Exemplo n.º 1
0
 def __setattr__(self, name, val, level=0):
     '''
     Set an attribute, redirecting state variable names to the variable.

     ``X.var = val`` is handled like ``X.var[:] = val`` (non-string values
     are first passed to `fail_for_dimension_mismatch` together with the
     variable's unit) and ``X.var_ = val`` like ``X.var_[:] = val`` (no
     unit check). Every other name is stored as a regular attribute.

     Parameters
     ----------
     name : str
         Name of the attribute or state variable.
     val : object
         The value to assign.
     level : int, optional
         Forwarded to ``set_item`` as ``level + 1``.

     Raises
     ------
     TypeError
         If the addressed variable is read-only.
     '''
     # Variable-aware assignment stays switched off until the
     # `_group_attribute_access_active` attribute has been created by
     # `_enable_group_attributes`; names already present in the instance
     # dictionary are always set directly
     if not hasattr(self, '_group_attribute_access_active') or name in self.__dict__:
         object.__setattr__(self, name, val)
         return

     known = self.variables
     if name in known:
         target = known[name]
         if not isinstance(val, basestring):
             fail_for_dimension_mismatch(val, target.unit,
                                         'Incorrect units for setting %s' % name)
         if target.read_only:
             raise TypeError('Variable %s is read-only.' % name)
         # X.var = ... behaves like X.var[:] = ...
         accessor = target.get_addressable_value_with_unit(name, self)
         accessor.set_item(slice(None), val, level=level + 1)
         return

     base_name = name[:-1]
     if name.endswith('_') and base_name in known:
         # Trailing underscore: assignment without unit checking
         target = known[base_name]
         if target.read_only:
             raise TypeError('Variable %s is read-only.' % base_name)
         # X.var_ = ... behaves like X.var_[:] = ...
         accessor = target.get_addressable_value(base_name, self)
         accessor.set_item(slice(None), val, level=level + 1)
         return

     object.__setattr__(self, name, val)
Exemplo n.º 2
0
 def __setattr__(self, name, val):
     '''
     Set an attribute, giving state variable names a special treatment.

     Assigning to the name of a state variable sets all of its values
     (``X.var = val`` acts like ``X.var[:] = val``); non-string values are
     passed to `fail_for_dimension_mismatch` with the variable's unit
     first. The same name with a trailing underscore does the assignment
     without a unit check. Any other name becomes a normal attribute.

     Raises
     ------
     TypeError
         If the addressed variable is read-only.
     '''
     # The special handling is inactive until _enable_group_attributes has
     # created `_group_attribute_access_active`; attributes that already
     # live in the instance dictionary are simply overwritten
     plain_assignment = (not hasattr(self, '_group_attribute_access_active')
                         or name in self.__dict__)
     if plain_assignment:
         object.__setattr__(self, name, val)
         return

     state_vars = self.variables
     if name in state_vars:
         variable = state_vars[name]
         if not isinstance(val, basestring):
             fail_for_dimension_mismatch(val, variable.unit,
                                         'Incorrect units for setting %s' % name)
         if variable.read_only:
             raise TypeError('Variable %s is read-only.' % name)
         # Same effect as X.var[:] = ...
         with_unit = variable.get_addressable_value_with_unit(name, self)
         with_unit.set_item(slice(None), val, level=1)
     elif name.endswith('_') and name[:-1] in state_vars:
         # Underscore variant: the value is taken as is, no unit checking
         stripped = name[:-1]
         variable = state_vars[stripped]
         if variable.read_only:
             raise TypeError('Variable %s is read-only.' % stripped)
         # Same effect as X.var_[:] = ...
         without_unit = variable.get_addressable_value(stripped, self)
         without_unit.set_item(slice(None), val, level=1)
     else:
         object.__setattr__(self, name, val)
Exemplo n.º 3
0
def check_dimensions(expression, dimensions, variables):
    '''
    Check that an expression has the expected physical dimensions.

    The dimensions of the expression are determined with
    `parse_expression_dimensions` and then compared to `dimensions` by
    `fail_for_dimension_mismatch`.

    Parameters
    ----------
    expression : str
        The expression to evaluate.
    dimensions : `Dimension`
        The expected physical dimensions for the `expression`.
    variables : dict
        Dictionary of all variables (including external constants) used in
        the `expression`.

    Raises
    ------
    KeyError
        In case one of the identifiers cannot be resolved.
    DimensionMismatchError
        If a unit mismatch occurs during the evaluation.
    '''
    actual_dims = parse_expression_dimensions(expression, variables)
    # The message names the stripped expression and the repr of the unit
    # corresponding to the expected dimensions
    stripped = expression.strip()
    expected_unit = repr(get_unit(dimensions))
    template = ('Expression {expr} does not have the '
                'expected unit {expected}')
    message = template.format(expr=stripped, expected=expected_unit)
    fail_for_dimension_mismatch(actual_dims, dimensions, message)
Exemplo n.º 4
0
def check_dimensions(expression, dimensions, variables):
    '''
    Verify the physical dimensions of an expression.

    Parses the dimensions of `expression` (resolving names via
    `variables`) and hands them, together with the expected `dimensions`,
    to `fail_for_dimension_mismatch`.

    Parameters
    ----------
    expression : str
        The expression to evaluate.
    dimensions : `Dimension`
        The expected physical dimensions for the `expression`.
    variables : dict
        Dictionary of all variables (including external constants) used in
        the `expression`.

    Raises
    ------
    KeyError
        In case one of the identifiers cannot be resolved.
    DimensionMismatchError
        If a unit mismatch occurs during the evaluation.
    '''
    found = parse_expression_dimensions(expression, variables)
    error_text = ('Expression {expr} does not have the '
                  'expected unit {expected}')
    # Fill in the stripped expression and the repr of the expected unit
    error_text = error_text.format(expr=expression.strip(),
                                   expected=repr(get_unit(dimensions)))
    fail_for_dimension_mismatch(found, dimensions, error_text)
Exemplo n.º 5
0
 def wrapper_function(*args):
     '''
     Call the wrapped function with unit-carrying arguments and check the
     dimensions of its result.

     Each positional argument is converted with `Quantity.with_dimensions`,
     using the dimensions of the corresponding entry of
     ``self._function._arg_units``. The result of ``orig_func`` is passed to
     `fail_for_dimension_mismatch` together with
     ``self._function._return_unit``.

     Returns
     -------
     result : `numpy.ndarray`
         The result of the wrapped function, passed through ``np.asarray``.

     Raises
     ------
     ValueError
         If the number of arguments differs from the number of entries in
         ``self._function._arg_units``.
     '''
     if not len(args) == len(self._function._arg_units):
         raise ValueError(('Function %s got %d arguments, '
                           'expected %d') % (self._function.pyfunc.__name__, len(args),
                                             len(self._function._arg_units)))
     new_args = [Quantity.with_dimensions(arg, get_dimensions(arg_unit))
                 for arg, arg_unit in zip(args, self._function._arg_units)]
     result = orig_func(*new_args)
     return_unit = self._function._return_unit
     # Compare the plain integer 1 by value: an identity test (`is 1`)
     # only works by accident of integer caching and triggers a
     # SyntaxWarning on newer Python versions. The isinstance guard avoids
     # comparing unit objects to 1.
     is_plain_one = isinstance(return_unit, int) and return_unit == 1
     if is_plain_one or return_unit.dim is DIMENSIONLESS:
         fail_for_dimension_mismatch(result,
                                     return_unit,
                                     'The function %s returned '
                                     '{value}, but it was expected '
                                     'to return a dimensionless '
                                     'quantity' % orig_func.__name__,
                                     value=result)
     else:
         fail_for_dimension_mismatch(result,
                                     return_unit,
                                     ('The function %s returned '
                                      '{value}, but it was expected '
                                      'to return a quantity with '
                                      'units %r') % (orig_func.__name__,
                                                     return_unit),
                                     value=result)
     return np.asarray(result)
Exemplo n.º 6
0
def check_unit(expression, unit, namespace, variables):
    '''
    Check that an expression has the expected unit in a given namespace.

    The unit of the expression is determined with `parse_expression_unit`
    and compared to `unit` by `fail_for_dimension_mismatch`.

    Parameters
    ----------
    expression : str
        The expression to evaluate.
    unit : object
        The expected unit, passed on to `fail_for_dimension_mismatch`.
    namespace : dict-like
        The namespace of external variables.
    variables : dict of `Variable` objects
        The information about the internal variables

    Raises
    ------
    KeyError
        In case one of the identifiers cannot be resolved.
    DimensionMismatchError
        If a unit mismatch occurs during the evaluation.

    See Also
    --------
    unit_from_expression
    '''
    actual_unit = parse_expression_unit(expression, namespace, variables)
    message = ('Expression %s does not '
               'have the expected units' % expression)
    fail_for_dimension_mismatch(actual_unit, unit, message)
Exemplo n.º 7
0
def check_unit(expression, unit, variables):
    '''
    Check that an expression has the expected unit.

    The unit of the expression is determined with `parse_expression_unit`
    and compared to `unit` by `fail_for_dimension_mismatch`; the error
    message names the stripped expression and the expected unit.

    Parameters
    ----------
    expression : str
        The expression to evaluate.
    unit : `Unit`
        The expected unit for the `expression`.
    variables : dict
        Dictionary of all variables (including external constants) used in
        the `expression`.

    Raises
    ------
    KeyError
        In case one of the identifiers cannot be resolved.
    DimensionMismatchError
        If a unit mismatch occurs during the evaluation.
    '''
    actual_unit = parse_expression_unit(expression, variables)
    template = ('Expression %s does not '
                'have the expected unit %r')
    message = template % (expression.strip(), unit)
    fail_for_dimension_mismatch(actual_unit, unit, message)
Exemplo n.º 8
0
def check_unit(expression, unit, variables):
    '''
    Check that an expression has the expected unit.

    The unit of the expression is determined with `parse_expression_unit`
    and compared to `unit` by `fail_for_dimension_mismatch`.

    Parameters
    ----------
    expression : str
        The expression to evaluate.
    unit : `Unit`
        The expected unit for the `expression`.
    variables : dict
        Dictionary of all variables (including external constants) used in
        the `expression`.

    Raises
    ------
    KeyError
        In case one of the identifiers cannot be resolved.
    DimensionMismatchError
        If a unit mismatch occurs during the evaluation.
    '''
    found_unit = parse_expression_unit(expression, variables)
    error_text = ('Expression %s does not '
                  'have the expected units' % expression)
    fail_for_dimension_mismatch(found_unit, unit, error_text)
Exemplo n.º 9
0
def check_unit(expression, unit, namespace, variables):
    '''
    Verify the unit of an expression in a given namespace.

    Determines the unit of `expression` with `parse_expression_unit` and
    passes it, together with the expected `unit`, to
    `fail_for_dimension_mismatch`.

    Parameters
    ----------
    expression : str
        The expression to evaluate.
    unit : object
        The expected unit, passed on to `fail_for_dimension_mismatch`.
    namespace : dict-like
        The namespace of external variables.
    variables : dict of `Variable` objects
        The information about the internal variables

    Raises
    ------
    KeyError
        In case one of the identifiers cannot be resolved.
    DimensionMismatchError
        If a unit mismatch occurs during the evaluation.

    See Also
    --------
    unit_from_expression
    '''
    found_unit = parse_expression_unit(expression, namespace, variables)
    error_text = ('Expression %s does not '
                  'have the expected units' % expression)
    fail_for_dimension_mismatch(found_unit, unit, error_text)
Exemplo n.º 10
0
 def __setattr__(self, name, val, level=0):
     '''
     Set an attribute, redirecting state variable names to the variable.

     ``X.var = val`` is handled like ``X.var[:] = val`` (non-string values
     are first passed to `fail_for_dimension_mismatch` together with the
     variable's unit) and ``X.var_ = val`` like ``X.var_[:] = val`` (no
     unit check). Names of existing attributes and names starting with an
     underscore are set as regular attributes; any other name is rejected.

     Parameters
     ----------
     name : str
         Name of the attribute or state variable.
     val : object
         The value to assign.
     level : int, optional
         Forwarded to ``set_item`` as ``level + 1``.

     Raises
     ------
     TypeError
         If the addressed variable is read-only.
     AttributeError
         If `name` is neither a variable nor an existing or private
         attribute. The message lists suggestions from `SpellChecker`.
     '''
     # Variable-aware assignment stays switched off until the
     # `_group_attribute_access_active` attribute has been created by
     # `_enable_group_attributes`; names already present in the instance
     # dictionary are always set directly
     if not hasattr(self, '_group_attribute_access_active') or name in self.__dict__:
         object.__setattr__(self, name, val)
         return

     if name in self.variables:
         target = self.variables[name]
         if not isinstance(val, basestring):
             # The wording of the error depends on whether the variable is
             # dimensionless
             if target.unit.dim is DIMENSIONLESS:
                 message = ('%s should be set with a '
                            'dimensionless value, but got '
                            '{value}') % name
             else:
                 message = ('%s should be set with a '
                            'value with units %r, but got '
                            '{value}') % (name, target.unit)
             fail_for_dimension_mismatch(val, target.unit, message, value=val)
         if target.read_only:
             raise TypeError('Variable %s is read-only.' % name)
         # X.var = ... behaves like X.var[:] = ...
         accessor = target.get_addressable_value_with_unit(name, self)
         accessor.set_item(slice(None), val, level=level+1)
         return

     if name.endswith('_') and name[:-1] in self.variables:
         # Trailing underscore: assignment without unit checking
         base_name = name[:-1]
         target = self.variables[base_name]
         if target.read_only:
             raise TypeError('Variable %s is read-only.' % base_name)
         # X.var_ = ... behaves like X.var_[:] = ...
         accessor = target.get_addressable_value(base_name, self)
         accessor.set_item(slice(None), val, level=level+1)
         return

     if hasattr(self, name) or name.startswith('_'):
         object.__setattr__(self, name, val)
         return

     # Unknown name: refuse the assignment and try to suggest the correct
     # name in case of a typo (only writable, non-private variables are
     # considered)
     candidates = [varname for varname, var in self.variables.iteritems()
                   if not varname.startswith('_') and not var.read_only]
     checker = SpellChecker(candidates)
     suffix = '_' if name.endswith('_') else ''
     if suffix:
         name = name[:-1]
     error_msg = 'Could not find a state variable with name "%s".' % name
     suggestions = checker.suggest(name)
     n_suggestions = len(suggestions)
     if n_suggestions == 1:
         suggestion, = suggestions
         error_msg += ' Did you mean to write "%s%s"?' % (suggestion,
                                                          suffix)
     elif n_suggestions > 1:
         quoted = ['"%s%s"' % (suggestion, suffix)
                   for suggestion in suggestions]
         error_msg += (' Did you mean to write any of the following: %s ?' %
                       ', '.join(quoted))
     error_msg += (' Use the add_attribute method if you intend to add '
                   'a new attribute to the object.')
     raise AttributeError(error_msg)
Exemplo n.º 11
0
    def scale_array_code(self, diff_vars, method_options):
        """
        Return code for definition of ``_GSL_scale_array`` in generated code.

        Every variable in `diff_vars` gets an absolute error: the value given
        for it in ``method_options['absolute_error_per_variable']`` if there
        is one, otherwise ``method_options['absolute_error']``.

        Parameters
        ----------
        diff_vars : dict
            dictionary with variable name (str) as key and differential variable
            index (int) as value
        method_options : dict
            dictionary containing integrator settings

        Returns
        -------
        code : str
            result of ``self.initialize_array`` for ``_GSL_scale_array``, called
            with the absolute errors of the differential variables, ordered by
            sorted variable name

        Raises
        ------
        TypeError
            If ``absolute_error`` is not a float, or if
            ``absolute_error_per_variable`` is neither None nor a dictionary.
        KeyError
            If ``absolute_error_per_variable`` names a variable that does not
            exist or is not being integrated.
        """
        per_variable = method_options['absolute_error_per_variable']
        default_error = method_options['absolute_error']

        if not isinstance(default_error, float):
            raise TypeError(f"The absolute_error key in method_options should be "
                            f"a float. Was type {type(default_error)}")

        if per_variable is not None and not isinstance(per_variable, dict):
            raise TypeError(f"The absolute_error_per_variable key in method_options "
                            f"should either be None or a dictionary "
                            f"containing the error for each individual state variable. "
                            f"Was type {type(per_variable)}")

        scales = {}
        if per_variable is not None:
            for var, error in per_variable.items():
                # only variables that are being integrated may get an error
                if var not in diff_vars:
                    if var not in self.variables:
                        raise KeyError(f"absolute_error specified for variable that "
                                       f"does not exist: {var}")
                    raise KeyError(f"absolute_error specified for variable that is "
                                   f"not being integrated: {var}")
                fail_for_dimension_mismatch(error, self.variables[var],
                                            f"Unit of absolute_error_per_variable "
                                            f"for variable {var} does not match "
                                            f"unit of variable itself")
                # all checks passed: store the error as a plain float
                scales[var] = float(error)

        # variables without an individual error fall back to the default
        for var in diff_vars:
            scales.setdefault(var, float(default_error))

        ordered_scales = [scales[var] for var in sorted(diff_vars)]
        return self.initialize_array('_GSL_scale_array', ordered_scales)
Exemplo n.º 12
0
    def scale_array_code(self, diff_vars, method_options):
        '''
        Return code for definition of ``_GSL_scale_array`` in generated code.

        Every variable in `diff_vars` gets an absolute error: the value given
        for it in ``method_options['absolute_error_per_variable']`` if there
        is one, otherwise ``method_options['absolute_error']``.

        Parameters
        ----------
        diff_vars : dict
            dictionary with variable name (str) as key and differential variable
            index (int) as value
        method_options : dict
            dictionary containing integrator settings

        Returns
        -------
        code : str
            result of ``self.initialize_array`` for ``_GSL_scale_array``, called
            with the absolute errors of the differential variables, ordered by
            sorted variable name

        Raises
        ------
        TypeError
            If ``absolute_error`` is not a float, or if
            ``absolute_error_per_variable`` is neither None nor a dictionary.
        KeyError
            If ``absolute_error_per_variable`` names a variable that does not
            exist or is not being integrated.
        '''
        # get scale values per variable from method_options
        abs_per_var = method_options['absolute_error_per_variable']
        abs_default = method_options['absolute_error']

        if not isinstance(abs_default, float):
            raise TypeError(("The absolute_error key in method_options should be "
                             "a float. Was type %s" % (str(type(abs_default)))))

        if abs_per_var is None:
            diff_scale = {var: float(abs_default) for var in diff_vars}
        elif isinstance(abs_per_var, dict):
            diff_scale = {}
            for var, error in abs_per_var.items():
                # first do some checks on input
                if var not in diff_vars:
                    if var not in self.variables:
                        raise KeyError("absolute_error specified for variable "
                                       "that does not exist: %s" % var)
                    raise KeyError("absolute_error specified for variable "
                                   "that is not being integrated: %s" % var)
                fail_for_dimension_mismatch(error, self.variables[var],
                                            ("Unit of absolute_error_per_variable "
                                             "for variable %s does not match "
                                             "unit of variable itself" % var))
                # if all these are passed we can add the value for error in base units
                diff_scale[var] = float(error)
            # set the variables that are not mentioned to default value
            for var in diff_vars:
                if var not in abs_per_var:
                    diff_scale[var] = float(abs_default)
        else:
            raise TypeError(("The absolute_error_per_variable key in method_options "
                             "should either be None or a dictionary "
                             "containing the error for each individual state variable. "
                             "Was type %s" % (str(type(abs_per_var)))))
        # write code
        return self.initialize_array('_GSL_scale_array',
                                     [diff_scale[var] for var in sorted(diff_vars)])
Exemplo n.º 13
0
 def wrapper_function(*args):
     '''
     Call the wrapped function with unit-carrying arguments and check the
     dimensions of its result.

     Each argument is converted with `Quantity.with_dimensions`, using the
     dimensions of the corresponding entry of ``self._function._arg_units``.
     The result of ``orig_func`` is passed to `fail_for_dimension_mismatch`
     together with ``self._function._return_unit`` and returned via
     ``np.asarray``.

     Raises
     ------
     ValueError
         If the number of arguments differs from the number of argument
         units.
     '''
     arg_units = self._function._arg_units
     if len(args) != len(arg_units):
         raise ValueError(('Function %s got %d arguments, '
                           'expected %d') % (self._function.name, len(args),
                                             len(arg_units)))
     new_args = []
     for arg, arg_unit in zip(args, arg_units):
         new_args.append(Quantity.with_dimensions(arg,
                                                  get_dimensions(arg_unit)))
     result = orig_func(*new_args)
     fail_for_dimension_mismatch(result, self._function._return_unit)
     return np.asarray(result)
