def __init__(self, template_store=None, cc_merge=None, cc_mixture=None, optimize=True, overlap_path=None):
    """Set up a template dictionary on top of a template store.

    Arguments:
        template_store: template store
            Store holding the templates. Its ``nb_channels`` attribute is read
            right away, so a store must be supplied despite the ``None`` default.
        cc_merge: none | float (optional)
            Stored as ``self.cc_merge``.
        cc_mixture: none | float (optional)
            Stored as ``self.cc_mixture``.
        optimize: boolean (optional)
            Stored as ``self.optimize`` and forwarded to the overlaps store.
        overlap_path: none | string (optional)
            Path forwarded to the ``OverlapsStore``.
    """
    store = template_store
    self.template_store = store
    self.nb_channels = store.nb_channels

    # Thresholds, kept as given.
    self.cc_merge, self.cc_mixture = cc_merge, cc_mixture

    # Per-channel bookkeeping of rejected templates, created lazily.
    self._duplicates = self._mixtures = None

    self.optimize = optimize
    # Channel indices of the templates added so far.
    self._indices = []

    self.overlaps_store = OverlapsStore(store, optimize=optimize, path=overlap_path)
def _initialize_templates(self):
    """Open the initial template store and its precomputed overlaps.

    The template store is opened read-only from ``self.templates_init_path``.
    The overlaps store is built on top of it from ``self.overlaps_init_path``
    in fitting mode. Both steps are logged at info level.
    """
    templates_path = self.templates_init_path
    store = TemplateStore(templates_path, mode='r')
    self._template_store = store
    self.log.info(
        "{} is initialized with {} templates from {}".format(self.name, store.nb_templates, templates_path)
    )

    overlaps_path = self.overlaps_init_path
    self._overlaps_store = OverlapsStore(template_store=store, path=overlaps_path, fitting_mode=True)
    self.log.info(
        "{} is initialized with precomputed overlaps from {}".format(self.name, overlaps_path)
    )
def _initialize(self):
    """Initialize template updater.

    Resolves the template file location and creates its directory if needed.
    Builds the template store (write mode), the template dictionary and the
    overlaps store. If ``precomputed_template_paths`` is set, loads those
    templates into the dictionary and prepares ``self._precomputed_output``
    for the first output packet; otherwise sets it to None.
    """
    # Where templates get recorded.
    if self.templates_path is None:
        self.templates_path = self._get_tmp_path()
    else:
        self.templates_path = os.path.abspath(os.path.expanduser(self.templates_path))

    data_directory = os.path.dirname(self.templates_path)
    if not os.path.exists(data_directory):
        os.makedirs(data_directory)

    if self.skip_overlaps:
        self.overlaps_path = None

    # Stores backing the dictionary.
    self._template_store = TemplateStore(self.templates_path, probe_file=self.probe_path, mode='w')
    self._template_dictionary = TemplateDictionary(
        self._template_store, cc_merge=self.cc_merge, cc_mixture=self.cc_mixture
    )
    self._overlap_store = OverlapsStore(template_store=self._template_store, path=self.overlaps_path)

    self.log.info("{} records templates into {}".format(self.name, self.templates_path))

    # Nothing precomputed: nothing to announce on the first packet.
    if self.precomputed_template_paths is None:
        self._precomputed_output = None
        return

    loaded = [load_template(path) for path in self.precomputed_template_paths]
    accepted = self._template_dictionary.initialize(loaded)
    nb_accepted = len(accepted)
    if nb_accepted > 0:
        self.log.debug("{} added {} precomputed templates".format(self.name, nb_accepted))

    if not self.skip_overlaps:
        overlap_store = self._overlap_store
        overlap_store.update(accepted)
        overlap_store.compute_overlaps()
        # Persist the precomputed overlaps.
        overlap_store.save_overlaps()
        if nb_accepted > 0:
            self.log.debug("{} precomputed overlaps".format(self.name))

    # Payload sent by ``_process`` at counter 0.
    self._precomputed_output = {
        'indices': accepted,
        'template_store': self._template_store.file_name,
        'overlaps': self._overlap_store.to_json,
    }
