Code example #1
import numpy

from brian2 import Quantity


def spike_trains(mon):
    """Group the spike times recorded by a SpikeMonitor per neuron index."""
    indices = mon.i[:]
    t = mon.t[:]
    # Stable sort keeps each neuron's spike times in chronological order
    sort_indices = numpy.argsort(indices, kind='mergesort')
    used_indices, first_pos = numpy.unique(indices[sort_indices],
                                           return_index=True)
    sorted_values = t[sort_indices]
    dim = t.dim
    event_values = {}
    current_pos = 0  # position in the used_indices array
    # Note the +1: range(indices.max()) would drop the highest neuron index
    for idx in range(int(indices.max()) + 1):
        if current_pos < len(
                used_indices) and used_indices[current_pos] == idx:
            if current_pos < len(used_indices) - 1:
                event_values[idx] = Quantity(
                    sorted_values[first_pos[current_pos]:first_pos[current_pos
                                                                   + 1]],
                    dim=dim,
                    copy=False)
            else:
                event_values[idx] = Quantity(
                    sorted_values[first_pos[current_pos]:],
                    dim=dim,
                    copy=False)
            current_pos += 1
        else:
            event_values[idx] = Quantity([], dim=dim)
    return event_values
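
A minimal usage sketch (the network below is illustrative and not part of the example above): record spikes with a SpikeMonitor, run the simulation, then group the spike times per neuron.

from brian2 import NeuronGroup, SpikeMonitor, ms, run

group = NeuronGroup(3, 'dv/dt = (1.1 - v) / (10*ms) : 1',
                    threshold='v > 1', reset='v = 0', method='exact')
mon = SpikeMonitor(group)
run(100 * ms)

trains = spike_trains(mon)  # {neuron_index: Quantity array of spike times}
print(trains[0])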
Code example #2
import os
import tempfile

import numpy as np
from numpy.testing import assert_equal

from brian2 import (Network, NeuronGroup, Quantity, SpikeGeneratorGroup,
                    SpikeMonitor, StateMonitor, Synapses, ms)


def test_store_restore_to_file_new_objects():
    # A more realistic test where the objects are completely re-created
    filename = tempfile.mktemp(suffix='state', prefix='brian_test')

    def create_net():
        # Use a bit of a complicated spike and connection pattern with
        # heterogeneous delays

        # Note: it is important that all objects have the same names in both
        # networks. This would automatically be the case if we were running
        # this in a new process, but to avoid relying on garbage collection
        # we assign explicit names here.
        source = SpikeGeneratorGroup(
            5,
            np.arange(5).repeat(3),
            [3, 4, 1, 2, 3, 7, 5, 4, 1, 0, 5, 9, 7, 8, 9] * ms,
            name='source')
        target = NeuronGroup(10, 'v:1', name='target')
        synapses = Synapses(source,
                            target,
                            model='w:1',
                            on_pre='v+=w',
                            name='synapses')
        synapses.connect('j>=i')
        synapses.w = 'i*1.0 + j*2.0'
        synapses.delay = '(5-i)*ms'
        state_mon = StateMonitor(target, 'v', record=True, name='statemonitor')
        input_spikes = SpikeMonitor(source, name='input_spikes')
        net = Network(source, target, synapses, state_mon, input_spikes)
        return net

    net = create_net()
    net.store(filename=filename)  # default time slot
    net.run(5 * ms)
    net.store('second', filename=filename)
    net.run(5 * ms)
    input_spike_indices = np.array(net['input_spikes'].i)
    input_spike_times = Quantity(net['input_spikes'].t, copy=True)
    v_values_full_sim = Quantity(net['statemonitor'].v[:, :], copy=True)

    net = create_net()
    net.restore(filename=filename)  # Go back to beginning
    net.run(10 * ms)
    assert_equal(input_spike_indices, net['input_spikes'].i)
    assert_equal(input_spike_times, net['input_spikes'].t)
    assert_equal(v_values_full_sim, net['statemonitor'].v[:, :])

    net = create_net()
    net.restore('second', filename=filename)  # Go back to middle
    net.run(5 * ms)
    assert_equal(input_spike_indices, net['input_spikes'].i)
    assert_equal(input_spike_times, net['input_spikes'].t)
    assert_equal(v_values_full_sim, net['statemonitor'].v[:, :])

    try:
        os.remove(filename)
    except OSError:
        pass
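
The store/restore pattern the test exercises, distilled (sketch; net is any brian2 Network and filename a writable path):

net.store(filename=filename)              # snapshot the default time slot to disk
net.run(5 * ms)
net.store('second', filename=filename)    # named snapshot taken at t = 5 ms
net.restore(filename=filename)            # rewind to t = 0, even from a fresh process
net.restore('second', filename=filename)  # or jump back to the t = 5 ms snapshot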
Code example #3
File: helper.py  Project: nikhil-garg/brian2tools
from brian2 import Function, Quantity, TimedArray, second
from brian2.core.namespace import (DEFAULT_CONSTANTS, DEFAULT_FUNCTIONS,
                                   DEFAULT_UNITS)
from brian2.core.variables import Constant


def _prepare_identifiers(identifiers):
    """
    Helper function to filter the identifiers resolved by the parent
    Group, keeping only the required ones in a standard dictionary format

    Parameters
    ----------
    identifiers : dict
        Dictionary of identifiers resolved by parent Group

    Returns
    -------
    clean_identifiers : dict
        Filtered identifiers to use with standard format
    """

    clean_identifiers = {}

    for (key, value) in identifiers.items():

        if isinstance(value, Constant):
            if key not in DEFAULT_CONSTANTS and key not in DEFAULT_UNITS:
                quant_identity = {key: Quantity(value.value, dim=value.dim)}
                clean_identifiers.update(quant_identity)
        # check if Function type
        elif isinstance(value, Function):
            # if TimedArray, serialize its values and metadata
            if isinstance(value, TimedArray):
                timed_arr = {
                    'name': value.name,
                    'values': Quantity(value.values, dim=value.dim),
                    'dt': Quantity(value.dt, dim=second),
                    'ndim': value.values.ndim,
                    'type': 'timedarray'
                }
                clean_identifiers.update({key: timed_arr})
            # else if custom function type
            elif key not in DEFAULT_FUNCTIONS:
                clean_identifiers.update({
                    key: {
                        'type': 'custom_func',
                        'arg_units': value._arg_units,
                        'arg_types': value._arg_types,
                        'return_type': value._return_type,
                        'return_unit': value._return_unit,
                    }
                })
        elif isinstance(value, Quantity):
            if key not in DEFAULT_UNITS:
                clean_identifiers.update({key: value})

    return clean_identifiers
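
A hypothetical call (sketch): the identifiers dictionary would normally come from the parent Group's namespace resolution; here a comparable dictionary is built by hand for illustration.

from brian2 import TimedArray, mV, ms

stim = TimedArray([0, 1, 2] * mV, dt=1 * ms)    # a Function subclass
identifiers = {'stim': stim, 'offset': 5 * mV}  # hand-built stand-in
clean = _prepare_identifiers(identifiers)
# 'stim' becomes a {'type': 'timedarray', ...} dict; 'offset' stays a Quantity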
Code example #4
 def state(self, value):
     # Distribute a flat value vector across the constraints' free
     # parameters, re-attaching each constraint's dimensions
     assert len(value) == len(self.constraints)
     for i, c in enumerate(self.constraints):
         dim = c.free.dimensions
         c.free = Quantity(value[i], dim=dim)
     for f in self.dependent_functions:
         f()
Code example #5
 def best_error(self):
     if self._best_error is None:
         return None
     if self.use_units:
         error_dim = self.metric.get_dimensions(self.output_dim)
         return Quantity(self._best_error, dim=error_dim)
     else:
         return self._best_error
Code example #6
 def best_params(self):
     if self._best_params is None:
         return None
     if self.use_units:
         params_with_units = {
             p: Quantity(v, dim=self.model[p].dim)
             for p, v in self._best_params.items()
         }
         return params_with_units
     else:
         return self._best_params
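
Read-side usage of the two properties above might look like this (sketch; fitter stands for an already-fitted brian2modelfitting fitter created with use_units=True):

params = fitter.best_params  # dict mapping parameter names to Quantity values
error = fitter.best_error    # Quantity carrying the metric's error dimensions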
Code example #7
        def _callback_wrapper(params, iter, resid, *args, **kwds):
            error = mean(resid**2)
            errors.append(error)
            if self.use_units:
                error_dim = self.output_dim**2 * get_dimensions(
                    normalization)**2
                all_errors = Quantity(errors, dim=error_dim)
                params = {
                    p: Quantity(val, dim=self.model[p].dim)
                    for p, val in params.items()
                }
            else:
                all_errors = array(errors)
                params = {p: float(val) for p, val in params.items()}
            tested_parameters.append(params)

            best_idx = argmin(errors)
            best_error = all_errors[best_idx]
            best_params = tested_parameters[best_idx]

            return callback_func(params, all_errors, best_params, best_error,
                                 iter)
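
A custom callback compatible with this wrapper follows the func(parameters, errors, best_parameters, best_error, index) signature documented elsewhere in the library (sketch):

def my_callback(params, errors, best_params, best_error, index):
    print('iteration %d: best error so far %s' % (index, best_error))
    return False  # returning True interrupts the fitting loop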
Code example #8
    def __init__(self,
                 model,
                 input_var,
                 input,
                 output_var,
                 output,
                 dt,
                 n_samples=30,
                 method=None,
                 reset=None,
                 refractory=False,
                 threshold=None,
                 level=0,
                 param_init=None,
                 t_start=0 * second):
        """Initialize the fitter."""
        super().__init__(dt, model, input, output, input_var, output_var,
                         n_samples, threshold, reset, refractory, method,
                         param_init)

        self.output = Quantity(output)
        self.output_ = array(output)

        if output_var not in self.model.names:
            raise NameError("%s is not a model variable" % output_var)
        if output.shape != input.shape:
            raise ValueError("Input and output must have the same size")

        # Store the target output as a TimedArray (for the online error)
        output_traces = TimedArray(output.transpose(), dt=dt)
        output_dim = get_dimensions(output)
        squared_output_dim = ('1' if output_dim is DIMENSIONLESS else repr(
            output_dim**2))
        error_eqs = Equations('total_error : {}'.format(squared_output_dim))
        self.model = self.model + error_eqs

        self.t_start = t_start

        if param_init:
            for param, val in param_init.items():
                if not (param in self.model.identifiers
                        or param in self.model.names):
                    raise ValueError("%s is not a model variable or an "
                                     "identifier in the model" % param)
            self.param_init = param_init

        self.simulator = None
Code example #9
    def __init__(self,
                 model,
                 input,
                 output,
                 dt,
                 reset,
                 threshold,
                 input_var='I',
                 refractory=False,
                 n_samples=30,
                 method=None,
                 param_init=None,
                 use_units=True):
        """Initialize the fitter."""
        if method is None:
            method = 'exponential_euler'
        super().__init__(dt,
                         model,
                         input,
                         output,
                         input_var,
                         'v',
                         n_samples,
                         threshold,
                         reset,
                         refractory,
                         method,
                         param_init,
                         use_units=use_units)
        self.output = [Quantity(o) for o in output]
        self.output_ = [array(o) for o in output]
        self.output_var = 'spikes'

        if param_init:
            for param, val in param_init.items():
                if not (param in self.model.identifiers
                        or param in self.model.names):
                    raise ValueError("%s is not a model variable or an "
                                     "identifier in the model" % param)
            self.param_init = param_init

        self.simulator = None
Code example #10
 def variableview_set_with_index_array(self,
                                       variableview,
                                       item,
                                       value,
                                       check_units=True):
     """
     Capture setters that use a particular index,
     e.g. obj.var[0:2] = -78 * mV
     """
     # happens when a plain int/float (dimensionless) value is passed
     if not isinstance(value, Quantity):
         value = Quantity(value)
     init_dict = {
         'source': variableview.group.name,
         'variable': variableview.name,
         'value': value,
         'type': 'initializer'
     }
     # a full slice (e.g. var[:]) is recorded with index True
     if type(item) == slice and item.start is None and item.stop is None:
         init_dict['index'] = True
     elif ((isinstance(item, int) or
            (isinstance(item, np.ndarray) and item.shape == ()))
           and value.size == 1):
         if self.array_cache.get(variableview.variable, None) is not None:
             self.array_cache[variableview.variable][item] = value
         init_dict['index'] = item
     else:
         # We have to calculate indices. This will not work for synaptic
         # variables
         try:
             init_dict['index'] = np.asarray(
                 variableview.indexing(item,
                                       index_var=variableview.index_var))
         except NotImplementedError:
              raise NotImplementedError(('Cannot set variable "%s" this way '
                                        'in device mode, '
                                        'try using string '
                                        'expressions') % variableview.name)
     self.initializers_connectors.append(init_dict)
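
Statements like the following are what this setter captures (sketch; the group and variable names are illustrative):

group.v[:] = -70 * mV    # full slice -> stored with index True
group.v[3] = -60 * mV    # scalar index -> cached and stored as-is
group.v[0:2] = -78 * mV  # partial slice -> indices computed via indexing()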
Code example #11
    def __init__(self,
                 model,
                 input_var,
                 input,
                 output_var,
                 output,
                 dt,
                 n_samples=30,
                 method=None,
                 reset=None,
                 refractory=False,
                 threshold=None,
                 param_init=None,
                 use_units=True):
        super().__init__(dt,
                         model,
                         input,
                         output,
                         input_var,
                         output_var,
                         n_samples,
                         threshold,
                         reset,
                         refractory,
                         method,
                         param_init,
                         use_units=use_units)
        self.output = Quantity(output)
        self.output_ = array(output)
        # We store the bounds set in TraceFitter.fit, so that TraceFitter.refine
        # can reuse them
        self.bounds = None

        if output_var not in self.model.names:
            raise NameError("%s is not a model variable" % output_var)
        if output.shape != input.shape:
            raise ValueError("Input and output must have the same size")
Code example #12
def test_spike_neurongroup():
    """
    Test dictionary representation of spiking neuron
    """
    eqn = ''' dv/dt = (v_th - v) / tau : volt
              v_th = 900 * mV :volt
              v_rest = -70 * mV :volt
              tau :second (constant)'''

    tau = 10 * ms
    size = 10

    grp = NeuronGroup(size,
                      eqn,
                      threshold='v > v_th',
                      reset='v = v_rest',
                      refractory=2 * ms)

    neuron_dict = collect_NeuronGroup(grp, get_local_namespace(0))

    assert neuron_dict['N'] == size
    assert neuron_dict['user_method'] is None

    eqns = Equations(eqn)
    assert neuron_dict['equations']['v']['type'] == DIFFERENTIAL_EQUATION
    assert neuron_dict['equations']['v']['unit'] == volt
    assert neuron_dict['equations']['v']['var_type'] == FLOAT
    assert neuron_dict['equations']['v']['expr'] == eqns['v'].expr.code

    assert neuron_dict['equations']['v_th']['type'] == SUBEXPRESSION
    assert neuron_dict['equations']['v_th']['unit'] == volt
    assert neuron_dict['equations']['v_th']['var_type'] == FLOAT
    assert neuron_dict['equations']['v_th']['expr'] == eqns['v_th'].expr.code

    assert neuron_dict['equations']['v_rest']['type'] == SUBEXPRESSION
    assert neuron_dict['equations']['v_rest']['unit'] == volt
    assert neuron_dict['equations']['v_rest']['var_type'] == FLOAT

    assert neuron_dict['equations']['tau']['type'] == PARAMETER
    assert neuron_dict['equations']['tau']['unit'] == second
    assert neuron_dict['equations']['tau']['var_type'] == FLOAT
    assert neuron_dict['equations']['tau']['flags'][0] == 'constant'

    thresholder = grp.thresholder['spike']
    neuron_events = neuron_dict['events']['spike']
    assert neuron_events['threshold']['code'] == 'v > v_th'
    assert neuron_events['threshold']['when'] == thresholder.when
    assert neuron_events['threshold']['order'] == thresholder.order
    assert neuron_events['threshold']['dt'] == grp.clock.dt

    resetter = grp.resetter['spike']
    assert neuron_events['reset']['code'] == 'v = v_rest'
    assert neuron_events['reset']['when'] == resetter.when
    assert neuron_events['reset']['order'] == resetter.order
    assert neuron_events['reset']['dt'] == resetter.clock.dt

    assert neuron_dict['events']['spike']['refractory'] == Quantity(2 * ms)

    # example 2 with threshold but no reset

    start_scope()
    grp2 = NeuronGroup(size,
                       '''dv/dt = (100 * mV - v) / tau_n : volt''',
                       threshold='v > 800 * mV',
                       method='euler')
    tau_n = 10 * ms

    neuron_dict2 = collect_NeuronGroup(grp2, get_local_namespace(0))
    thresholder = grp2.thresholder['spike']
    neuron_events = neuron_dict2['events']['spike']
    assert neuron_events['threshold']['code'] == 'v > 800 * mV'
    assert neuron_events['threshold']['when'] == thresholder.when
    assert neuron_events['threshold']['order'] == thresholder.order
    assert neuron_events['threshold']['dt'] == grp2.clock.dt

    # Each missing key needs its own `raises` block; a second statement
    # after the first raise would never execute
    with pytest.raises(KeyError):
        neuron_dict2['events']['spike']['reset']
    with pytest.raises(KeyError):
        neuron_dict2['events']['spike']['refractory']
Code example #13
    def refine(self,
               params=None,
               t_start=None,
               normalization=None,
               callback='text',
               calc_gradient=False,
               optimize=True,
               level=0,
               **kwds):
        """
        Refine the fitting results with a sequentially operating minimization
        algorithm. Uses the `lmfit <https://lmfit.github.io/lmfit-py/>`_
        package which itself makes use of
        `scipy.optimize <https://docs.scipy.org/doc/scipy/reference/optimize.html>`_.
        Has to be called after `~.TraceFitter.fit`, but a call with
        ``n_rounds=0`` is enough.

        Parameters
        ----------
        params : dict, optional
            A dictionary with the parameters to use as a starting point for the
            refinement. If not given, the best parameters found so far by
            `~.TraceFitter.fit` will be used.
        t_start : `~brian2.units.fundamentalunits.Quantity`, optional
            Initial simulation/model time that should be ignored for the error
            calculation. If not set, will reuse the `t_start` value from the
            previously used metric.
        normalization : float, optional
            A normalization term that will be used to rescale results before
            handing them to the optimization algorithm. Can be useful if the
            algorithm makes assumptions about the scale of errors, e.g. if the
            size of steps in the parameter space depends on the absolute value
            of the error. The difference between simulated and target traces
            will be divided by this value. If not set, will reuse the
            `normalization` value from the previously used metric.
        callback: `str` or `~typing.Callable`
            Either the name of a provided callback function (``text`` or
            ``progressbar``), or a custom feedback function
            ``func(parameters, errors, best_parameters, best_error, index)``.
            If this function returns ``True`` the fitting execution is
            interrupted.
        calc_gradient: bool, optional
            Whether to add "sensitivity variables" to the equation that track
            the sensitivity of the equation variables to the parameters. This
            information will be used to pass the local gradient of the error
            with respect to the parameters to the optimization function. This
            can lead to much faster convergence than with an estimated gradient
            but comes at the expense of additional computation. Defaults to
            ``False``.
        optimize : bool, optional
            Whether to remove sensitivity variables from the equations that do
            not evolve if initialized to zero (e.g. ``dS_x_y/dt = -S_x_y/tau``
            would be removed). This avoids unnecessary computation but will fail
            in the rare case that such a sensitivity variable needs to be
            initialized to a non-zero value. Only taken into account if
            ``calc_gradient`` is ``True``. Defaults to ``True``.
        level : int, optional
            How much farther to go down in the stack to find the namespace.
        kwds
            Additional arguments can overwrite the bounds for individual
            parameters (if not given, the bounds previously specified in the
            call to `~.TraceFitter.fit` will be used). All other arguments will
            be passed on to `.lmfit.minimize` and can be used to e.g. change the
            method, or to specify method-specific arguments.

        Returns
        -------
        parameters : dict
            The parameters at the end of the optimization process as a
            dictionary.
        result : `.lmfit.MinimizerResult`
            The result of the optimization process.

        Notes
        -----
        The default method used by `lmfit` is least-squares minimization using
        a Levenberg-Marquardt method. Note that there is no support for
        specifying a `Metric`; the given output trace(s) will be subtracted
        from the simulated trace(s) and passed on to the minimization algorithm.

        This method always uses the runtime mode, independent of the selection
        of the current device.
        """
        try:
            import lmfit
        except ImportError:
            raise ImportError('Refinement needs the "lmfit" package.')
        if params is None:
            if self.best_params is None:
                raise TypeError('You need to either specify parameters or run '
                                'the fit function first.')
            params = self.best_params

        if t_start is None:
            t_start = getattr(self.metric, 't_start', 0 * second)
        if normalization is None:
            normalization = getattr(self.metric, 'normalization', 1.)
        else:
            normalization = 1 / normalization

        callback_func = callback_setup(callback, None)

        # Set up Parameter objects
        parameters = lmfit.Parameters()
        for param_name in self.parameter_names:
            if param_name not in kwds:
                if self.bounds is None:
                    raise TypeError(
                        'You need to either specify bounds for all '
                        'parameters or run the fit function first.')
                min_bound, max_bound = self.bounds[param_name]
            else:
                min_bound, max_bound = kwds.pop(param_name)
            parameters.add(param_name,
                           value=array(params[param_name]),
                           min=array(min_bound),
                           max=array(max_bound))

        self.simulator = self.setup_simulator('refine',
                                              self.n_traces,
                                              output_var=self.output_var,
                                              param_init=self.param_init,
                                              calc_gradient=calc_gradient,
                                              optimize=optimize,
                                              level=level + 1)

        t_start_steps = int(round(t_start / self.dt))

        def _calc_error(params):
            param_dic = get_param_dic(
                [params[p] for p in self.parameter_names],
                self.parameter_names, self.n_traces, 1)
            self.simulator.run(self.duration,
                               param_dic,
                               self.parameter_names,
                               name='refine')
            trace = getattr(self.simulator.monitor, self.output_var + '_')
            residual = trace[:, t_start_steps:] - self.output_[:,
                                                               t_start_steps:]
            return residual.flatten() * normalization

        def _calc_gradient(params):
            residuals = []
            for name in self.parameter_names:
                trace = getattr(self.simulator.monitor,
                                f'S_{self.output_var}_{name}_')
                residual = trace[:, t_start_steps:]
                residuals.append(residual.flatten() * normalization)
            gradient = array(residuals)
            return gradient.T

        tested_parameters = []
        errors = []

        def _callback_wrapper(params, iter, resid, *args, **kwds):
            error = mean(resid**2)
            errors.append(error)
            if self.use_units:
                error_dim = self.output_dim**2 * get_dimensions(
                    normalization)**2
                all_errors = Quantity(errors, dim=error_dim)
                params = {
                    p: Quantity(val, dim=self.model[p].dim)
                    for p, val in params.items()
                }
            else:
                all_errors = array(errors)
                params = {p: float(val) for p, val in params.items()}
            tested_parameters.append(params)

            best_idx = argmin(errors)
            best_error = all_errors[best_idx]
            best_params = tested_parameters[best_idx]

            return callback_func(params, all_errors, best_params, best_error,
                                 iter)

        assert 'Dfun' not in kwds
        if calc_gradient:
            kwds.update({'Dfun': _calc_gradient})
        if 'iter_cb' in kwds:
            # Use the given callback but raise a warning if callback is not
            # set to None
            if callback is not None:
                logger.warn(
                    'The iter_cb keyword has been specified together '
                    f'with callback={callback!r}. Only the iter_cb '
                    'callback will be used. Use the standard '
                    'callback mechanism or set callback=None to '
                    'remove this warning.',
                    name_suffix='iter_cb_callback')
            iter_cb = kwds.pop('iter_cb')
        else:
            iter_cb = _callback_wrapper
        result = lmfit.minimize(_calc_error,
                                parameters,
                                iter_cb=iter_cb,
                                **kwds)

        if self.use_units:
            param_dict = {
                p: Quantity(float(val), dim=self.model[p].dim)
                for p, val in result.params.items()
            }
        else:
            param_dict = {p: float(val) for p, val in result.params.items()}

        return param_dict, result
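
Typical calls after an initial fit might look like this (sketch; the parameter name and bounds are illustrative):

refined_params, result = fitter.refine()  # start from the best parameters so far
refined_params, result = fitter.refine(params={'tau': 10 * ms},
                                       calc_gradient=True,
                                       tau=[1 * ms, 50 * ms])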
Code example #14
    def results(self, format='list', use_units=None):
        """
        Returns all of the gathered results (parameters and errors).
        In one of the 3 formats: 'dataframe', 'list', 'dict'.

        Parameters
        ----------
        format: str
            The desired output format. Currently supported: ``dataframe``,
            ``list``, or ``dict``.
        use_units: bool, optional
            Whether to use units in the results. If not specified, defaults to
            `.TraceFitter.use_units`, i.e. the value that was specified when
            the `.TraceFitter` object was created (``True`` by default).

        Returns
        -------
        object
            'dataframe': returns pandas `~pandas.DataFrame` without units
            'list': list of dictionaries
            'dict': dictionary of lists
        """
        if use_units is None:
            use_units = self.use_units
        names = list(self.parameter_names)

        params = array(self.optimizer.tested_parameters)
        params = params.reshape(-1, params.shape[-1])

        if use_units:
            error_dim = self.metric.get_dimensions(self.output_dim)
            errors = Quantity(array(self.optimizer.errors).flatten(),
                              dim=error_dim)
        else:
            errors = array(array(self.optimizer.errors).flatten())

        dim = self.model.dimensions

        if format == 'list':
            res_list = []
            for j in range(len(params)):
                temp_data = params[j]
                res_dict = dict()

                for i, n in enumerate(names):
                    if use_units:
                        res_dict[n] = Quantity(temp_data[i], dim=dim[n])
                    else:
                        res_dict[n] = float(temp_data[i])
                res_dict['error'] = errors[j]
                res_list.append(res_dict)

            return res_list

        elif format == 'dict':
            res_dict = dict()
            for i, n in enumerate(names):
                if use_units:
                    res_dict[n] = Quantity(params[:, i], dim=dim[n])
                else:
                    res_dict[n] = array(params[:, i])

            res_dict['error'] = errors
            return res_dict

        elif format == 'dataframe':
            from pandas import DataFrame
            if use_units:
                logger.warn(
                    'Results in dataframes do not support units. '
                    'Specify "use_units=False" to avoid this warning.',
                    name_suffix='dataframe_units')
            data = concatenate((params, array(errors)[None, :].transpose()),
                               axis=1)
            return DataFrame(data=data, columns=names + ['error'])
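
The three output formats side by side (sketch; 'tau' is an illustrative parameter name):

as_list = fitter.results(format='list')  # [{'tau': ..., 'error': ...}, ...]
as_dict = fitter.results(format='dict')  # {'tau': [...], 'error': [...]}
df = fitter.results(format='dataframe', use_units=False)  # pandas DataFrame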
Code example #15
    def fit(self,
            optimizer,
            metric=None,
            n_rounds=1,
            callback='text',
            restart=False,
            online_error=False,
            level=0,
            **params):
        """
        Run the optimization algorithm for the given number of rounds with the
        given number of samples drawn. Returns the best set of parameters and
        the corresponding error.

        Parameters
        ----------
        optimizer: `~.Optimizer` children
            Child of Optimizer class, specific for each library.
        metric: `~.Metric` children
            Child of Metric class, specifies optimization metric
        n_rounds: int
            Number of rounds to optimize over (feedback provided over each
            round).
        callback: `str` or `~typing.Callable`
            Either the name of a provided callback function (``text`` or
            ``progressbar``), or a custom feedback function
            ``func(parameters, errors, best_parameters, best_error, index)``.
            If this function returns ``True`` the fitting execution is
            interrupted.
        restart: bool
            Flag that reinitializes the Fitter to reset the optimization.
            With ``restart=True`` the user is allowed to change the optimizer
            and the metric.
        online_error: bool, optional
            Whether to calculate the squared error between target trace and
            simulated trace online. Defaults to ``False``.
        level: int, optional
            How much farther to go down in the stack to find the namespace.
        **params
            Bounds for each parameter, passed as ``name=[min, max]``.

        Returns
        -------
        best_results : dict
            dictionary with best parameter set
        error: float
            error value for best parameter set
        """
        if not (isinstance(metric, Metric) or metric is None):
            raise TypeError("metric has to be a child of class Metric or None "
                            "for OnlineTraceFitter")

        if not isinstance(optimizer, Optimizer) or optimizer is None:
            raise TypeError("optimizer has to be a child of class Optimizer")

        if self.metric is not None and restart is False:
            if metric is not self.metric:
                raise Exception("You can not change the metric between fits")

        if self.optimizer is not None and restart is False:
            if optimizer is not self.optimizer:
                raise Exception(
                    "You can not change the optimizer between fits")

        if self.optimizer is None or restart is True:
            optimizer.initialize(self.parameter_names,
                                 popsize=self.n_samples,
                                 **params)

        self.optimizer = optimizer
        self.metric = metric

        callback = callback_setup(callback, n_rounds)

        # Check whether we can reuse the current simulator or whether we have
        # to create a new one (only relevant for standalone, but does not hurt
        # for runtime)
        if self.simulator is None or self.simulator.current_net != 'fit':
            self.simulator = self.setup_simulator('fit',
                                                  self.n_neurons,
                                                  output_var=self.output_var,
                                                  online_error=online_error,
                                                  param_init=self.param_init,
                                                  level=level + 1)

        # Run Optimization Loop
        for index in range(n_rounds):
            best_params, parameters, errors = self.optimization_iter(
                optimizer, metric)
            self._best_error = nanmin(self.optimizer.errors)
            # create output variables
            self._best_params = make_dic(self.parameter_names, best_params)
            if self.use_units:
                if self.output_var == 'spikes':
                    output_dim = DIMENSIONLESS
                else:
                    output_dim = self.output_dim
                # Correct the units for the normalization factor
                error_dim = self.metric.get_normalized_dimensions(output_dim)
                best_error = Quantity(float(self.best_error), dim=error_dim)
                errors = Quantity(errors, dim=error_dim)
                param_dicts = [{
                    p: Quantity(v, dim=self.model[p].dim)
                    for p, v in zip(self.parameter_names, one_param_set)
                } for one_param_set in parameters]
            else:
                param_dicts = [{
                    p: v
                    for p, v in zip(self.parameter_names, one_param_set)
                } for one_param_set in parameters]
                best_error = self.best_error

            if callback(param_dicts, errors, self.best_params, best_error,
                        index) is True:
                break

        return self.best_params, self.best_error
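
A hypothetical fit() call (sketch; NevergradOptimizer and MSEMetric are real brian2modelfitting classes, while the parameter names and bounds are illustrative and match the TraceFitter sketch above):

from brian2 import ms, mV
from brian2modelfitting import MSEMetric, NevergradOptimizer

best_params, best_error = fitter.fit(optimizer=NevergradOptimizer(),
                                     metric=MSEMetric(),
                                     n_rounds=2,
                                     tau=[1 * ms, 30 * ms],
                                     E_L=[-90 * mV, -50 * mV])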