Exemplo n.º 14
0
def test_fail_for_dimension_mismatch():
    '''
    Test the fail_for_dimension_mismatch function.
    '''
    # argument combinations that should not raise an error
    accepted = [(3,),
                (3 * volt/volt,),
                (3 * volt/volt, 7),
                (3 * volt, 5 * volt)]
    for arguments in accepted:
        fail_for_dimension_mismatch(*arguments)

    # argument combinations that should raise an error
    rejected = [(6 * volt,),
                (6 * volt, 5 * second)]
    for arguments in rejected:
        assert_raises(DimensionMismatchError,
                      fail_for_dimension_mismatch, *arguments)
Exemplo n.º 15
0
def test_fail_for_dimension_mismatch():
    '''
    Test the fail_for_dimension_mismatch function.
    '''
    # a dimensionless value on its own, or two values with the same
    # dimensions, should pass without an error
    dimensionless = 3 * volt/volt
    fail_for_dimension_mismatch(3)
    fail_for_dimension_mismatch(dimensionless)
    fail_for_dimension_mismatch(dimensionless, 7)
    fail_for_dimension_mismatch(3 * volt, 5 * volt)

    # a value with dimensions on its own, or two values with different
    # dimensions, should raise an error
    assert_raises(DimensionMismatchError,
                  fail_for_dimension_mismatch, 6 * volt)
    assert_raises(DimensionMismatchError,
                  fail_for_dimension_mismatch, 6 * volt, 5 * second)
Exemplo n.º 16
0
 def wrapper_function(*args):
     '''
     Run the wrapped function on unit-carrying arguments and check the
     dimensions of what it returns.

     The arguments are paired with ``self._function._arg_units`` and
     converted with `Quantity.with_dimensions`; the result of ``orig_func``
     is passed to `fail_for_dimension_mismatch` together with
     ``self._function._return_unit`` and returned via ``np.asarray``.

     Raises
     ------
     ValueError
         If the number of arguments differs from the number of argument
         units.
     '''
     expected_units = self._function._arg_units
     n_given = len(args)
     n_expected = len(expected_units)
     if n_given != n_expected:
         message = ('Function %s got %d arguments, '
                    'expected %d') % (self._function.name, n_given,
                                      n_expected)
         raise ValueError(message)

     def attach_dimensions(arg, arg_unit):
         return Quantity.with_dimensions(arg, get_dimensions(arg_unit))

     new_args = [attach_dimensions(arg, arg_unit)
                 for arg, arg_unit in zip(args, expected_units)]
     result = orig_func(*new_args)
     fail_for_dimension_mismatch(result, self._function._return_unit)
     return np.asarray(result)
Exemplo n.º 17
0
 def __setattr__(self, name, val, level=0):
     '''
     Set an attribute, giving state variable names a special treatment.

     Assigning to the name of a state variable sets all of its values
     (``X.var = val`` acts like ``X.var[:] = val``); non-string values are
     passed to `fail_for_dimension_mismatch` with the variable's unit
     first. The same name with a trailing underscore does the assignment
     without a unit check. Names of existing attributes and names starting
     with an underscore become normal attributes; other names are rejected.

     Parameters
     ----------
     name : str
         Name of the attribute or state variable.
     val : object
         The value to assign.
     level : int, optional
         Forwarded to ``set_item`` as ``level + 1``.

     Raises
     ------
     TypeError
         If the addressed variable is read-only.
     AttributeError
         If `name` is neither a variable nor an existing or private
         attribute. The message lists suggestions from `SpellChecker`.
     '''
     # The special handling is inactive until _enable_group_attributes has
     # created `_group_attribute_access_active`; attributes that already
     # live in the instance dictionary are simply overwritten
     plain_assignment = (not hasattr(self, '_group_attribute_access_active')
                         or name in self.__dict__)
     if plain_assignment:
         object.__setattr__(self, name, val)
         return

     if name in self.variables:
         variable = self.variables[name]
         if not isinstance(val, basestring):
             fail_for_dimension_mismatch(val, variable.unit,
                                         'Incorrect units for setting %s' % name)
         if variable.read_only:
             raise TypeError('Variable %s is read-only.' % name)
         # Same effect as X.var[:] = ...
         with_unit = variable.get_addressable_value_with_unit(name, self)
         with_unit.set_item(slice(None), val, level=level+1)
         return

     if name.endswith('_') and name[:-1] in self.variables:
         # Underscore variant: the value is taken as is, no unit checking
         stripped = name[:-1]
         variable = self.variables[stripped]
         if variable.read_only:
             raise TypeError('Variable %s is read-only.' % stripped)
         # Same effect as X.var_[:] = ...
         without_unit = variable.get_addressable_value(stripped, self)
         without_unit.set_item(slice(None), val, level=level+1)
         return

     if hasattr(self, name) or name.startswith('_'):
         object.__setattr__(self, name, val)
         return

     # Nothing matched: raise an error, trying to suggest the correct name
     # in case of a typo (private and read-only variables are left out)
     writable_names = []
     for varname, var in self.variables.iteritems():
         if not (varname.startswith('_') or var.read_only):
             writable_names.append(varname)
     checker = SpellChecker(writable_names)
     if name.endswith('_'):
         suffix = '_'
         name = name[:-1]
     else:
         suffix = ''
     parts = ['Could not find a state variable with name "%s".' % name]
     suggestions = checker.suggest(name)
     if len(suggestions) == 1:
         suggestion, = suggestions
         parts.append(' Did you mean to write "%s%s"?' % (suggestion,
                                                          suffix))
     elif len(suggestions) > 1:
         alternatives = ', '.join(['"%s%s"' % (suggestion, suffix)
                                   for suggestion in suggestions])
         parts.append(' Did you mean to write any of the following: %s ?' %
                      alternatives)
     parts.append(' Use the add_attribute method if you intend to add '
                  'a new attribute to the object.')
     raise AttributeError(''.join(parts))
Exemplo n.º 18
0
 def before_run(self, run_namespace=None):
     '''
     Check the 'rates' variable before delegating to the parent class.

     If the 'rates' variable is a `Subexpression`, the dimensions of its
     expression are determined and passed to `fail_for_dimension_mismatch`
     together with ``Hz``. Afterwards, the ``before_run`` method of the
     parent class is called with `run_namespace`.
     '''
     rates = self.variables['rates']
     if isinstance(rates, Subexpression):
         # Make sure that the units of the expression make sense
         names = get_identifiers(rates.expr)
         resolved = self.resolve_all(names, run_namespace,
                                     user_identifiers=names)
         dimensions = parse_expression_dimensions(rates.expr, resolved)
         message = ("The expression provided for "
                    "PoissonGroup's 'rates' "
                    "argument, has to have units "
                    "of Hz")
         fail_for_dimension_mismatch(dimensions, Hz, message)
     super(PoissonGroup, self).before_run(run_namespace)
Exemplo n.º 19
0
 def before_run(self, run_namespace=None):
     '''
     Validate a 'rates' subexpression, then run the parent's ``before_run``.

     Only a 'rates' variable that is a `Subexpression` is checked: its
     identifiers are resolved with ``self.resolve_all``, the dimensions of
     the expression are parsed and handed to `fail_for_dimension_mismatch`
     together with ``Hz``.
     '''
     rates_var = self.variables['rates']
     needs_check = isinstance(rates_var, Subexpression)
     if needs_check:
         used_names = get_identifiers(rates_var.expr)
         namespace_vars = self.resolve_all(used_names,
                                           run_namespace,
                                           user_identifiers=used_names)
         rate_dims = parse_expression_dimensions(rates_var.expr,
                                                 namespace_vars)
         error_text = ("The expression provided for "
                       "PoissonGroup's 'rates' "
                       "argument, has to have units "
                       "of Hz")
         fail_for_dimension_mismatch(rate_dims, Hz, error_text)
     super(PoissonGroup, self).before_run(run_namespace)
Exemplo n.º 20
0
            def wrapper_function(*args):
                """
                Call the wrapped function with unit-carrying arguments and
                check what it returns.

                The arguments are paired with ``self._function._arg_units``
                (plus one dimensionless entry if
                ``self._function.auto_vectorise`` is set). Arguments whose
                unit entry is ``bool``, ``None`` or a string are passed on
                unchanged, all others are converted with
                `Quantity.with_dimensions`. The result of ``orig_func`` is
                checked against the return unit, which is computed from the
                dimensions of the arguments if ``_return_unit`` is callable.

                Returns
                -------
                result : `numpy.ndarray`
                    The result of the wrapped function, passed through
                    ``np.asarray``.

                Raises
                ------
                ValueError
                    If the number of arguments does not match.
                TypeError
                    If a boolean return value was expected but the result is
                    not boolean.
                """
                arg_units = list(self._function._arg_units)
                if self._function.auto_vectorise:
                    arg_units.append(DIMENSIONLESS)

                if len(args) != len(arg_units):
                    func_name = self._function.pyfunc.__name__
                    raise ValueError(
                        f"Function {func_name} got {len(args)} arguments, "
                        f"expected {len(arg_units)}.")

                def attach_dimensions(arg, arg_unit):
                    # bool, None and string entries mean: pass on unchanged
                    if (arg_unit == bool or arg_unit is None
                            or isinstance(arg_unit, str)):
                        return arg
                    return Quantity.with_dimensions(arg,
                                                    get_dimensions(arg_unit))

                new_args = [attach_dimensions(arg, arg_unit)
                            for arg, arg_unit in zip(args, arg_units)]
                result = orig_func(*new_args)

                declared_unit = self._function._return_unit
                if isinstance(declared_unit, Callable):
                    # the return unit depends on the argument dimensions
                    arg_dims = [get_dimensions(a) for a in args]
                    return_unit = declared_unit(*arg_dims)
                else:
                    return_unit = declared_unit

                if return_unit == bool:
                    is_boolean = (isinstance(result, bool)
                                  or np.asarray(result).dtype == bool)
                    if not is_boolean:
                        raise TypeError(
                            f"The function {orig_func.__name__} returned "
                            f"'{result}', but it was expected to return a "
                            f"boolean value ")
                else:
                    if ((isinstance(return_unit, int) and return_unit == 1)
                            or return_unit.dim is DIMENSIONLESS):
                        message = (f"The function '{orig_func.__name__}' "
                                   f"returned {result}, but it was "
                                   f"expected to return a dimensionless "
                                   f"quantity.")
                    else:
                        message = (f"The function '{orig_func.__name__}' "
                                   f"returned {result}, but it was "
                                   f"expected to return a quantity with "
                                   f"units {return_unit!r}.")
                    fail_for_dimension_mismatch(result, return_unit, message)
                return np.asarray(result)
Exemplo n.º 21
0
            def wrapper_function(*args):
                '''
                Call the wrapped function with unit-carrying arguments and
                check what it returns.

                The arguments are paired with ``self._function._arg_units``
                (plus one dimensionless entry if
                ``self._function.auto_vectorise`` is set). Arguments whose
                unit entry is ``bool``, ``None`` or a string are passed on
                unchanged, all others are converted with
                `Quantity.with_dimensions`. The result of ``orig_func`` is
                checked against the return unit, which is computed from the
                dimensions of the arguments if ``_return_unit`` is callable.

                Returns
                -------
                result : `numpy.ndarray`
                    The result of the wrapped function, passed through
                    ``np.asarray``.

                Raises
                ------
                ValueError
                    If the number of arguments does not match.
                TypeError
                    If a boolean return value was expected but the result is
                    not boolean.
                '''
                arg_units = list(self._function._arg_units)
                if self._function.auto_vectorise:
                    arg_units.append(DIMENSIONLESS)

                if len(args) != len(arg_units):
                    raise ValueError(('Function %s got %d arguments, '
                                      'expected %d') % (self._function.pyfunc.__name__,
                                                        len(args),
                                                        len(arg_units)))

                new_args = []
                for arg, arg_unit in zip(args, arg_units):
                    # bool, None and string entries mean: pass on unchanged
                    keep_as_is = (arg_unit == bool or arg_unit is None
                                  or isinstance(arg_unit, str))
                    if not keep_as_is:
                        arg = Quantity.with_dimensions(arg,
                                                       get_dimensions(arg_unit))
                    new_args.append(arg)
                result = orig_func(*new_args)

                declared_unit = self._function._return_unit
                if isinstance(declared_unit, Callable):
                    # the return unit depends on the argument dimensions
                    return_unit = declared_unit(*[get_dimensions(a)
                                                  for a in args])
                else:
                    return_unit = declared_unit

                if return_unit == bool:
                    is_boolean = (isinstance(result, bool) or
                                  np.asarray(result).dtype == bool)
                    if not is_boolean:
                        raise TypeError('The function %s returned '
                                        '%s, but it was expected '
                                        'to return a boolean '
                                        'value ' % (orig_func.__name__,
                                                    result))
                else:
                    if ((isinstance(return_unit, int) and return_unit == 1)
                            or return_unit.dim is DIMENSIONLESS):
                        message = ('The function %s returned '
                                   '{value}, but it was expected '
                                   'to return a dimensionless '
                                   'quantity' % orig_func.__name__)
                    else:
                        message = ('The function %s returned '
                                   '{value}, but it was expected '
                                   'to return a quantity with '
                                   'units %r') % (orig_func.__name__,
                                                  return_unit)
                    fail_for_dimension_mismatch(result, return_unit, message,
                                                value=result)
                return np.asarray(result)
Exemplo n.º 22
0
 def __setitem__(self, i, value):
     '''
     Assign `value` to the elements of the variable selected by `i`.

     For a scalar variable, only ``slice(None)``, ``0`` or an empty index
     are accepted. String values are handed to ``self.group._set_with_code``;
     all other values are passed to `fail_for_dimension_mismatch` (if
     ``self.unit`` is not None) and then written into ``variable.value``.

     Raises
     ------
     IndexError
         If a scalar variable is indexed with anything else.
     '''
     variable = self.variable
     if variable.scalar:
         selects_scalar = (i == slice(None) or i == 0 or
                           (hasattr(i, '__len__') and len(i) == 0))
         if not selects_scalar:
             raise IndexError('Variable is a scalar variable.')
         indices = np.array([0])
     else:
         group = self.group
         index_array = group.indices[group.variable_indices[self.name]]
         indices = index_array[i]

     if isinstance(value, basestring):
         # Units are only checked if this view has a unit
         check_units = self.unit is not None
         self.group._set_with_code(variable, indices, value,
                                   check_units, level=self.level + 1)
         return

     if self.unit is not None:
         fail_for_dimension_mismatch(value, self.unit)
     variable.value[indices] = value
Exemplo n.º 23
0
    def _get_refractory_code(self, run_namespace):
        """
        Build the abstract code line that updates ``not_refractory``.

        The group's ``_refractory`` attribute is interpreted as follows:
        ``False`` yields an empty string; a `Quantity` must have dimensions
        of seconds and is compared against ``t - lastspike``; anything else
        is treated as an expression whose identifiers are resolved via the
        group and which has to evaluate to a timespan or a boolean.

        Raises
        ------
        TypeError
            If the expression is dimensionless but not boolean, or has any
            dimension other than time.
        """
        refractory = self.group._refractory

        if refractory is False:
            # No refractoriness
            return ""

        if isinstance(refractory, Quantity):
            message = (
                "Refractory period has to "
                "be specified in units "
                "of seconds but got "
                "{value}"
            )
            fail_for_dimension_mismatch(refractory, second, message,
                                        value=refractory)
            return "not_refractory = (t - lastspike) > %f\n" % refractory

        # Expression given as a string: determine its unit first
        names = get_identifiers(refractory)
        resolved = self.group.resolve_all(names, names,
                                          run_namespace=run_namespace)
        expression = str(refractory)
        unit = parse_expression_unit(expression, resolved)

        if have_same_dimensions(unit, second):
            return "not_refractory = (t - lastspike) > %s\n" % refractory

        if not have_same_dimensions(unit, Unit(1)):
            raise TypeError(
                (
                    "Refractory expression has to evaluate to a "
                    "timespan or a boolean value, expression"
                    '"%s" has units %s instead'
                )
                % (refractory, unit)
            )

        if not is_boolean_expression(expression, resolved):
            raise TypeError(
                "Refractory expression is dimensionless "
                "but not a boolean value. It needs to "
                "either evaluate to a timespan or to a "
                "boolean value."
            )

        # boolean condition
        # The condition cannot be used as it is, because it should only
        # allow a neuron to *leave* refractoriness
        return "not_refractory = not_refractory or not (%s)\n" % refractory
Exemplo n.º 24
0
 def __setattr__(self, name, val):
     '''
     Set an attribute, routing names of state variables to their values.

     Until ``_group_attribute_access_active`` exists on the object, every
     assignment is a plain attribute assignment. Afterwards, ``X.var = val``
     is treated like ``X.var[:] = val`` (with a unit check for non-string
     values) and ``X.var_ = val`` does the same without any unit check.
     All other names fall back to normal attribute assignment.
     '''
     # attribute access is switched off until this attribute is created by
     # Group.__init__
     if not hasattr(self, '_group_attribute_access_active'):
         object.__setattr__(self, name, val)
         return

     known = self.variables
     if name in known:
         variable = known[name]
         if not isinstance(val, basestring):
             fail_for_dimension_mismatch(val, variable.unit,
                                         'Incorrect units for setting %s' % name)
         # X.var = ... is equivalent to X.var[:] = ...
         variable.get_addressable_value_with_unit(self, level=1)[:] = val
         return

     if len(name) and name[-1] == '_' and name[:-1] in known:
         # trailing underscore: no unit checking
         # X.var_ = ... is equivalent to X.var_[:] = ...
         known[name[:-1]].get_addressable_value(self, level=1)[:] = val
         return

     object.__setattr__(self, name, val)
Exemplo n.º 25
0
 def _get_refractory_code(self, run_namespace):
     '''
     Build the abstract code line that updates ``not_refractory``.

     The group's ``_refractory`` attribute is interpreted as follows:

     * ``False``: no refractoriness, an empty string is returned.
     * a `Quantity`: has to have dimensions of seconds; the elapsed time
       since ``lastspike`` is compared against it.
     * anything else: treated as an expression whose identifiers are
       resolved via the group. It has to evaluate either to a timespan or
       to a boolean value.

     If ``prefs.legacy.refractory_timing`` is set, the comparison
     ``(t - lastspike) > ...`` is generated, otherwise both sides are
     wrapped in ``timestep(..., dt)`` and compared with ``>=``.

     Parameters
     ----------
     run_namespace
         Passed on to ``self.group.resolve_all`` to resolve the identifiers
         of a string expression.

     Returns
     -------
     abstract_code : str
         A statement assigning to ``not_refractory`` (or an empty string).

     Raises
     ------
     TypeError
         If the expression is dimensionless but not boolean, or has any
         dimension other than time.
     '''
     ref = self.group._refractory
     if ref is False:
         # No refractoriness
         return ''

     if isinstance(ref, Quantity):
         fail_for_dimension_mismatch(ref,
                                     second, ('Refractory period has to '
                                              'be specified in units '
                                              'of seconds but got '
                                              '{value}'),
                                     value=ref)
         if prefs.legacy.refractory_timing:
             return 'not_refractory = (t - lastspike) > %f\n' % ref
         return 'not_refractory = timestep(t - lastspike, dt) >= timestep(%f, dt)\n' % ref

     identifiers = get_identifiers(ref)
     variables = self.group.resolve_all(identifiers,
                                        run_namespace,
                                        user_identifiers=identifiers)
     dims = parse_expression_dimensions(str(ref), variables)
     if dims is second.dim:
         if prefs.legacy.refractory_timing:
             # Has to be an assignment like in all other branches; a bare
             # comparison would never update not_refractory
             return 'not_refractory = (t - lastspike) > %s\n' % ref
         return 'not_refractory = timestep(t - lastspike, dt) >= timestep(%s, dt)\n' % ref

     if dims is DIMENSIONLESS:
         if not is_boolean_expression(str(ref), variables):
             raise TypeError(('Refractory expression is dimensionless '
                              'but not a boolean value. It needs to '
                              'either evaluate to a timespan or to a '
                              'boolean value.'))
         # boolean condition
         # we have to be a bit careful here, we can't just use the given
         # condition as it is, because we only want to *leave*
         # refractoriness, based on the condition
         return 'not_refractory = not_refractory or not (%s)\n' % ref

     raise TypeError(('Refractory expression has to evaluate to a '
                      'timespan or a boolean value, expression'
                      '"%s" has units %s instead') % (ref, dims))
Exemplo n.º 26
0
    def set_with_index_array(self, variable_name, variable, item, value,
                             check_units):
        '''
        Write `value` into `variable` at the positions selected by `item`.

        If `check_units` is true, the dimensions of `value` are first
        compared with ``variable.unit``. Scalar variables only accept the
        full slice ``slice(None)`` as index. For all other variables the
        index is converted with ``self.calc_indices`` and, if the variable
        is not indexed by ``'_idx'``, translated through the corresponding
        index variable.

        Raises
        ------
        IndexError
            If `variable` is scalar and `item` is not ``slice(None)``.
        '''
        if check_units:
            fail_for_dimension_mismatch(variable.unit, value,
                                        'Incorrect unit for setting variable %s' % variable_name)

        if not variable.scalar:
            positions = self.calc_indices(item)
            # We are not going via code generation so we have to take care
            # of correct indexing (in particular for subgroups) explicitly
            index_name = self.variables.indices[variable_name]
            if index_name != '_idx':
                positions = self.variables[index_name].get_value()[positions]
            variable.get_value()[positions] = value
            return

        is_full_slice = isinstance(item, slice) and item == slice(None)
        if not is_full_slice:
            raise IndexError(('Illegal index for variable %s, it is a '
                              'scalar variable.') % variable_name)
        variable.get_value()[0] = value
