Ejemplo n.º 1
0
class NeuronGroup(Group, SpikeSource):
    '''
    A group of neurons.

    
    Parameters
    ----------
    N : int
        Number of neurons in the group.
    model : (str, `Equations`)
        The differential equations defining the group
    method : (str, function), optional
        The numerical integration method. Either a string with the name of a
        registered method (e.g. "euler") or a function that receives an
        `Equations` object and returns the corresponding abstract code. If no
        method is specified, a suitable method will be chosen automatically.
    threshold : str, optional
        The condition which produces spikes. Should be a single line boolean
        expression.
    reset : str, optional
        The (possibly multi-line) string with the code to execute on reset.
    refractory : {str, `Quantity`}, optional
        Either the length of the refractory period (e.g. ``2*ms``), a string
        expression that evaluates to the length of the refractory period
        after each spike (e.g. ``'(1 + rand())*ms'``), or a string expression
        evaluating to a boolean value, given the condition under which the
        neuron stays refractory after a spike (e.g. ``'v > -20*mV'``)
    namespace : dict, optional
        A dictionary mapping variable/function names to the respective objects.
        If no `namespace` is given, the "implicit" namespace, consisting of
        the local and global namespace surrounding the creation of the class,
        is used.
    dtype : (`dtype`, `dict`), optional
        The `numpy.dtype` that will be used to store the values, or a
        dictionary specifying the type for variable names. If a value is not
        provided for a variable (or no value is provided at all), the preference
        setting `core.default_scalar_dtype` is used.
    codeobj_class : class, optional
        The `CodeObject` class to run code with.
    clock : Clock, optional
        The update clock to be used, or defaultclock if not specified.
    name : str, optional
        A unique name for the group, otherwise use ``neurongroup_0``, etc.
        
    Notes
    -----
    `NeuronGroup` contains a `StateUpdater`, `Thresholder` and `Resetter`, and
    these are run at the 'groups', 'thresholds' and 'resets' slots (i.e. the
    values of `Scheduler.when` take these values). The `Scheduler.order`
    attribute is set to 0 initially, but this can be modified using the
    attributes `state_updater`, `thresholder` and `resetter`.    
    '''
    def __init__(self, N, model, method=None,
                 threshold=None,
                 reset=None,
                 refractory=False,
                 namespace=None,
                 dtype=None,
                 clock=None, name='neurongroup*',
                 codeobj_class=None):
        '''
        Validate the arguments, create the variables and the contained
        objects (state updater and, if requested, thresholder and resetter).

        Raises
        ------
        TypeError
            If `N` is a string (most likely the equations were passed as the
            first argument) or if `model` is neither a string nor an
            `Equations` object.
        ValueError
            If `N` cannot be converted to an integer or is smaller than 1.
        '''
        Group.__init__(self, when=clock, name=name)

        self.codeobj_class = codeobj_class

        try:
            self._N = N = int(N)
        except ValueError:
            # A string that is not a number is almost certainly the model
            # given in the wrong position -- give a more helpful error.
            if isinstance(N, str):
                raise TypeError("First NeuronGroup argument should be size, not equations.")
            raise
        if N < 1:
            raise ValueError("NeuronGroup size should be at least 1, was " + str(N))

        self.start = 0
        self.stop = self._N

        ##### Prepare and validate equations
        if isinstance(model, basestring):
            model = Equations(model)
        if not isinstance(model, Equations):
            raise TypeError(('model has to be a string or an Equations '
                             'object, is "%s" instead.') % type(model))

        # Check flags. The allowed flags have to be given as tuples: a bare
        # parenthesised string would turn a membership test into a substring
        # test and let through flags such as 'unless' or 'const'.
        model.check_flags({DIFFERENTIAL_EQUATION: ('unless refractory',),
                           PARAMETER: ('constant',)})

        # add refractoriness
        if refractory is not False:
            model = add_refractoriness(model)
        self.equations = model
        uses_refractoriness = any('unless refractory' in eq.flags
                                  for eq in model.itervalues()
                                  if eq.type == DIFFERENTIAL_EQUATION)

        logger.debug("Creating NeuronGroup of size {self._N}, "
                     "equations {self.equations}.".format(self=self))

        # Setup the namespace
        self.namespace = create_namespace(namespace)

        # Setup variables
        self._create_variables(dtype)

        # All of the following will be created in before_run

        #: The threshold condition
        self.threshold = threshold

        #: The reset statement(s)
        self.reset = reset

        #: The refractory condition or timespan
        self._refractory = refractory
        if uses_refractoriness and refractory is False:
            logger.warn('Model equations use the "unless refractory" flag but '
                        'no refractory keyword was given.', 'no_refractory')

        #: The state update method selected by the user
        self.method_choice = method

        #: Performs thresholding step, sets the value of `spikes`
        self.thresholder = None
        if self.threshold is not None:
            self.thresholder = Thresholder(self)

        #: Resets neurons which have spiked (`spikes`)
        self.resetter = None
        if self.reset is not None:
            self.resetter = Resetter(self)

        # We try to run a before_run already now. This might fail because of an
        # incomplete namespace but if the namespace is already complete we
        # can spot unit errors in the equation already here.
        try:
            self.before_run(None)
        except KeyError:
            pass

        #: Performs numerical integration step
        self.state_updater = StateUpdater(self, method)

        # Creation of contained_objects that do the work
        self.contained_objects.append(self.state_updater)
        if self.thresholder is not None:
            self.contained_objects.append(self.thresholder)
        if self.resetter is not None:
            self.contained_objects.append(self.resetter)

        if refractory is not False:
            # Set the refractoriness information
            self.variables['lastspike'].set_value(-np.inf*second)
            self.variables['not_refractory'].set_value(True)

        # Activate name attribute access
        self._enable_group_attributes()

    def __len__(self):
        '''
        Return number of neurons in the group.
        '''
        # Use the plain integer stored at construction time (the same value
        # that is registered as the constant ``N``) instead of going through
        # the dynamic group attribute mechanism.
        return self._N

    @property
    def spikes(self):
        '''
        The spikes returned by the most recent thresholding operation.
        '''
        # Note that we have to directly access the ArrayVariable object here
        # instead of using the Group mechanism by accessing self._spikespace
        # Using the latter would cut _spikespace to the length of the group
        spikespace = self.variables['_spikespace'].get_value()
        # The last entry of the array holds the number of valid entries
        return spikespace[:spikespace[-1]]

    def __getitem__(self, item):
        '''
        Return a `Subgroup` for the contiguous, non-empty slice `item`.

        Raises
        ------
        TypeError
            If `item` is not a slice.
        IndexError
            If the slice has a step other than 1 or selects no neurons.
        '''
        if not isinstance(item, slice):
            raise TypeError('Subgroups can only be constructed using slicing syntax')
        start, stop, step = item.indices(self._N)
        if step != 1:
            raise IndexError('Subgroups have to be contiguous')
        if start >= stop:
            raise IndexError('Illegal start/end values for subgroup, %d>=%d' %
                             (start, stop))

        return Subgroup(self, start, stop)

    def runner(self, code, when=None, name=None):
        '''
        Returns a `CodeRunner` that runs abstract code in the group's namespace
        
        Parameters
        ----------
        code : str
            The abstract code to run.
        when : `Scheduler`, optional
            When to run, by default in the 'start' slot with the same clock as
            the group.
        name : str, optional
            A unique name, if none is given the name of the group appended with
            'runner', 'runner_1', etc. will be used. If a name is given
            explicitly, it will be used as given (i.e. the group name will not
            be prepended automatically).

        Raises
        ------
        NotImplementedError
            If `when` is given; only the default scheduling is currently
            supported.
        '''
        if when is None:  # TODO: make this better with default values
            when = Scheduler(clock=self.clock)
        else:
            raise NotImplementedError

        if name is None:
            name = self.name + '_runner*'

        runner = GroupCodeRunner(self, self.language.template_state_update,
                                 code=code, name=name, when=when)
        return runner

    def _create_variables(self, dtype=None):
        '''
        Create the variables dictionary for this `NeuronGroup`, containing
        entries for the equation variables and some standard entries.

        Parameters
        ----------
        dtype : (`dtype`, type, mapping), optional
            Either a single dtype (a `numpy.dtype` instance or a scalar type
            such as ``np.float32``) used for all equation variables, or a
            mapping from variable names to dtypes. Variables without an entry
            (or all variables, if `dtype` is ``None``) use the preference
            setting `core.default_scalar_dtype`.

        Raises
        ------
        TypeError
            If `dtype` is neither a dtype nor supports item access.
        '''
        self.variables = Variables(self)
        self.variables.add_clock_variables(self.clock)
        self.variables.add_constant('N', Unit(1), self._N)

        # Split the specification into a single dtype valid for all variables
        # and a per-variable mapping. Separate names are used on purpose:
        # rebinding ``dtype`` to a container whose default factory closes over
        # ``dtype`` would make the container return itself.
        fixed_dtype = None
        dtype_map = None
        if dtype is None:
            pass
        elif isinstance(dtype, (np.dtype, type)):
            fixed_dtype = dtype
        elif hasattr(dtype, '__getitem__'):
            dtype_map = dtype
        else:
            raise TypeError(('Cannot use type %s as dtype '
                             'specification') % type(dtype))

        def dtype_of(varname):
            # Resolve the dtype for one variable, falling back to the
            # preference setting when nothing was specified for it.
            if fixed_dtype is not None:
                return fixed_dtype
            if dtype_map is not None:
                try:
                    return dtype_map[varname]
                except KeyError:
                    pass
            return brian_prefs['core.default_scalar_dtype']

        # Standard variables always present
        # (one extra entry at the end of the spike space, see `spikes`)
        self.variables.add_array('_spikespace', unit=Unit(1), size=self._N+1,
                                 dtype=np.int32, constant=False)
        # Add the special variable "i" which can be used to refer to the neuron index
        self.variables.add_arange('i', size=self._N, constant=True,
                                  read_only=True)

        for eq in self.equations.itervalues():
            if eq.type in (DIFFERENTIAL_EQUATION, PARAMETER):
                constant = ('constant' in eq.flags)
                self.variables.add_array(eq.varname, size=self._N,
                                         unit=eq.unit,
                                         dtype=dtype_of(eq.varname),
                                         constant=constant, is_bool=eq.is_bool)
            elif eq.type == STATIC_EQUATION:
                self.variables.add_subexpression(eq.varname, unit=eq.unit,
                                                 expr=str(eq.expr),
                                                 is_bool=eq.is_bool)
            else:
                raise AssertionError('Unknown type of equation: ' +
                                     str(eq.type))

        # Stochastic variables
        for xi in self.equations.stochastic_variables:
            self.variables.add_auxiliary_variable(xi, unit=second**-0.5)

    def before_run(self, namespace):
        '''
        Pass `namespace` on to the subexpression variables (only if the group
        does not use an explicit namespace) and check the units of the
        equations.

        Parameters
        ----------
        namespace : dict-like or None
            The additional namespace handed to the subexpressions and to the
            unit check.
        '''
        # Update the namespace information in the variables in case the
        # namespace was not explicitly specified at creation time.
        # Note that values in the explicit namespace might still change
        # between runs, but the Subexpression stores a reference to
        # self.namespace so these changes are taken into account automatically
        if not self.namespace.is_explicit:
            for var in self.variables.itervalues():
                if isinstance(var, Subexpression):
                    var.additional_namespace = namespace

        # Check units
        self.equations.check_units(self.namespace, self.variables,
                                   namespace)

    def _repr_html_(self):
        '''
        Return an HTML summary of the group: name, size, model equations,
        integration method and, if set, threshold condition and reset
        statement.
        '''
        text = [r'NeuronGroup "%s" with %d neurons.<br>' % (self.name, self._N)]
        text.append(r'<b>Model:</b><br>')
        text.append(sympy.latex(self.equations))
        text.append(r'<b>Integration method:</b><br>')
        text.append(sympy.latex(self.state_updater.method)+'<br>')
        if self.threshold is not None:
            text.append(r'<b>Threshold condition:</b><br>')
            text.append('<code>%s</code><br>' % str(self.threshold))
            text.append('')
        if self.reset is not None:
            text.append(r'<b>Reset statement:</b><br>')
            text.append(r'<code>%s</code><br>' % str(self.reset))
            text.append('')

        return '\n'.join(text)