class TemplateUpdater(Block):
    """Template updater.

    Receives template packets on the 'templates' input and filters them through
    a ``TemplateDictionary``, which rejects duplicates. It then updates the
    overlaps store and sends the indices of the accepted templates on the
    'updater' output.

    Attributes:
        probe_path: string
        radius: float
        cc_merge: float
        cc_mixture: float
        templates_path: string
        overlaps_path: string
        precomputed_template_paths: none | list
        sampling_rate: float
        nb_samples: integer
        skip_overlaps: boolean
    """

    name = "Template updater"

    params = {
        'probe_path': None,
        'radius': None,
        'cc_merge': 0.95,
        'cc_mixture': None,
        'templates_path': None,
        'overlaps_path': None,
        'precomputed_template_paths': None,
        'sampling_rate': 20e+3,
        'nb_samples': 1024,
        'skip_overlaps': False
    }

    def __init__(self, **kwargs):
        """Initialize template updater.

        Arguments:
            probe_path: string
            radius: none | float (optional)
            cc_merge: float (optional)
            cc_mixture: none | float (optional)
            templates_path: none | string (optional)
            overlaps_path: none | string (optional)
            precomputed_template_paths: none | list (optional)
            sampling_rate: float (optional)
            nb_samples: integer (optional)
            skip_overlaps: boolean (optional)
        """

        Block.__init__(self, **kwargs)

        # The following lines are useful to avoid some PyCharm's warnings.
        self.probe_path = self.probe_path
        self.radius = self.radius
        self.cc_merge = self.cc_merge
        self.cc_mixture = self.cc_mixture
        self.templates_path = self.templates_path
        self.overlaps_path = self.overlaps_path
        self.precomputed_template_paths = self.precomputed_template_paths
        self.sampling_rate = self.sampling_rate
        self.nb_samples = self.nb_samples
        self.skip_overlaps = self.skip_overlaps

        # Initialize private attributes.
        if self.probe_path is None:
            self.probe = None
            # Log error message.
            string = "{}: the probe file must be specified!"
            message = string.format(self.name)
            self.log.error(message)
        else:
            self.probe = load_probe(self.probe_path, radius=self.radius, logger=self.log)
            # Log info message.
            string = "{} reads the probe layout"
            message = string.format(self.name)
            self.log.info(message)
        self._template_store = None
        self._template_dictionary = None
        self._overlap_store = None
        self._two_components = None
        # Set by ``_initialize``. Defined here so that ``_process`` never reads
        # an undefined attribute.
        self._precomputed_output = None

        self.add_input('templates', structure='dict')
        self.add_output('updater', structure='dict')

    def _initialize(self):
        """Initialize template updater."""

        # Initialize path to save the templates.
        if self.templates_path is None:
            self.templates_path = self._get_tmp_path()
        else:
            self.templates_path = os.path.expanduser(self.templates_path)
            self.templates_path = os.path.abspath(self.templates_path)

        # Create the corresponding directory if it does not exist.
        data_directory, _ = os.path.split(self.templates_path)
        if not os.path.exists(data_directory):
            os.makedirs(data_directory)

        if self.skip_overlaps:
            self.overlaps_path = None

        # Create object to handle templates.
        self._template_store = TemplateStore(self.templates_path, probe_file=self.probe_path, mode='w')
        self._template_dictionary = TemplateDictionary(self._template_store, cc_merge=self.cc_merge,
                                                       cc_mixture=self.cc_mixture)
        # Create object to handle overlaps.
        self._overlap_store = OverlapsStore(template_store=self._template_store, path=self.overlaps_path)

        # Log info message.
        string = "{} records templates into {}"
        message = string.format(self.name, self.templates_path)
        self.log.info(message)

        # Define precomputed templates (if necessary).
        if self.precomputed_template_paths is not None:
            # Load precomputed templates.
            precomputed_templates = [
                load_template(path)
                for path in self.precomputed_template_paths
            ]
            # Add precomputed templates to the dictionary.
            accepted = self._template_dictionary.initialize(precomputed_templates)
            # Log some information.
            if len(accepted) > 0:
                string = "{} added {} precomputed templates"
                message = string.format(self.name, len(accepted))
                self.log.debug(message)
            # Update precomputed overlaps.
            if not self.skip_overlaps:
                self._overlap_store.update(accepted)
                self._overlap_store.compute_overlaps()
                # Save precomputed overlaps to disk.
                self._overlap_store.save_overlaps()
                # Log some information.
                if len(accepted) > 0:
                    string = "{} precomputed overlaps"
                    message = string.format(self.name)
                    self.log.debug(message)
            # Send output data (emitted by ``_process`` at counter 0).
            self._precomputed_output = {
                'indices': accepted,
                'template_store': self._template_store.file_name,
                'overlaps': self._overlap_store.to_json,
            }
        else:
            self._precomputed_output = None

        return

    @staticmethod
    def _get_tmp_path():
        """Return the default template file path inside the system temp directory."""

        tmp_directory = tempfile.gettempdir()
        tmp_basename = "templates.h5"
        tmp_path = os.path.join(tmp_directory, tmp_basename)

        return tmp_path

    def _data_to_templates(self, data):
        """Convert a received payload into a flat list of templates.

        Every key except 'offset' is expected to map channels to dicts of
        template dicts, which are loaded with ``load_template_from_dict``.
        """

        all_templates = []
        keys = [key for key in data.keys() if key not in ['offset']]
        for key in keys:
            for channel in data[key].keys():
                templates = []
                for template in data[key][channel].values():
                    templates += [load_template_from_dict(template, self.probe)]
                if len(templates) > 0:
                    # Log debug message.
                    string = "{} received {} {} templates from electrode {}"
                    message = string.format(self.name, len(templates), key, channel)
                    self.log.debug(message)
                all_templates += templates

        return all_templates

    def _process(self):
        """Add received templates to the dictionary and broadcast the result."""

        # Send precomputed templates.
        if self.counter == 0 and self._precomputed_output is not None:
            # Prepare output packet.
            packet = {
                'number': -1,
                'payload': self._precomputed_output,
            }
            # Send templates.
            self.get_output('updater').send(packet)

        # Receive input data.
        templates_packet = self.get_input('templates').receive(blocking=False)
        data = templates_packet['payload'] if templates_packet is not None else None

        if data is not None:

            self._measure_time('start', period=1)

            # Set mode as active (if necessary).
            if not self.is_active:
                self._set_active_mode()

            # Add received templates to the dictionary.
            templates = self._data_to_templates(data)
            self._measure_time('add_template_start', period=1)
            if templates:
                accepted, nb_duplicates, nb_mixtures = self._template_dictionary.add(templates)
            else:
                # Empty batch: nothing can be accepted, duplicated or mixed.
                # Skip the dictionary call instead of relying on ``add`` to
                # handle an empty batch.
                accepted, nb_duplicates, nb_mixtures = [], 0, 0
            self._measure_time('add_template_end', period=1)

            # Log debug messages (if necessary).
            if nb_duplicates > 0:
                # Log debug message.
                string = "{} rejected {} duplicated templates"
                message = string.format(self.name, nb_duplicates)
                self.log.debug(message)
            if nb_mixtures > 0:
                # Log debug message.
                string = "{} rejected {} composite templates"
                message = string.format(self.name, nb_mixtures)
                self.log.debug(message)
            if len(accepted) > 0:
                # Log debug message.
                string = "{} accepted {} templates"
                message = string.format(self.name, len(accepted))
                self.log.debug(message)

            # Update and pre-compute the overlaps.
            self._overlap_store.update(accepted)
            if not self.skip_overlaps:
                self._measure_time('compute_overlap_start', period=1)
                self._overlap_store.compute_overlaps()
                self._measure_time('compute_overlap_end', period=1)
                # Log debug message.
                string = "{} updates and pre-computes the overlaps"
                message = string.format(self.name_and_counter)
                self.log.debug(message)
                # Save precomputed overlaps to disk.
                self._measure_time('save_overlap_start', period=1)
                self._overlap_store.save_overlaps()
                self._measure_time('save_overlap_end', period=1)
                # Log debug message.
                string = "{} saves precomputed overlaps"
                message = string.format(self.name_and_counter)
                self.log.debug(message)

            # Prepare output data.
            output_data = {
                'indices': accepted,
                'template_store': self._template_store.file_name,
                'overlaps': self._overlap_store.to_json,
            }
            # Prepare output packet.
            output_packet = {
                'number': templates_packet['number'],
                'payload': output_data,
            }
            # Send output packet.
            self.get_output('updater').send(output_packet)

            # Log debug message.
            string = "{} sends output packet"
            message = string.format(self.name_and_counter)
            self.log.debug(message)

            self._measure_time('end', period=1)

        return

    def _introspect(self):
        """Log the processing speed (data duration / processing duration)."""

        nb_buffers = self.counter - self.start_step
        start_times = np.array(self._measured_times.get('start', []))
        end_times = np.array(self._measured_times.get('end', []))
        # A buffer interrupted after its 'start' measurement but before its
        # 'end' one leaves the arrays with different lengths. Subtracting them
        # directly would raise, so only complete (start, end) pairs are used.
        nb_measures = min(start_times.size, end_times.size)
        durations = end_times[:nb_measures] - start_times[:nb_measures]
        data_duration = float(self.nb_samples) / self.sampling_rate
        ratios = data_duration / durations

        min_ratio = np.min(ratios) if ratios.size > 0 else np.nan
        mean_ratio = np.mean(ratios) if ratios.size > 0 else np.nan
        max_ratio = np.max(ratios) if ratios.size > 0 else np.nan

        # Log info message.
        string = "{} processed {} buffers [speed:x{:.2f} (min:x{:.2f}, max:x{:.2f})]"
        message = string.format(self.name, nb_buffers, mean_ratio, min_ratio, max_ratio)
        self.log.info(message)

        return