Exemplo n.º 27
0
 def _get_refractory_code(self, run_namespace):
     '''
     Return the abstract code statement that updates ``not_refractory``.

     ``self.group._refractory`` may be ``False`` (no refractoriness, an
     empty string is returned), a `Quantity` with dimensions of seconds, or
     an expression that is resolved via the group and has to evaluate to a
     timespan or a boolean value. With ``prefs.legacy.refractory_timing``
     set, a plain ``(t - lastspike) > ...`` comparison is generated;
     otherwise both sides are converted with ``timestep(..., dt)`` and
     compared with ``>=``.

     Parameters
     ----------
     run_namespace
         Passed on to ``self.group.resolve_all`` when resolving the
         identifiers of a string expression.

     Returns
     -------
     abstract_code : str
         A statement assigning to ``not_refractory`` (or an empty string).

     Raises
     ------
     TypeError
         If the expression is dimensionless but not boolean, or has any
         dimension other than time.
     '''
     ref = self.group._refractory
     if ref is False:
         # No refractoriness
         abstract_code = ''
     elif isinstance(ref, Quantity):
         fail_for_dimension_mismatch(ref, second, ('Refractory period has to '
                                                   'be specified in units '
                                                   'of seconds but got '
                                                   '{value}'),
                                     value=ref)
         if prefs.legacy.refractory_timing:
             abstract_code = 'not_refractory = (t - lastspike) > %f\n' % ref
         else:
             abstract_code = 'not_refractory = timestep(t - lastspike, dt) >= timestep(%f, dt)\n' % ref
     else:
         identifiers = get_identifiers(ref)
         variables = self.group.resolve_all(identifiers,
                                            run_namespace,
                                            user_identifiers=identifiers)
         dims = parse_expression_dimensions(str(ref), variables)
         if dims is second.dim:
             if prefs.legacy.refractory_timing:
                 # Assign the result, consistent with the Quantity branch;
                 # a bare comparison would never update not_refractory
                 abstract_code = 'not_refractory = (t - lastspike) > %s\n' % ref
             else:
                 abstract_code = 'not_refractory = timestep(t - lastspike, dt) >= timestep(%s, dt)\n' % ref
         elif dims is DIMENSIONLESS:
             if not is_boolean_expression(str(ref), variables):
                 raise TypeError(('Refractory expression is dimensionless '
                                  'but not a boolean value. It needs to '
                                  'either evaluate to a timespan or to a '
                                  'boolean value.'))
             # boolean condition
             # we have to be a bit careful here, we can't just use the given
             # condition as it is, because we only want to *leave*
             # refractoriness, based on the condition
             abstract_code = 'not_refractory = not_refractory or not (%s)\n' % ref
         else:
             raise TypeError(('Refractory expression has to evaluate to a '
                              'timespan or a boolean value, expression'
                              '"%s" has units %s instead') % (ref, dims))
     return abstract_code
Exemplo n.º 28
0
    def set_with_index_array(self, group, variable_name, variable, item, value,
                        check_units):
        '''
        Write `value` into `variable` at the positions selected by `item`.

        Scalar variables accept ``slice(None)``, ``0`` or an empty sequence
        as index and always address element 0. For other variables the
        index is converted with ``group.calc_indices``. If the variable is
        not indexed by ``'_idx'``, the positions are translated through the
        corresponding index variable of `group`. With `check_units` set,
        the dimensions of `value` are compared with ``variable.unit``
        before the assignment.

        Raises
        ------
        IndexError
            If `variable` is scalar and `item` is not an accepted index.
        '''
        if not variable.scalar:
            positions = group.calc_indices(item)
        else:
            whole_slice = isinstance(item, slice) and item == slice(None)
            if not (whole_slice or item == 0 or
                    (hasattr(item, '__len__') and len(item) == 0)):
                raise IndexError('Variable is a scalar variable.')
            positions = np.array([0])

        # We are not going via code generation so we have to take care
        # of correct indexing (in particular for subgroups) explicitly
        index_name = group.variables.indices[variable_name]
        if index_name != '_idx':
            positions = group.variables[index_name].get_value()[positions]

        if check_units:
            fail_for_dimension_mismatch(variable.unit, value,
                                        'Incorrect unit for setting variable %s' % variable_name)

        variable.get_value()[positions] = value
Exemplo n.º 29
0
    def _get_refractory_code(self, run_namespace, level=0):
        '''
        Build the abstract code line that updates ``not_refractory``.

        ``self.group._refractory`` may be ``False`` (an empty string is
        returned), a `Quantity` with dimensions of seconds, or an expression
        whose identifiers are resolved via the group (using `run_namespace`
        and ``level + 1``) and which has to evaluate to a timespan or a
        boolean value.

        Raises
        ------
        TypeError
            If the expression is dimensionless but not boolean, or has any
            dimension other than time.
        '''
        period = self.group._refractory

        if period is False:
            # No refractoriness
            return ''

        if isinstance(period, Quantity):
            unit_error = ('Refractory period has to '
                          'be specified in units '
                          'of seconds but got '
                          '{value}')
            fail_for_dimension_mismatch(period, second, unit_error,
                                        value=period)
            return 'not_refractory = (t - lastspike) > %f\n' % period

        names = get_identifiers(period)
        resolved = self.group.resolve_all(names, names,
                                          run_namespace=run_namespace,
                                          level=level + 1)
        expression = str(period)
        unit = parse_expression_unit(expression, resolved)

        if have_same_dimensions(unit, second):
            return 'not_refractory = (t - lastspike) > %s\n' % period

        if have_same_dimensions(unit, Unit(1)):
            if not is_boolean_expression(expression, resolved):
                raise TypeError(('Refractory expression is dimensionless '
                                 'but not a boolean value. It needs to '
                                 'either evaluate to a timespan or to a '
                                 'boolean value.'))
            # boolean condition
            # The condition must not be used directly: it is only meant to
            # let a neuron *leave* refractoriness
            return 'not_refractory = not_refractory or not (%s)\n' % period

        raise TypeError(('Refractory expression has to evaluate to a '
                         'timespan or a boolean value, expression'
                         '"%s" has units %s instead') % (period, unit))
Exemplo n.º 30
0
def test_fail_for_dimension_mismatch():
    '''
    Check the return values and errors of fail_for_dimension_mismatch.
    '''
    # argument combinations that are consistent and dimensionless
    dimensionless_cases = [
        (3,),
        (3 * volt / volt,),
        (3 * volt / volt, 7),
    ]
    for call_args in dimensionless_cases:
        dim_a, dim_b = fail_for_dimension_mismatch(*call_args)
        assert dim_a is DIMENSIONLESS
        assert dim_b is DIMENSIONLESS

    # consistent arguments with dimensions return those dimensions
    dim_a, dim_b = fail_for_dimension_mismatch(3 * volt, 5 * volt)
    assert dim_a is volt.dim
    assert dim_b is volt.dim

    # inconsistent arguments have to raise an error
    for bad_args in ((6 * volt,), (6 * volt, 5 * second)):
        with pytest.raises(DimensionMismatchError):
            fail_for_dimension_mismatch(*bad_args)
Exemplo n.º 31
0
def test_fail_for_dimension_mismatch():
    '''
    Check the return values and errors of fail_for_dimension_mismatch.
    '''
    # calls that must succeed, paired with the dimensions they should return
    passing = (
        ((3,), DIMENSIONLESS),
        ((3 * volt/volt,), DIMENSIONLESS),
        ((3 * volt/volt, 7), DIMENSIONLESS),
        ((3 * volt, 5 * volt), volt.dim),
    )
    for call_args, expected_dim in passing:
        dim1, dim2 = fail_for_dimension_mismatch(*call_args)
        assert dim1 is expected_dim
        assert dim2 is expected_dim

    # calls that must raise a DimensionMismatchError
    failing = ((6 * volt,), (6 * volt, 5 * second))
    for call_args in failing:
        assert_raises(DimensionMismatchError,
                      lambda: fail_for_dimension_mismatch(*call_args))
Exemplo n.º 32
0
def check_units_statements(code, variables):
    '''
    Verify unit consistency for a sequence of statements.

    Assignments to variables already present in `variables` have to use an
    expression with matching units. For temporary variables introduced by
    the code, the unit of the right-hand side is recorded in `variables`
    (as a non-scalar `Variable`) so that later statements can be checked
    against it.

    Parameters
    ----------
    code : str
        The statements, separated by newlines or semicolons.
    variables : dict of `Variable` objects
        Information about all variables used in `code` (including
        `Constant` objects for external variables). Entries for newly
        defined variables are added to this dictionary.

    Raises
    ------
    KeyError
        In case one of the identifiers cannot be resolved.
    DimensionMismatchError
        If a unit mismatch occurs during the evaluation.
    AssertionError
        If unknown identifiers or operators are encountered, or the
        assigned variable is neither known nor newly defined.
    '''
    # Avoid a circular import
    from brian2.codegen.translation import analyse_identifiers
    newly_defined, _, unknown = analyse_identifiers(code,
                                                    set(variables.keys()))

    if len(unknown):
        raise AssertionError(('Encountered unknown identifiers, this should '
                             'not happen at this stage. Unkown identifiers: %s'
                             % unknown))

    augmented_ops = ('+=', '-=', '*=', '/=', '%=')
    for statement in re.split(r'[;\n]', code):
        line = statement.strip()
        if not len(line):
            continue  # nothing to check on an empty line

        varname, op, expr, comment = parse_statement(line)
        if op in augmented_ops:
            # Treat statements such as "w *=2" like "w = w * 2"
            expr = '{var} {op_first} {expr}'.format(var=varname,
                                                    op_first=op[0],
                                                    expr=expr)
        elif op != '=':
            raise AssertionError('Unknown operator "%s"' % op)

        expr_unit = parse_expression_unit(expr, variables)

        if varname in variables:
            fail_for_dimension_mismatch(variables[varname].unit,
                                        expr_unit,
                                        ('Code statement "%s" does not use '
                                         'correct units' % line))
            continue

        if varname not in newly_defined:
            raise AssertionError(('Variable "%s" is neither in the variables '
                                  'dictionary nor in the list of undefined '
                                  'variables.' % varname))

        # remember the unit for the statements that follow
        variables[varname] = Variable(name=varname, unit=expr_unit,
                                      scalar=False)
Exemplo n.º 33
0
    def __init__(self, synapses, code, prepost, objname=None,
                 delay=None):
        '''
        Set up the pathway that propagates spikes through `synapses`.

        Parameters
        ----------
        synapses
            The object this pathway belongs to. Its ``source``, ``target``,
            ``variables``, ``clock``, ``name``, ``namespace`` and ``_N``
            attributes are used here.
        code
            The code for this pathway, stored and passed on to
            ``CodeRunner.__init__``.
        prepost : str
            Either ``'pre'`` or ``'post'``; decides which side of `synapses`
            acts as the source of spikes.
        objname : str, optional
            Name of the pathway, appended to the name of `synapses`.
            Defaults to `prepost` followed by ``'*'``.
        delay : `Quantity`, optional
            A scalar delay in units of seconds. If ``None``, a dynamic
            ``delay`` array with one entry per synapse is created instead
            and registered with `synapses` for automatic resizing.

        Raises
        ------
        ValueError
            If `prepost` is neither ``'pre'`` nor ``'post'``.
        TypeError
            If `delay` is given but is not a scalar `Quantity`.
        '''
        self.code = code
        self.prepost = prepost
        if prepost == 'pre':
            self.source = synapses.source
            self.target = synapses.target
            self.synapse_sources = synapses.variables['_synaptic_pre']
        elif prepost == 'post':
            self.source = synapses.target
            self.target = synapses.source
            self.synapse_sources = synapses.variables['_synaptic_post']
        else:
            raise ValueError('prepost argument has to be either "pre" or '
                             '"post"')
        self.synapses = synapses

        if objname is None:
            objname = prepost + '*'

        CodeRunner.__init__(self, synapses,
                            'synapses',
                            code=code,
                            when=(synapses.clock, 'synapses'),
                            name=synapses.name + '_' + objname,
                            template_kwds={'pathway': self})

        self._pushspikes_codeobj = None

        self.spikes_start = self.source.start
        self.spikes_stop = self.source.stop

        self.spiking_synapses = []
        self.variables = Variables(self)
        self.variables.add_attribute_variable('_spiking_synapses', unit=Unit(1),
                                              obj=self,
                                              attribute='spiking_synapses',
                                              constant=False,
                                              scalar=False)
        self.variables.add_reference('_spikespace',
                                     self.source.variables['_spikespace'])
        self.variables.add_reference('N', synapses.variables['N'])
        if delay is None:  # variable delays
            self.variables.add_dynamic_array('delay', unit=second,
                                             size=synapses._N, constant=True,
                                             constant_size=True)
            # Register the object with the `SynapticIndex` object so it gets
            # automatically resized
            synapses.register_variable(self.variables['delay'])
        else:
            if not isinstance(delay, Quantity):
                raise TypeError(('Cannot set the delay for pathway "%s": '
                                 'expected a quantity, got %s instead.') % (objname,
                                                                            type(delay)))
            if delay.size != 1:
                # The message has two placeholders, so both the pathway
                # name and the shape have to be provided
                raise TypeError(('Cannot set the delay for pathway "%s": '
                                 'expected a scalar quantity, got a '
                                 'quantity with shape %s instead.') % (objname,
                                                                       str(delay.shape)))
            fail_for_dimension_mismatch(delay, second, ('Delay has to be '
                                                        'specified in units '
                                                        'of seconds'))
            self.variables.add_array('delay', unit=second, size=1,
                                     constant=True, scalar=True)
            self.variables['delay'].set_value(delay)

        self._delays = self.variables['delay']

        # Re-extract the last part of the name from the full name
        self.objname = self.name[len(synapses.name) + 1:]

        #: The simulation dt (necessary for the delays)
        self.dt = self.synapses.clock.dt_

        #: The `SpikeQueue`
        self.queue = None

        #: The `CodeObject` initalising the `SpikeQueue` at the begin of a run
        self._initialise_queue_codeobj = None

        self.namespace = synapses.namespace
        # Enable access to the delay attribute via the specifier
        self._enable_group_attributes()
Exemplo n.º 34
0
def check_units_statements(code, variables):
    '''
    Check the units for a series of statements. Setting a model variable has to
    use the correct unit. For newly introduced temporary variables, the unit
    is determined and used to check the following statements to ensure
    consistency.

    Parameters
    ----------
    code : str
        The statements as a (multi-line) string; statements may also be
        separated by semicolons.
    variables : dict of `Variable` objects
        The information about all variables used in `code` (including
        `Constant` objects for external variables). Entries for newly
        defined temporary variables are added to this dictionary.

    Raises
    ------
    KeyError
        In case one of the identifiers cannot be resolved.
    DimensionMismatchError
        If a unit mismatch occurs during the evaluation.
    AssertionError
        If unknown identifiers or operators are encountered, or the
        assigned variable is neither known nor newly defined.
    '''
    # Avoid a circular import
    from brian2.codegen.translation import analyse_identifiers
    known = set(variables.keys())
    newly_defined, _, unknown = analyse_identifiers(code, known)

    if len(unknown):
        raise AssertionError(('Encountered unknown identifiers, this should '
                             'not happen at this stage. Unkown identifiers: %s'
                             % unknown))

    for line in re.split(r'[;\n]', code):
        line = line.strip()
        if not len(line):
            continue  # skip empty lines

        varname, op, expr, comment = parse_statement(line)
        if op in ('+=', '-=', '*=', '/=', '%='):
            # Replace statements such as "w *=2" by "w = w * 2"
            expr = '{var} {op_first} {expr}'.format(var=varname,
                                                    op_first=op[0],
                                                    expr=expr)
        elif op != '=':
            raise AssertionError('Unknown operator "%s"' % op)

        expr_unit = parse_expression_unit(expr, variables)

        if varname in variables:
            expected_unit = variables[varname].unit
            # The statement is quoted with a single pair of double quotes
            fail_for_dimension_mismatch(expr_unit, expected_unit,
                                        ('The right-hand-side of code '
                                         'statement "%s" does not have the '
                                         'expected unit %r') % (line,
                                                               expected_unit))
        elif varname in newly_defined:
            # note the unit for later
            variables[varname] = Variable(name=varname, unit=expr_unit,
                                          scalar=False)
        else:
            raise AssertionError(('Variable "%s" is neither in the variables '
                                  'dictionary nor in the list of undefined '
                                  'variables.' % varname))
Exemplo n.º 35
0
 def __setattr__(self, name, val, level=0):
     """
     Set an attribute, routing names of state variables to their values.

     Plain attribute assignment is used while
     ``_group_attribute_access_active`` does not exist yet, for names
     already present in the instance or class dictionary, and for existing
     attributes or names starting with an underscore. ``X.var = val`` is
     treated like ``X.var[:] = val`` with a unit check for non-string
     values; ``X.var_ = val`` does the same without a unit check.

     Raises
     ------
     TypeError
         If the addressed variable is read-only.
     AttributeError
         If `name` is none of the above; the message suggests similarly
         named writable variables.
     """
     # attribute access is switched off until this attribute is created by
     # _enable_group_attributes
     if not hasattr(self, "_group_attribute_access_active") or name in self.__dict__:
         object.__setattr__(self, name, val)
         return

     # Makes sure that classes can override the "variables" mechanism
     # with instance/class attributes and properties
     if name in self.__getattribute__("__dict__") or name in self.__getattribute__("__class__").__dict__:
         object.__setattr__(self, name, val)
         return

     if name in self.variables:
         var = self.variables[name]
         if not isinstance(val, basestring):
             if var.unit.dim is DIMENSIONLESS:
                 message = ("%s should be set with a " "dimensionless value, but got " "{value}") % name
             else:
                 message = ("%s should be set with a " "value with units %r, but got " "{value}") % (name, var.unit)
             fail_for_dimension_mismatch(val, var.unit, message, value=val)
         if var.read_only:
             raise TypeError("Variable %s is read-only." % name)
         # X.var = ... is equivalent to X.var[:] = ...
         var.get_addressable_value_with_unit(name, self).set_item(slice(None), val, level=level + 1)
         return

     if len(name) and name[-1] == "_" and name[:-1] in self.variables:
         # trailing underscore: no unit checking
         var = self.variables[name[:-1]]
         if var.read_only:
             raise TypeError("Variable %s is read-only." % name[:-1])
         # X.var_ = ... is equivalent to X.var_[:] = ...
         var.get_addressable_value(name[:-1], self).set_item(slice(None), val, level=level + 1)
         return

     if hasattr(self, name) or name.startswith("_"):
         object.__setattr__(self, name, val)
         return

     # Unknown name: try to suggest the correct name in case of a typo
     writable_names = [
         varname
         for varname, var in self.variables.iteritems()
         if not (varname.startswith("_") or var.read_only)
     ]
     checker = SpellChecker(writable_names)
     suffix = ""
     if name.endswith("_"):
         suffix = "_"
         name = name[:-1]
     error_msg = 'Could not find a state variable with name "%s".' % name
     suggestions = checker.suggest(name)
     if len(suggestions) == 1:
         (only_match,) = suggestions
         error_msg += ' Did you mean to write "%s%s"?' % (only_match, suffix)
     elif len(suggestions) > 1:
         quoted = ['"%s%s"' % (candidate, suffix) for candidate in suggestions]
         error_msg += " Did you mean to write any of the following: %s ?" % (", ".join(quoted))
     error_msg += " Use the add_attribute method if you intend to add " "a new attribute to the object."
     raise AttributeError(error_msg)
