def __init__(self,
                 template_store=None,
                 cc_merge=None,
                 cc_mixture=None,
                 optimize=True,
                 overlap_path=None):
        """Build the dictionary around a template store.

        Arguments:
            template_store: the underlying store of templates; assumed to
                expose `nb_channels` — TODO confirm against TemplateStore.
            cc_merge: none | float
                Correlation threshold above which a candidate is a duplicate.
            cc_mixture: none | float
                Correlation threshold used for mixture detection.
            optimize: boolean
                Whether spatial supports are tracked to restrict comparisons.
            overlap_path: none | string
                Path used by the overlaps store to persist its data.
        """

        # Keep a handle on the store and mirror its channel count.
        self.template_store = template_store
        self.nb_channels = self.template_store.nb_channels
        # Thresholds for duplicate / mixture rejection.
        self.cc_merge = cc_merge
        self.cc_mixture = cc_mixture
        self.optimize = optimize
        # Book-keeping for rejected templates and per-template spatial supports.
        self._duplicates = None
        self._mixtures = None
        self._indices = []
        # Overlap computations are delegated to a dedicated store.
        self.overlaps_store = OverlapsStore(self.template_store,
                                            optimize=self.optimize,
                                            path=overlap_path)
    def _initialize_templates(self):

        self._template_store = TemplateStore(self.templates_init_path,
                                             mode='r')

        # Log info message.
        string = "{} is initialized with {} templates from {}"
        message = string.format(self.name, self._template_store.nb_templates,
                                self.templates_init_path)
        self.log.info(message)

        self._overlaps_store = OverlapsStore(
            template_store=self._template_store,
            path=self.overlaps_init_path,
            fitting_mode=True)

        # Log info message.
        string = "{} is initialized with precomputed overlaps from {}"
        message = string.format(self.name, self.overlaps_init_path)
        self.log.info(message)

        return
    def _initialize(self):
        """Initialize template updater."""

        # Initialize path to save the templates.
        if self.templates_path is None:
            self.templates_path = self._get_tmp_path()
        else:
            self.templates_path = os.path.expanduser(self.templates_path)
            self.templates_path = os.path.abspath(self.templates_path)

        # Create the corresponding directory if it does not exist.
        data_directory, _ = os.path.split(self.templates_path)
        if not os.path.exists(data_directory):
            os.makedirs(data_directory)

        if self.skip_overlaps:
            self.overlaps_path = None

        # Create object to handle templates.
        self._template_store = TemplateStore(self.templates_path, probe_file=self.probe_path, mode='w')
        self._template_dictionary = TemplateDictionary(self._template_store, cc_merge=self.cc_merge,
                                                       cc_mixture=self.cc_mixture)
        # Create object to handle overlaps.
        self._overlap_store = OverlapsStore(template_store=self._template_store, path=self.overlaps_path)

        # Log info message.
        string = "{} records templates into {}"
        message = string.format(self.name, self.templates_path)
        self.log.info(message)

        # Define precomputed templates (if necessary).
        if self.precomputed_template_paths is not None:

            # Load precomputed templates.
            precomputed_templates = [
                load_template(path)
                for path in self.precomputed_template_paths
            ]

            # Add precomputed templates to the dictionary.
            accepted = self._template_dictionary.initialize(precomputed_templates)

            # Log some information.
            if len(accepted) > 0:
                string = "{} added {} precomputed templates"
                message = string.format(self.name, len(accepted))
                self.log.debug(message)

            # Update precomputed overlaps.
            if not self.skip_overlaps:
                self._overlap_store.update(accepted)
                self._overlap_store.compute_overlaps()
                # Save precomputed overlaps to disk.
                self._overlap_store.save_overlaps()

            # Log some information.
            if len(accepted) > 0:
                string = "{} precomputed overlaps"
                message = string.format(self.name)
                self.log.debug(message)

            # Send output data.
            self._precomputed_output = {
                'indices': accepted,
                'template_store': self._template_store.file_name,
                'overlaps': self._overlap_store.to_json,
            }

        else:

            self._precomputed_output = None

        return