class TemplateDictionary(object):
    """Collection of templates with duplicate filtering.

    Accepted templates are added to ``template_store`` and registered in an
    ``OverlapsStore``. Rejected duplicates (and, once implemented, mixtures)
    are recorded per channel as lists of template creation times.
    """

    def __init__(self, template_store=None, cc_merge=None, cc_mixture=None, optimize=True, overlap_path=None):
        """Initialize the template dictionary.

        Arguments:
            template_store: template store
                Store receiving the accepted templates. Its ``nb_channels``
                attribute is read immediately, so it must be supplied despite
                the ``None`` default.
            cc_merge: none | float (optional)
                Threshold passed to ``OverlapsStore.is_present`` for duplicate
                detection. If None, no template is ever considered a duplicate.
            cc_mixture: none | float (optional)
                Threshold for mixture detection (currently unused, see
                ``_is_mixture``).
            optimize: boolean (optional)
                If True, the channel indices of every added template are kept so
                that duplicate checks only consider templates sharing channels.
                Also forwarded to the overlaps store.
            overlap_path: none | string (optional)
                Path forwarded to the overlaps store.
        """
        self.template_store = template_store
        self.nb_channels = self.template_store.nb_channels
        self.cc_merge = cc_merge
        self.cc_mixture = cc_mixture
        self._duplicates = None
        self._mixtures = None
        self.optimize = optimize
        # Channel indices of each added template (only filled when optimizing).
        self._indices = []
        # Set by ``_init_from_template`` once the first template is seen.
        self.nb_elements = None
        self.overlaps_store = OverlapsStore(self.template_store, optimize=self.optimize, path=overlap_path)

    @property
    def is_empty(self):
        return len(self.template_store) == 0

    @property
    def first_component(self):
        return self.overlaps_store.first_component

    def _init_from_template(self, template):
        """Initialize template dictionary based on a sampled template.

        Argument:
            template: circusort.obj.Template
                The sampled template used to initialize the dictionary. This
                template won't be added to the dictionary.
        """

        self.nb_elements = self.nb_channels * template.temporal_width

        return

    def __str__(self):

        string = """
        Template dictionary with {} templates
        rejected as duplicates: {}
        rejected as mixtures  : {}
        """.format(self.nb_templates, self.nb_duplicates, self.nb_mixtures)

        return string

    def __iter__(self, index=None):
        """Iterate over the first components of the templates, in index order.

        Argument:
            index: ignored
                Kept only for backward compatibility with the former signature.
                ``iter()`` calls this method without arguments.
        """

        _ = index  # unused, see docstring
        # Index-based iteration consistent with ``__len__`` and ``__getitem__``.
        for i in range(len(self)):
            yield self[i]

    def __getitem__(self, index):
        return self.first_component[index]

    def __len__(self):
        return self.nb_templates

    @property
    def to_json(self):

        result = {
            'template_store': self.template_store.file_name,
            'overlaps': self.overlaps_store.to_json
        }

        return result

    @property
    def nb_templates(self):
        return self.overlaps_store.nb_templates

    @staticmethod
    def _count_entries(registry):
        """Count recorded entries in a {channel: [creation_time, ...]} registry (None -> 0)."""
        if registry is None:
            return 0
        return sum(len(times) for times in registry.values())

    @property
    def nb_mixtures(self):
        return self._count_entries(self._mixtures)

    @property
    def nb_duplicates(self):
        return self._count_entries(self._duplicates)

    def _add_duplicates(self, template):
        """Record the creation time of a template rejected as a duplicate."""
        if self._duplicates is None:
            self._duplicates = {}
        self._duplicates.setdefault(template.channel, []).append(template.creation_time)

        return

    def _add_mixtures(self, template):
        """Record the creation time of a template rejected as a mixture."""
        if self._mixtures is None:
            self._mixtures = {}
        self._mixtures.setdefault(template.channel, []).append(template.creation_time)

        return

    def _add_template(self, template):
        """Add a template to the template dictionary.

        Arguments:
            template: circusort.obj.Template
        Return:
            indices: list
                A list which contains the indices of templates successfully added
                to the underlying template store.
        """

        indices = self.template_store.add(template)
        self.overlaps_store.add_template(template)

        return indices

    def compute_overlaps(self):
        self.overlaps_store.compute_overlaps()

    def save_overlaps(self):
        self.overlaps_store.save_overlaps()

    def non_zeros(self, channel_indices):
        """Get indices of templates whose spatial supports include at least one of
        the given channel indices.

        Argument:
            channel_indices: iterable
                The channel indices.
        Return:
            template_indices: numpy.ndarray
                The template indices (int32).
        """

        template_indices = np.array([
            k
            for k, channel_indices_bis in enumerate(self._indices)
            if np.any(np.isin(channel_indices_bis, channel_indices))
        ], dtype=np.int32)

        return template_indices

    def initialize(self, templates):
        """Initialize the template dictionary with templates.

        Argument:
            templates: iterable of templates
                Added without duplicate or mixture checks.
        Return:
            accepted: list
                A list which contains the indices of templates successfully added
                to the underlying template store.
        """

        accepted, _, _ = self.add(templates, force=True)

        return accepted

    def add(self, templates, force=False):
        """Add templates to the template dictionary.

        Arguments:
            templates: iterable of templates
                May be any iterable, including a one-shot iterator, and may be
                empty.
            force: boolean (optional)
                If True, skip the duplicate check. The default value is False.
        Returns:
            accepted: list
                A list which contains the indices of templates successfully added
                to the underlying template store.
            nb_duplicates: integer
                The number of duplicates.
            nb_mixtures: integer
                The number of mixtures.
        """

        # Materialize once: peeking at the first element of a one-shot iterator
        # would otherwise consume it and silently drop that template.
        templates = list(templates)

        nb_duplicates = 0
        nb_mixtures = 0
        accepted = []

        # Nothing to add (also avoids peeking into an empty batch below).
        if not templates:
            return accepted, nb_duplicates, nb_mixtures

        # Initialize the dictionary (if necessary).
        if self.is_empty:
            self._init_from_template(templates[0])

        for t in templates:
            if not force:
                # Normalized, flattened sparse first component used for the check.
                csr_template = t.first_component.to_sparse('csr', flatten=True)
                norm = t.first_component.norm
                csr_template /= norm
                # Restrict the comparison to templates sharing channels (if optimizing).
                non_zeros = self.non_zeros(t.indices) if self.optimize else None
                # Check if the template is already in this dictionary.
                if self._is_present(csr_template, non_zeros):
                    nb_duplicates += 1
                    self._add_duplicates(t)
                    continue
                # TODO Removing mixtures online is difficult: some mixtures might have
                # TODO been added before the templates which compose them. For now,
                # TODO assume there is no mixture in the signal (see ``_is_mixture``).
            accepted += self._add_template(t)
            if self.optimize:
                self._indices += [t.indices]

        return accepted, nb_duplicates, nb_mixtures

    def _is_present(self, csr_template, non_zeros=None):

        if (self.cc_merge is None) or (self.nb_templates == 0):
            return False

        return self.overlaps_store.is_present(csr_template, self.cc_merge, non_zeros)

    def _is_mixture(self, csr_template, non_zeros=None):

        if (self.cc_mixture is None) or (self.nb_templates == 0):
            return False

        # TODO complete/clean (would delegate to ``self.overlaps_store.is_mixture``).
        _ = csr_template  # discard this argument
        _ = non_zeros  # discard this argument

        return False