Exemplo n.º 36
0
    def __init__(self,
                 source,
                 target=None,
                 model=None,
                 pre=None,
                 post=None,
                 connect=False,
                 delay=None,
                 namespace=None,
                 dtype=None,
                 codeobj_class=None,
                 clock=None,
                 method=None,
                 name='synapses*'):
        self._N = 0
        Group.__init__(self, when=clock, name=name)

        self.codeobj_class = codeobj_class

        self.source = weakref.proxy(source)
        if target is None:
            self.target = self.source
        else:
            self.target = weakref.proxy(target)

        ##### Prepare and validate equations
        if model is None:
            model = ''

        if isinstance(model, basestring):
            model = Equations(model)
        if not isinstance(model, Equations):
            raise TypeError(('model has to be a string or an Equations '
                             'object, is "%s" instead.') % type(model))

        # Check flags
        model.check_flags({
            DIFFERENTIAL_EQUATION: ['event-driven'],
            SUBEXPRESSION: ['summed', 'scalar'],
            PARAMETER: ['constant', 'scalar']
        })

        # Add the lastupdate variable, needed for event-driven updates
        if 'lastupdate' in model._equations:
            raise SyntaxError('lastupdate is a reserved name.')
        model._equations['lastupdate'] = SingleEquation(
            PARAMETER, 'lastupdate', second)
        self._create_variables(model)

        # Separate the equations into event-driven equations,
        # continuously updated equations and summed variable updates
        event_driven = []
        continuous = []
        summed_updates = []
        for single_equation in model.itervalues():
            if 'event-driven' in single_equation.flags:
                event_driven.append(single_equation)
            elif 'summed' in single_equation.flags:
                summed_updates.append(single_equation)
            else:
                continuous.append(single_equation)

        if len(event_driven):
            self.event_driven = Equations(event_driven)
        else:
            self.event_driven = None

        self.equations = Equations(continuous)

        if namespace is None:
            namespace = {}
        #: The group-specific namespace
        self.namespace = namespace

        #: Set of `Variable` objects that should be resized when the
        #: number of synapses changes
        self._registered_variables = set()

        for varname, var in self.variables.iteritems():
            if isinstance(var, DynamicArrayVariable):
                # Register the array with the `SynapticItemMapping` object so
                # it gets automatically resized
                self.register_variable(var)

        if delay is None:
            delay = {}

        if isinstance(delay, Quantity):
            delay = {'pre': delay}
        elif not isinstance(delay, collections.Mapping):
            raise TypeError('Delay argument has to be a quantity or a '
                            'dictionary, is type %s instead.' % type(delay))

        #: List of names of all updaters, e.g. ['pre', 'post']
        self._synaptic_updaters = []
        #: List of all `SynapticPathway` objects
        self._pathways = []
        for prepost, argument in zip(('pre', 'post'), (pre, post)):
            if not argument:
                continue
            if isinstance(argument, basestring):
                pathway_delay = delay.get(prepost, None)
                self._add_updater(argument, prepost, delay=pathway_delay)
            elif isinstance(argument, collections.Mapping):
                for key, value in argument.iteritems():
                    if not isinstance(key, basestring):
                        err_msg = ('Keys for the "{}" argument'
                                   'have to be strings, got '
                                   '{} instead.').format(prepost, type(key))
                        raise TypeError(err_msg)
                    pathway_delay = delay.get(key, None)
                    self._add_updater(value,
                                      prepost,
                                      objname=key,
                                      delay=pathway_delay)

        # Check whether any delays were specified for pathways that don't exist
        for pathway in delay:
            if not pathway in self._synaptic_updaters:
                raise ValueError(('Cannot set the delay for pathway '
                                  '"%s": unknown pathway.') % pathway)

        # If we have a pathway called "pre" (the most common use case), provide
        # direct access to its delay via a delay attribute (instead of having
        # to use pre.delay)
        if 'pre' in self._synaptic_updaters:
            self.variables.add_reference('delay', self.pre.variables['delay'])

        #: Performs numerical integration step
        self.state_updater = None

        # We only need a state update if we have differential equations
        if len(self.equations.diff_eq_names):
            self.state_updater = StateUpdater(self, method)
            self.contained_objects.append(self.state_updater)

        #: "Summed variable" mechanism -- sum over all synapses of a
        #: pre-/postsynaptic target
        self.summed_updaters = {}
        # We want to raise an error if the same variable is updated twice
        # using this mechanism. This could happen if the Synapses object
        # connected a NeuronGroup to itself since then all variables are
        # accessible as var_pre and var_post.
        summed_targets = set()
        for single_equation in summed_updates:
            varname = single_equation.varname
            if not (varname.endswith('_pre') or varname.endswith('_post')):
                raise ValueError(('The summed variable "%s" does not end '
                                  'in "_pre" or "_post".') % varname)
            if not varname in self.variables:
                raise ValueError(('The summed variable "%s" does not refer'
                                  'do any known variable in the '
                                  'target group.') % varname)
            if varname.endswith('_pre'):
                summed_target = self.source
                orig_varname = varname[:-4]
            else:
                summed_target = self.target
                orig_varname = varname[:-5]

            target_eq = getattr(summed_target, 'equations',
                                {}).get(orig_varname, None)
            if target_eq is None or target_eq.type != PARAMETER:
                raise ValueError(('The summed variable "%s" needs a '
                                  'corresponding parameter "%s" in the '
                                  'target group.') % (varname, orig_varname))

            fail_for_dimension_mismatch(
                self.variables['_summed_' + varname].unit,
                self.variables[varname].unit, ('Summed variables need to have '
                                               'the same units in Synapses '
                                               'and the target group'))
            if self.variables[varname] in summed_targets:
                raise ValueError(('The target variable "%s" is already '
                                  'updated by another summed '
                                  'variable') % orig_varname)
            summed_targets.add(self.variables[varname])
            updater = SummedVariableUpdater(single_equation.expr, varname,
                                            self, summed_target)
            self.summed_updaters[varname] = updater
            self.contained_objects.append(updater)

        # Do an initial connect, if requested
        if not isinstance(connect, (bool, basestring)):
            raise TypeError(
                ('"connect" keyword has to be a boolean value or a '
                 'string, is type %s instead.' % type(connect)))
        self._initial_connect = connect
        if not connect is False:
            self.connect(connect, level=1)

        # Activate name attribute access
        self._enable_group_attributes()
Exemplo n.º 37
0
    def __init__(self, morphology=None, model=None, threshold=None,
                 refractory=False, reset=None, events=None,
                 threshold_location=None,
                 dt=None, clock=None, order=0, Cm=0.9 * uF / cm ** 2, Ri=150 * ohm * cm,
                 name='spatialneuron*', dtype=None, namespace=None,
                 method=('linear', 'exponential_euler', 'rk2', 'heun')):
        '''
        Set up the equations, the morphology and the state updaters.

        The constructor validates ``model``, folds all equations flagged as
        ``point current`` into the transmembrane current ``Im``, splits
        ``Im`` into a part proportional to ``v`` and a remainder (stored as
        the equations ``gtot__private`` and ``I0__private``), adds the
        morphology-related constants and finally delegates to
        ``NeuronGroup.__init__``.

        Parameters
        ----------
        morphology
            The morphology; a deep copy is stored as ``self.morphology`` and
            a flattened version as ``self.flat_morphology``.
        model : str or `Equations`
            The model equations. Have to define ``Im``.
        threshold, refractory, reset, events, dt, clock, order, name, dtype,
        namespace, method
            Passed on to ``NeuronGroup.__init__``.
        threshold_location
            If given, the threshold condition is restricted to this
            compartment index. Objects with an ``_indices`` method are
            converted by calling it and have to refer to exactly one
            compartment.
        Cm, Ri
            Assigned to the state variables of the same name after the group
            has been created.

        Raises
        ------
        TypeError
            If ``model`` is neither a string nor an `Equations` object, if
            ``threshold_location`` is given without a ``threshold``, if
            ``Im`` is not defined, or if ``Im`` cannot be matched to
            ``_A*v + _B``.
        AttributeError
            If ``threshold_location`` refers to more than one compartment.
        '''

        # #### Prepare and validate equations
        if isinstance(model, basestring):
            model = Equations(model)
        if not isinstance(model, Equations):
            raise TypeError(('model has to be a string or an Equations '
                             'object, is "%s" instead.') % type(model))

        # Insert the threshold mechanism at the specified location
        if threshold_location is not None:
            if threshold is None:
                # The string concatenation below would fail with an
                # unhelpful message, so complain explicitly
                raise TypeError('A threshold condition has to be specified '
                                'when using threshold_location.')
            if hasattr(threshold_location,
                       '_indices'):  # assuming this is a method
                threshold_location = threshold_location._indices()
                # for now, only a single compartment allowed
                if len(threshold_location) == 1:
                    threshold_location = threshold_location[0]
                else:
                    raise AttributeError(('Threshold can only be applied on a '
                                          'single location'))
            threshold = '(' + threshold + ') and (i == ' + str(threshold_location) + ')'

        # Check flags (we have point currents)
        model.check_flags({DIFFERENTIAL_EQUATION: ('point current',),
                           PARAMETER: ('constant', 'shared', 'linked', 'point current'),
                           SUBEXPRESSION: ('shared', 'point current')})

        # Add the membrane potential
        model += Equations('''
        v:volt # membrane potential
        ''')

        # Extract membrane equation
        if 'Im' in model:
            membrane_eq = model['Im']  # the membrane equation
        else:
            raise TypeError('The transmembrane current Im must be defined')

        # Insert point currents in the membrane equation
        for eq in model.itervalues():
            if 'point current' in eq.flags:
                fail_for_dimension_mismatch(eq.unit, amp,
                                            "Point current " + eq.varname + " should be in amp")
                eq.flags.remove('point current')
                membrane_eq.expr = Expression(
                    str(membrane_eq.expr.code) + '+' + eq.varname + '/area')

        ###### Process model equations (Im) to extract total conductance and the remaining current
        # Check conditional linearity with respect to v
        # Match to _A*v+_B
        v_symbol = sp.Symbol('v', real=True)
        wildcard = sp.Wild('_A', exclude=[v_symbol])
        constant_wildcard = sp.Wild('_B', exclude=[v_symbol])
        pattern = wildcard * v_symbol + constant_wildcard

        # Expand expressions in the membrane equation. The equation type is
        # switched temporarily before asking for the substituted expressions
        # and restored afterwards.
        membrane_eq.type = DIFFERENTIAL_EQUATION
        Im_expr = None
        # Use a dedicated loop variable so that the sympy symbol for v is not
        # overwritten by the name of the last equation
        for varname, expr in model.get_substituted_expressions():
            if varname == 'Im':
                Im_expr = expr
        membrane_eq.type = SUBEXPRESSION
        if Im_expr is None:
            raise AssertionError('Model equations did not contain Im!')

        # Factor out the variable
        s_expr = sp.collect(str_to_sympy(Im_expr.code).expand(), v_symbol)
        matches = s_expr.match(pattern)

        if matches is None:
            raise TypeError("The membrane current must be linear with respect to v")
        a, b = (matches[wildcard],
                matches[constant_wildcard])

        # Extracts the total conductance from Im, and the remaining current
        minusa_str, b_str = sympy_to_str(-a), sympy_to_str(b)
        # Add correct units if necessary
        if minusa_str == '0':
            minusa_str += '*siemens/meter**2'
        if b_str == '0':
            b_str += '*amp/meter**2'
        gtot_str = "gtot__private=" + minusa_str + ": siemens/meter**2"
        I0_str = "I0__private=" + b_str + ": amp/meter**2"
        model += Equations(gtot_str + "\n" + I0_str)

        # Insert morphology (store a copy)
        self.morphology = copy.deepcopy(morphology)

        # Flatten the morphology
        self.flat_morphology = FlatMorphology(morphology)

        # Equations for morphology
        # TODO: check whether Cm and Ri are already in the equations
        #       no: should be shared instead of constant
        #       yes: should be constant (check)
        eqs_constants = Equations("""
        length : meter (constant)
        distance : meter (constant)
        area : meter**2 (constant)
        volume : meter**3
        diameter : meter (constant)
        Cm : farad/meter**2 (constant)
        Ri : ohm*meter (constant, shared)
        r_length_1 : meter (constant)
        r_length_2 : meter (constant)
        time_constant = Cm/gtot__private : second
        space_constant = (2/pi)**(1.0/3.0) * (area/(1/r_length_1 + 1/r_length_2))**(1.0/6.0) /
                         (2*(Ri*gtot__private)**(1.0/2.0)) : meter
        """)
        if self.flat_morphology.has_coordinates:
            eqs_constants += Equations('''
            x : meter (constant)
            y : meter (constant)
            z : meter (constant)
            ''')

        NeuronGroup.__init__(self, morphology.total_compartments,
                             model=model + eqs_constants,
                             threshold=threshold, refractory=refractory,
                             reset=reset, events=events,
                             method=method, dt=dt, clock=clock, order=order,
                             namespace=namespace, dtype=dtype, name=name)
        # Parameters and intermediate variables for solving the cable equations
        # Note that some of these variables could have meaningful physical
        # units (e.g. _v_star is in volt, _I0_all is in amp/meter**2 etc.) but
        # since these variables should never be used in user code, we don't
        # assign them any units
        self.variables.add_arrays(['_ab_star0', '_ab_star1', '_ab_star2',
                                   '_a_minus0', '_a_minus1', '_a_minus2',
                                   '_a_plus0', '_a_plus1', '_a_plus2',
                                   '_b_plus', '_b_minus',
                                   '_v_star', '_u_plus', '_u_minus',
                                   # The following three are for solving the
                                   # three tridiag systems in parallel
                                   '_c1', '_c2', '_c3',
                                   # The following two are only necessary for
                                   # C code where we cannot deal with scalars
                                   # and arrays interchangeably:
                                   '_I0_all', '_gtot_all'], unit=1,
                                  size=self.N, read_only=True)

        self.Cm = Cm
        self.Ri = Ri
        # These explicit assignments will load the morphology values from disk
        # in standalone mode
        self.distance_ = self.flat_morphology.distance
        self.length_ = self.flat_morphology.length
        self.area_ = self.flat_morphology.area
        self.diameter_ = self.flat_morphology.diameter
        self.r_length_1_ = self.flat_morphology.r_length_1
        self.r_length_2_ = self.flat_morphology.r_length_2
        if self.flat_morphology.has_coordinates:
            self.x_ = self.flat_morphology.x
            self.y_ = self.flat_morphology.y
            self.z_ = self.flat_morphology.z

        # Performs numerical integration step
        self.add_attribute('diffusion_state_updater')
        self.diffusion_state_updater = SpatialStateUpdater(self, method,
                                                           clock=self.clock,
                                                           order=order)

        # Creation of contained_objects that do the work
        self.contained_objects.extend([self.diffusion_state_updater])
Exemplo n.º 38
0
    def __init__(self,
                 morphology=None,
                 model=None,
                 threshold=None,
                 refractory=False,
                 reset=None,
                 events=None,
                 threshold_location=None,
                 dt=None,
                 clock=None,
                 order=0,
                 Cm=0.9 * uF / cm**2,
                 Ri=150 * ohm * cm,
                 name='spatialneuron*',
                 dtype=None,
                 namespace=None,
                 method=('linear', 'exponential_euler', 'rk2', 'heun')):
        '''
        Set up the equations, the morphology and the state updaters.

        The user model is validated and stored as ``self.user_equations``.
        Point currents are folded into the transmembrane current ``Im``,
        which is then differentiated with respect to ``v`` to obtain the
        equations ``gtot__private`` and ``I0__private``. After adding the
        morphology-related constants, ``NeuronGroup.__init__`` is called and
        the internal arrays and updaters are created.

        Raises
        ------
        TypeError
            If ``model`` is neither a string nor an `Equations` object, if
            ``Im`` is missing or carries flags, or if the derivative of
            ``Im`` with respect to ``v`` cannot be evaluated.
        AttributeError
            If ``threshold_location`` refers to more than one compartment.
        '''
        # --- Validate the model argument
        if isinstance(model, basestring):
            model = Equations(model)
        if not isinstance(model, Equations):
            raise TypeError(('model has to be a string or an Equations '
                             'object, is "%s" instead.') % type(model))

        # --- Restrict the threshold condition to the requested compartment
        if threshold_location is not None:
            if hasattr(threshold_location, '_indices'):
                # assuming this is a method
                threshold_location = threshold_location._indices()
                # for now, only a single compartment allowed
                if len(threshold_location) != 1:
                    raise AttributeError(('Threshold can only be applied on a '
                                          'single location'))
                threshold_location = threshold_location[0]
            condition_start = '(' + threshold + ') and (i == '
            threshold = condition_start + str(threshold_location) + ')'

        # --- Check flags (we have point currents)
        allowed_flags = {
            DIFFERENTIAL_EQUATION: ('point current', ),
            PARAMETER: ('constant', 'shared', 'linked', 'point current'),
            SUBEXPRESSION: ('shared', 'point current', 'constant over dt'),
        }
        model.check_flags(allowed_flags)

        #: The original equations as specified by the user (i.e. before
        #: inserting point-currents into the membrane equation, before adding
        #: all the internally used variables and constants, etc.).
        self.user_equations = model

        # Subexpressions that are constant over a time step are split off
        # here already (the NeuronGroup initializer would do it later as
        # well, but that would give incorrect results for the linearity
        # check)
        model, constant_over_dt = extract_constant_subexpressions(model)

        # --- Extract the membrane equation
        if 'Im' not in model:
            raise TypeError('The transmembrane current Im must be defined')
        if len(model['Im'].flags):
            raise TypeError('Cannot specify any flags for the transmembrane '
                            'current Im.')
        im_expression = model['Im'].expr

        # --- Fold the point currents into the membrane equation
        rebuilt_equations = []
        for single_eq in model.itervalues():
            if single_eq.varname == 'Im':
                # handled separately, re-added below
                continue
            if 'point current' in single_eq.flags:
                fail_for_dimension_mismatch(
                    single_eq.dim, amp,
                    "Point current " + single_eq.varname + " should be in amp")
                im_expression = Expression(
                    str(im_expression.code) + '+' + single_eq.varname + '/area')
                # Same equation, but without the "point current" flag
                other_flags = list(set(single_eq.flags) - set(['point current']))
                single_eq = SingleEquation(single_eq.type,
                                           single_eq.varname,
                                           single_eq.dim,
                                           expr=single_eq.expr,
                                           flags=other_flags)
            rebuilt_equations.append(single_eq)

        im_equation = SingleEquation(SUBEXPRESSION,
                                     'Im',
                                     dimensions=(amp / meter**2).dim,
                                     expr=im_expression)
        rebuilt_equations.append(im_equation)
        v_equation = SingleEquation(PARAMETER, 'v', volt.dim)
        rebuilt_equations.append(v_equation)
        model = Equations(rebuilt_equations)

        # --- Extract total conductance and the remaining current from Im
        # Look up the fully substituted expression for Im
        substituted = model.get_substituted_expressions(
            include_subexpressions=True)
        for eq_name, eq_expr in substituted:
            if eq_name == 'Im':
                Im_expr = eq_expr
                break
        else:
            raise AssertionError('Model equations did not contain Im!')

        # Differentiate Im with respect to v
        im_sympy = str_to_sympy(Im_expr.code)
        v_symbol = sp.Symbol('v', real=True)
        dIm_dv = sp.diff(im_sympy, v_symbol)

        if len(dIm_dv.atoms(sp.Derivative)):
            raise TypeError(
                'Cannot take the derivative of "{Im}" with respect '
                'to v.'.format(Im=Im_expr.code))

        gtot_rhs = sympy_to_str(sp.simplify(-dIm_dv))
        I0_rhs = sympy_to_str(sp.simplify(im_sympy - dIm_dv * v_symbol))

        # A bare zero needs explicit units
        if gtot_rhs == '0':
            gtot_rhs += '*siemens/meter**2'
        if I0_rhs == '0':
            I0_rhs += '*amp/meter**2'
        private_equations = "\n".join([
            "gtot__private=" + gtot_rhs + ": siemens/meter**2",
            "I0__private=" + I0_rhs + ": amp/meter**2",
        ])
        model += Equations(private_equations)

        # --- Morphology: keep a copy and a flattened version
        self.morphology = copy.deepcopy(morphology)
        self.flat_morphology = FlatMorphology(morphology)

        # Equations for morphology
        # TODO: check whether Cm and Ri are already in the equations
        #       no: should be shared instead of constant
        #       yes: should be constant (check)
        eqs_constants = Equations("""
        length : meter (constant)
        distance : meter (constant)
        area : meter**2 (constant)
        volume : meter**3
        Ic : amp/meter**2
        diameter : meter (constant)
        Cm : farad/meter**2 (constant)
        Ri : ohm*meter (constant, shared)
        r_length_1 : meter (constant)
        r_length_2 : meter (constant)
        time_constant = Cm/gtot__private : second
        space_constant = (2/pi)**(1.0/3.0) * (area/(1/r_length_1 + 1/r_length_2))**(1.0/6.0) /
                         (2*(Ri*gtot__private)**(1.0/2.0)) : meter
        """)
        if self.flat_morphology.has_coordinates:
            eqs_constants += Equations('''
            x : meter (constant)
            y : meter (constant)
            z : meter (constant)
            ''')

        NeuronGroup.__init__(self, morphology.total_compartments,
                             model=model + eqs_constants,
                             threshold=threshold, refractory=refractory,
                             reset=reset, events=events, method=method,
                             dt=dt, clock=clock, order=order,
                             namespace=namespace, dtype=dtype, name=name)

        # Parameters and intermediate variables for solving the cable equations
        # Note that some of these variables could have meaningful physical
        # units (e.g. _v_star is in volt, _I0_all is in amp/meter**2 etc.) but
        # since these variables should never be used in user code, we don't
        # assign them any units
        internal_arrays = [
            '_ab_star0', '_ab_star1', '_ab_star2',
            '_a_minus0', '_a_minus1', '_a_minus2',
            '_a_plus0', '_a_plus1', '_a_plus2',
            '_b_plus', '_b_minus',
            '_v_star', '_u_plus', '_u_minus', '_v_previous',
            # The following three are for solving the
            # three tridiag systems in parallel
            '_c1', '_c2', '_c3',
            # The following two are only necessary for
            # C code where we cannot deal with scalars
            # and arrays interchangeably:
            '_I0_all', '_gtot_all',
        ]
        self.variables.add_arrays(internal_arrays,
                                  size=self.N,
                                  read_only=True)

        self.Cm = Cm
        self.Ri = Ri
        # These explicit assignments will load the morphology values from disk
        # in standalone mode
        flat = self.flat_morphology
        self.distance_ = flat.distance
        self.length_ = flat.length
        self.area_ = flat.area
        self.diameter_ = flat.diameter
        self.r_length_1_ = flat.r_length_1
        self.r_length_2_ = flat.r_length_2
        if flat.has_coordinates:
            self.x_ = flat.x
            self.y_ = flat.y
            self.z_ = flat.z

        # Performs numerical integration step
        self.add_attribute('diffusion_state_updater')
        self.diffusion_state_updater = SpatialStateUpdater(
            self, method, clock=self.clock, order=order)

        # Update v after the gating variables to obtain consistent Ic and Im
        self.diffusion_state_updater.order = 1

        # Creation of contained_objects that do the work
        self.contained_objects.extend([self.diffusion_state_updater])

        if len(constant_over_dt):
            self.subexpression_updater = SubexpressionUpdater(
                self, constant_over_dt)
            self.contained_objects.append(self.subexpression_updater)