class TemplateUpdater(Block):
    """Template updater.

    Receives candidate templates on the 'templates' input, filters out
    duplicates and mixtures through a ``TemplateDictionary``, pre-computes
    overlaps, and broadcasts the updated store locations on the 'updater'
    output.

    Attributes:
        probe_path: string
            Path to the probe layout file (required).
        radius: float
            Spatial radius forwarded to ``load_probe``.
        cc_merge: float
            Correlation threshold used to reject duplicated templates.
        cc_mixture: float
            Correlation threshold used to reject composite templates.
        templates_path: string
            Where the template store is written (temporary file if None).
        overlaps_path: string
            Where precomputed overlaps are saved (None when skipped).
        precomputed_template_paths: none | list
            Optional template files used to seed the dictionary.
        sampling_rate: float
            Sampling rate used to compute speed ratios in ``_introspect``.
        nb_samples: integer
            Number of samples per buffer.
    """

    name = "Template updater"

    # Default parameters; overridable via **kwargs at construction.
    params = {
        'probe_path': None,
        'radius': None,
        'cc_merge': 0.95,
        'cc_mixture': None,
        'templates_path': None,
        'overlaps_path': None,
        'precomputed_template_paths': None,
        'sampling_rate': 20e+3,
        'nb_samples': 1024,
        'skip_overlaps': False
    }

    def __init__(self, **kwargs):
        """Initialize template updater.

        Arguments:
            probe_path: string
            radius: none | float (optional)
            cc_merge: float (optional)
            cc_mixture: none | float (optional)
            templates_path: none | string (optional)
            overlaps_path: none | string (optional)
            precomputed_template_paths: none | list (optional)
            sampling_rate: float (optional)
            nb_samples: integer (optional)
            skip_overlaps: boolean (optional)
        """

        Block.__init__(self, **kwargs)

        # The following lines are useful to avoid some PyCharm's warnings.
        self.probe_path = self.probe_path
        self.radius = self.radius
        self.cc_merge = self.cc_merge
        self.cc_mixture = self.cc_mixture
        self.templates_path = self.templates_path
        self.overlaps_path = self.overlaps_path
        self.precomputed_template_paths = self.precomputed_template_paths
        self.sampling_rate = self.sampling_rate
        self.nb_samples = self.nb_samples
        self.skip_overlaps = self.skip_overlaps

        # Initialize private attributes.
        # The probe layout is mandatory; a missing path is only logged as an
        # error here (the block is not aborted at construction time).
        if self.probe_path is None:
            self.probe = None
            # Log error message.
            string = "{}: the probe file must be specified!"
            message = string.format(self.name)
            self.log.error(message)
        else:
            self.probe = load_probe(self.probe_path, radius=self.radius, logger=self.log)
            # Log info message.
            string = "{} reads the probe layout"
            message = string.format(self.name)
            self.log.info(message)
        # Stores are created lazily in `_initialize`.
        self._template_store = None
        self._template_dictionary = None
        self._overlap_store = None
        self._two_components = None

        # Declare the block's endpoints (dict-structured packets).
        self.add_input('templates', structure='dict')
        self.add_output('updater', structure='dict')

    def _initialize(self):
        """Initialize template updater."""

        # Initialize path to save the templates.
        if self.templates_path is None:
            self.templates_path = self._get_tmp_path()
        else:
            self.templates_path = os.path.expanduser(self.templates_path)
            self.templates_path = os.path.abspath(self.templates_path)

        # Create the corresponding directory if it does not exist.
        # NOTE(review): exists()/makedirs() is racy if several processes share
        # this path — consider makedirs() guarded by an OSError/isdir check.
        data_directory, _ = os.path.split(self.templates_path)
        if not os.path.exists(data_directory):
            os.makedirs(data_directory)

        if self.skip_overlaps:
            self.overlaps_path = None

        # Create object to handle templates.
        self._template_store = TemplateStore(self.templates_path, probe_file=self.probe_path, mode='w')
        self._template_dictionary = TemplateDictionary(self._template_store, cc_merge=self.cc_merge,
                                                       cc_mixture=self.cc_mixture)
        # Create object to handle overlaps.
        self._overlap_store = OverlapsStore(template_store=self._template_store, path=self.overlaps_path)

        # Log info message.
        string = "{} records templates into {}"
        message = string.format(self.name, self.templates_path)
        self.log.info(message)

        # Define precomputed templates (if necessary).
        if self.precomputed_template_paths is not None:

            # Load precomputed templates.
            precomputed_templates = [
                load_template(path)
                for path in self.precomputed_template_paths
            ]

            # Add precomputed templates to the dictionary.
            accepted = self._template_dictionary.initialize(precomputed_templates)

            # Log some information.
            if len(accepted) > 0:
                string = "{} added {} precomputed templates"
                message = string.format(self.name, len(accepted))
                self.log.debug(message)

            # Update precomputed overlaps.
            if not self.skip_overlaps:
                self._overlap_store.update(accepted)
                self._overlap_store.compute_overlaps()
                # Save precomputed overlaps to disk.
                self._overlap_store.save_overlaps()

            # Log some information.
            if len(accepted) > 0:
                string = "{} precomputed overlaps"
                message = string.format(self.name)
                self.log.debug(message)

            # Send output data.
            # Kept aside until the first `_process` call broadcasts it.
            self._precomputed_output = {
                'indices': accepted,
                'template_store': self._template_store.file_name,
                'overlaps': self._overlap_store.to_json,
            }

        else:

            self._precomputed_output = None

        return

    @staticmethod
    def _get_tmp_path():
        """Return the default templates path in the system temp directory."""

        tmp_directory = tempfile.gettempdir()
        tmp_basename = "templates.h5"
        tmp_path = os.path.join(tmp_directory, tmp_basename)

        return tmp_path

    def _data_to_templates(self, data):
        """Flatten the received payload into a list of template objects.

        Argument:
            data: dict
                Nested payload of the form {key: {channel: {id: template_dict}}};
                the 'offset' entry is metadata and is skipped.
        Return:
            all_templates: list
                Templates deserialized with `load_template_from_dict`.
        """

        all_templates = []
        keys = [key for key in data.keys() if key not in ['offset']]
        for key in keys:
            for channel in data[key].keys():
                templates = []
                for template in data[key][channel].values():

                    templates += [load_template_from_dict(template, self.probe)]

                if len(templates) > 0:
                    # Log debug message.
                    string = "{} received {} {} templates from electrode {}"
                    message = string.format(self.name, len(templates), key, channel)
                    self.log.debug(message)

                all_templates += templates

        return all_templates

    def _process(self):
        """Broadcast precomputed templates once, then merge incoming ones.

        On the first call, forwards `self._precomputed_output` (if any) with
        packet number -1. Then, for each received templates packet: add the
        templates to the dictionary, update/precompute overlaps, and send the
        updated store description on the 'updater' output.
        """

        # Send precomputed templates.
        if self.counter == 0 and self._precomputed_output is not None:
            # Prepare output packet.
            packet = {
                'number': -1,
                'payload': self._precomputed_output,
            }
            # Send templates.
            self.get_output('updater').send(packet)

        # Receive input data.
        templates_packet = self.get_input('templates').receive(blocking=False)
        data = templates_packet['payload'] if templates_packet is not None else None

        if data is not None:

            self._measure_time('start', period=1)

            # Set mode as active (if necessary).
            if not self.is_active:
                self._set_active_mode()

            # Add received templates to the dictionary.
            templates = self._data_to_templates(data)
            self._measure_time('add_template_start', period=1)
            accepted, nb_duplicates, nb_mixtures = self._template_dictionary.add(templates)
            self._measure_time('add_template_end', period=1)

            # Log debug messages (if necessary).
            if nb_duplicates > 0:
                # Log debug message.
                string = "{} rejected {} duplicated templates"
                message = string.format(self.name, nb_duplicates)
                self.log.debug(message)
            if nb_mixtures > 0:
                # Log debug message.
                string = "{} rejected {} composite templates"
                message = string.format(self.name, nb_mixtures)
                self.log.debug(message)
            if len(accepted) > 0:
                # Log debug message.
                string = "{} accepted {} templates"
                message = string.format(self.name, len(accepted))
                self.log.debug(message)

            # Update and pre-compute the overlaps.
            # NOTE(review): unlike `_initialize`, the store is updated even
            # when `skip_overlaps` is set — confirm this is intended.
            self._overlap_store.update(accepted)

            if not self.skip_overlaps:
                self._measure_time('compute_overlap_start', period=1)
                self._overlap_store.compute_overlaps()
                self._measure_time('compute_overlap_end', period=1)
                # Log debug message.
                string = "{} updates and pre-computes the overlaps"
                message = string.format(self.name_and_counter)
                self.log.debug(message)

                # Save precomputed overlaps to disk.
                self._measure_time('save_overlap_start', period=1)
                self._overlap_store.save_overlaps()
                self._measure_time('save_overlap_end', period=1)
                # Log debug message.
                string = "{} saves precomputed overlaps"
                message = string.format(self.name_and_counter)
                self.log.debug(message)

            # Prepare output data.
            output_data = {
                'indices': accepted,
                'template_store': self._template_store.file_name,
                'overlaps': self._overlap_store.to_json,
            }
            # Prepare output packet.
            output_packet = {
                'number': templates_packet['number'],
                'payload': output_data,
            }
            # Send output packet.
            self.get_output('updater').send(output_packet)
            # Log debug message.
            string = "{} sends output packet"
            message = string.format(self.name_and_counter)
            self.log.debug(message)

            self._measure_time('end', period=1)

        return

    def _introspect(self):
        """Log processing speed relative to real time (min/mean/max ratios)."""

        nb_buffers = self.counter - self.start_step
        start_times = np.array(self._measured_times.get('start', []))
        end_times = np.array(self._measured_times.get('end', []))
        durations = end_times - start_times
        # Duration of one buffer of data, in the same unit as the timers.
        data_duration = float(self.nb_samples) / self.sampling_rate
        ratios = data_duration / durations

        min_ratio = np.min(ratios) if ratios.size > 0 else np.nan
        mean_ratio = np.mean(ratios) if ratios.size > 0 else np.nan
        max_ratio = np.max(ratios) if ratios.size > 0 else np.nan

        # Log info message.
        string = "{} processed {} buffers [speed:x{:.2f} (min:x{:.2f}, max:x{:.2f})]"
        message = string.format(self.name, nb_buffers, mean_ratio, min_ratio, max_ratio)
        self.log.info(message)

        return