def _process(self, verbose=False, timing=False):
    """Run one fitting step.

    Collects the data and peak buffers for the current chunk and applies every
    pending update packet from the 'updater' input to the template/overlap
    stores. When peaks and templates are available, fits the chunk and sends
    ``self.r`` on 'spikes'. Otherwise, when several fitters run in parallel,
    sends an empty result so the downstream sequence stays complete.
    """

    def receive_update():
        # Non-blocking read of the next updater payload (None if nothing pending).
        update_packet = self.get_input('updater').receive(
            blocking=False, discarding_eoc=self.discarding_eoc_from_updater)
        return None if update_packet is None else update_packet['payload']

    if timing:
        self._measure_time('preamble_start', period=10)

    # Gather the buffers needed for this chunk.
    chunk = self._nb_samples
    if self.counter == 0:
        # Voltage buffer spanning two chunks.
        self.x = np.zeros((2 * chunk, self._nb_channels), dtype=np.float32)
    elif self._nb_fitters == 1:
        # Single fitter: the second half becomes the first half.
        self.x[0:chunk, :] = self.x[chunk:2 * chunk, :]

    very_first_buffer = self.counter == 0 and self._fitter_id == 0
    if self._nb_fitters > 1 and not very_first_buffer:
        # Several fitters: the preceding chunk has to be received explicitly.
        self._collect_data(shift=-1)
        self._collect_peaks(verbose=verbose)

    self._collect_data(shift=0)
    self._collect_peaks(verbose=verbose)

    update = receive_update()

    if timing:
        self._measure_time('preamble_end', period=10)

    # Apply every pending update packet.
    if update is not None:
        self._measure_time('update_start', period=1)
        while update is not None:
            self.log.debug("{} modifies template and overlap stores".format(self.name_and_counter))
            if self._template_store is None:
                # First packet: open the stores it points to.
                self._template_store = TemplateStore(update['template_store'], mode='r')
                self._overlaps_store = OverlapsStore(
                    template_store=self._template_store,
                    path=update['overlaps']['path'],
                    fitting_mode=True)
                self._init_temp_window()
                self.log.debug(
                    "{} initializes template and overlap stores ({}, {})".format(
                        self.name_and_counter, update['template_store'], update['overlaps']['path']))
            else:
                # Later packets: register the new template indices.
                # TODO avoid duplicates in template store.
                new_indices = update.get('indices', None)
                lazy = update['overlaps']['path'] is None
                self._overlaps_store.update(new_indices, laziness=lazy)
                self.log.debug("{} updates template and overlap stores".format(self.name_and_counter))
            self.log.debug("{} modified template and overlap stores".format(self.name_and_counter))
            update = receive_update()
        self._measure_time('update_end', period=1)

    # Fit the chunk, or keep the output sequence complete with an empty result.
    if self.p is not None and self.nb_templates > 0:
        self._measure_time('start')
        if timing:
            self._measure_time('fit_start', period=10)
        self._fit_chunk(verbose=verbose, timing=timing)
        if timing:
            self._measure_time('fit_end', period=10)
        if timing:
            self._measure_time('output_start', period=10)
        spikes_packet = {
            'number': self._number,
            'payload': self.r,
        }
        self.get_output('spikes').send(spikes_packet)
        if timing:
            self._measure_time('output_end', period=10)
        self._measure_time('end')
    elif self._nb_fitters > 1:
        spikes_packet = {
            'number': self._number,
            'payload': self._empty_result,
        }
        self.get_output('spikes').send(spikes_packet)

    return