def _colorbar_keywords(value_colorbar, value_unit, value_varname, text_key):
    '''
    Return the keyword arguments for the colorbar/scalarbar call.

    A mapping given by the user is returned unchanged. Otherwise a new
    dictionary is created that stores "name (unit)" (or only the name for
    dimensionless values) under ``text_key``, provided that a name exists.
    '''
    if isinstance(value_colorbar, Mapping):
        return value_colorbar
    keywords = {}
    if not have_same_dimensions(value_unit, DIMENSIONLESS):
        unit_str = f' ({value_unit!s})'
    else:
        unit_str = ''
    if value_varname:
        keywords[text_key] = f'{value_varname}{unit_str}'
    return keywords


def plot_morphology(morphology, plot_3d=None, show_compartments=False,
                    show_diameter=False, colors=('darkblue', 'darkred'),
                    values=None, value_norm=(None, None), value_colormap='hot',
                    value_colorbar=True, value_unit=None, axes=None):
    '''
    Plot a given `~brian2.spatialneuron.morphology.Morphology` in 2D or 3D.

    Parameters
    ----------
    morphology : `~brian2.spatialneuron.morphology.Morphology`
        The morphology to plot
    plot_3d : bool, optional
        Whether to plot the morphology in 3D or in 2D. If not set (the default)
        a morphology where all z values are 0 is plotted in 2D, otherwise it is
        plotted in 3D.
    show_compartments : bool, optional
        Whether to plot a dot at the center of each compartment. Defaults to
        ``False``.
    show_diameter : bool, optional
        Whether to plot the compartments with the diameter given in the
        morphology. Defaults to ``False``.
    colors : sequence of color specifications
        A list of colors that is cycled through for each new section. Can be
        any color specification that matplotlib understands (e.g. a string such
        as ``'darkblue'`` or a tuple such as ``(0, 0.7, 0)``).
    values : ~brian2.units.fundamentalunits.Quantity, optional
        Values to fill compartment patches with a color that corresponds to
        their given value.
    value_norm : tuple or callable, optional
        Normalization function to scale the displayed values. Can be a tuple
        of a minimum and a maximum value (where either of them can be ``None``
        to denote taking the minimum/maximum from the data) or a function that
        takes a value and returns the scaled value (e.g. as returned by
        `.matplotlib.colors.PowerNorm`). For a tuple of values, will use
        ``matplotlib.colors.Normalize(vmin, vmax, clip=True)`` with the
        given ``(vmin, vmax)`` values. The given minimum/maximum objects are
        not modified.
    value_colormap : str or matplotlib.colors.Colormap, optional
        Desired colormap for plots. Either the name of a standard colormap
        or a `.matplotlib.colors.Colormap` instance. Defaults to ``'hot'``.
        Note that this uses ``matplotlib`` color maps even for 3D plots with
        Mayavi.
    value_colorbar : bool or dict, optional
        Whether to add a colorbar for the ``values``. Defaults to ``True``,
        but will be ignored if no ``values`` are provided. Can also be a
        dictionary with the keyword arguments for matplotlib's
        `~.matplotlib.figure.Figure.colorbar` method (2D plot), or for
        Mayavi's `~.mayavi.mlab.scalarbar` method (3D plot).
    value_unit : `Unit`, optional
        A `Unit` to rescale the values for display in the colorbar. Does not
        have any visible effect if no colorbar is used. If not specified, will
        try to determine the "best unit" by itself.
    axes : `~matplotlib.axes.Axes` or `~mayavi.core.api.Scene`, optional
        A matplotlib `~matplotlib.axes.Axes` (for 2D plots) or mayavi
        `~mayavi.core.api.Scene` (for 3D plots) instance, where the plot will
        be added.

    Returns
    -------
    axes : `~matplotlib.axes.Axes` or `~mayavi.core.api.Scene`
        The `~matplotlib.axes.Axes` or `~mayavi.core.api.Scene` instance that
        was used for plotting. This object allows to modify the plot further,
        e.g. by setting the plotted range, the axis labels, the plot title, etc.

    Raises
    ------
    TypeError
        If ``value_unit`` is not a `Unit`, if a ``value_norm`` tuple does not
        have exactly two elements, or if a 3D plot is requested with a
        ``value_norm`` that is not a tuple.
    ImportError
        If a 3D plot is requested but ``mayavi`` cannot be imported.
    '''
    # Avoid circular import issues
    from brian2tools.plotting.base import (_setup_axes_matplotlib,
                                           _setup_axes_mayavi)

    if plot_3d is None:
        # Decide whether to use 2d or 3d plotting based on the coordinates
        flat_morphology = FlatMorphology(morphology)
        plot_3d = bool(np.any(np.abs(flat_morphology.z) > 1e-12))

    # Only needed (and only set) when values are given
    value_varname = None
    if values is not None:
        if hasattr(values, 'name'):
            value_varname = values.name
        else:
            value_varname = 'values'
        if value_unit is not None:
            if not isinstance(value_unit, Unit):
                raise TypeError('\'value_unit\' has to be a unit but is '
                                f'\'{type(value_unit)}\'.')
            fail_for_dimension_mismatch(value_unit, values,
                                        'The \'value_unit\' arguments needs '
                                        'to have the same dimensions as '
                                        'the \'values\'.')
        else:
            if have_same_dimensions(values, DIMENSIONLESS):
                value_unit = 1.
            else:
                value_unit = values[:].get_best_unit()
        orig_values = values
        values = values/value_unit
        if isinstance(value_norm, tuple):
            if not len(value_norm) == 2:
                raise TypeError('Need a (vmin, vmax) tuple for the value '
                                'normalization, but got a tuple of length '
                                f'{len(value_norm)}.')
            vmin, vmax = value_norm
            # Note: the limits are rescaled with "x = x / unit" instead of
            # "x /= unit" so that the objects passed in by the caller are
            # never modified in place
            if vmin is not None:
                err_msg = ('The minimum value in \'value_norm\' needs to '
                           'have the same units as \'values\'.')
                fail_for_dimension_mismatch(vmin, orig_values,
                                            error_message=err_msg)
                vmin = vmin / value_unit
            if vmax is not None:
                err_msg = ('The maximum value in \'value_norm\' needs to '
                           'have the same units as \'values\'.')
                fail_for_dimension_mismatch(vmax, orig_values,
                                            error_message=err_msg)
                vmax = vmax / value_unit
            if plot_3d:
                value_norm = (vmin, vmax)
            else:
                value_norm = Normalize(vmin=vmin, vmax=vmax, clip=True)
                value_norm.autoscale_None(values)
        elif plot_3d:
            raise TypeError('3d plots only support normalizations given by '
                            'a (min, max) tuple.')
        value_colormap = plt.get_cmap(value_colormap)

    if plot_3d:
        try:
            import mayavi.mlab as mayavi
        except ImportError as ex:
            raise ImportError('3D plotting needs the mayavi library') from ex
        axes = _setup_axes_mayavi(axes)
        # Rendering is switched off while the plot is assembled
        axes.scene.disable_render = True
        surf = _plot_morphology3D(morphology, axes, colors=colors,
                                  values=values, value_norm=value_norm,
                                  value_colormap=value_colormap,
                                  show_diameters=show_diameter,
                                  show_compartments=show_compartments)
        if values is not None and value_colorbar:
            value_colorbar = _colorbar_keywords(value_colorbar, value_unit,
                                                value_varname, 'title')
            cb = mayavi.scalarbar(surf, **value_colorbar)
            # Make text dark gray
            cb.title_text_property.color = (0.1, 0.1, 0.1)
            cb.label_text_property.color = (0.1, 0.1, 0.1)
        axes.scene.disable_render = False
    else:
        axes = _setup_axes_matplotlib(axes)

        _plot_morphology2D(morphology, axes, colors,
                           values, value_norm, value_colormap,
                           show_compartments=show_compartments,
                           show_diameter=show_diameter)
        axes.set_xlabel('x (um)')
        axes.set_ylabel('y (um)')
        axes.set_aspect('equal')
        if values is not None and value_colorbar:
            # Put the colorbar into its own axes to the right of the plot
            divider = make_axes_locatable(axes)
            cax = divider.append_axes("right", size="5%", pad=0.1)
            mappable = ScalarMappable(norm=value_norm, cmap=value_colormap)
            mappable.set_array([])
            fig = axes.get_figure()
            value_colorbar = _colorbar_keywords(value_colorbar, value_unit,
                                                value_varname, 'label')
            fig.colorbar(mappable, cax=cax, **value_colorbar)
    return axes
Exemplo n.º 40
0
 def __setattr__(self, name, val, level=0):
     '''
     Set an attribute, redirecting names of state variables to the
     variable mechanism.

     ``X.var = val`` is treated like ``X.var[:] = val`` (with a check of
     the dimensions unless ``val`` is a string) and ``X.var_ = val`` does
     the same without the check. Names that are neither variables nor
     existing/private attributes raise an `AttributeError` that suggests
     similar variable names.

     Raises
     ------
     TypeError
         If the variable is read-only.
     AttributeError
         If ``name`` does not refer to a variable or a known attribute.
     '''
     # attribute access is switched off until this attribute is created by
     # _enable_group_attributes
     if (not hasattr(self, '_group_attribute_access_active')
             or name in self.__dict__):
         object.__setattr__(self, name, val)
         return

     # Makes sure that classes can override the "variables" mechanism
     # with instance/class attributes and properties
     if (name in self.__getattribute__('__dict__')
             or name in self.__getattribute__('__class__').__dict__):
         object.__setattr__(self, name, val)
         return

     if name in self.variables:
         var = self.variables[name]
         if not isinstance(val, basestring):
             # Only the error message differs between the two cases
             if var.dim is DIMENSIONLESS:
                 mismatch_msg = ('%s should be set with a '
                                 'dimensionless value, but got '
                                 '{value}') % name
             else:
                 mismatch_msg = ('%s should be set with a '
                                 'value with units %r, but got '
                                 '{value}') % (name, get_unit(var.dim))
             fail_for_dimension_mismatch(val, var.dim, mismatch_msg,
                                         value=val)
         if var.read_only:
             raise TypeError('Variable %s is read-only.' % name)
         # Make the call X.var = ... equivalent to X.var[:] = ...
         addressable = var.get_addressable_value_with_unit(name, self)
         addressable.set_item(slice(None), val, level=level + 1)
         return

     if name.endswith('_') and name[:-1] in self.variables:
         # no unit checking
         stripped_name = name[:-1]
         var = self.variables[stripped_name]
         if var.read_only:
             raise TypeError('Variable %s is read-only.' % stripped_name)
         # Make the call X.var = ... equivalent to X.var[:] = ...
         addressable = var.get_addressable_value(stripped_name, self)
         addressable.set_item(slice(None), val, level=level + 1)
         return

     if hasattr(self, name) or name.startswith('_'):
         object.__setattr__(self, name, val)
         return

     # Unknown name: try to suggest the correct name in case of a typo
     candidates = [
         candidate for candidate, candidate_var in self.variables.items()
         if not candidate.startswith('_') and not candidate_var.read_only
     ]
     checker = SpellChecker(candidates)
     suffix = '_' if name.endswith('_') else ''
     if suffix:
         name = name[:-1]
     error_msg = 'Could not find a state variable with name "%s".' % name
     suggestions = checker.suggest(name)
     n_suggestions = len(suggestions)
     if n_suggestions == 1:
         suggestion, = suggestions
         error_msg += ' Did you mean to write "%s%s"?' % (suggestion, suffix)
     elif n_suggestions > 1:
         quoted = ', '.join(['"%s%s"' % (suggestion, suffix)
                             for suggestion in suggestions])
         error_msg += (' Did you mean to write any of the following: %s ?' %
                       quoted)
     error_msg += (' Use the add_attribute method if you intend to add '
                   'a new attribute to the object.')
     raise AttributeError(error_msg)
Exemplo n.º 41
0
    def __init__(self,
                 source,
                 target=None,
                 model=None,
                 pre=None,
                 post=None,
                 connect=False,
                 delay=None,
                 namespace=None,
                 dtype=None,
                 codeobj_class=None,
                 clock=None,
                 method=None,
                 name='synapses*'):
        '''
        Validate the model equations, allocate memory, create the pathway
        updaters given by ``pre``/``post``, apply fixed delays, attach the
        state and lumped-variable updaters and optionally perform an initial
        connect.

        Parameters
        ----------
        source : object
            The source group; only a weak proxy to it is stored.
        target : object, optional
            The target group (stored as a weak proxy). If ``None``, the
            source proxy is used as the target as well.
        model : str or `Equations`, optional
            The model equations; ``None`` is treated as an empty string.
        pre, post : str or mapping, optional
            A string is handed to ``_add_updater`` together with
            ``'pre'``/``'post'``; a mapping adds one updater per item, using
            the (string) key as ``objname``.
        connect : bool or str, optional
            Anything other than ``False`` is passed on to ``connect`` at the
            end of the setup.
        delay : `Quantity` or mapping, optional
            A scalar delay for the ``'pre'`` pathway, or a mapping from
            pathway names to scalar delays.
        namespace : optional
            Passed to ``create_namespace``.
        dtype : optional
            Passed to ``_allocate_memory``.
        codeobj_class : optional
            Stored as ``self.codeobj_class``.
        clock : optional
            Passed as ``when`` to ``BrianObject.__init__``.
        method : optional
            Passed to `StateUpdater`.
        name : str, optional
            Passed to ``BrianObject.__init__``.

        Raises
        ------
        TypeError
            If ``model``, ``connect``, ``delay`` or a key of ``pre``/``post``
            has an unsupported type, or if a delay is not a scalar quantity.
        ValueError
            If an event-driven variable is flagged as lumped, a delay refers
            to an unknown pathway, or a lumped variable has no counterpart
            in the target group.
        '''

        BrianObject.__init__(self, when=clock, name=name)

        self.codeobj_class = codeobj_class

        self.source = weakref.proxy(source)
        if target is None:
            self.target = self.source
        else:
            self.target = weakref.proxy(target)

        ##### Prepare and validate equations
        if model is None:
            model = ''

        if isinstance(model, basestring):
            model = Equations(model)
        if not isinstance(model, Equations):
            raise TypeError(('model has to be a string or an Equations '
                             'object, is "%s" instead.') % type(model))

        # Check flags
        model.check_flags({
            DIFFERENTIAL_EQUATION: ['event-driven', 'lumped'],
            STATIC_EQUATION: ['lumped'],
            PARAMETER: ['constant', 'lumped']
        })

        # Separate the equations into event-driven and continuously updated
        # equations
        event_driven = []
        continuous = []
        for single_equation in model.itervalues():
            if 'event-driven' in single_equation.flags:
                if 'lumped' in single_equation.flags:
                    raise ValueError(
                        ('Event-driven variable %s cannot be '
                         'a lumped variable.') % single_equation.varname)
                event_driven.append(single_equation)
            else:
                continuous.append(single_equation)
        # Add the lastupdate variable, used by event-driven equations
        continuous.append(SingleEquation(PARAMETER, 'lastupdate', second))

        if len(event_driven):
            self.event_driven = Equations(event_driven)
        else:
            self.event_driven = None

        self.equations = Equations(continuous)

        ##### Setup the memory
        self.arrays = self._allocate_memory(dtype=dtype)

        # Setup the namespace
        self._given_namespace = namespace
        self.namespace = create_namespace(namespace)

        self._queues = {}
        self._delays = {}

        self.item_mapping = SynapticItemMapping(self)
        self.indices = {
            '_idx': self.item_mapping,
            '_presynaptic_idx': self.item_mapping.synaptic_pre,
            '_postsynaptic_idx': self.item_mapping.synaptic_post
        }
        # Allow S.i instead of S.indices.i, etc.
        self.i = self.item_mapping.i
        self.j = self.item_mapping.j
        self.k = self.item_mapping.k

        # Setup variables
        self.variables = self._create_variables()

        #: List of names of all updaters, e.g. ['pre', 'post']
        self._updaters = []
        for prepost, argument in zip(('pre', 'post'), (pre, post)):
            if not argument:
                continue
            if isinstance(argument, basestring):
                self._add_updater(argument, prepost)
            elif isinstance(argument, collections.Mapping):
                for key, value in argument.iteritems():
                    if not isinstance(key, basestring):
                        err_msg = ('Keys for the "{}" argument '
                                   'have to be strings, got '
                                   '{} instead.').format(prepost, type(key))
                        raise TypeError(err_msg)
                    self._add_updater(value, prepost, objname=key)

        # If we have a pathway called "pre" (the most common use case), provide
        # direct access to its delay via a delay attribute (instead of having
        # to use pre.delay)
        if 'pre' in self._updaters:
            self.variables['delay'] = self.pre.variables['delay']

        if delay is not None:
            if isinstance(delay, Quantity):
                # A bare quantity is shorthand for {'pre': delay}
                if 'pre' not in self._updaters:
                    raise ValueError(
                        ('Cannot set delay, no "pre" pathway exists. '
                         'Use a dictionary if you want to set the '
                         'delay for a pathway with a different name.'))
                delay = {'pre': delay}

            if not isinstance(delay, collections.Mapping):
                raise TypeError('Delay argument has to be a quantity or a '
                                'dictionary, is type %s instead.' %
                                type(delay))
            for pathway, pathway_delay in delay.iteritems():
                if pathway not in self._updaters:
                    raise ValueError(('Cannot set the delay for pathway '
                                      '"%s": unknown pathway.') % pathway)
                if not isinstance(pathway_delay, Quantity):
                    raise TypeError(('Cannot set the delay for pathway "%s": '
                                     'expected a quantity, got %s instead.') %
                                    (pathway, type(pathway_delay)))
                if pathway_delay.size != 1:
                    # Both placeholders need a value, otherwise building the
                    # message itself fails with a formatting error.
                    raise TypeError(('Cannot set the delay for pathway "%s": '
                                     'expected a scalar quantity, got a '
                                     'quantity with shape %s instead.') %
                                    (pathway, str(pathway_delay.shape)))
                fail_for_dimension_mismatch(pathway_delay, second,
                                            ('Delay has to be '
                                             'specified in units '
                                             'of seconds'))
                updater = getattr(self, pathway)
                self.item_mapping.unregister_variable(updater._delays)
                del updater._delays
                # For simplicity, store the delay as a one-element array
                # so that for example updater._delays[:] works.
                updater._delays = np.array([float(pathway_delay)])
                variable = ArrayVariable('delay',
                                         second,
                                         updater._delays,
                                         group_name=self.name,
                                         scalar=True)
                updater.variables['delay'] = variable
                if pathway == 'pre':
                    self.variables['delay'] = variable

        #: Performs numerical integration step
        self.state_updater = StateUpdater(self, method)
        self.contained_objects.append(self.state_updater)

        #: "Lumped variable" mechanism -- sum over all synapses of a
        #: postsynaptic target
        self.lumped_updaters = {}
        for single_equation in self.equations.itervalues():
            if 'lumped' in single_equation.flags:
                varname = single_equation.varname
                # For a lumped variable, we need an equivalent parameter in the
                # target group
                if varname not in self.target.variables:
                    raise ValueError(
                        ('The lumped variable %s needs a variable '
                         'of the same name in the target '
                         'group ') % single_equation.varname)
                # Compare unit with unit (not unit with the Variable object)
                fail_for_dimension_mismatch(self.variables[varname].unit,
                                            self.target.variables[varname].unit,
                                            ('Lumped variables need to have '
                                             'the same units in Synapses '
                                             'and the target group'))
                # TODO: Add some more stringent check about the type of
                # variable in the target group
                updater = LumpedUpdater(varname, self, self.target)
                self.lumped_updaters[varname] = updater
                self.contained_objects.append(updater)

        # Do an initial connect, if requested
        if not isinstance(connect, (bool, basestring)):
            raise TypeError(
                ('"connect" keyword has to be a boolean value or a '
                 'string, is type %s instead.' % type(connect)))
        self._initial_connect = connect
        if connect is not False:
            self.connect(connect, level=1)

        # Activate name attribute access
        Group.__init__(self)