class TemplateDictionary(object):
    """Dictionary of templates with duplicate (and mixture) rejection.

    Wraps a ``TemplateStore`` and an ``OverlapsStore``: candidate templates
    are accepted only if they are not already present (normalized correlation
    above ``cc_merge``); mixture rejection is currently disabled (see ``add``).
    """

    def __init__(self,
                 template_store=None,
                 cc_merge=None,
                 cc_mixture=None,
                 optimize=True,
                 overlap_path=None):
        """Initialize the dictionary.

        Arguments:
            template_store: the underlying template store; assumed to expose
                ``nb_channels`` — TODO confirm against TemplateStore.
            cc_merge: none | float
                Correlation threshold above which a candidate is a duplicate.
            cc_mixture: none | float
                Correlation threshold for mixture detection (unused while
                mixture rejection is disabled).
            optimize: boolean
                If true, spatial supports are tracked so that only templates
                sharing channels with a candidate are compared.
            overlap_path: none | string
                Path used by the overlaps store to persist its data.
        """

        self.template_store = template_store
        self.nb_channels = self.template_store.nb_channels
        self.cc_merge = cc_merge
        self.cc_mixture = cc_mixture
        self._duplicates = None  # channel -> creation times of rejected duplicates
        self._mixtures = None  # channel -> creation times of rejected mixtures
        self.optimize = optimize
        self._indices = []  # spatial support (channel indices) per accepted template
        self.overlaps_store = OverlapsStore(self.template_store,
                                            optimize=self.optimize,
                                            path=overlap_path)

    @property
    def is_empty(self):
        """Whether the underlying template store holds no template."""

        return len(self.template_store) == 0

    @property
    def first_component(self):
        """First components of the stored templates (from the overlaps store)."""

        return self.overlaps_store.first_component

    def _init_from_template(self, template):
        """Initialize template dictionary based on a sampled template.

        Argument:
            template: circusort.obj.Template
                The sampled template used to initialize the dictionary. This template won't be added to the dictionary.
        """

        self.nb_elements = self.nb_channels * template.temporal_width

        return

    def __str__(self):

        string = """
        Template dictionary with {} templates
        rejected as duplicates: {}
        rejected as mixtures  : {}
        """.format(self.nb_templates, self.nb_duplicates, self.nb_mixtures)

        return string

    def __iter__(self):
        """Iterate over the stored first components.

        The previous signature was ``__iter__(self, index)``; the spurious
        parameter made ``iter(dictionary)`` raise a TypeError.
        """

        for i in self.first_component:
            yield self[i]

        return

    def __getitem__(self, index):

        return self.first_component[index]

    def __len__(self):

        return self.nb_templates

    @property
    def to_json(self):
        """JSON-serializable description (file names) of the dictionary."""

        result = {
            'template_store': self.template_store.file_name,
            'overlaps': self.overlaps_store.to_json
        }

        return result

    @property
    def nb_templates(self):

        return self.overlaps_store.nb_templates

    @property
    def nb_mixtures(self):
        """Total number of templates rejected as mixtures."""

        if self._mixtures is None:
            nb_mixtures = 0
        else:
            # Sum the lengths of the per-channel lists. The previous code
            # iterated over `.items()`, so each entry counted as
            # `len((key, value)) == 2` regardless of the list length.
            nb_mixtures = np.sum(
                [len(value) for value in self._mixtures.values()])

        return nb_mixtures

    @property
    def nb_duplicates(self):
        """Total number of templates rejected as duplicates."""

        if self._duplicates is None:
            nb_duplicates = 0
        else:
            # Same fix as `nb_mixtures`: iterate the values, not the items.
            nb_duplicates = np.sum(
                [len(value) for value in self._duplicates.values()])

        return nb_duplicates

    def _add_duplicates(self, template):
        """Record the creation time of a duplicated template, per channel."""

        if self._duplicates is None:
            self._duplicates = {}

        if template.channel in self._duplicates:
            self._duplicates[template.channel] += [template.creation_time]
        else:
            self._duplicates[template.channel] = [template.creation_time]

        return

    def _add_mixtures(self, template):
        """Record the creation time of a composite template, per channel."""

        if self._mixtures is None:
            self._mixtures = {}

        if template.channel in self._mixtures:
            self._mixtures[template.channel] += [template.creation_time]
        else:
            self._mixtures[template.channel] = [template.creation_time]

        return

    def _add_template(self, template):
        """Add a template to the template dictionary.

        Arguments:
            template: circusort.obj.Template
        Return:
            indices: list
                A list which contains the indices of templates successfully added to the underlying template store.
        """

        indices = self.template_store.add(template)
        self.overlaps_store.add_template(template)

        return indices

    def compute_overlaps(self):
        """Pre-compute overlaps in the underlying overlaps store."""

        self.overlaps_store.compute_overlaps()

    def save_overlaps(self):
        """Persist precomputed overlaps to disk."""

        self.overlaps_store.save_overlaps()

    def non_zeros(self, channel_indices):
        """Get indices of templates whose spatial supports include at least one of the given channel indices.

        Argument:
            channel_indices: iterable
                The channel indices.
        Return:
            template_indices: numpy.ndarray
                The template indices.
        """

        template_indices = np.array([
            k for k, channel_indices_bis in enumerate(self._indices)
            if np.any(np.in1d(channel_indices_bis, channel_indices))
        ],
                                    dtype=np.int32)

        return template_indices

    def initialize(self, templates):
        """Initialize the template dictionary with templates.

        Argument:
            templates
        Return:
            accepted: list
                A list which contains the indices of templates successfully added to the underlying template store.
        """

        # Force insertion: no duplicate / mixture filtering at start-up.
        accepted, _, _ = self.add(templates, force=True)

        return accepted

    def add(self, templates, force=False):
        """Add templates to the template dictionary.

        Arguments:
            templates
            force: boolean (optional)
                If true, skip duplicate and mixture checks.
                The default value is False.
        Returns:
            accepted: list
                A list which contains the indices of templates successfully added to the underlying template store.
            nb_duplicates: integer
                The number of duplicates.
            nb_mixtures: integer
                The number of mixtures.
        """

        nb_duplicates = 0
        nb_mixtures = 0
        accepted = []

        # Initialize the dictionary (if necessary).
        # NOTE(review): `next(iter(templates))` would consume the first element
        # of a one-shot iterator; callers are expected to pass a sequence.
        if self.is_empty:
            template = next(iter(templates))
            self._init_from_template(template)

        if force:
            # Add all the given templates without checking duplicates and mixtures.

            for t in templates:

                accepted += self._add_template(t)

                if self.optimize:
                    self._indices += [t.indices]

        else:

            for t in templates:

                # Normalize the flattened first component before comparison.
                csr_template = t.first_component.to_sparse('csr', flatten=True)
                norm = t.first_component.norm
                csr_template /= norm

                if self.optimize:
                    non_zeros = self.non_zeros(t.indices)
                else:
                    non_zeros = None

                # Check if the template is already in this dictionary.
                is_present = self._is_present(csr_template, non_zeros)
                if is_present:
                    nb_duplicates += 1
                    self._add_duplicates(t)

                # TODO Removing mixture online is difficult.
                # TODO Some mixture might have been added before the templates which compose this mixture.
                # TODO For now, let's say that there is no mixture in the signal.

                # # Check if the template is a mixture of templates already present in the dictionary.
                # is_mixture = self._is_mixture(csr_template, non_zeros)
                # if is_mixture:
                #     nb_mixtures += 1
                #     self._add_mixtures(t)

                is_mixture = False

                if not is_present and not is_mixture:
                    accepted += self._add_template(t)

                    if self.optimize:
                        self._indices += [t.indices]

        return accepted, nb_duplicates, nb_mixtures

    def _is_present(self, csr_template, non_zeros=None):
        """Whether the (normalized) template already exists in the dictionary."""

        # No threshold or empty dictionary: nothing can match.
        if (self.cc_merge is None) or (self.nb_templates == 0):
            return False

        return self.overlaps_store.is_present(csr_template, self.cc_merge,
                                              non_zeros)

    def _is_mixture(self, csr_template, non_zeros=None):
        """Whether the template is a mixture of stored ones (disabled: always False)."""

        if (self.cc_mixture is None) or (self.nb_templates == 0):
            return False

        # TODO complete/clean.
        _ = csr_template  # discard this argument
        _ = non_zeros  # discard this argument
        # return self.overlaps_store.is_mixture(csr_template, self.cc_mixture, non_zeros)
        return False