class Fitter(Block): """Fitter Attributes: templates_init_path: none | string (optional) Path to the location used to load templates to initialize the dictionary of templates. If equal to None, this dictionary will start empty. The default value is None. overlaps_init_path: none | string (optional) Path to the location used to load the overlaps to initialize the overlap store. The default value is None. with_rejected_times: boolean (optional) The default value is False. sampling_rate: float (optional) The default value is 20e+3. discarding_eoc_from_updater: boolean (optional) The default value is False. """ name = "Fitter" params = { 'templates_init_path': None, 'overlaps_init_path': None, 'with_rejected_times': False, 'sampling_rate': 20e+3, 'discarding_eoc_from_updater': False, '_nb_fitters': 1, '_fitter_id': 0, } def __init__(self, **kwargs): """Initialize fitter Arguments: templates_init_path: string (optional) overlaps_init_path: string (optional) with_rejected_times: boolean (optional) sampling_rate: float (optional) discarding_eoc_from_updater: boolean (optional) _nb_fitters: integer (optional) _fitter_id: integer (optional) """ Block.__init__(self, **kwargs) # The following lines are useful to avoid some PyCharm's warnings. self.templates_init_path = self.templates_init_path self.overlaps_init_path = self.overlaps_init_path self.with_rejected_times = self.with_rejected_times self.sampling_rate = self.sampling_rate self.discarding_eoc_from_updater = self.discarding_eoc_from_updater self._nb_fitters = self._nb_fitters self._fitter_id = self._fitter_id # Initialize private attributes. 
self._template_store = None self._overlaps_store = None self.add_input('updater', structure='dict') self.add_input('data', structure='dict') self.add_input('peaks', structure='dict') self.add_output('spikes', structure='dict') self._nb_channels = None self._nb_samples = None self._number = None def _initialize(self): self.space_explo = 0.5 self.nb_chances = 3 if self.templates_init_path is not None: self.templates_init_path = os.path.expanduser( self.templates_init_path) self.templates_init_path = os.path.abspath( self.templates_init_path) self._initialize_templates() # Variables used to handle buffer edges. self.x = None # voltage signal self.p = None # peak time steps self.r = { # temporary result 'spike_times': np.zeros(0, dtype=np.int32), 'amplitudes': np.zeros(0, dtype=np.float32), 'templates': np.zeros(0, dtype=np.int32), } if self.with_rejected_times: self.r.update({ 'rejected_times': np.zeros(0, dtype=np.int32), 'rejected_amplitudes': np.zeros(0, dtype=np.float32), }) return def _initialize_templates(self): self._template_store = TemplateStore(self.templates_init_path, mode='r') # Log info message. string = "{} is initialized with {} templates from {}" message = string.format(self.name, self._template_store.nb_templates, self.templates_init_path) self.log.info(message) self._overlaps_store = OverlapsStore( template_store=self._template_store, path=self.overlaps_init_path, fitting_mode=True) # Log info message. 
string = "{} is initialized with precomputed overlaps from {}" message = string.format(self.name, self.overlaps_init_path) self.log.info(message) return @property def nb_templates(self): if self._overlaps_store is not None: nb_templates = self._overlaps_store.nb_templates else: nb_templates = 0 return nb_templates @property def min_scalar_product(self): return np.min(self._overlaps_store.amplitudes[:, 0] * self._overlaps_store.norms['1']) def _configure_input_parameters(self, nb_channels=None, nb_samples=None, **kwargs): if nb_channels is not None: self._nb_channels = nb_channels if nb_samples is not None: self._nb_samples = nb_samples return def _update_initialization(self): if self.templates_init_path is not None: self._init_temp_window() return def _init_temp_window(self): self._width = (self._overlaps_store.temporal_width - 1) // 2 self._2_width = 2 * self._width self.temp_window = np.arange(-self._width, self._width + 1) buffer_size = 2 * self._nb_samples return def _is_valid(self, peak_step): i_min = self._width i_max = self._nb_samples - self._width is_valid = (i_min <= peak_step) & (peak_step < i_max) return is_valid def _get_all_valid_peaks(self, peak_steps): all_peak_steps = set([]) for key in peak_steps.keys(): for channel in peak_steps[key].keys(): all_peak_steps = all_peak_steps.union(peak_steps[key][channel]) all_peak_steps = np.array(list(all_peak_steps), dtype=np.int32) mask = self._is_valid(all_peak_steps) all_valid_peak_steps = all_peak_steps[mask] return all_valid_peak_steps @property def _empty_result(self): r = { 'offset': self.offset, } return r def _reset_result(self): self.r = { 'spike_times': np.zeros(0, dtype=np.int32), 'amplitudes': np.zeros(0, dtype=np.float32), 'templates': np.zeros(0, dtype=np.int32), 'offset': self.offset, } if self.with_rejected_times: self.r.update({ 'rejected_times': np.zeros(0, dtype=np.int32), 'rejected_amplitudes': np.zeros(0, dtype=np.float32) }) return def _extract_waveforms(self, peak_time_steps): """Extract 
waveforms from buffer Argument: peak_time_steps: np.array Peak time steps. Array of shape (number of peaks,). Return: waveforms: np.array Extracted waveforms. Array of shape (waveform size, number of peaks), where waveform size is equal to number of channels x number of samples. """ waveforms = self.x[peak_time_steps[:, None] + self.temp_window] waveforms = waveforms.transpose(2, 1, 0).reshape( self._overlaps_store.nb_elements, len(peak_time_steps)) return waveforms def _fit_chunk(self, verbose=False, timing=False): if verbose: # Log debug message. string = "{} fits spikes... ({} templates)" message = string.format(self.name_and_counter, self.nb_templates) self.log.debug(message) # Reset result. self._reset_result() # Compute the number of peaks in the current chunk. is_in_work_area = np.logical_and(self.work_area_start <= self.p, self.p < self.work_area_end) nb_peaks = np.count_nonzero(is_in_work_area) peaks = self.p[is_in_work_area] if verbose: # Log debug message. string = "{} has {} peaks in the work area among {} peaks" message = string.format(self.name, nb_peaks, len(self.p)) self.log.debug(message) if verbose: # Log debug message. string = "{} peaks: {}" message = string.format(self.name, peaks) self.log.debug(message) # If there is at least one peak in the work area... if 0 < nb_peaks: # Extract waveforms from buffer. waveforms = self._extract_waveforms(peaks) if timing: self._measure_time('scalar_products_start', period=10) # Compute the scalar products between waveforms and templates. scalar_products = self._overlaps_store.dot(waveforms) if timing: self._measure_time('scalar_products_end', period=10) # Initialize the failure counter of each peak. nb_failures = np.zeros(nb_peaks, dtype=np.int32) if verbose: # Log debug message. 
string = "{} buffer offset: {}" message = string.format(self.name, self.offset) self.log.debug(message) if timing: self._measure_time('while_loop_start', period=10) # TODO rewrite condition according to the 3 last lines of the nested while loop. # while not np.all(nb_failures == self.max_nb_trials): numerous_argmax = False nb_argmax = self.nb_templates best_indices = np.zeros(0, dtype=np.int32) # Set scalar products of tested matches to zero. data = scalar_products[:self._overlaps_store.nb_templates, :] data_flatten = data.ravel() min_scalar_product = self.min_scalar_product while np.mean(nb_failures) < self.nb_chances: # Find the best template. if numerous_argmax: if len(best_indices) == 0: best_indices = largest_indices(data_flatten, nb_argmax) best_template_index, peak_index = np.unravel_index( best_indices[0], data.shape) else: best_template_index, peak_index = np.unravel_index( data.argmax(), data.shape) # TODO remove peaks with scalar products equal to zero? # TODO consider the absolute values of the scalar products? # Compute the best amplitude. best_amplitude = scalar_products[best_template_index, peak_index] if best_amplitude < min_scalar_product: nb_failures[:] = self.nb_chances break if self._overlaps_store.two_components: best_scalar_product = scalar_products[best_template_index + self.nb_templates, peak_index] best_amplitude_2 = best_scalar_product # Compute the best normalized amplitude. best_amplitude_ = best_amplitude / self._overlaps_store.norms[ '1'][best_template_index] if self._overlaps_store.two_components: best_amplitude_2_ = best_amplitude_2 / self._overlaps_store.norms[ '2'][best_template_index] # Verify amplitude constraint. a_min = self._overlaps_store.amplitudes[best_template_index, 0] a_max = self._overlaps_store.amplitudes[best_template_index, 1] if (a_min <= best_amplitude_) & (best_amplitude_ <= a_max): if verbose: # Log debug message. 
string = "{} processes (p {}, t {}) -> (a {}, keep)" message = string.format(self.name, peak_index, best_template_index, best_amplitude) self.log.debug(message) if timing: self._measure_time('for_loop_accept_start', period=10) # Keep the matching. peak_time_step = peaks[peak_index] # # Compute the neighboring peaks. # # TODO use this definition of `is_neighbor` instead of the other. # is_neighbor = np.abs(peaks - peak_index) <= 2 * self._width if timing: self._measure_time('for_loop_update_start', period=10) if timing: self._measure_time('for_loop_update_1_start', period=10) # Update scalar products. # TODO simplify the following 11 lines. tmp = np.dot(np.ones((1, 1), dtype=np.int32), np.reshape(peaks, (1, nb_peaks))) tmp -= np.array([[peak_time_step]]) is_neighbor = np.abs(tmp) <= self._2_width ytmp = tmp[0, is_neighbor[0, :]] + self._2_width indices = np.zeros((self._overlaps_store.size, len(ytmp)), dtype=np.int32) indices[ytmp, np.arange(len(ytmp))] = 1 if timing: self._measure_time('for_loop_update_1_end', period=10) if timing: self._measure_time('for_loop_update_2_start', period=10) if timing: self._measure_time('for_loop_overlaps_start', period=10) tmp1_ = self._overlaps_store.get_overlaps( best_template_index, '1') if timing: self._measure_time('for_loop_overlaps_end', period=10) tmp1 = tmp1_.multiply(-best_amplitude) to_add = tmp1.toarray()[:, ytmp] scalar_products[:, is_neighbor[0, :]] += to_add if timing: self._measure_time('for_loop_update_2_end', period=10) if self._overlaps_store.two_components: tmp2_ = self._overlaps_store.get_overlaps( best_template_index, '2') tmp2 = tmp2_.multiply(-best_amplitude_2) to_add = tmp2.toarray()[:, ytmp] scalar_products[:, is_neighbor[0, :]] += to_add if timing: self._measure_time('for_loop_update_end', period=10) if timing: self._measure_time('for_loop_concatenate_start', period=10) # Add matching to the result. 
self.r['spike_times'] = np.concatenate( (self.r['spike_times'], [peak_time_step])) self.r['amplitudes'] = np.concatenate( (self.r['amplitudes'], [best_amplitude_])) self.r['templates'] = np.concatenate( (self.r['templates'], [best_template_index])) if timing: self._measure_time('for_loop_concatenate_end', period=10) # Mark current matching as tried. scalar_products[best_template_index, peak_index] = -np.inf best_indices = np.zeros(0, dtype=np.int32) if timing: self._measure_time('for_loop_accept_end', period=10) numerous_argmax = False else: numerous_argmax = True if verbose: # Log debug message. string = "{} processes (p {}, t {}) -> (a {}, reject)" message = string.format(self.name, peak_index, best_template_index, best_amplitude) self.log.debug(message) if timing: self._measure_time('for_loop_reject_start', period=10) # Reject the matching. # Update failure counter of the peak. nb_failures[peak_index] += 1 # If the maximal number of failures is reached then mark peak as solved (i.e. not fitted). if nb_failures[peak_index] == self.nb_chances: scalar_products[:, peak_index] = -np.inf else: scalar_products[best_template_index, peak_index] = -np.inf best_indices = best_indices[ data_flatten[best_indices] > -np.inf] # Add reject to the result if necessary. if self.with_rejected_times: self.r['rejected_times'] = np.concatenate( (self.r['rejected_times'], [peaks[peak_index]])) self.r['rejected_amplitudes'] = np.concatenate( (self.r['rejected_amplitudes'], [best_amplitude_])) if timing: self._measure_time('for_loop_reject_end', period=10) if timing: self._measure_time('while_loop_end', period=10) # Handle result. keys = ['spike_times', 'amplitudes', 'templates'] if self.with_rejected_times: keys += ['rejected_times', 'rejected_amplitudes'] # # Keep only spikes in the result area. 
is_in_result = np.logical_and( self.result_area_start <= self.r['spike_times'], self.r['spike_times'] < self.result_area_end) for key in keys: self.r[key] = self.r[key][is_in_result] # # Sort spike. indices = np.argsort(self.r['spike_times']) for key in keys: self.r[key] = self.r[key][indices] # # Modify spike time reference. self.r['spike_times'] = self.r['spike_times'] - self._nb_samples if verbose: # Log debug message. nb_spike_times = len(self.r['spike_times']) if nb_spike_times > 0: string = "{} fitted {} spikes ({} templates)" message = string.format(self.name_and_counter, nb_spike_times, self.nb_templates) else: string = "{} fitted no spikes ({} templates)" message = string.format(self.name_and_counter, self.nb_templates) self.log.debug(message) else: # i.e. nb_peaks == 0 if verbose: # Log debug message. string = "{} can't fit spikes ({} templates)" message = string.format(self.name_and_counter, self.nb_templates) self.log.debug(message) return @staticmethod def _merge_peaks(peaks): """Merge positive and negative peaks from all the channels.""" time_steps = set([]) keys = [key for key in peaks.keys() if key not in ['offset']] for key in keys: for channel in peaks[key].keys(): time_steps = time_steps.union(peaks[key][channel]) time_steps = np.array(list(time_steps), dtype=np.int32) time_steps = np.sort(time_steps) return time_steps @property def nb_buffers(self): return self.x.shape[0] // self._nb_samples @property def result_area_start(self): return (self.nb_buffers - 1) * self._nb_samples - self._nb_samples // 2 @property def result_area_end(self): return (self.nb_buffers - 1) * self._nb_samples + self._nb_samples // 2 @property def work_area_start(self): return self.result_area_start - self._width @property def work_area_end(self): return self.result_area_end + self._width @property def buffer_id(self): return self.counter * self._nb_fitters + self._fitter_id @property def first_buffer_id(self): # TODO check if the comment fix the "offset bug". 
return self.buffer_id # - 1 @property def offset(self): # TODO check if the comment fix the "offset bug". return self.first_buffer_id * self._nb_samples # + self.result_area_start def _collect_data(self, shift=0): k = (self.nb_buffers - 1) + shift data_packet = self.get_input('data').receive(blocking=True) self._number = data_packet['number'] self.x[k * self._nb_samples:(k + 1) * self._nb_samples, :] = data_packet['payload'] return def _handle_peaks(self, peaks): p = self._nb_samples + self._merge_peaks(peaks) self.p = self.p - self._nb_samples self.p = self.p[0 <= self.p] self.p = np.concatenate((self.p, p)) return def _collect_peaks(self, verbose=False): if self.is_active: peaks_packet = self.get_input('peaks').receive(blocking=True, number=self._number) if peaks_packet is None: # This is the last packet (last data packet don't have a corresponding peak packet since the peak # detector needs two consecutive data packets to produce one peak packet). peaks = {} else: peaks = peaks_packet['payload']['peaks'] self._handle_peaks(peaks) if verbose: # Log debug message. string = "{} collects peaks {} (reg)" message = string.format(self.name, peaks_packet['payload']['offset']) self.log.debug(message) else: if self.get_input('peaks').has_received(): peaks_packet = self.get_input('peaks').receive( blocking=True, number=self._number) if peaks_packet is not None: peaks = peaks_packet['payload']['peaks'] p = self._nb_samples + self._merge_peaks(peaks) self.p = p if verbose: # Log debug message. string = "{} collects peaks {} (init)" message = string.format( self.name, peaks_packet['payload']['offset']) self.log.debug(message) # Set active mode. self._set_active_mode() else: self.p = None else: self.p = None return def _process(self, verbose=False, timing=False): if timing: self._measure_time('preamble_start', period=10) # First, collect all the buffers we need. # # Prepare everything to collect buffers. if self.counter == 0: # Initialize 'self.x'. 
shape = (2 * self._nb_samples, self._nb_channels) self.x = np.zeros(shape, dtype=np.float32) elif self._nb_fitters == 1: # Copy the end of 'self.x' at its beginning. self.x[0 * self._nb_samples:1 * self._nb_samples, :] = \ self.x[1 * self._nb_samples:2 * self._nb_samples, :] else: pass # # Collect precedent data and peaks buffers. if self._nb_fitters > 1 and not (self.counter == 0 and self._fitter_id == 0): self._collect_data(shift=-1) self._collect_peaks(verbose=verbose) # # Collect current data and peaks buffers. self._collect_data(shift=0) self._collect_peaks(verbose=verbose) # # Collect current updater buffer. updater_packet = self.get_input('updater').receive( blocking=False, discarding_eoc=self.discarding_eoc_from_updater) updater = updater_packet[ 'payload'] if updater_packet is not None else None if timing: self._measure_time('preamble_end', period=10) if updater is not None: self._measure_time('update_start', period=1) while updater is not None: # Log debug message. string = "{} modifies template and overlap stores" message = string.format(self.name_and_counter) self.log.debug(message) # Modify template and overlap stores. indices = updater.get('indices', None) _ = indices # Discard unused variable. if self._template_store is None: # Initialize template and overlap stores. self._template_store = TemplateStore( updater['template_store'], mode='r') self._overlaps_store = OverlapsStore( template_store=self._template_store, path=updater['overlaps']['path'], fitting_mode=True) self._init_temp_window() # Log debug message. string = "{} initializes template and overlap stores ({}, {})" message = string.format(self.name_and_counter, updater['template_store'], updater['overlaps']['path']) self.log.debug(message) else: # TODO avoid duplicates in template store and uncomment the 3 following lines. # Update template and overlap stores. laziness = updater['overlaps']['path'] is None self._overlaps_store.update(indices, laziness=laziness) # Log debug message. 
string = "{} updates template and overlap stores" message = string.format(self.name_and_counter) self.log.debug(message) # Log debug message. string = "{} modified template and overlap stores" message = string.format(self.name_and_counter) self.log.debug(message) updater_packet = self.get_input('updater').receive( blocking=False, discarding_eoc=self.discarding_eoc_from_updater) updater = updater_packet[ 'payload'] if updater_packet is not None else None self._measure_time('update_end', period=1) if self.p is not None: if self.nb_templates > 0: self._measure_time('start') if timing: self._measure_time('fit_start', period=10) self._fit_chunk(verbose=verbose, timing=timing) if timing: self._measure_time('fit_end', period=10) if timing: self._measure_time('output_start', period=10) packet = { 'number': self._number, 'payload': self.r, } self.get_output('spikes').send(packet) if timing: self._measure_time('output_end', period=10) self._measure_time('end') elif self._nb_fitters > 1: packet = { 'number': self._number, 'payload': self._empty_result, } self.get_output('spikes').send(packet) elif self._nb_fitters > 1: packet = { 'number': self._number, 'payload': self._empty_result, } self.get_output('spikes').send(packet) return def _introspect(self): """Introspection.""" nb_buffers = self.counter - self.start_step start_times = np.array(self._measured_times.get('start', [])) end_times = np.array(self._measured_times.get('end', [])) durations = end_times - start_times data_duration = float( self._nb_fitters * self._nb_samples) / self.sampling_rate ratios = data_duration / durations min_ratio = np.min(ratios) if ratios.size > 0 else np.nan mean_ratio = np.mean(ratios) if ratios.size > 0 else np.nan max_ratio = np.max(ratios) if ratios.size > 0 else np.nan # Log info message. string = "{} processed {} buffers [speed:x{:.2f} (min:x{:.2f}, max:x{:.2f})]" message = string.format(self.name, nb_buffers, mean_ratio, min_ratio, max_ratio) self.log.info(message) return