Exemplo n.º 42
0
    def __init__(self,
                 morphology=None,
                 model=None,
                 threshold=None,
                 refractory=False,
                 reset=None,
                 threshold_location=None,
                 dt=None,
                 clock=None,
                 order=0,
                 Cm=0.9 * uF / cm**2,
                 Ri=150 * ohm * cm,
                 name='spatialneuron*',
                 dtype=None,
                 namespace=None,
                 method=('linear', 'exponential_euler', 'rk2', 'heun')):
        '''
        Build the equations for a spatially extended neuron, initialise the
        underlying `NeuronGroup` with one element per compartment and attach
        the `SpatialStateUpdater`.

        The transmembrane current ``Im`` of ``model`` is split into a part
        proportional to ``v`` (stored as ``gtot__private``) and a remaining
        part (stored as ``I0__private``).

        Parameters
        ----------
        morphology : object
            The morphology; its length gives the number of compartments and
            its ``compress`` method fills a `MorphologyData` object.
        model : str or `Equations`
            The model equations; have to define ``Im``.
        threshold, refractory, reset, dt, clock, order, namespace, dtype, name
            Passed on to ``NeuronGroup.__init__`` (``threshold`` is first
            restricted to ``threshold_location`` if that is given).
        threshold_location : optional
            Compartment index, or an object with an ``_indices`` method
            returning exactly one index, at which the threshold applies.
        Cm, Ri : optional
            Initial values assigned to the ``Cm`` and ``Ri`` variables.
        method : optional
            Passed to ``NeuronGroup.__init__`` and `SpatialStateUpdater`.

        Raises
        ------
        TypeError
            If ``model`` has the wrong type, does not define ``Im``, or if
            ``Im`` is not linear with respect to ``v``.
        AttributeError
            If ``threshold_location`` refers to more than one compartment.
        '''

        # #### Prepare and validate equations
        if isinstance(model, basestring):
            model = Equations(model)
        if not isinstance(model, Equations):
            raise TypeError(('model has to be a string or an Equations '
                             'object, is "%s" instead.') % type(model))

        # Insert the threshold mechanism at the specified location
        if threshold_location is not None:
            if hasattr(threshold_location,
                       '_indices'):  # assuming this is a method
                threshold_location = threshold_location._indices()
                # for now, only a single compartment allowed
                if len(threshold_location) == 1:
                    threshold_location = threshold_location[0]
                else:
                    raise AttributeError(('Threshold can only be applied on a '
                                          'single location'))
            threshold = '(' + threshold + ') and (i == ' + str(
                threshold_location) + ')'

        # Check flags (we have point currents)
        model.check_flags({
            DIFFERENTIAL_EQUATION: ('point current', ),
            PARAMETER: ('constant', 'shared', 'linked', 'point current'),
            SUBEXPRESSION: ('shared', 'point current')
        })

        # Add the membrane potential
        model += Equations('''
        v:volt # membrane potential
        ''')

        # Extract membrane equation
        if 'Im' in model:
            membrane_eq = model['Im']  # the membrane equation
        else:
            raise TypeError('The transmembrane current Im must be defined')

        # Insert point currents in the membrane equation
        for eq in model.itervalues():
            if 'point current' in eq.flags:
                fail_for_dimension_mismatch(
                    eq.unit, amp,
                    "Point current " + eq.varname + " should be in amp")
                eq.flags.remove('point current')
                membrane_eq.expr = Expression(
                    str(membrane_eq.expr.code) + '+' + eq.varname + '/area')

        ###### Process model equations (Im) to extract total conductance and the remaining current
        # Check conditional linearity with respect to v
        # Match to _A*v+_B
        v_symbol = sp.Symbol('v', real=True)
        wildcard = sp.Wild('_A', exclude=[v_symbol])
        constant_wildcard = sp.Wild('_B', exclude=[v_symbol])
        pattern = wildcard * v_symbol + constant_wildcard

        # Expand expressions in the membrane equation. Im is temporarily
        # marked as a differential equation so that it is part of the
        # substituted expressions (these are returned for diff eqs).
        membrane_eq.type = DIFFERENTIAL_EQUATION
        # The loop variable must not reuse the name of the sympy symbol:
        # the symbol is still needed for the collect call below.
        for eq_name, expr in model._get_substituted_expressions():
            if eq_name == 'Im':
                Im_expr = expr
        membrane_eq.type = SUBEXPRESSION

        # Factor out the variable
        s_expr = sp.collect(Im_expr.sympy_expr.expand(), v_symbol)
        matches = s_expr.match(pattern)

        if matches is None:
            raise TypeError(
                "The membrane current must be linear with respect to v")
        a, b = (matches[wildcard], matches[constant_wildcard])

        # Extracts the total conductance from Im, and the remaining current
        minusa_str, b_str = sympy_to_str(-a), sympy_to_str(b)
        # Add correct units if necessary
        if minusa_str == '0':
            minusa_str += '*siemens/meter**2'
        if b_str == '0':
            b_str += '*amp/meter**2'
        gtot_str = "gtot__private=" + minusa_str + ": siemens/meter**2"
        I0_str = "I0__private=" + b_str + ": amp/meter**2"
        model += Equations(gtot_str + "\n" + I0_str)

        # Equations for morphology
        # TODO: check whether Cm and Ri are already in the equations
        #       no: should be shared instead of constant
        #       yes: should be constant (check)
        eqs_constants = Equations("""
        diameter : meter (constant)
        length : meter (constant)
        x : meter (constant)
        y : meter (constant)
        z : meter (constant)
        distance : meter (constant)
        area : meter**2 (constant)
        Cm : farad/meter**2 (constant)
        Ri : ohm*meter (constant, shared)
        space_constant = (diameter/(4*Ri*gtot__private))**.5 : meter # Not so sure about the name

        ### Parameters and intermediate variables for solving the cable equation
        ab_star0 : siemens/meter**2
        ab_plus0 : siemens/meter**2
        ab_minus0 : siemens/meter**2
        ab_star1 : siemens/meter**2
        ab_plus1 : siemens/meter**2
        ab_minus1 : siemens/meter**2
        ab_star2 : siemens/meter**2
        ab_plus2 : siemens/meter**2
        ab_minus2 : siemens/meter**2
        b_plus : siemens/meter**2
        b_minus : siemens/meter**2
        v_star : volt
        u_plus : 1
        u_minus : 1
        # The following two are only necessary for C code where we cannot deal
        # with scalars and arrays interchangeably
        gtot_all : siemens/meter**2
        I0_all : amp/meter**2
        """)
        # Possibilities for the name: characteristic_length, electrotonic_length, length_constant, space_constant

        # Insert morphology
        self.morphology = morphology

        # Link morphology variables to neuron's state variables
        self.morphology_data = MorphologyData(len(morphology))
        self.morphology.compress(self.morphology_data)

        NeuronGroup.__init__(self,
                             len(morphology),
                             model=model + eqs_constants,
                             threshold=threshold,
                             refractory=refractory,
                             reset=reset,
                             method=method,
                             dt=dt,
                             clock=clock,
                             order=order,
                             namespace=namespace,
                             dtype=dtype,
                             name=name)

        self.Cm = Cm
        self.Ri = Ri
        # TODO: View instead of copy for runtime?
        # The trailing underscore sets the values without unit checking
        self.diameter_ = self.morphology_data.diameter
        self.distance_ = self.morphology_data.distance
        self.length_ = self.morphology_data.length
        self.area_ = self.morphology_data.area
        self.x_ = self.morphology_data.x
        self.y_ = self.morphology_data.y
        self.z_ = self.morphology_data.z

        # Performs numerical integration step
        self.add_attribute('diffusion_state_updater')
        self.diffusion_state_updater = SpatialStateUpdater(self,
                                                           method,
                                                           clock=self.clock,
                                                           order=order)

        # Creation of contained_objects that do the work
        self.contained_objects.extend([self.diffusion_state_updater])
Exemplo n.º 43
0
def check_units_statements(code, namespace, variables):
    '''
    Check the units for a series of statements. Setting a model variable has to
    use the correct unit. For newly introduced temporary variables, the unit
    is determined and used to check the following statements to ensure
    consistency.

    Parameters
    ----------
    code : str
        The statements to check, separated by newlines or semicolons.
    namespace : dict-like
        The namespace of external variables.
    variables : dict of `Variable` objects
        The information about the internal variables. The dictionary is
        copied, the caller's object is not modified.

    Raises
    ------
    KeyError
        In case one of the identifiers cannot be resolved.
    DimensionMismatchError
        If an unit mismatch occurs during the evaluation.
    AssertionError
        If the code contains unknown identifiers, an unsupported operator,
        or assigns to a name that is neither a known nor a newly defined
        variable.
    '''
    known = set(variables.keys()) | set(namespace.keys())
    newly_defined, _, unknown = analyse_identifiers(code, known)

    if unknown:
        raise AssertionError(
            ('Encountered unknown identifiers, this should '
             'not happen at this stage. Unknown identifiers: %s' %
             (unknown, )))

    # We want to add newly defined variables to the variables dictionary so we
    # make a copy now
    variables = dict(variables)

    for statement in re.split(r'[;\n]', code):
        statement = statement.strip()
        if not statement:
            continue  # skip empty lines

        varname, op, expr = parse_statement(statement)
        if op in ('+=', '-=', '*=', '/=', '%='):
            # Replace statements such as "w *= 2" by "w = w * (2)". The
            # parentheses keep the right-hand side together, so that e.g.
            # "w *= a + b" is checked as "w * (a + b)" and not "w * a + b".
            expr = '{var} {op_first} ({expr})'.format(var=varname,
                                                      op_first=op[0],
                                                      expr=expr)
        elif op != '=':
            raise AssertionError('Unknown operator "%s"' % op)

        expr_unit = parse_expression_unit(expr, namespace, variables)

        if varname in variables:
            fail_for_dimension_mismatch(variables[varname].unit, expr_unit,
                                        ('Code statement "%s" does not use '
                                         'correct units' % statement))
        elif varname in newly_defined:
            # note the unit for later
            variables[varname] = Variable(expr_unit,
                                          is_bool=False,
                                          scalar=False)
        else:
            raise AssertionError(('Variable "%s" is neither in the variables '
                                  'dictionary nor in the list of undefined '
                                  'variables.' % varname))
Exemplo n.º 44
0
    def __init__(self, source, target=None, model=None, pre=None, post=None,
                 connect=False, delay=None, namespace=None, dtype=None,
                 codeobj_class=None,
                 clock=None, method=None, name='synapses*'):
        '''
        Validate the model equations, create the variables and pathway
        updaters, apply fixed delays, attach the state and summed-variable
        updaters and optionally perform an initial connect.

        Parameters
        ----------
        source : object
            The source group; only a weak proxy to it is stored.
        target : object, optional
            The target group (stored as a weak proxy). If ``None``, the
            source proxy is used as the target as well.
        model : str or `Equations`, optional
            The model equations; ``None`` is treated as an empty string.
        pre, post : str or mapping, optional
            A string is handed to ``_add_updater`` together with
            ``'pre'``/``'post'``; a mapping adds one updater per item, using
            the (string) key as ``objname``.
        connect : bool or str, optional
            Anything other than ``False`` is passed on to ``connect`` at the
            end of the setup.
        delay : `Quantity` or mapping, optional
            A scalar delay for the ``'pre'`` pathway, or a mapping from
            pathway names to scalar delays.
        namespace : optional
            Passed to ``create_namespace``.
        dtype : optional
            Not used directly in this method.
        codeobj_class : optional
            Stored as ``self.codeobj_class``.
        clock : optional
            Passed as ``when`` to ``Group.__init__``.
        method : optional
            Passed to `StateUpdater`.
        name : str, optional
            Passed to ``Group.__init__``.

        Raises
        ------
        TypeError
            If ``model``, ``connect``, ``delay`` or a key of ``pre``/``post``
            has an unsupported type, or if a delay is not a scalar quantity.
        ValueError
            If a delay refers to an unknown pathway or a summed variable is
            misnamed, unknown, has no matching parameter in the target
            group, or targets a variable that is already updated.
        '''
        self._N = 0
        Group.__init__(self, when=clock, name=name)

        self.codeobj_class = codeobj_class

        self.source = weakref.proxy(source)
        if target is None:
            self.target = self.source
        else:
            self.target = weakref.proxy(target)

        ##### Prepare and validate equations
        if model is None:
            model = ''

        if isinstance(model, basestring):
            model = Equations(model)
        if not isinstance(model, Equations):
            raise TypeError(('model has to be a string or an Equations '
                             'object, is "%s" instead.') % type(model))

        # Check flags
        model.check_flags({DIFFERENTIAL_EQUATION: ['event-driven'],
                           STATIC_EQUATION: ['summed'],
                           PARAMETER: ['constant']})

        # Separate the equations into event-driven and continuously updated
        # equations
        event_driven = []
        continuous = []
        for single_equation in model.itervalues():
            if 'event-driven' in single_equation.flags:
                event_driven.append(single_equation)
            else:
                continuous.append(single_equation)
        # Add the lastupdate variable, used by event-driven equations
        continuous.append(SingleEquation(PARAMETER, 'lastupdate', second))

        if len(event_driven):
            self.event_driven = Equations(event_driven)
        else:
            self.event_driven = None

        self.equations = Equations(continuous)

        # Setup the namespace
        self._given_namespace = namespace
        self.namespace = create_namespace(namespace)

        self._queues = {}
        self._delays = {}

        # Setup variables
        self._create_variables()

        #: Set of `Variable` objects that should be resized when the
        #: number of synapses changes
        self._registered_variables = set()

        for varname, var in self.variables.iteritems():
            if isinstance(var, DynamicArrayVariable):
                # Register the array with the `SynapticItemMapping` object so
                # it gets automatically resized
                self.register_variable(var)

        #: List of names of all updaters, e.g. ['pre', 'post']
        self._synaptic_updaters = []
        #: List of all `SynapticPathway` objects
        self._pathways = []
        for prepost, argument in zip(('pre', 'post'), (pre, post)):
            if not argument:
                continue
            if isinstance(argument, basestring):
                self._add_updater(argument, prepost)
            elif isinstance(argument, collections.Mapping):
                for key, value in argument.iteritems():
                    if not isinstance(key, basestring):
                        err_msg = ('Keys for the "{}" argument '
                                   'have to be strings, got '
                                   '{} instead.').format(prepost, type(key))
                        raise TypeError(err_msg)
                    self._add_updater(value, prepost, objname=key)

        # If we have a pathway called "pre" (the most common use case), provide
        # direct access to its delay via a delay attribute (instead of having
        # to use pre.delay)
        if 'pre' in self._synaptic_updaters:
            self.variables.add_reference('delay', self.pre.variables['delay'])

        if delay is not None:
            if isinstance(delay, Quantity):
                # A bare quantity is shorthand for {'pre': delay}
                if 'pre' not in self._synaptic_updaters:
                    raise ValueError(('Cannot set delay, no "pre" pathway exists. '
                                      'Use a dictionary if you want to set the '
                                      'delay for a pathway with a different name.'))
                delay = {'pre': delay}

            if not isinstance(delay, collections.Mapping):
                raise TypeError('Delay argument has to be a quantity or a '
                                'dictionary, is type %s instead.' % type(delay))
            for pathway, pathway_delay in delay.iteritems():
                if pathway not in self._synaptic_updaters:
                    raise ValueError(('Cannot set the delay for pathway '
                                      '"%s": unknown pathway.') % pathway)
                if not isinstance(pathway_delay, Quantity):
                    raise TypeError(('Cannot set the delay for pathway "%s": '
                                     'expected a quantity, got %s instead.') % (pathway,
                                                                                type(pathway_delay)))
                if pathway_delay.size != 1:
                    # Both placeholders need a value, otherwise building the
                    # message itself fails with a formatting error.
                    raise TypeError(('Cannot set the delay for pathway "%s": '
                                     'expected a scalar quantity, got a '
                                     'quantity with shape %s instead.') % (pathway,
                                                                           str(pathway_delay.shape)))
                fail_for_dimension_mismatch(pathway_delay, second, ('Delay has to be '
                                                                    'specified in units '
                                                                    'of seconds'))
                updater = getattr(self, pathway)
                # For simplicity, store the delay as a one-element array
                # so that for example updater._delays[:] works.
                updater._delays.resize(1)
                updater._delays.set_value(float(pathway_delay))
                updater._delays.scalar = True
                # Do not resize the scalar delay variable when adding synapses
                self.unregister_variable(updater._delays)

        #: Performs numerical integration step
        self.state_updater = StateUpdater(self, method)
        self.contained_objects.append(self.state_updater)

        #: "Summed variable" mechanism -- sum over all synapses of a
        #: pre-/postsynaptic target
        self.summed_updaters = {}
        # We want to raise an error if the same variable is updated twice
        # using this mechanism. This could happen if the Synapses object
        # connected a NeuronGroup to itself since then all variables are
        # accessible as var_pre and var_post.
        summed_targets = set()
        for single_equation in self.equations.itervalues():
            if 'summed' in single_equation.flags:
                varname = single_equation.varname
                if not (varname.endswith('_pre') or varname.endswith('_post')):
                    raise ValueError(('The summed variable "%s" does not end '
                                      'in "_pre" or "_post".') % varname)
                if varname not in self.variables:
                    raise ValueError(('The summed variable "%s" does not refer '
                                      'to any known variable in the '
                                      'target group.') % varname)
                # Strip the suffix to get the variable name in the group
                if varname.endswith('_pre'):
                    summed_target = self.source
                    orig_varname = varname[:-4]
                else:
                    summed_target = self.target
                    orig_varname = varname[:-5]

                # Groups without an "equations" attribute have no parameters
                target_eq = getattr(summed_target, 'equations', {}).get(orig_varname, None)
                if target_eq is None or target_eq.type != PARAMETER:
                    raise ValueError(('The summed variable "%s" needs a '
                                      'corresponding parameter "%s" in the '
                                      'target group.') % (varname,
                                                          orig_varname))

                fail_for_dimension_mismatch(self.variables['_summed_'+varname].unit,
                                            self.variables[varname].unit,
                                            ('Summed variables need to have '
                                             'the same units in Synapses '
                                             'and the target group'))
                if self.variables[varname] in summed_targets:
                    raise ValueError(('The target variable "%s" is already '
                                      'updated by another summed '
                                      'variable') % orig_varname)
                summed_targets.add(self.variables[varname])
                updater = SummedVariableUpdater(single_equation.expr,
                                                varname, self, summed_target)
                self.summed_updaters[varname] = updater
                self.contained_objects.append(updater)

        # Do an initial connect, if requested
        if not isinstance(connect, (bool, basestring)):
            raise TypeError(('"connect" keyword has to be a boolean value or a '
                             'string, is type %s instead.' % type(connect)))
        self._initial_connect = connect
        if connect is not False:
            self.connect(connect, level=1)

        # Activate name attribute access
        self._enable_group_attributes()
Exemplo n.º 45
0
def check_units_statements(code, variables):
    """
    Verify the unit consistency of a sequence of statements.

    An assignment to a variable that already exists in `variables` has to use
    an expression with the dimensions of that variable. Temporary variables
    introduced by the statements are recorded together with the dimensions of
    their defining expression, so that the following statements can be
    checked against them.

    Parameters
    ----------
    code : str
        The statements as a (multi-line) string. Individual statements are
        separated by newlines or semicolons.
    variables : dict of `Variable` objects
        The information about all variables used in `code` (including
        `Constant` objects for external variables). A copy of the dictionary
        is used internally, the caller's dictionary is not modified.

    Raises
    ------
    AssertionError
        If `code` refers to unknown identifiers, uses an unsupported
        assignment operator, or assigns to a name that is neither a known
        nor a newly defined variable.
    KeyError
        In case one of the identifiers cannot be resolved.
    DimensionMismatchError
        If a unit mismatch occurs during the evaluation.
    """
    # Temporary variables get added below, so work on a copy
    variables = dict(variables)
    # Avoid a circular import
    from brian2.codegen.translation import analyse_identifiers
    newly_defined, _, unknown = analyse_identifiers(code, variables)

    if len(unknown) > 0:
        raise AssertionError(
            f"Encountered unknown identifiers, this should not "
            f"happen at this stage. Unknown identifiers: {unknown}")

    augmented_ops = ('+=', '-=', '*=', '/=', '%=')
    for statement in re.split(r'[;\n]', code):
        statement = statement.strip()
        if not statement:
            continue  # nothing to check on an empty line

        varname, op, expr, _comment = parse_statement(statement)
        if op in augmented_ops:
            # Treat an in-place operation such as "w *=2" like "w = w * 2"
            expr = f'{varname} {op[0]} {expr}'
        elif op != '=':
            raise AssertionError(f'Unknown operator "{op}"')

        expr_unit = parse_expression_dimensions(expr, variables)

        if varname in variables:
            # Known variable: the expression has to match its dimensions
            expected_unit = variables[varname].dim
            message = ('The right-hand-side of code '
                       'statement "%s" does not have the '
                       'expected unit {expected}') % statement
            fail_for_dimension_mismatch(expr_unit,
                                        expected_unit,
                                        message,
                                        expected=expected_unit)
            continue

        if varname not in newly_defined:
            raise AssertionError(
                f"Variable '{varname}' is neither in the variables "
                f"dictionary nor in the list of undefined "
                f"variables.")

        # Temporary variable: remember its dimensions for later statements
        variables[varname] = Variable(name=varname,
                                      dimensions=get_dimensions(expr_unit),
                                      scalar=False)