def _process(self, verbose=False, timing=False):
        """Process one buffer: collect data/peaks, apply updater packets, fit.

        Maintains a rolling two-buffer window `self.x`, drains the 'updater'
        input to (re)build the template and overlap stores, then fits the
        current chunk and sends the result on the 'spikes' output.

        Arguments:
            verbose: boolean (optional)
                Forwarded to the peak-collection helpers.
            timing: boolean (optional)
                If true, record fine-grained timing measurements.
        """

        if timing:
            self._measure_time('preamble_start', period=10)

        # First, collect all the buffers we need.
        # # Prepare everything to collect buffers.
        if self.counter == 0:
            # Initialize 'self.x'.
            # Window of two consecutive buffers (previous + current).
            shape = (2 * self._nb_samples, self._nb_channels)
            self.x = np.zeros(shape, dtype=np.float32)
        elif self._nb_fitters == 1:
            # Copy the end of 'self.x' at its beginning.
            self.x[0 * self._nb_samples:1 * self._nb_samples, :] = \
                self.x[1 * self._nb_samples:2 * self._nb_samples, :]
        else:
            pass
        # # Collect precedent data and peaks buffers.
        # With several fitters the previous buffer must be re-collected
        # (except for the very first buffer of fitter 0).
        if self._nb_fitters > 1 and not (self.counter == 0
                                         and self._fitter_id == 0):
            self._collect_data(shift=-1)
            self._collect_peaks(verbose=verbose)
        # # Collect current data and peaks buffers.
        self._collect_data(shift=0)
        self._collect_peaks(verbose=verbose)
        # # Collect current updater buffer.
        updater_packet = self.get_input('updater').receive(
            blocking=False, discarding_eoc=self.discarding_eoc_from_updater)
        updater = updater_packet[
            'payload'] if updater_packet is not None else None

        if timing:
            self._measure_time('preamble_end', period=10)

        if updater is not None:

            self._measure_time('update_start', period=1)

            # Drain every pending updater packet before fitting.
            while updater is not None:

                # Log debug message.
                string = "{} modifies template and overlap stores"
                message = string.format(self.name_and_counter)
                self.log.debug(message)

                # Modify template and overlap stores.
                indices = updater.get('indices', None)
                _ = indices  # Discard unused variable.
                if self._template_store is None:
                    # Initialize template and overlap stores.
                    self._template_store = TemplateStore(
                        updater['template_store'], mode='r')
                    self._overlaps_store = OverlapsStore(
                        template_store=self._template_store,
                        path=updater['overlaps']['path'],
                        fitting_mode=True)
                    self._init_temp_window()
                    # Log debug message.
                    string = "{} initializes template and overlap stores ({}, {})"
                    message = string.format(self.name_and_counter,
                                            updater['template_store'],
                                            updater['overlaps']['path'])
                    self.log.debug(message)

                else:

                    # TODO avoid duplicates in template store and uncomment the 3 following lines.
                    # Update template and overlap stores.
                    # A missing overlaps path means the update is done lazily.
                    laziness = updater['overlaps']['path'] is None
                    self._overlaps_store.update(indices, laziness=laziness)
                    # Log debug message.
                    string = "{} updates template and overlap stores"
                    message = string.format(self.name_and_counter)
                    self.log.debug(message)

                # Log debug message.
                string = "{} modified template and overlap stores"
                message = string.format(self.name_and_counter)
                self.log.debug(message)

                updater_packet = self.get_input('updater').receive(
                    blocking=False,
                    discarding_eoc=self.discarding_eoc_from_updater)
                updater = updater_packet[
                    'payload'] if updater_packet is not None else None

            self._measure_time('update_end', period=1)

        # `self.p` presumably holds the collected peaks — TODO confirm.
        if self.p is not None:

            if self.nb_templates > 0:

                self._measure_time('start')

                if timing:
                    self._measure_time('fit_start', period=10)
                self._fit_chunk(verbose=verbose, timing=timing)
                if timing:
                    self._measure_time('fit_end', period=10)
                if timing:
                    self._measure_time('output_start', period=10)
                packet = {
                    'number': self._number,
                    'payload': self.r,
                }
                self.get_output('spikes').send(packet)
                if timing:
                    self._measure_time('output_end', period=10)

                self._measure_time('end')

            elif self._nb_fitters > 1:

                # Keep downstream blocks in sync even without templates.
                packet = {
                    'number': self._number,
                    'payload': self._empty_result,
                }
                self.get_output('spikes').send(packet)

        elif self._nb_fitters > 1:

            # Keep downstream blocks in sync even without peaks.
            packet = {
                'number': self._number,
                'payload': self._empty_result,
            }
            self.get_output('spikes').send(packet)

        return
class Fitter(Block):
    """Fitter

    Block that matches detected peaks against a dictionary of templates
    (greedy template fitting) and sends the fitted spikes on its 'spikes'
    output.

    Attributes:
        templates_init_path: none | string (optional)
            Path to the location used to load templates to initialize the
            dictionary of templates. If equal to None, this dictionary will
            start empty. The default value is None.
        overlaps_init_path: none | string (optional)
            Path to the location used to load the overlaps to initialize the
            overlap store.
            The default value is None.
        with_rejected_times: boolean (optional)
            Whether rejected peak times and amplitudes are also collected in
            the result. The default value is False.
        sampling_rate: float (optional)
            Sampling rate of the signal (in Hz). The default value is 20e+3.
        discarding_eoc_from_updater: boolean (optional)
            The default value is False.
    """

    name = "Fitter"

    # Default parameters; the leading-underscore entries support running
    # several fitters in parallel (see `buffer_id`).
    params = {
        'templates_init_path': None,
        'overlaps_init_path': None,
        'with_rejected_times': False,
        'sampling_rate': 20e+3,
        'discarding_eoc_from_updater': False,
        '_nb_fitters': 1,
        '_fitter_id': 0,
    }

    def __init__(self, **kwargs):
        """Initialize fitter

        Arguments:
            templates_init_path: string (optional)
            overlaps_init_path: string (optional)
            with_rejected_times: boolean (optional)
            sampling_rate: float (optional)
            discarding_eoc_from_updater: boolean (optional)
            _nb_fitters: integer (optional)
            _fitter_id: integer (optional)
        """

        Block.__init__(self, **kwargs)

        # The following lines are useful to avoid some PyCharm's warnings.
        # NOTE(review): `Block.__init__` presumably injects the entries of
        # `params` as instance attributes; these self-assignments only make
        # them visible to static analysis — confirm against `Block`.
        self.templates_init_path = self.templates_init_path
        self.overlaps_init_path = self.overlaps_init_path
        self.with_rejected_times = self.with_rejected_times
        self.sampling_rate = self.sampling_rate
        self.discarding_eoc_from_updater = self.discarding_eoc_from_updater
        self._nb_fitters = self._nb_fitters
        self._fitter_id = self._fitter_id

        # Initialize private attributes.
        self._template_store = None  # built lazily (init path or updater packet)
        self._overlaps_store = None  # built together with the template store

        self.add_input('updater', structure='dict')
        self.add_input('data', structure='dict')
        self.add_input('peaks', structure='dict')
        self.add_output('spikes', structure='dict')

        # Input geometry and packet bookkeeping (configured later).
        self._nb_channels = None
        self._nb_samples = None
        self._number = None  # number of the last received data packet

    def _initialize(self):
        """Set up fitting parameters, initial templates and result buffers."""

        self.space_explo = 0.5
        self.nb_chances = 3

        if self.templates_init_path is not None:
            # Normalize the path before loading the initial templates.
            expanded_path = os.path.expanduser(self.templates_init_path)
            self.templates_init_path = os.path.abspath(expanded_path)
            self._initialize_templates()

        # State used to handle buffer edges.
        self.x = None  # voltage signal
        self.p = None  # peak time steps
        self.r = {  # temporary result
            'spike_times': np.zeros(0, dtype=np.int32),
            'amplitudes': np.zeros(0, dtype=np.float32),
            'templates': np.zeros(0, dtype=np.int32),
        }
        if self.with_rejected_times:
            self.r['rejected_times'] = np.zeros(0, dtype=np.int32)
            self.r['rejected_amplitudes'] = np.zeros(0, dtype=np.float32)

        return

    def _initialize_templates(self):
        """Load the initial template and overlap stores from disk."""

        self._template_store = TemplateStore(self.templates_init_path,
                                             mode='r')
        message = "{} is initialized with {} templates from {}".format(
            self.name, self._template_store.nb_templates,
            self.templates_init_path)
        self.log.info(message)

        self._overlaps_store = OverlapsStore(
            template_store=self._template_store,
            path=self.overlaps_init_path,
            fitting_mode=True)
        message = "{} is initialized with precomputed overlaps from {}".format(
            self.name, self.overlaps_init_path)
        self.log.info(message)

        return

    @property
    def nb_templates(self):
        """Number of templates known to the overlap store (0 before it exists)."""

        if self._overlaps_store is None:
            return 0
        return self._overlaps_store.nb_templates

    @property
    def min_scalar_product(self):
        """Smallest admissible scalar product (lower amplitude bound times first-component norm)."""

        lower_bounds = self._overlaps_store.amplitudes[:, 0]
        norms = self._overlaps_store.norms['1']
        return np.min(lower_bounds * norms)

    def _configure_input_parameters(self,
                                    nb_channels=None,
                                    nb_samples=None,
                                    **kwargs):
        """Record the input data geometry (channel count and buffer length).

        Extra keyword arguments are accepted and ignored.
        """

        for attr, value in (('_nb_channels', nb_channels),
                            ('_nb_samples', nb_samples)):
            if value is not None:
                setattr(self, attr, value)

        return

    def _update_initialization(self):
        """Finish initialization once the input geometry is known."""

        if self.templates_init_path is None:
            return
        self._init_temp_window()

        return

    def _init_temp_window(self):
        """Precompute the temporal window used to slice waveforms.

        Derives the template half-width from the overlap store and builds
        the index window [-width, ..., +width] centered on a peak.
        """

        # A template spans an odd number of time steps: 2 * width + 1.
        self._width = (self._overlaps_store.temporal_width - 1) // 2
        self._2_width = 2 * self._width
        self.temp_window = np.arange(-self._width, self._width + 1)
        # Fix: removed the dead local `buffer_size = 2 * self._nb_samples`,
        # which was computed but never used.

        return

    def _is_valid(self, peak_step):
        """Return a boolean mask of peaks far enough from the buffer edges
        for a full template-width waveform to be extracted."""

        lower = self._width
        upper = self._nb_samples - self._width
        return (lower <= peak_step) & (peak_step < upper)

    def _get_all_valid_peaks(self, peak_steps):
        """Collect the unique peak time steps that lie away from buffer edges.

        Argument:
            peak_steps: dict
                Maps peak type to a dict which maps channel to time steps.
        """

        collected = set()
        for per_channel in peak_steps.values():
            for steps in per_channel.values():
                collected = collected.union(steps)
        all_steps = np.array(list(collected), dtype=np.int32)

        return all_steps[self._is_valid(all_steps)]

    @property
    def _empty_result(self):
        """Minimal result payload: only the time offset of the current buffer."""

        return {'offset': self.offset}

    def _reset_result(self):
        """Reset the temporary result dictionary for the current chunk."""

        result = {
            'spike_times': np.zeros(0, dtype=np.int32),
            'amplitudes': np.zeros(0, dtype=np.float32),
            'templates': np.zeros(0, dtype=np.int32),
            'offset': self.offset,
        }
        if self.with_rejected_times:
            result['rejected_times'] = np.zeros(0, dtype=np.int32)
            result['rejected_amplitudes'] = np.zeros(0, dtype=np.float32)
        self.r = result

        return

    def _extract_waveforms(self, peak_time_steps):
        """Extract waveforms from buffer

        Argument:
            peak_time_steps: np.array
                Peak time steps. Array of shape (number of peaks,).

        Return:
            waveforms: np.array
                Extracted waveforms. Array of shape (waveform size, number of peaks), where waveform size is equal to
                number of channels x number of samples.
        """

        # Fancy-index the signal around every peak, then flatten each
        # snippet into one column per peak.
        snippets = self.x[peak_time_steps[:, None] + self.temp_window]
        nb_peaks = len(peak_time_steps)
        waveforms = snippets.transpose(2, 1, 0).reshape(
            self._overlaps_store.nb_elements, nb_peaks)

        return waveforms

    def _fit_chunk(self, verbose=False, timing=False):
        """Fit the templates to the peaks of the current chunk.

        Greedy template matching: repeatedly take the largest remaining
        scalar product between a waveform and a template, accept the match
        if its normalized amplitude lies within the template's amplitude
        bounds, subtract the template's overlaps from neighboring scalar
        products, and otherwise count a failure for that peak (up to
        `self.nb_chances` failures per peak). Accepted spikes are
        accumulated in `self.r`, then restricted to the result area and
        sorted.

        Arguments:
            verbose: boolean (optional)
                Whether to log debug messages. The default value is False.
            timing: boolean (optional)
                Whether to measure the duration of internal steps.
                The default value is False.
        """

        if verbose:
            # Log debug message.
            string = "{} fits spikes... ({} templates)"
            message = string.format(self.name_and_counter, self.nb_templates)
            self.log.debug(message)

        # Reset result.
        self._reset_result()

        # Compute the number of peaks in the current chunk.
        is_in_work_area = np.logical_and(self.work_area_start <= self.p,
                                         self.p < self.work_area_end)
        nb_peaks = np.count_nonzero(is_in_work_area)
        peaks = self.p[is_in_work_area]

        if verbose:
            # Log debug message.
            string = "{} has {} peaks in the work area among {} peaks"
            message = string.format(self.name, nb_peaks, len(self.p))
            self.log.debug(message)

        if verbose:
            # Log debug message.
            string = "{} peaks: {}"
            message = string.format(self.name, peaks)
            self.log.debug(message)

        # If there is at least one peak in the work area...
        if 0 < nb_peaks:

            # Extract waveforms from buffer.
            waveforms = self._extract_waveforms(peaks)

            if timing:
                self._measure_time('scalar_products_start', period=10)
            # Compute the scalar products between waveforms and templates.
            scalar_products = self._overlaps_store.dot(waveforms)
            if timing:
                self._measure_time('scalar_products_end', period=10)

            # Initialize the failure counter of each peak.
            nb_failures = np.zeros(nb_peaks, dtype=np.int32)

            if verbose:
                # Log debug message.
                string = "{} buffer offset: {}"
                message = string.format(self.name, self.offset)
                self.log.debug(message)

            if timing:
                self._measure_time('while_loop_start', period=10)
            # TODO rewrite condition according to the 3 last lines of the nested while loop.
            # while not np.all(nb_failures == self.max_nb_trials):

            # After a rejection, `numerous_argmax` switches the search to a
            # cached list of the largest entries (`best_indices`) instead of
            # recomputing a full argmax at every iteration.
            numerous_argmax = False
            nb_argmax = self.nb_templates
            best_indices = np.zeros(0, dtype=np.int32)

            # Set scalar products of tested matches to zero.
            # NOTE(review): `ravel` is assumed to return a view of `data`
            # here, so the in-place updates of `scalar_products` below stay
            # visible through `data_flatten` — confirm the slice stays
            # contiguous.
            data = scalar_products[:self._overlaps_store.nb_templates, :]
            data_flatten = data.ravel()
            min_scalar_product = self.min_scalar_product

            while np.mean(nb_failures) < self.nb_chances:

                # Find the best template.

                if numerous_argmax:
                    if len(best_indices) == 0:
                        best_indices = largest_indices(data_flatten, nb_argmax)
                    best_template_index, peak_index = np.unravel_index(
                        best_indices[0], data.shape)
                else:
                    best_template_index, peak_index = np.unravel_index(
                        data.argmax(), data.shape)

                # TODO remove peaks with scalar products equal to zero?
                # TODO consider the absolute values of the scalar products?

                # Compute the best amplitude.
                best_amplitude = scalar_products[best_template_index,
                                                 peak_index]

                # No remaining candidate can satisfy the lower amplitude
                # bound: give up on the whole chunk.
                if best_amplitude < min_scalar_product:
                    nb_failures[:] = self.nb_chances
                    break

                if self._overlaps_store.two_components:
                    best_scalar_product = scalar_products[best_template_index +
                                                          self.nb_templates,
                                                          peak_index]
                    best_amplitude_2 = best_scalar_product

                # Compute the best normalized amplitude.
                best_amplitude_ = best_amplitude / self._overlaps_store.norms[
                    '1'][best_template_index]
                if self._overlaps_store.two_components:
                    best_amplitude_2_ = best_amplitude_2 / self._overlaps_store.norms[
                        '2'][best_template_index]

                # Verify amplitude constraint.
                a_min = self._overlaps_store.amplitudes[best_template_index, 0]
                a_max = self._overlaps_store.amplitudes[best_template_index, 1]

                if (a_min <= best_amplitude_) & (best_amplitude_ <= a_max):
                    if verbose:
                        # Log debug message.
                        string = "{} processes (p {}, t {}) -> (a {}, keep)"
                        message = string.format(self.name, peak_index,
                                                best_template_index,
                                                best_amplitude)
                        self.log.debug(message)
                    if timing:
                        self._measure_time('for_loop_accept_start', period=10)
                    # Keep the matching.
                    peak_time_step = peaks[peak_index]
                    # # Compute the neighboring peaks.
                    # # TODO use this definition of `is_neighbor` instead of the other.
                    # is_neighbor = np.abs(peaks - peak_index) <= 2 * self._width

                    if timing:
                        self._measure_time('for_loop_update_start', period=10)
                    if timing:
                        self._measure_time('for_loop_update_1_start',
                                           period=10)
                    # Update scalar products.
                    # TODO simplify the following 11 lines.
                    tmp = np.dot(np.ones((1, 1), dtype=np.int32),
                                 np.reshape(peaks, (1, nb_peaks)))
                    tmp -= np.array([[peak_time_step]])
                    is_neighbor = np.abs(tmp) <= self._2_width
                    ytmp = tmp[0, is_neighbor[0, :]] + self._2_width
                    indices = np.zeros((self._overlaps_store.size, len(ytmp)),
                                       dtype=np.int32)
                    indices[ytmp, np.arange(len(ytmp))] = 1
                    if timing:
                        self._measure_time('for_loop_update_1_end', period=10)

                    if timing:
                        self._measure_time('for_loop_update_2_start',
                                           period=10)
                    if timing:
                        self._measure_time('for_loop_overlaps_start',
                                           period=10)
                    # Subtract the accepted template (scaled by its fitted
                    # amplitude) from the scalar products of its neighbors.
                    tmp1_ = self._overlaps_store.get_overlaps(
                        best_template_index, '1')
                    if timing:
                        self._measure_time('for_loop_overlaps_end', period=10)
                    tmp1 = tmp1_.multiply(-best_amplitude)
                    to_add = tmp1.toarray()[:, ytmp]

                    scalar_products[:, is_neighbor[0, :]] += to_add
                    if timing:
                        self._measure_time('for_loop_update_2_end', period=10)

                    if self._overlaps_store.two_components:
                        tmp2_ = self._overlaps_store.get_overlaps(
                            best_template_index, '2')
                        tmp2 = tmp2_.multiply(-best_amplitude_2)
                        to_add = tmp2.toarray()[:, ytmp]
                        scalar_products[:, is_neighbor[0, :]] += to_add

                    if timing:
                        self._measure_time('for_loop_update_end', period=10)

                    if timing:
                        self._measure_time('for_loop_concatenate_start',
                                           period=10)
                    # Add matching to the result.
                    self.r['spike_times'] = np.concatenate(
                        (self.r['spike_times'], [peak_time_step]))
                    self.r['amplitudes'] = np.concatenate(
                        (self.r['amplitudes'], [best_amplitude_]))
                    self.r['templates'] = np.concatenate(
                        (self.r['templates'], [best_template_index]))
                    if timing:
                        self._measure_time('for_loop_concatenate_end',
                                           period=10)
                    # Mark current matching as tried.
                    scalar_products[best_template_index, peak_index] = -np.inf
                    best_indices = np.zeros(0, dtype=np.int32)
                    if timing:
                        self._measure_time('for_loop_accept_end', period=10)

                    numerous_argmax = False
                else:

                    numerous_argmax = True

                    if verbose:
                        # Log debug message.
                        string = "{} processes (p {}, t {}) -> (a {}, reject)"
                        message = string.format(self.name, peak_index,
                                                best_template_index,
                                                best_amplitude)
                        self.log.debug(message)
                    if timing:
                        self._measure_time('for_loop_reject_start', period=10)
                    # Reject the matching.
                    # Update failure counter of the peak.
                    nb_failures[peak_index] += 1
                    # If the maximal number of failures is reached then mark peak as solved (i.e. not fitted).
                    if nb_failures[peak_index] == self.nb_chances:
                        scalar_products[:, peak_index] = -np.inf
                    else:
                        scalar_products[best_template_index,
                                        peak_index] = -np.inf

                    # Drop cached candidates that were invalidated (-inf).
                    best_indices = best_indices[
                        data_flatten[best_indices] > -np.inf]

                    # Add reject to the result if necessary.
                    if self.with_rejected_times:
                        self.r['rejected_times'] = np.concatenate(
                            (self.r['rejected_times'], [peaks[peak_index]]))
                        self.r['rejected_amplitudes'] = np.concatenate(
                            (self.r['rejected_amplitudes'], [best_amplitude_]))
                    if timing:
                        self._measure_time('for_loop_reject_end', period=10)

            if timing:
                self._measure_time('while_loop_end', period=10)

            # Handle result.
            keys = ['spike_times', 'amplitudes', 'templates']
            if self.with_rejected_times:
                keys += ['rejected_times', 'rejected_amplitudes']
            # # Keep only spikes in the result area.
            is_in_result = np.logical_and(
                self.result_area_start <= self.r['spike_times'],
                self.r['spike_times'] < self.result_area_end)
            for key in keys:
                self.r[key] = self.r[key][is_in_result]
            # # Sort spike.
            indices = np.argsort(self.r['spike_times'])
            for key in keys:
                self.r[key] = self.r[key][indices]
            # # Modify spike time reference.
            self.r['spike_times'] = self.r['spike_times'] - self._nb_samples

            if verbose:
                # Log debug message.
                nb_spike_times = len(self.r['spike_times'])
                if nb_spike_times > 0:
                    string = "{} fitted {} spikes ({} templates)"
                    message = string.format(self.name_and_counter,
                                            nb_spike_times, self.nb_templates)
                else:
                    string = "{} fitted no spikes ({} templates)"
                    message = string.format(self.name_and_counter,
                                            self.nb_templates)
                self.log.debug(message)

        else:  # i.e. nb_peaks == 0

            if verbose:
                # Log debug message.
                string = "{} can't fit spikes ({} templates)"
                message = string.format(self.name_and_counter,
                                        self.nb_templates)
                self.log.debug(message)

        return

    @staticmethod
    def _merge_peaks(peaks):
        """Merge positive and negative peaks from all the channels.

        Argument:
            peaks: dict
                Maps peak type to a dict which maps channel to peak time
                steps. An optional 'offset' entry is ignored.

        Return:
            time_steps: np.array
                Sorted array of unique peak time steps (int32).
        """

        time_steps = set()
        for key, channels in peaks.items():
            if key == 'offset':
                # 'offset' is metadata, not a channel dictionary.
                continue
            for channel_time_steps in channels.values():
                time_steps |= set(channel_time_steps)

        return np.sort(np.array(list(time_steps), dtype=np.int32))

    @property
    def nb_buffers(self):
        """Number of data buffers currently held in the signal array."""

        nb_rows = self.x.shape[0]
        return nb_rows // self._nb_samples

    @property
    def result_area_start(self):
        """First time step of the result area (half a buffer before the last buffer's start)."""

        last_buffer_start = (self.nb_buffers - 1) * self._nb_samples
        return last_buffer_start - self._nb_samples // 2

    @property
    def result_area_end(self):
        """Last time step (exclusive) of the result area (half a buffer after the last buffer's start)."""

        last_buffer_start = (self.nb_buffers - 1) * self._nb_samples
        return last_buffer_start + self._nb_samples // 2

    @property
    def work_area_start(self):
        """First time step of the work area (result area padded by one template half-width)."""

        padding = self._width
        return self.result_area_start - padding

    @property
    def work_area_end(self):
        """Last time step (exclusive) of the work area (result area padded by one template half-width)."""

        padding = self._width
        return self.result_area_end + padding

    @property
    def buffer_id(self):
        """Global buffer index when the work is shared by several fitters."""

        return self._fitter_id + self._nb_fitters * self.counter

    @property
    def first_buffer_id(self):
        # Buffer index used as the time reference for `offset`.
        # TODO check if the comment fix the "offset bug".
        return self.buffer_id  # - 1

    @property
    def offset(self):
        # Time offset (in time steps) attached to the result packets; spike
        # times in `self.r` are relative to this offset.
        # TODO check if the comment fix the "offset bug".
        return self.first_buffer_id * self._nb_samples  # + self.result_area_start

    def _collect_data(self, shift=0):
        """Receive one data packet and store it in the signal buffer.

        Argument:
            shift: integer (optional)
                Buffer slot offset relative to the last slot. The default value is 0.
        """

        slot = (self.nb_buffers - 1) + shift
        start = slot * self._nb_samples
        end = start + self._nb_samples

        data_packet = self.get_input('data').receive(blocking=True)
        self._number = data_packet['number']
        self.x[start:end, :] = data_packet['payload']

        return

    def _handle_peaks(self, peaks):
        """Shift the stored peaks back by one buffer and append the new ones."""

        new_peaks = self._nb_samples + self._merge_peaks(peaks)
        shifted = self.p - self._nb_samples
        shifted = shifted[shifted >= 0]
        self.p = np.concatenate((shifted, new_peaks))

        return

    def _collect_peaks(self, verbose=False):
        """Receive the peaks packet matching the current data packet.

        In active mode the received peaks are merged into the running peak
        buffer. Before activation, peaks are stored only once a first
        packet arrives, which switches the block to active mode.

        Argument:
            verbose: boolean (optional)
                Whether to log debug messages. The default value is False.
        """

        if self.is_active:
            peaks_packet = self.get_input('peaks').receive(blocking=True,
                                                           number=self._number)
            if peaks_packet is None:
                # This is the last packet (last data packet don't have a corresponding peak packet since the peak
                # detector needs two consecutive data packets to produce one peak packet).
                peaks = {}
            else:
                peaks = peaks_packet['payload']['peaks']
            self._handle_peaks(peaks)
            # Fix: guard against `peaks_packet is None` (the last-packet case
            # above), which previously crashed the verbose logging below.
            if verbose and peaks_packet is not None:
                # Log debug message.
                string = "{} collects peaks {} (reg)"
                message = string.format(self.name,
                                        peaks_packet['payload']['offset'])
                self.log.debug(message)
        else:
            if self.get_input('peaks').has_received():
                peaks_packet = self.get_input('peaks').receive(
                    blocking=True, number=self._number)
                if peaks_packet is not None:
                    peaks = peaks_packet['payload']['peaks']
                    p = self._nb_samples + self._merge_peaks(peaks)
                    self.p = p
                    if verbose:
                        # Log debug message.
                        string = "{} collects peaks {} (init)"
                        message = string.format(
                            self.name, peaks_packet['payload']['offset'])
                        self.log.debug(message)
                    # Set active mode.
                    self._set_active_mode()
                else:
                    self.p = None
            else:
                self.p = None

        return

    def _process(self, verbose=False, timing=False):
        """Process one chunk: collect input buffers, apply updates, fit spikes.

        Per call:
          1. maintain the two-buffer signal window `self.x` (previous +
             current data packet) and the peak buffer `self.p`;
          2. drain every pending packet on the 'updater' input, building or
             updating the template and overlap stores;
          3. if peaks are available, either fit the chunk and send `self.r`
             on the 'spikes' output, or (with several fitters) send an empty
             result so the packet numbering stays dense.

        Arguments:
            verbose: boolean (optional)
                Whether to log debug messages. The default value is False.
            timing: boolean (optional)
                Whether to measure the duration of internal steps.
                The default value is False.
        """

        if timing:
            self._measure_time('preamble_start', period=10)

        # First, collect all the buffers we need.
        # # Prepare everything to collect buffers.
        if self.counter == 0:
            # Initialize 'self.x'.
            shape = (2 * self._nb_samples, self._nb_channels)
            self.x = np.zeros(shape, dtype=np.float32)
        elif self._nb_fitters == 1:
            # Copy the end of 'self.x' at its beginning.
            self.x[0 * self._nb_samples:1 * self._nb_samples, :] = \
                self.x[1 * self._nb_samples:2 * self._nb_samples, :]
        else:
            pass
        # # Collect precedent data and peaks buffers.
        # With several fitters the previous buffer was consumed by another
        # fitter, so it must be re-received here (except for the very first
        # buffer of fitter 0, which has no predecessor).
        if self._nb_fitters > 1 and not (self.counter == 0
                                         and self._fitter_id == 0):
            self._collect_data(shift=-1)
            self._collect_peaks(verbose=verbose)
        # # Collect current data and peaks buffers.
        self._collect_data(shift=0)
        self._collect_peaks(verbose=verbose)
        # # Collect current updater buffer.
        updater_packet = self.get_input('updater').receive(
            blocking=False, discarding_eoc=self.discarding_eoc_from_updater)
        updater = updater_packet[
            'payload'] if updater_packet is not None else None

        if timing:
            self._measure_time('preamble_end', period=10)

        if updater is not None:

            self._measure_time('update_start', period=1)

            # Drain every pending updater packet before fitting.
            while updater is not None:

                # Log debug message.
                string = "{} modifies template and overlap stores"
                message = string.format(self.name_and_counter)
                self.log.debug(message)

                # Modify template and overlap stores.
                indices = updater.get('indices', None)
                _ = indices  # Discard unused variable.
                if self._template_store is None:
                    # Initialize template and overlap stores.
                    self._template_store = TemplateStore(
                        updater['template_store'], mode='r')
                    self._overlaps_store = OverlapsStore(
                        template_store=self._template_store,
                        path=updater['overlaps']['path'],
                        fitting_mode=True)
                    self._init_temp_window()
                    # Log debug message.
                    string = "{} initializes template and overlap stores ({}, {})"
                    message = string.format(self.name_and_counter,
                                            updater['template_store'],
                                            updater['overlaps']['path'])
                    self.log.debug(message)

                else:

                    # TODO avoid duplicates in template store and uncomment the 3 following lines.
                    # Update template and overlap stores.
                    laziness = updater['overlaps']['path'] is None
                    self._overlaps_store.update(indices, laziness=laziness)
                    # Log debug message.
                    string = "{} updates template and overlap stores"
                    message = string.format(self.name_and_counter)
                    self.log.debug(message)

                # Log debug message.
                string = "{} modified template and overlap stores"
                message = string.format(self.name_and_counter)
                self.log.debug(message)

                updater_packet = self.get_input('updater').receive(
                    blocking=False,
                    discarding_eoc=self.discarding_eoc_from_updater)
                updater = updater_packet[
                    'payload'] if updater_packet is not None else None

            self._measure_time('update_end', period=1)

        if self.p is not None:

            if self.nb_templates > 0:

                self._measure_time('start')

                if timing:
                    self._measure_time('fit_start', period=10)
                self._fit_chunk(verbose=verbose, timing=timing)
                if timing:
                    self._measure_time('fit_end', period=10)
                if timing:
                    self._measure_time('output_start', period=10)
                packet = {
                    'number': self._number,
                    'payload': self.r,
                }
                self.get_output('spikes').send(packet)
                if timing:
                    self._measure_time('output_end', period=10)

                self._measure_time('end')

            elif self._nb_fitters > 1:

                # No templates yet: keep the output packet stream dense.
                packet = {
                    'number': self._number,
                    'payload': self._empty_result,
                }
                self.get_output('spikes').send(packet)

        elif self._nb_fitters > 1:

            # No peaks: keep the output packet stream dense.
            packet = {
                'number': self._number,
                'payload': self._empty_result,
            }
            self.get_output('spikes').send(packet)

        return

    def _introspect(self):
        """Log throughput statistics (real-time speed ratios) for this block."""

        nb_buffers = self.counter - self.start_step
        start_times = np.array(self._measured_times.get('start', []))
        end_times = np.array(self._measured_times.get('end', []))
        durations = end_times - start_times
        # Duration of signal covered by one processed buffer, in seconds.
        data_duration = float(
            self._nb_fitters * self._nb_samples) / self.sampling_rate
        ratios = data_duration / durations

        if ratios.size > 0:
            min_ratio = np.min(ratios)
            mean_ratio = np.mean(ratios)
            max_ratio = np.max(ratios)
        else:
            min_ratio = mean_ratio = max_ratio = np.nan

        # Log info message.
        string = "{} processed {} buffers [speed:x{:.2f} (min:x{:.2f}, max:x{:.2f})]"
        message = string.format(self.name, nb_buffers, mean_ratio, min_ratio,
                                max_ratio)
        self.log.info(message)

        return