Exemplo n.º 46
0
    def __init__(self, source, target=None, model=None, pre=None, post=None,
                 connect=False, delay=None, namespace=None, dtype=None,
                 codeobj_class=None,
                 clock=None, method=None, name='synapses*'):
        """
        Set up the synapses between `source` and `target`.

        Parameters
        ----------
        source : group
            The presynaptic group. Only a weak proxy to it is stored.
        target : group, optional
            The postsynaptic group (stored as a weak proxy). If ``None``,
            `source` is used as the target as well.
        model : str or `Equations`, optional
            The synaptic model. ``None`` is treated like an empty string.
        pre, post : str or mapping, optional
            Code for the pathways. A string creates a single pathway named
            after the argument; for a mapping, one pathway is created per
            item, with the (string) key used as the object name.
        connect : bool or str, optional
            Unless ``False``, it is passed on to `connect` at the end of the
            initialisation.
        delay : `Quantity` or mapping, optional
            A fixed scalar delay. A single quantity refers to the ``'pre'``
            pathway; a mapping assigns delays by pathway name.
        namespace : optional
            Passed to `create_namespace`.
        dtype : optional
            Passed to `_allocate_memory`.
        codeobj_class : optional
            Stored as ``codeobj_class`` attribute.
        clock : optional
            Passed to `BrianObject` as its ``when`` argument.
        method : optional
            Passed to `StateUpdater`.
        name : str, optional
            Passed to `BrianObject`.

        Raises
        ------
        TypeError
            If `model`, `connect`, `delay` or a key of a `pre`/`post` mapping
            has an unsupported type, or if a delay is not a scalar quantity.
        ValueError
            If an event-driven equation is flagged as lumped, if a delay
            refers to a pathway that does not exist, or if a lumped variable
            has no variable of the same name in the target group.
        """
        BrianObject.__init__(self, when=clock, name=name)

        self.codeobj_class = codeobj_class

        self.source = weakref.proxy(source)
        if target is None:
            self.target = self.source
        else:
            self.target = weakref.proxy(target)

        ##### Prepare and validate equations
        if model is None:
            model = ''

        if isinstance(model, basestring):
            model = Equations(model)
        if not isinstance(model, Equations):
            raise TypeError(('model has to be a string or an Equations '
                             'object, is "%s" instead.') % type(model))

        # Check flags
        model.check_flags({DIFFERENTIAL_EQUATION: ['event-driven', 'lumped'],
                           STATIC_EQUATION: ['lumped'],
                           PARAMETER: ['constant', 'lumped']})

        # Separate the equations into event-driven and continuously updated
        # equations
        event_driven = []
        continuous = []
        for single_equation in model.itervalues():
            if 'event-driven' in single_equation.flags:
                if 'lumped' in single_equation.flags:
                    raise ValueError(('Event-driven variable %s cannot be '
                                      'a lumped variable.') % single_equation.varname)
                event_driven.append(single_equation)
            else:
                continuous.append(single_equation)
        # Add the lastupdate variable, used by event-driven equations
        continuous.append(SingleEquation(PARAMETER, 'lastupdate', second))

        if len(event_driven):
            self.event_driven = Equations(event_driven)
        else:
            self.event_driven = None

        self.equations = Equations(continuous)

        ##### Setup the memory
        self.arrays = self._allocate_memory(dtype=dtype)

        # Setup the namespace
        self._given_namespace = namespace
        self.namespace = create_namespace(namespace)

        self._queues = {}
        self._delays = {}

        self.item_mapping = SynapticItemMapping(self)
        self.indices = {'_idx': self.item_mapping,
                        '_presynaptic_idx': self.item_mapping.synaptic_pre,
                        '_postsynaptic_idx': self.item_mapping.synaptic_post}
        # Allow S.i instead of S.indices.i, etc.
        self.i = self.item_mapping.i
        self.j = self.item_mapping.j
        self.k = self.item_mapping.k

        # Setup variables
        self.variables = self._create_variables()

        #: List of names of all updaters, e.g. ['pre', 'post']
        self._updaters = []
        for prepost, argument in zip(('pre', 'post'), (pre, post)):
            if not argument:
                continue
            if isinstance(argument, basestring):
                self._add_updater(argument, prepost)
            elif isinstance(argument, collections.Mapping):
                for key, value in argument.iteritems():
                    if not isinstance(key, basestring):
                        # The trailing space keeps "argument" and "have"
                        # apart when the literals are concatenated
                        err_msg = ('Keys for the "{}" argument '
                                   'have to be strings, got '
                                   '{} instead.').format(prepost, type(key))
                        raise TypeError(err_msg)
                    self._add_updater(value, prepost, objname=key)

        # If we have a pathway called "pre" (the most common use case), provide
        # direct access to its delay via a delay attribute (instead of having
        # to use pre.delay)
        if 'pre' in self._updaters:
            self.variables['delay'] = self.pre.variables['delay']

        if delay is not None:
            if isinstance(delay, Quantity):
                if 'pre' not in self._updaters:
                    raise ValueError(('Cannot set delay, no "pre" pathway exists. '
                                      'Use a dictionary if you want to set the '
                                      'delay for a pathway with a different name.'))
                delay = {'pre': delay}

            if not isinstance(delay, collections.Mapping):
                raise TypeError('Delay argument has to be a quantity or a '
                                'dictionary, is type %s instead.' % type(delay))
            for pathway, pathway_delay in delay.iteritems():
                if pathway not in self._updaters:
                    raise ValueError(('Cannot set the delay for pathway '
                                      '"%s": unknown pathway.') % pathway)
                if not isinstance(pathway_delay, Quantity):
                    raise TypeError(('Cannot set the delay for pathway "%s": '
                                     'expected a quantity, got %s instead.') % (pathway,
                                                                                type(pathway_delay)))
                if pathway_delay.size != 1:
                    # Both placeholders need a value, otherwise the formatting
                    # itself fails and hides the intended message
                    raise TypeError(('Cannot set the delay for pathway "%s": '
                                     'expected a scalar quantity, got a '
                                     'quantity with shape %s instead.') % (pathway,
                                                                          str(pathway_delay.shape)))
                fail_for_dimension_mismatch(pathway_delay, second, ('Delay has to be '
                                                                    'specified in units '
                                                                    'of seconds'))
                updater = getattr(self, pathway)
                self.item_mapping.unregister_variable(updater._delays)
                del updater._delays
                # For simplicity, store the delay as a one-element array
                # so that for example updater._delays[:] works.
                updater._delays = np.array([float(pathway_delay)])
                variable = ArrayVariable('delay', second, updater._delays,
                                          group_name=self.name, scalar=True)
                updater.variables['delay'] = variable
                if pathway == 'pre':
                    self.variables['delay'] = variable

        #: Performs numerical integration step
        self.state_updater = StateUpdater(self, method)
        self.contained_objects.append(self.state_updater)

        #: "Lumped variable" mechanism -- sum over all synapses of a
        #: postsynaptic target
        self.lumped_updaters = {}
        for single_equation in self.equations.itervalues():
            if 'lumped' in single_equation.flags:
                varname = single_equation.varname
                # For a lumped variable, we need an equivalent parameter in the
                # target group
                if varname not in self.target.variables:
                    raise ValueError(('The lumped variable %s needs a variable '
                                      'of the same name in the target '
                                      'group ') % single_equation.varname)
                # Compare unit against unit (not against the variable object)
                fail_for_dimension_mismatch(self.variables[varname].unit,
                                            self.target.variables[varname].unit,
                                            ('Lumped variables need to have '
                                             'the same units in Synapses '
                                             'and the target group'))
                # TODO: Add some more stringent check about the type of
                # variable in the target group
                updater = LumpedUpdater(varname, self, self.target)
                self.lumped_updaters[varname] = updater
                self.contained_objects.append(updater)

        # Do an initial connect, if requested
        if not isinstance(connect, (bool, basestring)):
            raise TypeError(('"connect" keyword has to be a boolean value or a '
                             'string, is type %s instead.' % type(connect)))
        self._initial_connect = connect
        if connect is not False:
            self.connect(connect, level=1)

        # Activate name attribute access
        Group.__init__(self)
Exemplo n.º 47
0
    def __init__(self,
                 morphology=None,
                 model=None,
                 threshold=None,
                 refractory=False,
                 reset=None,
                 events=None,
                 threshold_location=None,
                 dt=None,
                 clock=None,
                 order=0,
                 Cm=0.9 * uF / cm**2,
                 Ri=150 * ohm * cm,
                 name='spatialneuron*',
                 dtype=None,
                 namespace=None,
                 method=('linear', 'exponential_euler', 'rk2', 'heun')):
        """
        Set up a spatially extended neuron from a morphology and a model.

        The transmembrane current ``Im`` of the model is decomposed into a
        part that is linear in ``v`` (stored as ``gtot__private``) and a
        remainder (stored as ``I0__private``). The morphology-related
        equations are then appended and the result is handed over to
        `NeuronGroup.__init__`.

        Parameters
        ----------
        morphology : morphology object
            The morphology. A deep copy is stored as ``morphology``
            attribute and a `FlatMorphology` is created from it.
        model : str or `Equations`
            The model equations, which have to define ``Im``.
        threshold : str, optional
            Threshold condition, passed to `NeuronGroup.__init__`. If
            `threshold_location` is given, the condition is restricted to
            that compartment.
        threshold_location : optional
            Either a compartment index or an object providing an
            ``_indices()`` method that returns exactly one index.
        refractory, reset, events, dt, clock, order, namespace, dtype, name
            Passed on to `NeuronGroup.__init__`.
        Cm, Ri : quantities, optional
            Assigned to the state variables of the same name after the group
            has been created.
        method : optional
            Passed to `NeuronGroup.__init__` and to `SpatialStateUpdater`.

        Raises
        ------
        TypeError
            If `model` is neither a string nor an `Equations` object, if it
            does not define ``Im``, or if ``Im`` is not linear in ``v``.
        AttributeError
            If `threshold_location` refers to more than one compartment.
        """

        # #### Prepare and validate equations
        if isinstance(model, basestring):
            model = Equations(model)
        if not isinstance(model, Equations):
            raise TypeError(('model has to be a string or an Equations '
                             'object, is "%s" instead.') % type(model))

        # Insert the threshold mechanism at the specified location
        if threshold_location is not None:
            if hasattr(threshold_location,
                       '_indices'):  # assuming this is a method
                threshold_location = threshold_location._indices()
                # for now, only a single compartment allowed
                if len(threshold_location) == 1:
                    threshold_location = threshold_location[0]
                else:
                    raise AttributeError(('Threshold can only be applied on a '
                                          'single location'))
            threshold = '(' + threshold + ') and (i == ' + str(
                threshold_location) + ')'

        # Check flags (we have point currents)
        model.check_flags({
            DIFFERENTIAL_EQUATION: ('point current', ),
            PARAMETER: ('constant', 'shared', 'linked', 'point current'),
            SUBEXPRESSION: ('shared', 'point current')
        })

        # Add the membrane potential
        model += Equations('''
        v:volt # membrane potential
        ''')

        # Extract membrane equation
        if 'Im' in model:
            membrane_eq = model['Im']  # the membrane equation
        else:
            raise TypeError('The transmembrane current Im must be defined')

        # Insert point currents in the membrane equation
        for eq in model.itervalues():
            if 'point current' in eq.flags:
                fail_for_dimension_mismatch(
                    eq.unit, amp,
                    "Point current " + eq.varname + " should be in amp")
                eq.flags.remove('point current')
                membrane_eq.expr = Expression(
                    str(membrane_eq.expr.code) + '+' + eq.varname + '/area')

        ###### Process model equations (Im) to extract total conductance and the remaining current
        # Check conditional linearity with respect to v
        # Match to _A*v+_B
        var = sp.Symbol('v', real=True)
        wildcard = sp.Wild('_A', exclude=[var])
        constant_wildcard = sp.Wild('_B', exclude=[var])
        pattern = wildcard * var + constant_wildcard

        # Expand expressions in the membrane equation
        # (Im is temporarily marked as a differential equation for this)
        membrane_eq.type = DIFFERENTIAL_EQUATION
        # The loop variable must not be called `var`: that would overwrite
        # the symbol for v that is needed for the factoring below
        for eq_name, expr in model.get_substituted_expressions():
            if eq_name == 'Im':
                Im_expr = expr
        membrane_eq.type = SUBEXPRESSION

        # Factor out the variable
        s_expr = sp.collect(str_to_sympy(Im_expr.code).expand(), var)
        matches = s_expr.match(pattern)

        if matches is None:
            raise TypeError("The membrane current must be linear with respect to v")
        a, b = (matches[wildcard], matches[constant_wildcard])

        # Extracts the total conductance from Im, and the remaining current
        minusa_str, b_str = sympy_to_str(-a), sympy_to_str(b)
        # Add correct units if necessary
        if minusa_str == '0':
            minusa_str += '*siemens/meter**2'
        if b_str == '0':
            b_str += '*amp/meter**2'
        gtot_str = "gtot__private=" + minusa_str + ": siemens/meter**2"
        I0_str = "I0__private=" + b_str + ": amp/meter**2"
        model += Equations(gtot_str + "\n" + I0_str)

        # Insert morphology (store a copy)
        self.morphology = copy.deepcopy(morphology)

        # Flatten the morphology
        self.flat_morphology = FlatMorphology(morphology)

        # Equations for morphology
        # TODO: check whether Cm and Ri are already in the equations
        #       no: should be shared instead of constant
        #       yes: should be constant (check)
        eqs_constants = Equations("""
        length : meter (constant)
        distance : meter (constant)
        area : meter**2 (constant)
        volume : meter**3
        diameter : meter (constant)
        Cm : farad/meter**2 (constant)
        Ri : ohm*meter (constant, shared)
        r_length_1 : meter (constant)
        r_length_2 : meter (constant)
        time_constant = Cm/gtot__private : second
        space_constant = (2/pi)**(1.0/3.0) * (area/(1/r_length_1 + 1/r_length_2))**(1.0/6.0) /
                         (2*(Ri*gtot__private)**(1.0/2.0)) : meter
        """)
        if self.flat_morphology.has_coordinates:
            eqs_constants += Equations('''
            x : meter (constant)
            y : meter (constant)
            z : meter (constant)
            ''')

        NeuronGroup.__init__(self,
                             morphology.total_compartments,
                             model=model + eqs_constants,
                             threshold=threshold,
                             refractory=refractory,
                             reset=reset,
                             events=events,
                             method=method,
                             dt=dt,
                             clock=clock,
                             order=order,
                             namespace=namespace,
                             dtype=dtype,
                             name=name)
        # Parameters and intermediate variables for solving the cable equations
        # Note that some of these variables could have meaningful physical
        # units (e.g. _v_star is in volt, _I0_all is in amp/meter**2 etc.) but
        # since these variables should never be used in user code, we don't
        # assign them any units
        self.variables.add_arrays(
            [
                '_ab_star0',
                '_ab_star1',
                '_ab_star2',
                '_a_minus0',
                '_a_minus1',
                '_a_minus2',
                '_a_plus0',
                '_a_plus1',
                '_a_plus2',
                '_b_plus',
                '_b_minus',
                '_v_star',
                '_u_plus',
                '_u_minus',
                # The following three are for solving the
                # three tridiag systems in parallel
                '_c1',
                '_c2',
                '_c3',
                # The following two are only necessary for
                # C code where we cannot deal with scalars
                # and arrays interchangeably:
                '_I0_all',
                '_gtot_all'
            ],
            unit=1,
            size=self.N,
            read_only=True)

        self.Cm = Cm
        self.Ri = Ri
        # These explicit assignments will load the morphology values from disk
        # in standalone mode
        self.distance_ = self.flat_morphology.distance
        self.length_ = self.flat_morphology.length
        self.area_ = self.flat_morphology.area
        self.diameter_ = self.flat_morphology.diameter
        self.r_length_1_ = self.flat_morphology.r_length_1
        self.r_length_2_ = self.flat_morphology.r_length_2
        if self.flat_morphology.has_coordinates:
            self.x_ = self.flat_morphology.x
            self.y_ = self.flat_morphology.y
            self.z_ = self.flat_morphology.z

        # Performs numerical integration step
        self.add_attribute('diffusion_state_updater')
        self.diffusion_state_updater = SpatialStateUpdater(self,
                                                           method,
                                                           clock=self.clock,
                                                           order=order)

        # Creation of contained_objects that do the work
        self.contained_objects.extend([self.diffusion_state_updater])
Exemplo n.º 48
0
    def __init__(self, morphology=None, model=None, threshold=None,
                 refractory=False, reset=None,
                 threshold_location=None,
                 dt=None, clock=None, order=0, Cm=0.9 * uF / cm ** 2, Ri=150 * ohm * cm,
                 name='spatialneuron*', dtype=None, namespace=None,
                 method=('linear', 'exponential_euler', 'rk2', 'milstein')):
        """
        Set up a spatially extended neuron from a morphology and a model.

        The transmembrane current ``Im`` of the model is decomposed into a
        part that is linear in ``v`` (stored as ``gtot__private``) and a
        remainder (stored as ``I0__private``). The morphology-related
        equations are then appended and the result is handed over to
        `NeuronGroup.__init__`.

        Parameters
        ----------
        morphology : morphology object
            The morphology. It is stored as ``morphology`` attribute and
            compressed into a `MorphologyData` object of the same length.
        model : str or `Equations`
            The model equations, which have to define ``Im``.
        threshold : str, optional
            Threshold condition, passed to `NeuronGroup.__init__`. If
            `threshold_location` is given, the condition is restricted to
            that compartment.
        threshold_location : optional
            Either a compartment index or an object providing an
            ``_indices()`` method that returns exactly one index.
        refractory, reset, dt, clock, order, namespace, dtype, name
            Passed on to `NeuronGroup.__init__`.
        Cm, Ri : quantities, optional
            Assigned to the state variables of the same name after the group
            has been created.
        method : optional
            Passed to `NeuronGroup.__init__` and to `SpatialStateUpdater`.

        Raises
        ------
        TypeError
            If `model` is neither a string nor an `Equations` object, if it
            does not define ``Im``, or if ``Im`` is not linear in ``v``.
        AttributeError
            If `threshold_location` refers to more than one compartment.
        """

        # #### Prepare and validate equations
        if isinstance(model, basestring):
            model = Equations(model)
        if not isinstance(model, Equations):
            raise TypeError(('model has to be a string or an Equations '
                             'object, is "%s" instead.') % type(model))

        # Insert the threshold mechanism at the specified location
        if threshold_location is not None:
            if hasattr(threshold_location,
                       '_indices'):  # assuming this is a method
                threshold_location = threshold_location._indices()
                # for now, only a single compartment allowed
                if len(threshold_location) == 1:
                    threshold_location = threshold_location[0]
                else:
                    raise AttributeError(('Threshold can only be applied on a '
                                          'single location'))
            threshold = '(' + threshold + ') and (i == ' + str(threshold_location) + ')'

        # Check flags (we have point currents)
        model.check_flags({DIFFERENTIAL_EQUATION: ('point current',),
                           PARAMETER: ('constant', 'shared', 'linked', 'point current'),
                           SUBEXPRESSION: ('shared', 'point current')})

        # Add the membrane potential
        model += Equations('''
        v:volt # membrane potential
        ''')

        # Extract membrane equation
        if 'Im' in model:
            membrane_eq = model['Im']  # the membrane equation
        else:
            raise TypeError('The transmembrane current Im must be defined')

        # Insert point currents in the membrane equation
        for eq in model.itervalues():
            if 'point current' in eq.flags:
                fail_for_dimension_mismatch(eq.unit, amp,
                                            "Point current " + eq.varname + " should be in amp")
                eq.flags.remove('point current')
                membrane_eq.expr = Expression(
                    str(membrane_eq.expr.code) + '+' + eq.varname + '/area')

        ###### Process model equations (Im) to extract total conductance and the remaining current
        # Check conditional linearity with respect to v
        # Match to _A*v+_B
        var = sp.Symbol('v', real=True)
        wildcard = sp.Wild('_A', exclude=[var])
        constant_wildcard = sp.Wild('_B', exclude=[var])
        pattern = wildcard * var + constant_wildcard

        # Expand expressions in the membrane equation
        # (Im is temporarily marked as a differential equation for this)
        membrane_eq.type = DIFFERENTIAL_EQUATION
        # The loop variable must not be called `var`: that would overwrite
        # the symbol for v that is needed for the factoring below
        for eq_name, expr in model._get_substituted_expressions():  # this returns substituted expressions for diff eqs
            if eq_name == 'Im':
                Im_expr = expr
        membrane_eq.type = SUBEXPRESSION

        # Factor out the variable
        s_expr = sp.collect(Im_expr.sympy_expr.expand(), var)
        matches = s_expr.match(pattern)

        if matches is None:
            raise TypeError("The membrane current must be linear with respect to v")
        a, b = (matches[wildcard],
                matches[constant_wildcard])

        # Extracts the total conductance from Im, and the remaining current
        minusa_str, b_str = sympy_to_str(-a), sympy_to_str(b)
        # Add correct units if necessary
        if minusa_str == '0':
            minusa_str += '*siemens/meter**2'
        if b_str == '0':
            b_str += '*amp/meter**2'
        gtot_str = "gtot__private=" + minusa_str + ": siemens/meter**2"
        I0_str = "I0__private=" + b_str + ": amp/meter**2"
        model += Equations(gtot_str + "\n" + I0_str)

        # Equations for morphology
        # TODO: check whether Cm and Ri are already in the equations
        #       no: should be shared instead of constant
        #       yes: should be constant (check)
        eqs_constants = Equations("""
        diameter : meter (constant)
        length : meter (constant)
        x : meter (constant)
        y : meter (constant)
        z : meter (constant)
        distance : meter (constant)
        area : meter**2 (constant)
        Cm : farad/meter**2 (constant)
        Ri : ohm*meter (constant, shared)
        space_constant = (diameter/(4*Ri*gtot__private))**.5 : meter # Not so sure about the name

        ### Parameters and intermediate variables for solving the cable equation
        ab_star0 : siemens/meter**2
        ab_plus0 : siemens/meter**2
        ab_minus0 : siemens/meter**2
        ab_star1 : siemens/meter**2
        ab_plus1 : siemens/meter**2
        ab_minus1 : siemens/meter**2
        ab_star2 : siemens/meter**2
        ab_plus2 : siemens/meter**2
        ab_minus2 : siemens/meter**2
        b_plus : siemens/meter**2
        b_minus : siemens/meter**2
        v_star : volt
        u_plus : 1
        u_minus : 1
        """)
        # Possibilities for the name: characteristic_length, electrotonic_length, length_constant, space_constant

        # Insert morphology
        self.morphology = morphology

        # Link morphology variables to neuron's state variables
        self.morphology_data = MorphologyData(len(morphology))
        self.morphology.compress(self.morphology_data)

        NeuronGroup.__init__(self, len(morphology), model=model + eqs_constants,
                             threshold=threshold, refractory=refractory,
                             reset=reset,
                             method=method, dt=dt, clock=clock, order=order,
                             namespace=namespace, dtype=dtype, name=name)

        self.Cm = Cm
        self.Ri = Ri
        # TODO: View instead of copy for runtime?
        self.diameter_ = self.morphology_data.diameter
        self.distance_ = self.morphology_data.distance
        self.length_ = self.morphology_data.length
        self.area_ = self.morphology_data.area
        self.x_ = self.morphology_data.x
        self.y_ = self.morphology_data.y
        self.z_ = self.morphology_data.z

        # Performs numerical integration step
        self.add_attribute('diffusion_state_updater')
        self.diffusion_state_updater = SpatialStateUpdater(self, method,
                                                           clock=self.clock,
                                                           order=order)

        # Creation of contained_objects that do the work
        self.contained_objects.extend([self.diffusion_state_updater])
Exemplo n.º 49
0
    def __init__(self, synapses, code, prepost, objname=None, delay=None):
        """
        Set up a pathway that runs `code` for one side of `synapses`.

        Parameters
        ----------
        synapses : synapses object
            The object this pathway belongs to. Its ``source``, ``target``,
            ``variables``, ``clock``, ``name`` and ``namespace`` are used.
        code : str
            The code of the pathway, passed to `CodeRunner.__init__`.
        prepost : {'pre', 'post'}
            For ``'pre'``, the pathway goes from ``synapses.source`` to
            ``synapses.target``; for ``'post'``, the two are swapped.
        objname : str, optional
            Name suffix of the object; defaults to ``prepost + '*'``.
        delay : `Quantity`, optional
            A fixed scalar delay. If ``None``, a dynamic ``delay`` array of
            size ``synapses._N`` is created and registered with `synapses`.

        Raises
        ------
        ValueError
            If `prepost` is neither ``'pre'`` nor ``'post'``.
        TypeError
            If `delay` is given but is not a scalar `Quantity`.
        """
        self.code = code
        self.prepost = prepost
        if prepost == 'pre':
            self.source = synapses.source
            self.target = synapses.target
            self.synapse_sources = synapses.variables['_synaptic_pre']
        elif prepost == 'post':
            self.source = synapses.target
            self.target = synapses.source
            self.synapse_sources = synapses.variables['_synaptic_post']
        else:
            raise ValueError('prepost argument has to be either "pre" or '
                             '"post"')
        self.synapses = synapses

        if objname is None:
            objname = prepost + '*'

        CodeRunner.__init__(self,
                            synapses,
                            'synapses',
                            code=code,
                            when=(synapses.clock, 'synapses'),
                            name=synapses.name + '_' + objname,
                            template_kwds={'pathway': self})

        self._pushspikes_codeobj = None

        self.spikes_start = self.source.start
        self.spikes_stop = self.source.stop

        self.spiking_synapses = []
        self.variables = Variables(self)
        self.variables.add_attribute_variable('_spiking_synapses',
                                              unit=Unit(1),
                                              obj=self,
                                              attribute='spiking_synapses',
                                              constant=False,
                                              scalar=False)
        self.variables.add_reference('_spikespace',
                                     self.source.variables['_spikespace'])
        self.variables.add_reference('N', synapses.variables['N'])
        if delay is None:  # variable delays
            self.variables.add_dynamic_array('delay',
                                             unit=second,
                                             size=synapses._N,
                                             constant=True,
                                             constant_size=True)
            # Register the object with the `SynapticIndex` object so it gets
            # automatically resized
            synapses.register_variable(self.variables['delay'])
        else:
            if not isinstance(delay, Quantity):
                raise TypeError(('Cannot set the delay for pathway "%s": '
                                 'expected a quantity, got %s instead.') %
                                (objname, type(delay)))
            if delay.size != 1:
                # Both placeholders need a value, otherwise the formatting
                # itself fails and hides the intended message
                raise TypeError(
                    ('Cannot set the delay for pathway "%s": '
                     'expected a scalar quantity, got a '
                     'quantity with shape %s instead.') %
                    (objname, str(delay.shape)))
            fail_for_dimension_mismatch(delay, second, ('Delay has to be '
                                                        'specified in units '
                                                        'of seconds'))
            # A fixed delay is stored in a scalar one-element array
            self.variables.add_array('delay',
                                     unit=second,
                                     size=1,
                                     constant=True,
                                     scalar=True)
            self.variables['delay'].set_value(delay)

        self._delays = self.variables['delay']

        # Re-extract the last part of the name from the full name
        self.objname = self.name[len(synapses.name) + 1:]

        #: The simulation dt (necessary for the delays)
        self.dt = self.synapses.clock.dt_

        #: The `SpikeQueue`
        self.queue = None

        #: The `CodeObject` initialising the `SpikeQueue` at the beginning of a run
        self._initialise_queue_codeobj = None

        self.namespace = synapses.namespace
        # Enable access to the delay attribute via the specifier
        self._enable_group_attributes()
Exemplo n.º 50
0
def check_units_statements(code, namespace, variables):
    '''
    Check the units for a series of statements. Setting a model variable has to
    use the correct unit. For newly introduced temporary variables, the unit
    is determined and used to check the following statements to ensure
    consistency.

    Parameters
    ----------
    code : str
        The statements to check, separated by newlines and/or semicolons.
        Empty statements are skipped.
    namespace : dict-like
        The namespace of external variables.
    variables : dict of `Variable` objects
        The information about the internal variables. The dictionary itself
        is not modified, newly defined variables are tracked in a copy.

    Raises
    ------
    AssertionError
        If the code contains unknown identifiers, uses an operator other
        than ``=``, ``+=``, ``-=``, ``*=``, ``/=`` or ``%=``, or assigns to a
        variable that is neither known nor newly defined.
    KeyError
        In case one of the identifiers cannot be resolved.
    DimensionMismatchError
        If a unit mismatch occurs during the evaluation.
    '''
    known = set(variables.keys()) | set(namespace.keys())
    newly_defined, _, unknown = analyse_identifiers(code, known)

    if len(unknown):
        raise AssertionError(('Encountered unknown identifiers, this should '
                             'not happen at this stage. Unkown identifiers: %s'
                             % unknown))

    # We want to add newly defined variables to the variables dictionary so we
    # make a copy now
    variables = dict(variables)

    for line in re.split(r'[;\n]', code):
        line = line.strip()
        if not len(line):
            continue  # skip empty lines

        varname, op, expr = parse_statement(line)
        if op in ('+=', '-=', '*=', '/=', '%='):
            # Replace statements such as "w *= 2" by "w = w * (2)". The
            # parentheses are essential: an augmented assignment applies the
            # operator to the *whole* right-hand side, so "w /= a*b" has to be
            # checked as "w / (a*b)" and not as "(w / a) * b", which would
            # have different units.
            expr = '{var} {op_first} ({expr})'.format(var=varname,
                                                      op_first=op[0],
                                                      expr=expr)
            op = '='
        elif op == '=':
            pass
        else:
            raise AssertionError('Unknown operator "%s"' % op)

        expr_unit = parse_expression_unit(expr, namespace, variables)

        if varname in variables:
            # Assignment to an already known variable: units have to match
            fail_for_dimension_mismatch(variables[varname].unit,
                                        expr_unit,
                                        ('Code statement "%s" does not use '
                                         'correct units' % line))
        elif varname in newly_defined:
            # note the unit for later
            variables[varname] = Variable(expr_unit, is_bool=False,
                                          scalar=False)
        else:
            raise AssertionError(('Variable "%s" is neither in the variables '
                                  'dictionary nor in the list of undefined '
                                  'variables.' % varname))
Exemplo n.º 51
0
    def __init__(self, morphology=None, model=None, threshold=None,
                 refractory=False, reset=None, events=None,
                 threshold_location=None,
                 dt=None, clock=None, order=0, Cm=0.9 * uF / cm ** 2, Ri=150 * ohm * cm,
                 name='spatialneuron*', dtype=None, namespace=None,
                 method=('exact', 'exponential_euler', 'rk2', 'heun'),
                 method_options=None):
        '''
        Prepare the model equations, set up the underlying `NeuronGroup` with
        one neuron per compartment and attach the state updaters.

        Parameters
        ----------
        morphology : optional
            The morphology. A deep copy is stored in ``self.morphology`` and
            a `FlatMorphology` is built from it. Note that, despite the
            default of ``None``, ``morphology.total_compartments`` is read
            unconditionally.
        model : str or `Equations`
            The model equations (a string is converted to `Equations`). They
            have to define the transmembrane current ``Im`` without flags.
        threshold : str, optional
            The threshold condition, handed over to ``NeuronGroup.__init__``.
            If `threshold_location` is given, it is combined with a condition
            on the compartment index ``i`` via string concatenation.
        threshold_location : optional
            The compartment where the threshold applies. If the object has
            an ``_indices`` attribute, it is called and has to return exactly
            one index; otherwise the value is inserted via ``str()``.
        refractory, reset, events, dt, clock, dtype, namespace, name, method_options
            Handed over unchanged to ``NeuronGroup.__init__``.
        order : int, optional
            Handed over to ``NeuronGroup.__init__`` and `SpatialStateUpdater`
            (the ``order`` attribute of the latter is set to 1 afterwards).
        Cm : optional
            Value assigned to the ``Cm`` variable of the group.
        Ri : optional
            Value assigned to the ``Ri`` variable of the group.
        method : optional
            Handed over to ``NeuronGroup.__init__`` and `SpatialStateUpdater`.

        Raises
        ------
        TypeError
            If `model` is neither a string nor an `Equations` object, if
            ``Im`` is missing or has flags, or if the derivative of ``Im``
            with respect to ``v`` cannot be evaluated.
        AttributeError
            If ``threshold_location._indices()`` does not return exactly one
            element.
        AssertionError
            If the substituted expressions unexpectedly do not contain ``Im``.
        '''

        # #### Prepare and validate equations
        if isinstance(model, basestring):
            model = Equations(model)
        if not isinstance(model, Equations):
            raise TypeError(('model has to be a string or an Equations '
                             'object, is "%s" instead.') % type(model))

        # Insert the threshold mechanism at the specified location
        if threshold_location is not None:
            if hasattr(threshold_location,
                       '_indices'):  # assuming this is a method
                threshold_location = threshold_location._indices()
                # for now, only a single compartment allowed
                if len(threshold_location) == 1:
                    threshold_location = threshold_location[0]
                else:
                    raise AttributeError(('Threshold can only be applied on a '
                                          'single location'))
            # Restrict the user-given condition to the chosen compartment
            # NOTE(review): this concatenation requires `threshold` to be a
            # string, i.e. threshold_location without threshold fails here
            threshold = '(' + threshold + ') and (i == ' + str(threshold_location) + ')'

        # Check flags (we have point currents)
        model.check_flags({DIFFERENTIAL_EQUATION: ('point current',),
                           PARAMETER: ('constant', 'shared', 'linked', 'point current'),
                           SUBEXPRESSION: ('shared', 'point current',
                                           'constant over dt')})
        #: The original equations as specified by the user (i.e. before
        #: inserting point-currents into the membrane equation, before adding
        #: all the internally used variables and constants, etc.).
        self.user_equations = model

        # Separate subexpressions depending whether they are considered to be
        # constant over a time step or not (this would also be done by the
        # NeuronGroup initializer later, but this would give incorrect results
        # for the linearity check)
        model, constant_over_dt = extract_constant_subexpressions(model)

        # Extract membrane equation
        if 'Im' in model:
            if len(model['Im'].flags):
                raise TypeError('Cannot specify any flags for the transmembrane '
                                'current Im.')
            membrane_expr = model['Im'].expr  # the membrane equation
        else:
            raise TypeError('The transmembrane current Im must be defined')

        # Rebuild the list of equations: everything except Im is kept, point
        # currents are additionally added (divided by the area) to Im
        model_equations = []
        # Insert point currents in the membrane equation
        for eq in model.itervalues():
            if eq.varname == 'Im':
                continue  # ignore -- handled separately
            if 'point current' in eq.flags:
                fail_for_dimension_mismatch(eq.dim, amp,
                                            "Point current " + eq.varname + " should be in amp")
                membrane_expr = Expression(
                    str(membrane_expr.code) + '+' + eq.varname + '/area')
                # Re-create the equation without the 'point current' flag
                eq = SingleEquation(eq.type, eq.varname, eq.dim, expr=eq.expr,
                                    flags=list(set(eq.flags)-set(['point current'])))
            model_equations.append(eq)

        # Im (now including the point currents) becomes a subexpression and
        # the membrane potential v a parameter
        model_equations.append(SingleEquation(SUBEXPRESSION, 'Im',
                                              dimensions=(amp/meter**2).dim,
                                              expr=membrane_expr))
        model_equations.append(SingleEquation(PARAMETER, 'v', volt.dim))
        model = Equations(model_equations)

        ###### Process model equations (Im) to extract total conductance and the remaining current
        # Expand expressions in the membrane equation
        for var, expr in model.get_substituted_expressions(include_subexpressions=True):
            if var == 'Im':
                Im_expr = expr
                break
        else:
            raise AssertionError('Model equations did not contain Im!')

        # Differentiate Im with respect to v
        Im_sympy_exp = str_to_sympy(Im_expr.code)
        v_sympy = sp.Symbol('v', real=True)
        diffed = sp.diff(Im_sympy_exp, v_sympy)

        # A remaining Derivative object means that sympy could not evaluate
        # the derivative
        unevaled_derivatives = diffed.atoms(sp.Derivative)
        if len(unevaled_derivatives):
            raise TypeError('Cannot take the derivative of "{Im}" with respect '
                            'to v.'.format(Im=Im_expr.code))

        # gtot = -dIm/dv and I0 = Im - v*dIm/dv, i.e. Im = I0 - gtot*v
        gtot_str = sympy_to_str(sp.simplify(-diffed))
        I0_str = sympy_to_str(sp.simplify(Im_sympy_exp - diffed*v_sympy))

        # A plain '0' carries no units, attach the units that the equations
        # below declare for the respective variable
        if gtot_str == '0':
            gtot_str += '*siemens/meter**2'
        if I0_str == '0':
            I0_str += '*amp/meter**2'
        gtot_str = "gtot__private=" + gtot_str + ": siemens/meter**2"
        I0_str = "I0__private=" + I0_str + ": amp/meter**2"

        model += Equations(gtot_str + "\n" + I0_str)

        # Insert morphology (store a copy)
        self.morphology = copy.deepcopy(morphology)

        # Flatten the morphology
        # (this uses the object passed by the caller, not the stored copy)
        self.flat_morphology = FlatMorphology(morphology)

        # Equations for morphology
        # TODO: check whether Cm and Ri are already in the equations
        #       no: should be shared instead of constant
        #       yes: should be constant (check)
        eqs_constants = Equations("""
        length : meter (constant)
        distance : meter (constant)
        area : meter**2 (constant)
        volume : meter**3
        Ic : amp/meter**2
        diameter : meter (constant)
        Cm : farad/meter**2 (constant)
        Ri : ohm*meter (constant, shared)
        r_length_1 : meter (constant)
        r_length_2 : meter (constant)
        time_constant = Cm/gtot__private : second
        space_constant = (2/pi)**(1.0/3.0) * (area/(1/r_length_1 + 1/r_length_2))**(1.0/6.0) /
                         (2*(Ri*gtot__private)**(1.0/2.0)) : meter
        """)
        # Coordinates are only added if the morphology provides them
        if self.flat_morphology.has_coordinates:
            eqs_constants += Equations('''
            x : meter (constant)
            y : meter (constant)
            z : meter (constant)
            ''')

        # One neuron per compartment; the user model is extended by the
        # morphology-related equations defined above
        NeuronGroup.__init__(self, morphology.total_compartments,
                             model=model + eqs_constants,
                             method_options=method_options,
                             threshold=threshold, refractory=refractory,
                             reset=reset, events=events,
                             method=method, dt=dt, clock=clock, order=order,
                             namespace=namespace, dtype=dtype, name=name)
        # Parameters and intermediate variables for solving the cable equations
        # Note that some of these variables could have meaningful physical
        # units (e.g. _v_star is in volt, _I0_all is in amp/meter**2 etc.) but
        # since these variables should never be used in user code, we don't
        # assign them any units
        self.variables.add_arrays(['_ab_star0', '_ab_star1', '_ab_star2',
                                   '_b_plus', '_b_minus',
                                   '_v_star', '_u_plus', '_u_minus',
                                   '_v_previous', '_c',
                                   # The following two are only necessary for
                                   # C code where we cannot deal with scalars
                                   # and arrays interchangeably:
                                   '_I0_all', '_gtot_all'],
                                  size=self.N, read_only=True)

        # Set the electrical properties given as constructor arguments
        self.Cm = Cm
        self.Ri = Ri
        # These explicit assignments will load the morphology values from disk
        # in standalone mode
        self.distance_ = self.flat_morphology.distance
        self.length_ = self.flat_morphology.length
        self.area_ = self.flat_morphology.area
        self.diameter_ = self.flat_morphology.diameter
        self.r_length_1_ = self.flat_morphology.r_length_1
        self.r_length_2_ = self.flat_morphology.r_length_2
        if self.flat_morphology.has_coordinates:
            self.x_ = self.flat_morphology.x
            self.y_ = self.flat_morphology.y
            self.z_ = self.flat_morphology.z

        # Performs numerical integration step
        # NOTE(review): the attribute is registered via add_attribute before
        # it is assigned -- presumably required by the group's attribute
        # handling, confirm against the Group implementation
        self.add_attribute('diffusion_state_updater')
        self.diffusion_state_updater = SpatialStateUpdater(self, method,
                                                           clock=self.clock,
                                                           order=order)

        # Update v after the gating variables to obtain consistent Ic and Im
        # (this replaces the `order` value passed to the updater above)
        self.diffusion_state_updater.order = 1

        # Creation of contained_objects that do the work
        self.contained_objects.extend([self.diffusion_state_updater])

        # An updater for the "constant over dt" subexpressions is only created
        # if the model contains any
        if len(constant_over_dt):
            self.subexpression_updater = SubexpressionUpdater(self,
                                                              constant_over_dt)
            self.contained_objects.append(self.subexpression_updater)