Ejemplo n.º 1
0
def test_parse_expression_unit():
    """Check the units inferred by ``parse_expression_unit``.

    Each entry of ``EE`` pairs an expected result with an expression over the
    model variables ``a`` (volt*amp), ``b`` (volt) and ``c`` (amp): either the
    unit the expression should have, or ``DimensionMismatchError`` when the
    expression is dimensionally inconsistent. A final set of expressions must
    be rejected with ``SyntaxError``.
    """
    Var = namedtuple('Var', ['unit', 'dtype'])
    variables = {
        'a': Var(unit=volt * amp, dtype=np.float64),
        'b': Var(unit=volt, dtype=np.float64),
        'c': Var(unit=amp, dtype=np.float64),
    }
    group = SimpleGroup(namespace={}, variables=variables)

    def lookup(expression):
        # Model variables are taken directly; any other identifier (units
        # such as mV, functions such as sqrt) is resolved via the group.
        return {identifier: (variables[identifier] if identifier in variables
                             else group.resolve(identifier))
                for identifier in get_identifiers(expression)}

    EE = [
        (volt * amp, 'a+b*c'),
        (DimensionMismatchError, 'a+b'),
        (DimensionMismatchError, 'a<b'),
        (1, 'a<b*c'),
        (1, 'a or b'),
        (1, 'not (a >= b*c)'),
        (DimensionMismatchError, 'a or b<c'),
        (1, 'a/(b*c)<1'),
        (1, 'a/(a-a)'),
        (1, 'a<mV*mA'),
        (volt**2, 'b**2'),
        (volt * amp, 'a%(b*c)'),
        (volt, '-b'),
        (1, '(a/a)**(a/a)'),
        # Expressions involving functions
        (volt, 'rand()*b'),
        (volt**0.5, 'sqrt(b)'),
        (volt, 'ceil(b)'),
        (volt, 'sqrt(randn()*b**2)'),
        (1, 'sin(b/b)'),
        (DimensionMismatchError, 'sin(b)'),
        (DimensionMismatchError, 'sqrt(b) + b')
    ]
    for expected, expression in EE:
        namespace = lookup(expression)
        if expected is not DimensionMismatchError:
            result = parse_expression_unit(expression, namespace)
            assert have_same_dimensions(result, expected)
        else:
            assert_raises(DimensionMismatchError, parse_expression_unit,
                          expression, namespace)

    # Expressions that must be rejected (the last one has a deliberate typo)
    for expression in ('a**b', 'a << b', 'int(True'):
        assert_raises(SyntaxError, parse_expression_unit, expression,
                      lookup(expression))
Ejemplo n.º 2
0
    def __init__(self, template, template_source):
        """Record *template* and parse the declarations in *template_source*.

        The source is searched for ``USES_VARIABLES { ... }``,
        ``ITERATE_ALL { ... }`` and ``WRITES_TO_READ_ONLY_VARIABLES { ... }``
        blocks. The identifiers inside all blocks of one kind are merged into
        a single set. The presence of the text ``ALLOWS_SCALAR_WRITE``
        anywhere in the source sets ``allows_scalar_write``.
        """
        self.template = template
        self.template_source = template_source

        def identifiers_in_blocks(pattern):
            # ``.*?`` with re.S stops at the first closing brace but may
            # span several lines.
            collected = set([])
            for contents in re.findall(pattern, template_source, re.M | re.S):
                collected.update(get_identifiers(contents))
            return collected

        #: The set of variables in this template
        self.variables = identifiers_in_blocks(
            r'\bUSES_VARIABLES\b\s*\{(.*?)\}')
        #: The indices over which the template iterates completely
        self.iterate_all = identifiers_in_blocks(
            r'\bITERATE_ALL\b\s*\{(.*?)\}')
        #: Read-only variables that are changed by this template
        self.writes_read_only = identifiers_in_blocks(
            r'\bWRITES_TO_READ_ONLY_VARIABLES\b\s*\{(.*?)\}')
        #: Does this template allow writing to scalar variables?
        self.allows_scalar_write = 'ALLOWS_SCALAR_WRITE' in template_source
Ejemplo n.º 3
0
 def __init__(self, template):
     """Record *template* and collect the identifiers it uses.

     The object is called once with ``['']``. The result is either a single
     string or an object whose ``_templates`` mapping holds several template
     sources (presumably strings). For every source this collects all
     identifiers (``words``), the identifiers inside
     ``USES_VARIABLES { ... }`` blocks (``variables``) and those inside
     ``ITERATE_ALL { ... }`` blocks (``iterate_all``).
     """
     self.template = template
     rendered = self([''])
     sources = [rendered] if isinstance(rendered, str) else rendered._templates.values()
     #: The set of words in this template
     self.words = set([])
     #: The set of variables in this template
     self.variables = set([])
     #: The indices over which the template iterates completely
     self.iterate_all = set([])
     for source in sources:
         self.words.update(get_identifiers(source))
         # Contents of ``USES_VARIABLES { list of words }`` blocks
         for block in re.findall(r'\bUSES_VARIABLES\b\s*\{(.*?)\}', source,
                                 re.M | re.S):
             self.variables.update(get_identifiers(block))
         # Same for ``ITERATE_ALL``
         for block in re.findall(r'\bITERATE_ALL\b\s*\{(.*?)\}', source,
                                 re.M | re.S):
             self.iterate_all.update(get_identifiers(block))
Ejemplo n.º 4
0
 def __init__(self, template):
     """Record *template* and gather the identifiers of its source(s).

     Calling the object with ``['']`` yields either one string or an object
     whose ``_templates`` mapping holds several sources. ``words`` collects
     every identifier in those sources, while ``variables`` and
     ``iterate_all`` collect the identifiers listed inside
     ``USES_VARIABLES { ... }`` and ``ITERATE_ALL { ... }`` blocks.
     """
     self.template = template
     result = self([''])
     if isinstance(result, str):
         sources = [result]
     else:
         sources = result._templates.values()

     def collect(pattern):
         # Identifiers inside every block matching *pattern*, across sources
         found = set([])
         for source in sources:
             for block in re.findall(pattern, source, re.M | re.S):
                 found.update(get_identifiers(block))
         return found

     #: The set of words in this template
     self.words = set([])
     for source in sources:
         self.words.update(get_identifiers(source))
     #: The set of variables in this template
     self.variables = collect(r'\bUSES_VARIABLES\b\s*\{(.*?)\}')
     #: The indices over which the template iterates completely
     self.iterate_all = collect(r'\bITERATE_ALL\b\s*\{(.*?)\}')
Ejemplo n.º 5
0
    def __init__(self, template, template_source):
        """Store *template* and extract its declarations from *template_source*.

        Identifiers inside ``USES_VARIABLES``, ``ITERATE_ALL`` and
        ``WRITES_TO_READ_ONLY_VARIABLES`` brace blocks populate
        ``variables``, ``iterate_all`` and ``writes_read_only``. The marker
        ``ALLOWS_SCALAR_WRITE`` anywhere in the source enables
        ``allows_scalar_write``.
        """
        self.template = template
        self.template_source = template_source
        flags = re.M | re.S

        def scan(pattern):
            # Union of the identifiers found in every matching ``{...}`` body
            return {name
                    for body in re.findall(pattern, template_source, flags)
                    for name in get_identifiers(body)}

        #: The set of variables in this template
        self.variables = scan(r'\bUSES_VARIABLES\b\s*\{(.*?)\}')
        #: The indices over which the template iterates completely
        self.iterate_all = scan(r'\bITERATE_ALL\b\s*\{(.*?)\}')
        #: Read-only variables that are changed by this template
        self.writes_read_only = scan(
            r'\bWRITES_TO_READ_ONLY_VARIABLES\b\s*\{(.*?)\}')
        #: Does this template allow writing to scalar variables?
        self.allows_scalar_write = 'ALLOWS_SCALAR_WRITE' in template_source
Ejemplo n.º 6
0
def test_parse_expression_unit():
    """Check the dimensions inferred by ``parse_expression_dimensions``.

    ``EE`` pairs each expression over ``a`` (volt*amp), ``b`` (volt) and
    ``c`` (amp) with either the expected unit or the exception class that
    parsing must raise. That is ``DimensionMismatchError`` for inconsistent
    dimensions and ``SyntaxError`` for calls with the wrong number of
    arguments. A second list of expressions must always raise
    ``SyntaxError``.
    """
    Var = namedtuple('Var', ['dim', 'dtype'])
    variables = {'a': Var(dim=(volt*amp).dim, dtype=np.float64),
                 'b': Var(dim=volt.dim, dtype=np.float64),
                 'c': Var(dim=amp.dim, dtype=np.float64)}
    group = SimpleGroup(namespace={}, variables=variables)

    def lookup(expression):
        # Model variables come from ``variables``; every other identifier
        # (units, functions, ...) is resolved through the group.
        resolved = {}
        for identifier in get_identifiers(expression):
            resolved[identifier] = (variables[identifier]
                                    if identifier in variables
                                    else group._resolve(identifier, {}))
        return resolved

    EE = [
        (volt*amp, 'a+b*c'),
        (DimensionMismatchError, 'a+b'),
        (DimensionMismatchError, 'a<b'),
        (1, 'a<b*c'),
        (1, 'a or b'),
        (1, 'not (a >= b*c)'),
        (DimensionMismatchError, 'a or b<c'),
        (1, 'a/(b*c)<1'),
        (1, 'a/(a-a)'),
        (1, 'a<mV*mA'),
        (volt**2, 'b**2'),
        (volt*amp, 'a%(b*c)'),
        (volt, '-b'),
        (1, '(a/a)**(a/a)'),
        # Expressions involving functions
        (volt, 'rand()*b'),
        (volt**0.5, 'sqrt(b)'),
        (volt, 'ceil(b)'),
        (volt, 'sqrt(randn()*b**2)'),
        (1, 'sin(b/b)'),
        (DimensionMismatchError, 'sin(b)'),
        (DimensionMismatchError, 'sqrt(b) + b'),
        (SyntaxError, 'sqrt(b, b)'),
        (SyntaxError, 'sqrt()'),
        (SyntaxError, 'int(1, 2)'),
        ]
    for expected, expression in EE:
        namespace = lookup(expression)
        expects_error = isinstance(expected, type) and issubclass(expected, Exception)
        if expects_error:
            assert_raises(expected, parse_expression_dimensions, expression,
                          namespace)
            continue
        dims = parse_expression_dimensions(expression, namespace)
        assert have_same_dimensions(dims, expected)

    # Always invalid (the last one has a deliberate typo)
    for expression in ('a**b', 'a << b', 'int(True'):
        assert_raises(SyntaxError, parse_expression_dimensions, expression,
                      lookup(expression))
Ejemplo n.º 7
0
def collect_PoissonGroup(poisson_grp, run_namespace):
    """
    Describe a PoissonGroup as a plain dictionary.

    The result always contains ``name``, ``N`` and ``rates``. Every
    contained object whose type is exactly ``CodeRunner`` is listed under
    ``run_regularly`` (name, code, dt, when, order). Identifiers referenced
    by string rates or by that code are resolved against the group and
    ``run_namespace``, pruned with ``_prepare_identifiers``, and stored
    under ``identifiers`` if any remain.

    Parameters
    ----------
    poisson_grp : brian2.input.poissongroup.PoissonGroup
            PoissonGroup object

    run_namespace : dict
            Namespace dictionary

    Returns
    -------
    poisson_grp_dict : dict
                Dictionary with extracted information
    """
    rates = poisson_grp._rates
    poisson_grp_dict = {
        'name': poisson_grp._name,
        'N': poisson_grp._N,
        # rates can be a Quantity or a string expression
        'rates': rates,
    }
    referenced = set()
    if isinstance(rates, str):
        referenced |= get_identifiers(rates)

    for obj in poisson_grp.contained_objects:
        # Exact type comparison (not isinstance): CodeRunner subclasses are
        # deliberately not collected here.
        if type(obj) != CodeRunner:
            continue
        poisson_grp_dict.setdefault('run_regularly', []).append({
            'name': obj.name,
            'code': obj.abstract_code,
            'dt': obj.clock.dt,
            'when': obj.when,
            'order': obj.order
        })
        referenced |= get_identifiers(obj.abstract_code)

    resolved = poisson_grp.resolve_all(referenced, run_namespace)
    # drop the resolved identifiers that should not be exported
    resolved = _prepare_identifiers(resolved)
    if resolved:
        poisson_grp_dict['identifiers'] = resolved

    return poisson_grp_dict
Ejemplo n.º 8
0
def collect_Events(group):
    """
    Collect Events (spiking) of the NeuronGroup

    For every event known to ``group.thresholder`` this records the
    threshold condition (code, when, order, dt). If a reset is defined in
    ``group.event_codes``, it also records the reset. For the ``spike``
    event it records the group's refractoriness when that is truthy.

    Parameters
    ----------
    group : brian2.groups.neurongroup.NeuronGroup
        NeuronGroup object

    Returns
    -------
    event_dict : dict
        Dictionary with extracted information, keyed by event name
        (empty if the group has no events)

    event_identifiers : set
        Set of identifiers related to events
    """

    event_dict = {}
    event_identifiers = set()

    # loop over the thresholder to check `spike` or custom event
    for event in group.thresholder:
        event_subdict = {}
        event_dict[event] = event_subdict
        thresholder = group.thresholder[event]
        # add threshold
        event_subdict['threshold'] = {
            'code': group.events[event],
            'when': thresholder.when,
            'order': thresholder.order,
            'dt': thresholder.clock.dt
        }
        event_identifiers |= get_identifiers(group.events[event])

        # check reset is defined
        if event in group.event_codes:
            resetter = group.resetter[event]
            event_subdict['reset'] = {
                'code': group.event_codes[event],
                'when': resetter.when,
                'order': resetter.order,
                'dt': resetter.clock.dt
            }
            event_identifiers |= get_identifiers(group.event_codes[event])

        # Refractoriness only applies to the spike event. It is checked
        # inside the loop so it is attached to the spike entry whatever the
        # iteration order, and no unbound loop variable is touched when the
        # group has no events at all.
        if event == 'spike' and group._refractory:
            event_subdict['refractory'] = group._refractory

    return event_dict, event_identifiers
Ejemplo n.º 9
0
def test_parse_expression_unit_functions(expr, correct):
    """Check dimension inference for expressions calling a custom function.

    ``foo`` is registered with ``arg_units=[None, volt, 'x']`` and a return
    unit computed as ``x * y``. When *correct* is true, parsing *expr* must
    yield a ``Dimension``; otherwise it must raise ``DimensionMismatchError``.
    """
    Var = namedtuple('Var', ['dim', 'dtype'])

    def foo(x, y, z):
        return (x + z) * y

    variables = {
        'a': Var(dim=(volt * amp).dim, dtype=np.float64),
        'b': Var(dim=volt.dim, dtype=np.float64),
        'c': Var(dim=amp.dim, dtype=np.float64),
        'd': Var(dim=DIMENSIONLESS, dtype=np.float64),
        'foo': Function(pyfunc=foo,
                        arg_units=[None, volt, 'x'],
                        arg_names=['x', 'y', 'z'],
                        return_unit=lambda x, y, z: x * y),
    }
    group = SimpleGroup(namespace={}, variables=variables)
    # Model variables directly, everything else resolved via the group
    all_variables = {
        name: variables[name] if name in variables else group._resolve(name, {})
        for name in get_identifiers(expr)
    }
    if not correct:
        with pytest.raises(DimensionMismatchError):
            parse_expression_dimensions(expr, all_variables)
    else:
        dims = parse_expression_dimensions(expr, all_variables)
        assert isinstance(dims, Dimension)
Ejemplo n.º 10
0
    def update_abstract_code(self, run_namespace):
        """
        Set ``user_code`` and ``abstract_code`` from the event's condition.

        The condition is ``self.group.events[self.event]``. Its identifiers
        are resolved against *run_namespace*, and it must be a boolean
        expression. For the ``spike`` event, ``and not_refractory`` is
        appended unless the group's ``_refractory`` is ``False``.

        Raises
        ------
        TypeError
            If the condition is not a string (e.g. Brian1-style syntax) or
            is not a boolean expression.
        """
        code = self.group.events[self.event]
        if not isinstance(code, str):
            # Give users of the old (Brian1) syntax a targeted message
            kind = 'a quantity' if isinstance(code, Quantity) else f'{type(code)}'
            guess = None
            if self.event == 'spike':
                try:
                    guess = _guess_membrane_potential(self.group.equations)
                except AttributeError:  # not a group with equations...
                    guess = None
            suffix = ('' if guess is None
                      else f" Probably you intended to use '{guess} > ...'?")
            raise TypeError(f'Threshold condition has to be a string, not {kind}.'
                            + suffix)

        self.user_code = f"_cond = {code}"

        identifiers = get_identifiers(code)
        variables = self.group.resolve_all(identifiers,
                                           run_namespace,
                                           user_identifiers=identifiers)
        if not is_boolean_expression(code, variables):
            raise TypeError(f"Threshold condition '{code}' is not a boolean "
                            f"expression")
        respects_refractoriness = (self.group._refractory is not False
                                   and self.event == 'spike')
        self.abstract_code = (f'_cond = ({code}) and not_refractory'
                              if respects_refractoriness
                              else f'_cond = {code}')
Ejemplo n.º 11
0
def parse_expressions(renderer, evaluator, numvalues=10):
    """Check that *renderer* translates every test expression faithfully.

    Each non-blank line of ``TEST_EXPRESSIONS`` is rendered with
    ``renderer.render_expr`` and evaluated ``numvalues`` times. Each time it
    is evaluated once with Python's ``eval`` (after mapping ``&``/``|`` to
    ``and``/``or``) and once through ``evaluator`` on the rendered
    expression. Both evaluations use the same deterministic values for the
    expression's single-letter identifiers. The results are compared with
    ``assert_allclose(..., atol=10)``.

    Raises
    ------
    AssertionError
        If the two results differ. The message names the original and the
        translated expression, and the comparison failure is chained.
    """
    # (variable names, expression) pairs; only single-letter identifiers
    # are treated as variables to be filled with values
    exprs = [([m for m in get_identifiers(line) if len(m) == 1], line.strip())
             for line in TEST_EXPRESSIONS.split('\n') if line.strip()]
    # ``i`` cycles deterministically through 1..imod across all variables
    # and iterations, so successive evaluations see different inputs
    i, imod = 1, 33
    for varids, expr in exprs:
        pexpr = renderer.render_expr(expr)
        for _ in range(numvalues):
            # assign some pseudo-random values
            ns = {}
            for v in varids:
                if v in ['n', 'm']:  # integer values
                    ns[v] = i
                else:
                    ns[v] = float(i) / imod
                i = i % imod + 1
            # eval is applied to the module's own test expressions only
            r1 = eval(expr.replace('&', ' and ').replace('|', ' or '), ns)
            r2 = evaluator(pexpr, ns)
            try:
                # Use allclose because sympy's rearrangements can introduce
                # small numerical differences.
                # NOTE(review): atol=10 is a very loose absolute tolerance;
                # confirm it is intended.
                assert_allclose(r1, r2, atol=10)
            except AssertionError as e:
                raise AssertionError("In expression " + str(expr) +
                                     " translated to " + str(pexpr) + " " +
                                     str(e)) from e
Ejemplo n.º 12
0
    def _check_expression_scalar(self, expr, varname, level=0, run_namespace=None):
        """
        Helper function to check that an expression only refers to scalar
        variables, used when setting a scalar variable with a string expression.

        Parameters
        ----------
        expr : str
            The expression to check.
        varname : str
            The variable that is being set (only used for the error message)
        level : int, optional
            How far to go up in the stack to find the local namespace (if
            `run_namespace` is not set).
        run_namespace : dict-like, optional
            A specific namespace provided for this expression.

        Raises
        ------
        ValueError
            If the expression refers to a non-scalar variable (any resolved
            object without a truthy ``scalar`` attribute counts as one).
        """
        identifiers = get_identifiers(expr)
        # ``level + 1`` presumably accounts for this helper's own stack frame
        referred_variables = self.resolve_all(identifiers, run_namespace=run_namespace, level=level + 1)
        # ``items()`` works on Python 2 and 3; ``iteritems()`` is Python 2 only
        for ref_varname, ref_var in referred_variables.items():
            if not getattr(ref_var, "scalar", False):
                raise ValueError(
                    ("String expression for setting scalar " "variable %s refers to %s which is not " "scalar.")
                    % (varname, ref_varname)
                )
Ejemplo n.º 13
0
def analyse_identifiers(code, variables, recursive=False):
    '''
    Analyses a code string (sequence of statements) to find all identifiers by type.

    In a given code block, some variable names (identifiers) must be given as inputs to the code
    block, and some are created by the code block. For example, the line::

        a = b+c

    This could mean to create a new variable a from b and c, or it could mean modify the existing
    value of a from b or c, depending on whether a was previously known.

    Parameters
    ----------
    code : str
        The code string, a sequence of statements one per line.
    variables : dict of `Variable`, set of names
        Specifiers for the model variables or a set of known names. A set of
        names is converted into placeholder `Variable` objects (float64).
    recursive : bool, optional
        Whether to recurse down into subexpressions (defaults to ``False``).

    Returns
    -------
    newly_defined : set
        A set of variables that are created by the code block (statements
        with the ``:=`` operator).
    used_known : set
        A set of variables that are used and already known, i.e. a subset of
        the names in ``variables`` (``STANDARD_IDENTIFIERS`` are excluded).
    unknown : set
        A set of variables which are used by the code block but not defined by
        it and not previously known. Should correspond to variables in the
        external namespace.
    '''
    try:
        from collections.abc import Mapping
    except ImportError:  # Python 2: the ABCs live in ``collections`` itself
        from collections import Mapping

    if isinstance(variables, Mapping):
        # NOTE(review): this tests the key ``k`` rather than the value ``v``.
        # Keys are presumably names (strings), so nothing is excluded. It is
        # kept as-is to preserve current behaviour; confirm the intent before
        # switching to ``v``.
        known = set(k for k, v in variables.items()
                    if not isinstance(k, AuxiliaryVariable))
    else:
        known = set(variables)
        variables = dict(
            (k, Variable(unit=None, name=k, dtype=np.float64)) for k in known)
    # From here on ``variables`` is always a mapping, so the recursive branch
    # below needs no further type check.

    known |= STANDARD_IDENTIFIERS
    scalar_stmts, vector_stmts = make_statements(code, variables, np.float64)
    stmts = scalar_stmts + vector_stmts
    defined = set(stmt.var for stmt in stmts if stmt.op == ':=')
    if not stmts:
        allids = set()
    elif recursive:
        allids = get_identifiers_recursively(
            [stmt.expr
             for stmt in stmts], variables) | set([stmt.var for stmt in stmts])
    else:
        allids = set.union(*[get_identifiers(stmt.expr)
                             for stmt in stmts]) | set(
                                 [stmt.var for stmt in stmts])
    dependent = allids.difference(defined, known)
    used_known = allids.intersection(known) - STANDARD_IDENTIFIERS
    return defined, used_known, dependent
Ejemplo n.º 14
0
def get_identifiers_recursively(expressions, variables, include_numbers=False):
    '''
    Collect the identifiers used in *expressions*, expanding subexpressions.

    Any identifier that names a `Subexpression` in *variables* also
    contributes the identifiers of that subexpression's own expression,
    applied recursively.

    Parameters
    ----------
    expressions : list of str
        List of expressions to check.
    variables : dict-like
        Dictionary of `Variable` objects
    include_numbers : bool, optional
        Whether to include number literals in the output. Defaults to ``False``.

    Returns
    -------
    set
        Union of all identifiers found.
    '''
    if not len(expressions):
        return set()
    identifiers = set()
    for expression in expressions:
        identifiers.update(get_identifiers(expression,
                                           include_numbers=include_numbers))
    # Decide which names to expand before the set grows any further
    to_expand = [name for name in identifiers
                 if name in variables and isinstance(variables[name], Subexpression)]
    for name in to_expand:
        identifiers |= get_identifiers_recursively(
            [variables[name].expr], variables, include_numbers=include_numbers)
    return identifiers
Ejemplo n.º 15
0
def get_identifiers_recursively(expressions, variables, include_numbers=False):
    '''
    Return every identifier in *expressions*, following subexpressions.

    Names that refer to a `Subexpression` in *variables* are expanded by
    recursively collecting the identifiers of its ``expr``.

    Parameters
    ----------
    expressions : list of str
        List of expressions to check.
    variables : dict-like
        Dictionary of `Variable` objects
    include_numbers : bool, optional
        Whether to include number literals in the output. Defaults to ``False``.

    Returns
    -------
    set
        All identifiers, including those reached through subexpressions.
    '''
    if len(expressions) == 0:
        return set()
    found = set().union(*[get_identifiers(e, include_numbers=include_numbers)
                          for e in expressions])
    for name in list(found):  # snapshot: ``found`` grows inside the loop
        if name not in variables:
            continue
        variable = variables[name]
        if isinstance(variable, Subexpression):
            found.update(get_identifiers_recursively(
                [variable.expr], variables, include_numbers=include_numbers))
    return found
Ejemplo n.º 16
0
def analyse_identifiers(code, variables, recursive=False):
    '''
    Analyses a code string (sequence of statements) to find all identifiers by type.

    In a given code block, some variable names (identifiers) must be given as inputs to the code
    block, and some are created by the code block. For example, the line::

        a = b+c

    This could mean to create a new variable a from b and c, or it could mean modify the existing
    value of a from b or c, depending on whether a was previously known.

    Parameters
    ----------
    code : str
        The code string, a sequence of statements one per line.
    variables : dict of `Variable`, set of names
        Specifiers for the model variables or a set of known names. A set of
        names is converted into placeholder `Variable` objects (float64).
    recursive : bool, optional
        Whether to recurse down into subexpressions (defaults to ``False``).

    Returns
    -------
    newly_defined : set
        A set of variables that are created by the code block (statements
        with the ``:=`` operator).
    used_known : set
        A set of variables that are used and already known, i.e. a subset of
        the names in ``variables`` (``STANDARD_IDENTIFIERS`` are excluded).
    unknown : set
        A set of variables which are used by the code block but not defined by
        it and not previously known. Should correspond to variables in the
        external namespace.
    '''
    try:
        from collections.abc import Mapping
    except ImportError:  # Python 2: the ABCs live in ``collections`` itself
        from collections import Mapping

    if isinstance(variables, Mapping):
        # NOTE(review): this tests the key ``k`` rather than the value ``v``.
        # Keys are presumably names (strings), so nothing is excluded. It is
        # kept as-is to preserve current behaviour; confirm the intent before
        # switching to ``v``.
        known = set(k for k, v in variables.items()
                    if not isinstance(k, AuxiliaryVariable))
    else:
        known = set(variables)
        variables = dict((k, Variable(unit=None, name=k,
                                      dtype=np.float64))
                         for k in known)
    # From here on ``variables`` is always a mapping, so the recursive branch
    # below needs no further type check.

    known |= STANDARD_IDENTIFIERS
    # optimise=False: the statements are only inspected, not executed
    scalar_stmts, vector_stmts = make_statements(code, variables, np.float64, optimise=False)
    stmts = scalar_stmts + vector_stmts
    defined = set(stmt.var for stmt in stmts if stmt.op == ':=')
    if not stmts:
        allids = set()
    elif recursive:
        allids = get_identifiers_recursively([stmt.expr for stmt in stmts],
                                             variables) | set([stmt.var
                                                               for stmt in stmts])
    else:
        allids = set.union(*[get_identifiers(stmt.expr)
                             for stmt in stmts]) | set([stmt.var for stmt in stmts])
    dependent = allids.difference(defined, known)
    used_known = allids.intersection(known) - STANDARD_IDENTIFIERS
    return defined, used_known, dependent
Ejemplo n.º 17
0
    def __init__(self, code):
        """Wrap the *code* string and pre-compute the identifiers it uses."""
        identifiers = get_identifiers(code)
        #: The code string
        self.code = code
        #: Set of identifiers in the code string
        self.identifiers = identifiers
Ejemplo n.º 18
0
 def __init__(self, template, template_source):
     """Store *template* and collect the identifiers declared in its source.

     ``variables`` receives every identifier listed inside a
     ``USES_VARIABLES { ... }`` block of *template_source*, and
     ``iterate_all`` every identifier inside an ``ITERATE_ALL { ... }`` block.
     """
     self.template = template
     #: The set of variables in this template
     self.variables = {name
                       for block in re.findall(r'\bUSES_VARIABLES\b\s*\{(.*?)\}',
                                               template_source, re.M | re.S)
                       for name in get_identifiers(block)}
     #: The indices over which the template iterates completely
     self.iterate_all = {name
                         for block in re.findall(r'\bITERATE_ALL\b\s*\{(.*?)\}',
                                                 template_source, re.M | re.S)
                         for name in get_identifiers(block)}
Ejemplo n.º 19
0
 def variableview_set_with_expression_conditional(self,
                                                  variableview,
                                                  cond,
                                                  code,
                                                  run_namespace,
                                                  check_units=True):
     """
     Capture setters with conditional expressions,
     e.g. ``obj.var['i>5'] = 'rand() * -78 * mV'``.

     Appends an initializer description (source group, variable, index
     condition, value code) to ``self.initializers_connectors``. The
     identifiers in *code* are resolved against the variable's group and
     *run_namespace*, pruned with ``_prepare_identifiers``, and added under
     ``identifiers`` when any remain. *check_units* is accepted but not used
     here.
     """
     resolved = _prepare_identifiers(
         variableview.group.resolve_all(get_identifiers(code), run_namespace))
     init_dict = {
         'source': variableview.group.name,
         'variable': variableview.name,
         'index': cond,
         'value': code,
         'type': 'initializer'
     }
     if resolved:
         init_dict['identifiers'] = resolved
     self.initializers_connectors.append(init_dict)
Ejemplo n.º 20
0
    def update_abstract_code(self, run_namespace):
        """
        Set ``user_code`` and ``abstract_code`` from the event's condition.

        The condition is ``self.group.events[self.event]`` and must be a
        string forming a boolean expression. For the ``spike`` event,
        ``and not_refractory`` is appended unless the group's
        ``_refractory`` is ``False``.

        Raises
        ------
        TypeError
            If the condition is not a string (e.g. Brian1-style syntax) or
            is not a boolean expression.
        """
        code = self.group.events[self.event]
        if not isinstance(code, basestring):
            # Give users of the old (Brian1) syntax a targeted message
            description = 'a quantity' if isinstance(code, Quantity) else '%s' % type(code)
            error_msg = 'Threshold condition has to be a string, not %s.' % description
            guess = None
            if self.event == 'spike':
                try:
                    guess = _guess_membrane_potential(self.group.equations)
                except AttributeError:  # not a group with equations...
                    guess = None
            if guess is not None:
                error_msg += " Probably you intended to use '%s > ...'?" % guess
            raise TypeError(error_msg)

        self.user_code = '_cond = ' + code

        identifiers = get_identifiers(code)
        variables = self.group.resolve_all(identifiers,
                                           run_namespace,
                                           user_identifiers=identifiers)
        if not is_boolean_expression(code, variables):
            raise TypeError(('Threshold condition "%s" is not a boolean '
                             'expression') % code)
        if self.group._refractory is not False and self.event == 'spike':
            self.abstract_code = '_cond = (%s) and not_refractory' % code
        else:
            self.abstract_code = '_cond = %s' % code
Ejemplo n.º 21
0
    def update_abstract_code(self, run_namespace):
        """
        Build ``user_code`` and ``abstract_code`` for this event's condition.

        The condition string comes from ``self.group.events[self.event]``.
        It must be a string and a boolean expression. For the ``spike``
        event of a group whose ``_refractory`` is not ``False``, the
        condition is additionally combined with ``not_refractory``.

        Raises
        ------
        TypeError
            If the condition is not a string (e.g. Brian1-style syntax) or
            is not a boolean expression.
        """
        code = self.group.events[self.event]
        if not isinstance(code, basestring):
            if isinstance(code, Quantity):
                what = 'a quantity'
            else:
                what = '%s' % type(code)
            vm_var = None
            if self.event == 'spike':
                try:
                    vm_var = _guess_membrane_potential(self.group.equations)
                except AttributeError:  # not a group with equations...
                    vm_var = None
            hint = ('' if vm_var is None
                    else " Probably you intended to use '%s > ...'?" % vm_var)
            raise TypeError(('Threshold condition has to be a string, not %s.' % what)
                            + hint)

        self.user_code = '_cond = ' + code

        identifiers = get_identifiers(code)
        # NOTE(review): the second positional argument is presumably
        # ``user_identifiers`` in this version's ``resolve_all`` signature.
        variables = self.group.resolve_all(identifiers,
                                           identifiers,
                                           run_namespace=run_namespace)
        if not is_boolean_expression(code, variables):
            raise TypeError(('Threshold condition "%s" is not a boolean '
                             'expression') % code)
        use_refractory = not (self.group._refractory is False or self.event != 'spike')
        template = '_cond = (%s) and not_refractory' if use_refractory else '_cond = %s'
        self.abstract_code = template % code
Ejemplo n.º 22
0
    def _check_expression_scalar(self,
                                 expr,
                                 varname,
                                 level=0,
                                 run_namespace=None):
        '''
        Helper function to check that an expression only refers to scalar
        variables, used when setting a scalar variable with a string expression.

        Parameters
        ----------
        expr : str
            The expression to check.
        varname : str
            The variable that is being set (only used for the error message)
        level : int, optional
            How far to go up in the stack to find the local namespace (if
            `run_namespace` is not set).
        run_namespace : dict-like, optional
            A specific namespace provided for this expression.

        Raises
        ------
        ValueError
            If the expression refers to a non-scalar variable (any resolved
            object without a truthy ``scalar`` attribute counts as one).
        '''
        identifiers = get_identifiers(expr)
        # ``level + 1`` presumably accounts for this helper's own stack frame
        referred_variables = self.resolve_all(identifiers,
                                              run_namespace=run_namespace,
                                              level=level + 1)
        # ``items()`` works on Python 2 and 3; ``iteritems()`` is Python 2 only
        for ref_varname, ref_var in referred_variables.items():
            if not getattr(ref_var, 'scalar', False):
                raise ValueError(('String expression for setting scalar '
                                  'variable %s refers to %s which is not '
                                  'scalar.') % (varname, ref_varname))
Ejemplo n.º 23
0
def collect_PoissonInput(poinp, run_namespace):
    """
    Collect details of `PoissonInput` and represent them in dictionary

    Parameters
    ----------
    poinp : brian2.input.poissoninput.PoissonInput
            PoissonInput object

    run_namespace : dict
            Namespace dictionary

    Returns
    -------
    poinp_dict : dict
            Dictionary with the keys ``target``, ``rate``, ``N``, ``when``,
            ``order``, ``dt``, ``weight`` and ``target_var``; when the weight
            is a string expression whose identifiers resolve to something
            non-empty, an ``identifiers`` entry is added as well.
    """
    details = {
        'target': poinp._group.name,
        'rate': poinp.rate,
        'N': poinp.N,
        'when': poinp.when,
        'order': poinp.order,
        'dt': poinp.clock.dt,
        'weight': poinp._weight,
        'target_var': poinp._target_var,
    }
    # Only string weights can refer to external identifiers
    if not isinstance(details['weight'], str):
        return details

    names = get_identifiers(details['weight'])
    resolved = poinp._group.resolve_all(names, run_namespace)
    pruned = _prepare_identifiers(resolved)
    if pruned:
        details['identifiers'] = pruned
    return details
Ejemplo n.º 24
0
def translate_subexpression(subexpr, variables):
    '''
    Rewrite the expression of `subexpr` so that every variable of its owner
    is referred to by the name under which it appears in `variables`.

    Identifiers that are not variables of ``subexpr.owner`` are treated as
    external names and left alone. A variable available under its own name
    needs no change; otherwise the first name in `variables` bound to the
    very same object is substituted.

    Parameters
    ----------
    subexpr
        A subexpression with ``expr``, ``owner`` and ``name`` attributes.
    variables : dict
        The variables available in the target context.

    Returns
    -------
    new_expr : str
        The expression with the substitutions applied.

    Raises
    ------
    KeyError
        If an owner variable used by the subexpression cannot be found in
        `variables` under any name.
    '''
    not_found = object()
    substitutions = {}
    for name in get_identifiers(subexpr.expr):
        if name not in subexpr.owner.variables:
            # external name, nothing to translate
            continue
        target = subexpr.owner.variables[name]
        if name in variables and variables[name] is target:
            # already available under the same name
            continue
        # maybe available under a different name (e.g. x_post instead of x)
        alias = next((candidate_name
                      for candidate_name, candidate in variables.iteritems()
                      if candidate is target),
                     not_found)
        if alias is not_found:
            raise KeyError(('Variable %s, referred to by the subexpression '
                            '%s, is not available in this '
                            'context.') % (name, subexpr.name))
        substitutions[name] = alias
    return word_substitute(subexpr.expr, substitutions)
Ejemplo n.º 25
0
    def make_multiinstantiate(self, special_properties, name, parameters):
        """
        Adds ComponentType with MultiInstantiate in order to make
        a population of neurons.

        Each entry of `special_properties` becomes an ``Assign`` of the
        MultiInstantiate. Entries whose value is ``None`` are exposed as a
        population parameter ``<name>_p`` whose value is taken from
        `parameters`. Other entries are expression strings. In those, the
        neuron index ``i`` is replaced by ``INDEX`` and unit names are
        replaced by LEMS constants ``<unit>const``.

        Parameters
        ----------
        special_properties : dict
            all variables to be defined in MultiInstantiate
        name : str
            MultiInstantiate component name
        parameters : dict
            all extra parameters needed
        """
        PARAM_SUBSCRIPT = "_p"
        # Matches `{}` only as a whole identifier, i.e. not preceded or
        # followed by a letter, digit or underscore. Lookarounds keep the
        # neighbouring characters (operators, brackets, spaces) intact instead
        # of consuming them in the substitution.
        standalone = r"(?<![A-Za-z0-9_]){}(?![A-Za-z0-9_])"
        self._model_namespace["ct_populationname"] = name + "Multi"
        multi_ct = lems.ComponentType(self._model_namespace["ct_populationname"],
                                      extends=BASE_POPULATION)
        structure = lems.Structure()
        multi_ins = lems.MultiInstantiate(component_type=name,
                                          number="N")
        param_dict = {}
        # number of neurons
        multi_ct.add(lems.Parameter(name="N", dimension="none"))
        # other parameters
        for sp in special_properties:
            if special_properties[sp] is None:
                multi_ct.add(lems.Parameter(name=sp + PARAM_SUBSCRIPT,
                                            dimension=self._all_params_unit[sp]))
                multi_ins.add(lems.Assign(property=sp, value=sp + PARAM_SUBSCRIPT))
                param_dict[sp + PARAM_SUBSCRIPT] = parameters[sp]
            else:
                # check if there are some units in equations
                equation = special_properties[sp]
                # add spaces around brackets (keeps the generated expression
                # readable and tokens separated)
                equation = equation.replace("(", " ( ").replace(")", " ) ")
                for identifier in get_identifiers(equation):
                    # the neuron index is a special case
                    if identifier == "i":
                        equation = re.sub(standalone.format("i"),
                                          " {} ".format(INDEX), equation)
                    # here it's assumed that we don't use Newton in neuron models
                    elif identifier in name_to_unit and identifier != "N":
                        const_i = identifier + 'const'
                        # NOTE(review): the constant gets the dimension of the
                        # property `sp`, not of the unit `identifier` itself;
                        # these only agree when both have the same dimension.
                        # Confirm this is intended.
                        multi_ct.add(lems.Constant(name=const_i, symbol=const_i,
                                                   dimension=self._all_params_unit[sp],
                                                   value="1" + identifier))
                        # whole-identifier match, so that e.g. replacing `V`
                        # does not also rewrite the `V` inside `mV`
                        equation = re.sub(standalone.format(re.escape(identifier)),
                                          const_i, equation)
                multi_ins.add(lems.Assign(property=sp, value=equation))
        structure.add(multi_ins)
        multi_ct.structure = structure
        self._model.add(multi_ct)
        param_dict["N"] = self._nr_of_neurons
        self._model_namespace["populationname"] = self._model_namespace["ct_populationname"] + "pop"
        self._model_namespace["networkname"] = self._model_namespace["ct_populationname"] + "Net"
        self.add_population(self._model_namespace["networkname"],
                            self._model_namespace["populationname"],
                            self._model_namespace["ct_populationname"],
                            **param_dict)
Ejemplo n.º 26
0
 def _get_refractory_code(self, run_namespace, level=0):
     ref = self.group._refractory
     if ref is False:
         # No refractoriness
         abstract_code = ''
     elif isinstance(ref, Quantity):
         abstract_code = 'not_refractory = (t - lastspike) > %f\n' % ref
     else:
         identifiers = get_identifiers(ref)
         variables = self.group.resolve_all(identifiers,
                                            identifiers,
                                            run_namespace=run_namespace,
                                            level=level + 1)
         unit = parse_expression_unit(str(ref), variables)
         if have_same_dimensions(unit, second):
             abstract_code = 'not_refractory = (t - lastspike) > %s\n' % ref
         elif have_same_dimensions(unit, Unit(1)):
             if not is_boolean_expression(str(ref), variables):
                 raise TypeError(('Refractory expression is dimensionless '
                                  'but not a boolean value. It needs to '
                                  'either evaluate to a timespan or to a '
                                  'boolean value.'))
             # boolean condition
             # we have to be a bit careful here, we can't just use the given
             # condition as it is, because we only want to *leave*
             # refractoriness, based on the condition
             abstract_code = 'not_refractory = not_refractory or not (%s)\n' % ref
         else:
             raise TypeError(('Refractory expression has to evaluate to a '
                              'timespan or a boolean value, expression'
                              '"%s" has units %s instead') % (ref, unit))
     return abstract_code
Ejemplo n.º 27
0
def parse_expressions(renderer, evaluator, numvalues=10):
    '''
    Check that `renderer` and `evaluator` reproduce Python's own evaluation
    of every (non-blank) line of ``TEST_EXPRESSIONS``.

    Every single-letter identifier of an expression is assigned a value
    drawn from a deterministic cycle over 1..33. ``n`` and ``m`` receive
    integers and all other variables receive floats in (0, 1]. The cycle
    continues across expressions. Each expression is checked with
    `numvalues` such assignments.

    Parameters
    ----------
    renderer
        Object with a ``render_expr(expr)`` method.
    evaluator : callable
        ``evaluator(rendered_expr, namespace)`` returning the value.
    numvalues : int, optional
        Number of value assignments tested per expression.

    Raises
    ------
    AssertionError
        If the rendered expression's value differs from the Python one.
    '''
    test_cases = []
    for raw_line in TEST_EXPRESSIONS.split('\n'):
        stripped = raw_line.strip()
        if not stripped:
            continue
        single_letter_ids = [ident for ident in get_identifiers(raw_line)
                             if len(ident) == 1]
        test_cases.append((single_letter_ids, stripped))

    counter, modulus = 1, 33
    for var_ids, expr in test_cases:
        rendered = renderer.render_expr(expr)
        python_expr = expr.replace('&', ' and ').replace('|', ' or ')
        for _ in xrange(numvalues):
            namespace = {}
            for var_id in var_ids:
                if var_id in ('n', 'm'):  # integer values
                    namespace[var_id] = counter
                else:
                    namespace[var_id] = float(counter) / modulus
                counter = counter % modulus + 1
            expected = eval(python_expr, namespace)
            result = evaluator(rendered, namespace)
            try:
                # Use all close because we can introduce small numerical
                # difference through sympy's rearrangements
                assert_allclose(expected, result, atol=10)
            except AssertionError as e:
                raise AssertionError("In expression " + str(expr) +
                                     " translated to " + str(rendered) +
                                     " " + str(e))
Ejemplo n.º 28
0
    def variables_to_namespace(self):
        '''
        Fill ``self.namespace`` with the values the generated code needs and
        collect the names whose values must be refreshed at every time step.

        The method proceeds in three steps.

        1. Every entry of ``self.variables`` except `AuxiliaryVariable` and
           `Subexpression` objects is stored under the name(s) the generated
           code uses. If ``get_value()`` raises `TypeError` or
           `AttributeError`, the object itself is stored instead.
        2. The namespace is then pruned to the names that textually occur in
           ``self.code``.
        3. ``self.nonconstant_values`` is filled with ``(name, getter)``
           pairs for `DynamicArrayVariable` objects flagged with
           ``needs_reference_update``.

        NOTE(review): ``self.namespace`` is written to before it is
        reassigned here, so it must already exist (presumably created in the
        constructor) -- confirm.
        '''

        # Variables can refer to values that are either constant (e.g. dt)
        # or change every timestep (e.g. t). We add the values of the
        # constant variables here and add the names of non-constant variables
        # to a list

        # A list containing tuples of name and a function giving the value
        self.nonconstant_values = []

        for name, var in self.variables.iteritems():
            if isinstance(var, (AuxiliaryVariable, Subexpression)):
                continue
            try:
                value = var.get_value()
            except (TypeError, AttributeError):
                # A dummy Variable without value or a function
                self.namespace[name] = var
                continue

            if isinstance(var, ArrayVariable):
                # the array data under the device's array name, plus its
                # length under '_num<name>'
                self.namespace[self.device.get_array_name(var,
                                                            self.variables)] = value
                self.namespace['_num'+name] = var.get_len()
                # scalar constants are additionally made available as a
                # plain Python scalar under their own name
                if var.scalar and var.constant:
                    self.namespace[name] = value.item()
            else:
                self.namespace[name] = value

            if isinstance(var, DynamicArrayVariable):
                # also provide the dynamic array object (access_data=False),
                # presumably so templates can resize it -- confirm
                dyn_array_name = self.generator_class.get_array_name(var,
                                                                    access_data=False)
                self.namespace[dyn_array_name] = self.device.get_value(var,
                                                                       access_data=False)
            # Also provide the Variable object itself in the namespace (can be
            # necessary for resize operations, for example)
            self.namespace['_var_'+name] = var

        # Get all identifiers in the code -- note that this is not a smart
        # function, it will get identifiers from strings, comments, etc. This
        # is not a problem here, since we only use this list to filter out
        # things. If we include something incorrectly, this only means that we
        # will pass something into the namespace unnecessarily.
        all_identifiers = get_identifiers(self.code)
        # Filter out all unneeded objects
        self.namespace = {k: v for k, v in self.namespace.iteritems()
                          if k in all_identifiers}

        # There is one type of objects that we have to inject into the
        # namespace with their current value at each time step: dynamic
        # arrays that change in size during runs, where the size change is not
        # initiated by the template itself
        for name, var in self.variables.iteritems():
            if (isinstance(var, DynamicArrayVariable) and
                    var.needs_reference_update):
                array_name = self.device.get_array_name(var, self.variables)
                # only names that survived the filtering above are refreshed
                if array_name in self.namespace:
                    self.nonconstant_values.append((array_name, var.get_value))
                if '_num'+name in self.namespace:
                    self.nonconstant_values.append(('_num'+name, var.get_len))
Ejemplo n.º 29
0
Archivo: base.py Proyecto: yger/brian2
 def array_read_write(self, statements, variables, variable_indices):
     '''
     Helper function, gives the sets of ArrayVariables that are read from and
     written to in the series of statements, together with the index arrays
     they use.

     Parameters
     ----------
     statements : sequence
         Statements with ``expr``, ``var`` and ``inplace`` attributes.
     variables : dict
         Mapping from name to variable object; only `ArrayVariable`
         instances are reported.
     variable_indices : dict
         Mapping from variable name to the name of its index variable.

     Returns
     -------
     read : set of str
         Array variables read by the statements (an in-place update counts as
         a read), excluding those that are also used as index arrays.
     write : set of str
         Array variables written by the statements.
     indices : set of str
         Index variables other than the special ``_idx`` that are themselves
         `ArrayVariable` objects and index a read or written array.
     '''
     read_names = set()
     write_names = set()
     for stmt in statements:
         read_names |= set(get_identifiers(stmt.expr))
         # if the operation is inplace this counts as a read.
         if stmt.inplace:
             read_names.add(stmt.var)
         write_names.add(stmt.var)
     read = set(varname for varname, var in variables.items()
                if isinstance(var, ArrayVariable) and varname in read_names)
     write = set(varname for varname, var in variables.items()
                 if isinstance(var, ArrayVariable) and varname in write_names)
     # Gather the indices stored as arrays (ignore _idx which is special)
     indices = set()
     for varname in read | write:
         index_name = variable_indices[varname]
         if (index_name != '_idx' and
                 isinstance(variables[index_name], ArrayVariable)):
             indices.add(index_name)
     # don't list arrays that are read explicitly and used as indices twice
     read -= indices
     return read, write, indices
Ejemplo n.º 30
0
 def _get_refractory_code(self, run_namespace, level=0):
     ref = self.group._refractory
     if ref is False:
         # No refractoriness
         abstract_code = ''
     elif isinstance(ref, Quantity):
         abstract_code = 'not_refractory = (t - lastspike) > %f\n' % ref
     else:
         identifiers = get_identifiers(ref)
         variables = self.group.resolve_all(identifiers,
                                            identifiers,
                                            run_namespace=run_namespace,
                                            level=level+1)
         unit = parse_expression_unit(str(ref), variables)
         if have_same_dimensions(unit, second):
             abstract_code = 'not_refractory = (t - lastspike) > %s\n' % ref
         elif have_same_dimensions(unit, Unit(1)):
             if not is_boolean_expression(str(ref), variables):
                 raise TypeError(('Refractory expression is dimensionless '
                                  'but not a boolean value. It needs to '
                                  'either evaluate to a timespan or to a '
                                  'boolean value.'))
             # boolean condition
             # we have to be a bit careful here, we can't just use the given
             # condition as it is, because we only want to *leave*
             # refractoriness, based on the condition
             abstract_code = 'not_refractory = not_refractory or not (%s)\n' % ref
         else:
             raise TypeError(('Refractory expression has to evaluate to a '
                              'timespan or a boolean value, expression'
                              '"%s" has units %s instead') % (ref, unit))
     return abstract_code
Ejemplo n.º 31
0
    def update_abstract_code(self, run_namespace=None, level=0):
        '''
        Set ``self.abstract_code`` (and ``self.user_code``) for the group's
        threshold condition.

        The abstract code is ``_cond = <threshold>``. If the group has
        refractoriness (``_refractory`` is not ``False``), the condition is
        additionally combined with ``not_refractory``.

        Parameters
        ----------
        run_namespace : dict-like, optional
            Namespace used to resolve identifiers of the threshold.
        level : int, optional
            How far to go up in the stack to find the local namespace.

        Raises
        ------
        TypeError
            If the threshold is not a string, or not a boolean expression.
        '''
        code = self.group.threshold
        if not isinstance(code, basestring):
            # Give a helpful message for Brian1-style (non-string) thresholds
            description = ('a quantity' if isinstance(code, Quantity)
                           else '%s' % type(code))
            error_msg = 'Threshold condition has to be a string, not %s.' % description
            vm_var = _guess_membrane_potential(self.group.equations)
            if vm_var is not None:
                error_msg += " Probably you intended to use '%s > ...'?" % vm_var
            raise TypeError(error_msg)

        self.user_code = '_cond = ' + code

        identifiers = get_identifiers(code)
        variables = self.group.resolve_all(identifiers,
                                           identifiers,
                                           run_namespace=run_namespace,
                                           level=level + 1)
        if not is_boolean_expression(code, variables):
            raise TypeError(('Threshold condition "%s" is not a boolean '
                             'expression') % code)
        if self.group._refractory is False:
            template = '_cond = %s'
        else:
            template = '_cond = (%s) and not_refractory'
        self.abstract_code = template % code
Ejemplo n.º 32
0
    def update_abstract_code(self, run_namespace=None, level=0):
        '''
        Set ``self.abstract_code`` (and ``self.user_code``) for the group's
        threshold condition.

        The abstract code is ``_cond = <threshold>``. If the group has
        refractoriness (``_refractory`` is not ``False``), the condition is
        additionally combined with ``not_refractory``.

        Parameters
        ----------
        run_namespace : dict-like, optional
            Namespace used to resolve identifiers of the threshold.
        level : int, optional
            How far to go up in the stack to find the local namespace.

        Raises
        ------
        TypeError
            If the threshold is not a string, or not a boolean expression.
        '''
        code = self.group.threshold
        if not isinstance(code, basestring):
            # Give a helpful message for Brian1-style (non-string) thresholds
            description = ('a quantity' if isinstance(code, Quantity)
                           else '%s' % type(code))
            error_msg = 'Threshold condition has to be a string, not %s.' % description
            vm_var = _guess_membrane_potential(self.group.equations)
            if vm_var is not None:
                error_msg += " Probably you intended to use '%s > ...'?" % vm_var
            raise TypeError(error_msg)

        self.user_code = '_cond = ' + code

        identifiers = get_identifiers(code)
        variables = self.group.resolve_all(identifiers,
                                           identifiers,
                                           run_namespace=run_namespace,
                                           level=level+1)
        if not is_boolean_expression(code, variables):
            raise TypeError(('Threshold condition "%s" is not a boolean '
                             'expression') % code)
        if self.group._refractory is False:
            template = '_cond = %s'
        else:
            template = '_cond = (%s) and not_refractory'
        self.abstract_code = template % code
Ejemplo n.º 33
0
 def translate_one_statement_sequence(self, statements, scalar=False):
     '''
     Translate a sequence of abstract statements into target code lines.

     The generated code consists of read-array lines, declarations,
     the translated statements and write-array lines, in that order.
     Afterwards it may warn when a line calls a ``_brian_<func>`` default
     function (for a name in ``functions_C99``) and one of that line's
     identifiers is declared with a 64bit integer type anywhere in the
     generated code. In ``single_precision`` mode, 32bit integer types also
     trigger the warning. The check is skipped once a warning was issued,
     unless the ``default_functions_integral_convertion`` preference changed
     since then.

     ``scalar`` is accepted but not used in this body.

     NOTE(review): ``self.warned_integral_convertion`` and
     ``self.previous_convertion_pref`` are read here but initialised
     elsewhere (presumably in the constructor) -- confirm.

     Returns
     -------
     The result of ``stripped_deindented_lines`` applied to the joined code.
     '''
     # This function is refactored into four functions which perform the
     # four necessary operations. It's done like this so that code
     # deriving from this class can overwrite specific parts.
     lines = []
     # index and read arrays (index arrays first)
     lines += self.translate_to_read_arrays(statements)
     # simply declare variables that will be written but not read
     lines += self.translate_to_declarations(statements)
     # the actual code
     statement_lines = self.translate_to_statements(statements)
     lines += statement_lines
     # write arrays
     lines += self.translate_to_write_arrays(statements)
     code = '\n'.join(lines)
     # Check if 64bit integer types occur in the same line as a default function.
     # We can't get the arguments of the function call directly with regex due to
     # possibly nested parentheses inside the function call.
     convertion_pref = prefs.codegen.generators.cuda.default_functions_integral_convertion
     # only check if there was no warning yet or if conversion preference has changed
     if not self.warned_integral_convertion or self.previous_convertion_pref != convertion_pref:
         for line in statement_lines:
             # does this line call one of the C99 default functions?
             brian_funcs = re.search(
                 '_brian_(' + '|'.join(functions_C99) + ')', line)
             if brian_funcs is not None:
                 for identifier in get_identifiers(line):
                     # NOTE: the type lookup searches the whole generated
                     # `code` (where declarations appear), not only `line`
                     if convertion_pref == 'double_precision':
                         # 64bit integer to floating-point conversions are not type safe
                         int64_type = re.search(
                             r'\bu?int64_t\s*{}\b'.format(identifier), code)
                         if int64_type is not None:
                             logger.warn(
                                 "Detected code statement with default function and 64bit integer type in the same line. "
                                 "Using 64bit integer types as default function arguments is not type safe due to convertion of "
                                 "integer to 64bit floating-point types in device code. (relevant functions: sin, cos, tan, sinh, "
                                 "cosh, tanh, exp, log, log10, sqrt, ceil, floor, arcsin, arccos, arctan)\nDetected code "
                                 "statement:\n\t{}\nGenerated from abstract code statements:\n\t{}\n"
                                 .format(line, statements),
                                 once=True)
                             # remember that we warned, and for which preference
                             self.warned_integral_convertion = True
                             self.previous_convertion_pref = 'double_precision'
                     else:  # convertion_pref = 'single_precision'
                         # 32bit and 64bit integer to floating-point conversions are not type safe
                         int32_64_type = re.search(
                             r'\bu?int(32|64)_t\s*{}\b'.format(identifier),
                             code)
                         if int32_64_type is not None:
                             logger.warn(
                                 "Detected code statement with default function and 32bit or 64bit integer type in the same line and the "
                                 "preference for default_functions_integral_convertion is 'single_precision'. "
                                 "Using 32bit or 64bit integer types as default function arguments is not type safe due to convertion of "
                                 "integer to single-precision floating-point types in device code. (relevant functions: sin, cos, tan, sinh, "
                                 "cosh, tanh, exp, log, log10, sqrt, ceil, floor, arcsin, arccos, arctan)\nDetected code "
                                 "statement:\n\t{}\nGenerated from abstract code statements:\n\t{}\n"
                                 .format(line, statements),
                                 once=True)
                             # remember that we warned, and for which preference
                             self.warned_integral_convertion = True
                             self.previous_convertion_pref = 'single_precision'
     return stripped_deindented_lines(code)
Ejemplo n.º 34
0
    def __call__(self, equations, variables=None):
        '''
        Return abstract code that advances a linear ODE system exactly by one
        time step ``dt``.

        The system is brought into the form dX/dt = M*X + B and integrated
        with the matrix exponential of ``M*dt``.

        Parameters
        ----------
        equations
            The equations to integrate (passed to `get_linear_system`).
        variables : dict-like, optional
            Known variables. Variables occurring in the solution that are
            scalar and constant are replaced by their numerical value.

        Returns
        -------
        abstract_code : str
            One ``_x = ...`` line per state variable, followed by ``x = _x``
            lines that commit the new values.

        Raises
        ------
        ValueError
            If the coefficient matrix contains non-constant symbols.
        '''
        if variables is None:
            variables = {}

        # Get a representation of the ODE system in the form of
        # dX/dt = M*X + B
        varnames, matrix, constants = get_linear_system(equations)

        # Make sure that the matrix M is constant, i.e. it only contains
        # external variables or constant variables
        symbols = set.union(*(el.atoms() for el in matrix))
        non_constant = _non_constant_symbols(symbols, variables)
        if len(non_constant):
            raise ValueError(('The coefficient matrix for the equations '
                              'contains the symbols %s, which are not '
                              'constant.') % str(non_constant))

        state_symbols = [Symbol(variable, real=True) for variable in varnames]
        solution = sp.solve_linear_system(matrix.row_join(constants),
                                          *state_symbols)
        b = sp.ImmutableMatrix([solution[symbol]
                                for symbol in state_symbols]).transpose()

        # Solve the system
        dt = Symbol('dt', real=True, positive=True)
        A = (matrix * dt).exp()
        C = sp.ImmutableMatrix([A.dot(b)]) - b
        _S = sp.MatrixSymbol('_S', len(varnames), 1)
        updates = A * _S + C.transpose()
        try:
            # In sympy 0.7.3, we have to explicitly convert it to a single matrix
            # In sympy 0.7.2, it is already a matrix (which doesn't have an
            # is_explicit method)
            updates = updates.as_explicit()
        except AttributeError:
            pass

        # The solution contains _S[0, 0], _S[1, 0] etc. for the state variables,
        # replace them with the state variable names
        abstract_code = []
        for variable, update in zip(varnames, updates):
            # In a coupled system the update of one variable depends on the
            # old values of *all* state variables, so every _S[k, 0] has to be
            # substituted, not only the entry belonging to `variable`.
            rhs = update
            for row_idx, state_symbol in enumerate(state_symbols):
                rhs = rhs.subs(_S[row_idx, 0], state_symbol)
            identifiers = get_identifiers(sympy_to_str(rhs))
            for identifier in identifiers:
                if identifier not in variables:
                    continue
                var = variables[identifier]
                # Inline scalar constants. Entries lacking these attributes
                # (e.g. None) are simply left symbolic.
                if getattr(var, 'scalar', False) and getattr(var, 'constant', False):
                    float_val = var.get_value()
                    rhs = rhs.xreplace({Symbol(identifier, real=True): Float(float_val)})

            # Do not overwrite the real state variables yet, the update step
            # of other state variables might still need the original values
            abstract_code.append('_' + variable + ' = ' + sympy_to_str(rhs))

        # Update the state variables
        for variable in varnames:
            abstract_code.append('{variable} = _{variable}'.format(variable=variable))
        return '\n'.join(abstract_code)
Ejemplo n.º 35
0
    def vectorise_code(self, statements, variables, variable_indices, index='_idx'):
        '''
        Translate `statements` into numpy code, one statement at a time.

        Each statement is first tried with `ufunc_at_vectorisation`. If that
        raises `VectorisationError`, a Python loop over ``_idx`` is emitted
        for the statement instead (with a one-time warning).

        Parameters
        ----------
        statements : list
            The abstract code statements.
        variables : dict
            Mapping from name to variable object.
        variable_indices : dict
            Mapping from variable name to the name of its index.
        index : str, optional
            Passed on to `ufunc_at_vectorisation`.

        Returns
        -------
        lines : list of str
        '''
        # We treat every statement individually with its own read and write code
        # to be on the safe side
        lines = []
        created_vars = {stmt.var for stmt in statements if stmt.op == ':='}
        for statement in statements:
            lines.append('#  Abstract code:  {var} {op} {expr}'.format(var=statement.var,
                                                                       op=statement.op,
                                                                       expr=statement.expr))
            read, write, indices, conditional_write_vars = self.arrays_helper([statement])
            try:
                # We make sure that we only add code to `lines` after it went
                # through completely
                ufunc_lines = []
                # The ufunc path prunes the read/write sets. It works on its
                # own copies so that the Python-loop fallback below still sees
                # the complete sets (it must load an in-place target per _idx).
                ufunc_read = read
                ufunc_write = write
                # No need to load a variable if it is only in read because of
                # the in-place operation
                if (statement.inplace and
                        variable_indices[statement.var] != '_idx' and
                        statement.var not in get_identifiers(statement.expr)):
                    ufunc_read = ufunc_read - {statement.var}
                ufunc_lines.extend(self.read_arrays(ufunc_read, ufunc_write, indices,
                                                    variables, variable_indices))
                ufunc_lines.append(self.ufunc_at_vectorisation(statement,
                                                               variables,
                                                               variable_indices,
                                                               conditional_write_vars,
                                                               created_vars,
                                                               index=index))
                # Do not write back such values, the ufuncs have modified the
                # underlying array already
                if statement.inplace and variable_indices[statement.var] != '_idx':
                    ufunc_write = ufunc_write - {statement.var}
                ufunc_lines.extend(self.write_arrays([statement], ufunc_read, ufunc_write,
                                                     variables,
                                                     variable_indices))
                lines.extend(ufunc_lines)
            except VectorisationError:
                logger.warn("Failed to vectorise synapses code, falling back on Python loop: note that "
                            "this will be very slow! Switch to another code generation target for "
                            "best performance (e.g. cython or weave).",
                            once=True)
                lines.extend(['_full_idx = _idx',
                              'for _idx in _full_idx:'])
                lines.extend(indent(code) for code in
                             self.read_arrays(read, write, indices,
                                              variables, variable_indices))
                line = self.translate_statement(statement)
                line = self.conditional_write(line, statement, variables,
                                              conditional_write_vars,
                                              created_vars)
                lines.append(indent(line))
                lines.extend(indent(code) for code in
                             self.write_arrays(statements, read, write,
                                               variables, variable_indices))
                lines.append('_idx = _full_idx')

        return lines
Ejemplo n.º 36
0
    def ufunc_at_vectorisation(self, statement, variables, indices,
                               conditional_write_vars, created_vars,
                               used_variables):
        '''
        Render a single statement as one line of vectorised numpy code.

        In-place updates (``+=``, ``-=``, ``*=``, ``/=``) of a variable whose
        index is not ``_idx`` become ``_numpy.<ufunc>.at(...)`` calls, passed
        through `conditional_write`. All other statements become a plain
        assignment (``:=`` is written as ``=``).

        Parameters
        ----------
        statement
            The abstract code statement.
        variables : dict
            Mapping from name to variable object.
        indices : dict
            Mapping from variable name to the name of its index.
        conditional_write_vars, created_vars
            Passed on to `conditional_write`.
        used_variables : set
            Updated in place with the variables read by this statement
            through an index other than ``_idx``. Presumably shared across
            the statements of one code block (see the linked PR) -- confirm.

        Raises
        ------
        VectorisationError
            If ufunc vectorisation is disabled, if the written variable is in
            `used_variables`, or if the in-place operator has no ufunc
            equivalent.
        '''
        if not self._use_ufunc_at_vectorisation:
            raise VectorisationError()
        # Avoids circular import
        from brian2.devices.device import device

        # See https://github.com/brian-team/brian2/pull/531 for explanation
        referenced = set(get_identifiers(statement.expr))
        indirectly_indexed = [name for name in list(variables.keys())
                              if name in indices and indices[name] != '_idx']
        used_variables.update(referenced.intersection(indirectly_indexed))
        if statement.var in used_variables:
            raise VectorisationError()
        renderer = NumpyNodeRenderer(auto_vectorise=self.auto_vectorise)
        expr = renderer.render_expr(statement.expr)

        # short-circuits exactly like the original condition, so `indices`
        # is never consulted for ':=' (auxiliary) statements
        needs_ufunc = (statement.op != ':=' and
                       indices[statement.var] != '_idx' and
                       statement.inplace)
        if not needs_ufunc:
            op = '=' if statement.op == ':=' else statement.op
            line = '{var} {op} {expr}'.format(var=statement.var,
                                              op=op,
                                              expr=expr)
        else:
            ufunc_name = {'+=': '_numpy.add',
                          '*=': '_numpy.multiply',
                          '/=': '_numpy.divide',
                          '-=': '_numpy.subtract'}.get(statement.op)
            if ufunc_name is None:
                raise VectorisationError()
            line = '{ufunc_name}.at({array_name}, {idx}, {expr})'.format(
                ufunc_name=ufunc_name,
                array_name=device.get_array_name(variables[statement.var]),
                idx=indices[statement.var],
                expr=expr)
            line = self.conditional_write(line,
                                          statement,
                                          variables,
                                          conditional_write_vars=conditional_write_vars,
                                          created_vars=created_vars)

        if len(statement.comment):
            line += ' # ' + statement.comment

        return line
Ejemplo n.º 37
0
 def synapses_connect(self,
                      synapses,
                      condition=None,
                      i=None,
                      j=None,
                      p=1,
                      n=1,
                      skip_if_invalid=False,
                      namespace=None,
                      level=0):
     """
     Override synapses_connect() to get details from Synapses.connect()

     The arguments are recorded as a ``'connect'`` entry appended to
     ``self.initializers_connectors``. The entry includes the resolved
     identifiers used in `condition`, `p` and `n` when there are any.
     ``synapses._connect_called`` is set to ``True`` afterwards.

     Parameters
     ----------
     synapses
         The `Synapses` object being connected.
     condition : str, optional
         Connection condition; cannot be combined with `i` or `j`.
     i, j : optional
         Explicit source/target indices.
     p, n : optional
         Connection probability and number of connections.
     skip_if_invalid : bool, optional
         Accepted for interface compatibility; not used here.
     namespace : dict, optional
         Namespace for resolving identifiers. If ``None``, the local
         namespace ``level + 2`` frames up the stack is used.
     level : int, optional
         Stack level offset for the namespace lookup.

     Raises
     ------
     ValueError
         If `condition` is combined with `i` or `j`.
     """
     # NOTE: get_local_namespace depends on the stack depth, so this call has
     # to stay directly in this function
     if namespace is None:
         namespace = get_local_namespace(level + 2)
     # prepare objects using `before_run()`
     synapses.before_run(namespace)
     if condition is not None and (i is not None or j is not None):
         raise ValueError("Cannot combine condition with i or j "
                          "arguments")

     connect = {}
     # string expressions whose identifiers have to be resolved
     expressions = []
     if condition:
         connect['condition'] = condition
         expressions.append(condition)
     else:
         if i is not None:
             connect['i'] = i
         if j is not None:
             connect['j'] = j
     connect['probability'] = p
     connect['n_connections'] = n
     connect['synapses'] = synapses.name
     connect['source'] = synapses.source.name
     connect['target'] = synapses.target.name
     connect['type'] = 'connect'

     expressions.append(str(p))
     expressions.append(str(n))
     identifiers = set()
     for expression in expressions:
         identifiers |= get_identifiers(expression)
     resolved = _prepare_identifiers(synapses.resolve_all(identifiers, namespace))
     if resolved:
         connect['identifiers'] = resolved
     self.initializers_connectors.append(connect)
     # update `_connect_called` to allow initialization on
     # synaptic variables
     synapses._connect_called = True
Ejemplo n.º 38
0
def check_for_order_independence(statements, variables, indices):
    '''
    Check that a block of statements does not depend on the order in which
    the (non-main) indices are processed.

    Parameters
    ----------
    statements : sequence of Statement
        The abstract code statements; each exposes ``var``, ``op`` and
        ``expr``.
    variables : dict
        Mapping from variable name to variable object. Index names other
        than ``'_idx'`` and ``'0'`` are looked up here too, to read their
        optional ``unique`` attribute.
    indices : dict
        Mapping from variable name to the name of the index used to access
        it.

    Raises
    ------
    OrderDependenceError
        If the result could depend on the processing order.
    '''
    # Variables accessed through the main index, a scalar index, or an index
    # flagged as ``unique``.
    main_index_variables = set(
        v for v in variables
        if (indices[v] in ('_idx', '0')
            or getattr(variables[indices[v]], 'unique', False)))
    different_index_variables = set(variables.keys()) - main_index_variables
    all_variables = variables.keys()

    # Start by assuming every non-main-index variable is permutation
    # independent; in-place writes remove variables from this set and the
    # statements are re-examined until the set no longer changes.
    permutation_independent = set(different_index_variables)
    changed_permutation_independent = True
    while changed_permutation_independent:
        changed_permutation_independent = False
        for statement in statements:
            vars_in_expr = get_identifiers(
                statement.expr).intersection(all_variables)
            nonsyn_vars_in_expr = vars_in_expr.intersection(
                different_index_variables)
            permdep = any(var not in permutation_independent
                          for var in nonsyn_vars_in_expr)
            if statement.op == ':=':
                continue  # auxiliary variable
            elif statement.var in main_index_variables:
                if permdep:
                    raise OrderDependenceError()
            elif statement.var in different_index_variables:
                if statement.op == '+=' or statement.op == '*=':
                    if permdep:
                        raise OrderDependenceError()
                    if statement.var in permutation_independent:
                        permutation_independent.remove(statement.var)
                        changed_permutation_independent = True
                elif statement.op == '=':
                    # Variables accessed through an index different from the
                    # target's (main and scalar indices excluded)
                    otheridx = set(
                        v for v in variables
                        if indices[v] not in (indices[statement.var], '_idx',
                                              '0'))
                    if any(var in otheridx for var in vars_in_expr):
                        raise OrderDependenceError()
                    if permdep:
                        raise OrderDependenceError()
                    if any(var in main_index_variables
                           for var in vars_in_expr):
                        raise OrderDependenceError()
                else:
                    raise OrderDependenceError()
            else:
                raise AssertionError('Should never get here...')
Ejemplo n.º 39
0
def check_pre_code(codegen, stmts, vars_pre, vars_syn, vars_post):
    '''
    Given a set of statements stmts where the variables names in vars_pre are
    presynaptic, in vars_syn are synaptic and in vars_post are postsynaptic,
    check that the conditions for compatibility with GeNN are met, and return
    a new statement sequence translated for compatibility with GeNN, along
    with the name of the targeted variable.

    The single in-place ``+=`` statement on the postsynaptic variable is
    replaced by an assignment to ``addtoinSyn`` of
    ``_hidden_weightmatrix*(<expr>)``; every other statement is passed through
    unchanged.

    Parameters
    ----------
    codegen
        Object whose ``array_read_write(stmts)`` returns the triple
        ``(read, write, indices)``; only ``write`` is used.
    stmts : sequence of Statement
        Iterated twice (by ``codegen`` and here).
    vars_pre, vars_syn
        Not used by this function.
    vars_post : iterable of str
        Names of the postsynaptic variables.

    Returns
    -------
    post_write_var : str
        The single postsynaptic variable that is written to.
    new_stmts : list
        The translated statements.

    Raises
    ------
    NotImplementedError
        If no, or more than one, postsynaptic variable is written; if the
        write is not an in-place ``+=``; or if it is written more than once.
    '''
    _, write, _ = codegen.array_read_write(stmts)

    post_write = set(write).intersection(vars_post)
    if len(post_write) == 0:
        raise NotImplementedError(
            "GeNN does not support Synapses with no postsynaptic effect.")
    if len(post_write) > 1:
        raise NotImplementedError(
            "GeNN only supports writing to a single postsynaptic variable.")

    post_write_var = next(iter(post_write))

    found_write_statement = False
    new_stmts = []
    for stmt in stmts:
        if stmt.var != post_write_var:
            new_stmts.append(stmt)
            continue
        if not stmt.inplace:
            # TODO: we could support expressions like v = v + expr, but this requires some additional work
            # namely, for an expression like v = expr we need to check if (expr-v) when simplified reduces to
            # an expression that only has postsynaptic variables using sympy
            raise NotImplementedError(
                "GeNN only supports in-place modification of postsynaptic variables"
            )
        if stmt.op != '+=':
            raise NotImplementedError(
                "GeNN only supports the += in place operation on postsynaptic variables."
            )
        # Reject a second write before building a statement for it
        if found_write_statement:
            raise NotImplementedError(
                "GeNN does not support multiple writes to postsynaptic variables."
            )
        found_write_statement = True
        new_stmts.append(Statement('addtoinSyn',
                                   '=',
                                   '_hidden_weightmatrix*(' + stmt.expr +
                                   ')',
                                   comment=stmt.comment,
                                   dtype=stmt.dtype))

    return post_write_var, new_stmts
Ejemplo n.º 40
0
def get_identifiers_recursively(expr, variables):
    '''
    Gets all the identifiers in a code, recursing down into subexpressions.

    Every identifier of ``expr`` is returned. For each identifier that names
    a `Subexpression` in ``variables``, the identifiers of that
    subexpression's ``expr`` are added as well, transitively. Each name is
    expanded at most once, so subexpressions shared by several others are
    not re-parsed, and mutually referencing subexpressions terminate.

    Parameters
    ----------
    expr : str
        The code to analyse.
    variables : dict
        Mapping from names to variable objects.

    Returns
    -------
    identifiers : set of str
    '''
    identifiers = set(get_identifiers(expr))
    # Every name is pushed at most once: initially, or when first discovered.
    to_expand = list(identifiers)
    while to_expand:
        name = to_expand.pop()
        if name in variables and isinstance(variables[name], Subexpression):
            for sub_name in get_identifiers(variables[name].expr):
                if sub_name not in identifiers:
                    identifiers.add(sub_name)
                    to_expand.append(sub_name)
    return identifiers
Ejemplo n.º 41
0
def get_identifiers_recursively(expr, variables):
    '''
    Gets all the identifiers in a code, recursing down into subexpressions.

    Every identifier of ``expr`` is returned. For each identifier that names
    a `Subexpression` in ``variables``, the identifiers of that
    subexpression's ``expr`` are added as well, transitively. Each name is
    expanded at most once, so subexpressions shared by several others are
    not re-parsed, and mutually referencing subexpressions terminate.

    Parameters
    ----------
    expr : str
        The code to analyse.
    variables : dict
        Mapping from names to variable objects.

    Returns
    -------
    identifiers : set of str
    '''
    identifiers = set(get_identifiers(expr))
    # Every name is pushed at most once: initially, or when first discovered.
    to_expand = list(identifiers)
    while to_expand:
        name = to_expand.pop()
        if name in variables and isinstance(variables[name], Subexpression):
            for sub_name in get_identifiers(variables[name].expr):
                if sub_name not in identifiers:
                    identifiers.add(sub_name)
                    to_expand.append(sub_name)
    return identifiers
Ejemplo n.º 42
0
def numerically_check_permutation_code(code):
    '''
    Numerically check that a code block is independent of the synapse order.

    Creates a presynaptic and a postsynaptic group of 3 neurons each with
    full connectivity between them (9 synapses). Variables are classified by
    their suffix: ``*_syn`` (9 values, indexed by ``_idx``), ``*_pre``
    (3 values, indexed by ``_presynaptic_idx``) and ``*_post`` (3 values,
    indexed by ``_postsynaptic_idx``). The values are repeatedly filled with
    random numbers, the code is executed once in natural order and then for
    several random shuffles of the synapse order, and the results are
    compared. This is a sort of test of the test itself, to make sure we
    didn't accidentally assign a good/bad example to the wrong class.

    Raises
    ------
    OrderDependenceError
        If any shuffled execution gives a result that differs from the
        unshuffled one.
    '''
    from collections import defaultdict
    code = deindent(code)
    identifiers = get_identifiers(code)
    indices = defaultdict(lambda: '_idx')
    vals = {}
    for var in identifiers:
        if var.endswith('_syn'):
            indices[var] = '_idx'
            vals[var] = zeros(9)
        elif var.endswith('_pre'):
            indices[var] = '_presynaptic_idx'
            vals[var] = zeros(3)
        elif var.endswith('_post'):
            indices[var] = '_postsynaptic_idx'
            vals[var] = zeros(3)
    subs = dict((var, var + '[' + idx + ']') for var, idx in indices.items())
    code = word_substitute(code, subs)
    code = '''
from numpy import *
from numpy.random import rand, randn
for _idx in shuffled_indices:
    _presynaptic_idx = presyn[_idx]
    _postsynaptic_idx = postsyn[_idx]
{code}
    '''.format(code=indent(code))
    ns = vals.copy()
    ns['shuffled_indices'] = arange(9)
    ns['presyn'] = arange(9) % 3
    # Floor division keeps the indices integral under Python 3 as well
    ns['postsyn'] = arange(9) // 3
    for _ in range(10):
        origvals = {}
        for k, v in vals.items():
            v[:] = randn(len(v))
            origvals[k] = v.copy()
        # exec(...) with a namespace works on both Python 2 and Python 3
        exec(code, ns)
        endvals = dict((k, v.copy()) for k, v in vals.items())
        for _ in range(10):
            for k, v in vals.items():
                v[:] = origvals[k]
            shuffle(ns['shuffled_indices'])
            exec(code, ns)
            for k, v in vals.items():
                try:
                    assert_allclose(v, endvals[k])
                except AssertionError:
                    raise OrderDependenceError()
Ejemplo n.º 43
0
def neurongroup_description(neurongroup, run_namespace):
    """
    Describe ``neurongroup`` as a line of Python source that recreates it.

    Parameters
    ----------
    neurongroup : NeuronGroup
        The group to describe.
    run_namespace
        Passed on to ``get_namespace_dict`` to resolve external identifiers.

    Returns
    -------
    desc : str
        Source of the form ``<name> = NeuronGroup(<N>, <equations>, ...)``,
        including ``threshold``, ``reset`` and ``refractory`` arguments when
        present, and the ``name`` argument.
    namespace
        The result of ``get_namespace_dict`` for the identifiers of the
        equations, threshold, reset and (string) refractory condition.
    """
    eqs = eq_string(neurongroup.user_equations)
    # Copy: the set is extended in place below and must not leak back into
    # the equations object.
    identifiers = set(neurongroup.user_equations.identifiers)
    desc = "%d,\n'''%s'''" % (len(neurongroup), eqs)
    if 'spike' in neurongroup.events:
        threshold = neurongroup.events['spike']
        desc += ',\nthreshold=%r' % threshold
        identifiers |= get_identifiers(threshold)
    if 'spike' in neurongroup.event_codes:
        reset = neurongroup.event_codes['spike']
        desc += ',\nreset=%r' % reset
        identifiers |= get_identifiers(reset)
    if neurongroup._refractory is not None:
        refractory = neurongroup._refractory
        desc += ',\nrefractory=%r' % refractory
        # Only a string refractory condition contains identifiers
        if isinstance(refractory, basestring):
            identifiers |= get_identifiers(refractory)
    namespace = get_namespace_dict(identifiers, neurongroup, run_namespace)
    desc += ',\nname=%r' % neurongroup.name
    desc = '%s = NeuronGroup(%s)' % (neurongroup.name, desc)
    return desc, namespace
Ejemplo n.º 44
0
def ufunc_at_vectorisation(statements, variables, indices, index):
    '''
    Translate synaptic statements into vectorised numpy code, using
    ``ufunc.at`` for in-place updates through indices other than ``index``.

    Parameters
    ----------
    statements : sequence of Statement
        Statements exposing ``var``, ``op``, ``expr`` and ``inplace``.
    variables
        Container of variable names; only iteration and membership are used.
    indices : dict
        Mapping from variable name to the name of its index array.
    index : str
        The main index; variables accessed through it are updated directly.

    Returns
    -------
    code : str
        The generated lines joined by newlines. Any
        ``_unique_<idx> = _numpy.unique(<idx>)`` lines come first, sorted by
        index name.

    Raises
    ------
    SynapseVectorisationError
        For an in-place operation other than ``+=`` or ``*=`` on a variable
        not accessed through ``index``.
    '''
    # We assume that the code has passed the test for synapse order independence

    main_index_variables = set(v for v in variables if indices[v] == index)

    lines = []
    need_unique_indices = set()

    for statement in statements:
        vars_in_expr = get_identifiers(statement.expr).intersection(variables)
        subs = {}
        for var in vars_in_expr:
            subs[var] = '{var}[{idx}]'.format(var=var, idx=indices[var])
        expr = word_substitute(statement.expr, subs)
        if statement.var in main_index_variables:
            line = '{var}[{idx}] {op} {expr}'.format(var=statement.var,
                                                     op=statement.op,
                                                     expr=expr,
                                                     idx=index)
            lines.append(line)
        else:
            if statement.inplace:
                if statement.op == '+=':
                    ufunc_name = '_numpy.add'
                elif statement.op == '*=':
                    ufunc_name = '_numpy.multiply'
                else:
                    raise SynapseVectorisationError()
                line = '{ufunc_name}.at({var}, {idx}, {expr})'.format(
                    ufunc_name=ufunc_name,
                    var=statement.var,
                    idx=indices[statement.var],
                    expr=expr)
                lines.append(line)
            else:
                # if statement is not in-place then we assume the expr has no synaptic
                # variables in it otherwise it would have failed the order independence
                # check. In this case, we only need to work with the unique indices
                need_unique_indices.add(indices[statement.var])
                idx = '_unique_' + indices[statement.var]
                expr = word_substitute(expr, {indices[statement.var]: idx})
                line = '{var}[{idx}] = {expr}'.format(var=statement.var,
                                                      idx=idx,
                                                      expr=expr)
                lines.append(line)

    # Sort so the generated code does not depend on set iteration order
    # (which varies between runs with string hash randomisation).
    header = ['_unique_{idx} = _numpy.unique({idx})'.format(idx=unique_idx)
              for unique_idx in sorted(need_unique_indices)]

    return '\n'.join(header + lines)
Ejemplo n.º 45
0
 def update_abstract_code(self, run_namespace=None, level=0):
     '''
     Build ``self.abstract_code`` for the threshold of ``self.group``.

     The identifiers of the threshold expression are resolved with
     ``self.group.resolve_all``. The abstract code assigns the condition to
     ``_cond``, and combines it with ``not_refractory`` unless the group's
     ``_refractory`` attribute is ``False``.

     Raises
     ------
     TypeError
         If the threshold condition is not a boolean expression.
     '''
     threshold = self.group.threshold
     resolved = self.group.resolve_all(get_identifiers(threshold),
                                       run_namespace=run_namespace,
                                       level=level + 1)
     if not is_boolean_expression(threshold, resolved):
         raise TypeError(('Threshold condition "%s" is not a boolean '
                          'expression') % threshold)
     if self.group._refractory is False:
         template = '_cond = %s'
     else:
         template = '_cond = (%s) and not_refractory'
     self.abstract_code = template % threshold
Ejemplo n.º 46
0
 def update_abstract_code(self, run_namespace=None, level=0):
     '''
     Build ``self.abstract_code`` for the threshold of ``self.group``.

     The identifiers of the threshold expression are resolved with
     ``self.group.resolve_all``. The abstract code assigns the condition to
     ``_cond``, and combines it with ``not_refractory`` unless the group's
     ``_refractory`` attribute is ``False``.

     Raises
     ------
     TypeError
         If the threshold condition is not a boolean expression.
     '''
     threshold = self.group.threshold
     resolved = self.group.resolve_all(get_identifiers(threshold),
                                       run_namespace=run_namespace,
                                       level=level + 1)
     if not is_boolean_expression(threshold, resolved):
         raise TypeError(('Threshold condition "%s" is not a boolean '
                          'expression') % threshold)
     if self.group._refractory is False:
         template = '_cond = %s'
     else:
         template = '_cond = (%s) and not_refractory'
     self.abstract_code = template % threshold
Ejemplo n.º 47
0
def neurongroup_description(neurongroup, run_namespace):
    """
    Describe ``neurongroup`` as a line of Python source that recreates it.

    Parameters
    ----------
    neurongroup : NeuronGroup
        The group to describe.
    run_namespace
        Passed on to ``get_namespace_dict`` to resolve external identifiers.

    Returns
    -------
    desc : str
        Source of the form ``<name> = NeuronGroup(<N>, <equations>, ...)``,
        including ``threshold``, ``reset`` and ``refractory`` arguments when
        present, and the ``name`` argument.
    namespace
        The result of ``get_namespace_dict`` for the identifiers of the
        equations, threshold, reset and (string) refractory condition.
    """
    eqs = eq_string(neurongroup.user_equations)
    # Copy: the set is extended in place below and must not leak back into
    # the equations object.
    identifiers = set(neurongroup.user_equations.identifiers)
    desc = "%d,\n'''%s'''" % (len(neurongroup), eqs)
    if 'spike' in neurongroup.events:
        threshold = neurongroup.events['spike']
        desc += ',\nthreshold=%r' % threshold
        identifiers |= get_identifiers(threshold)
    if 'spike' in neurongroup.event_codes:
        reset = neurongroup.event_codes['spike']
        desc += ',\nreset=%r' % reset
        identifiers |= get_identifiers(reset)
    if neurongroup._refractory is not None:
        refractory = neurongroup._refractory
        desc += ',\nrefractory=%r' % refractory
        # Only a string refractory condition contains identifiers
        if isinstance(refractory, basestring):
            identifiers |= get_identifiers(refractory)
    namespace = get_namespace_dict(identifiers, neurongroup,
                                   run_namespace)
    desc += ',\nname=%r' % neurongroup.name
    desc = '%s = NeuronGroup(%s)' % (neurongroup.name, desc)
    return desc, namespace
Ejemplo n.º 48
0
def check_pre_code(codegen, stmts, vars_pre, vars_syn, vars_post,
                   conditional_write_vars):
    '''
    Given a set of statements stmts where the variables names in vars_pre are
    presynaptic, in vars_syn are synaptic and in vars_post are postsynaptic,
    check that the conditions for compatibility with GeNN are met, and return
    a new statement sequence translated for compatibility with GeNN, along
    with the name of the targeted variable.

    Also adapts the synaptic statement to be multiplied by 0 for a refractory
    post-synaptic cell.

    Parameters
    ----------
    codegen
        Object whose ``array_read_write(stmts)`` returns the triple
        ``(read, write, indices)``; only ``write`` is used.
    stmts : sequence of Statement
        Iterated twice (by ``codegen`` and here).
    vars_pre, vars_syn
        Not used by this function.
    vars_post : iterable of str
        Names of the postsynaptic variables.
    conditional_write_vars : dict
        Mapping from variable name to the name of the flag guarding writes to
        it; only ``'not_refractory'`` is supported.

    Returns
    -------
    post_write_var : str
        The single postsynaptic variable that is written to.
    new_stmts : list
        The translated statements; the write becomes an assignment to
        ``addtoinSyn`` of ``_hidden_weightmatrix*(<expr>)``.

    Raises
    ------
    NotImplementedError
        If no, or more than one, postsynaptic variable is written; if the
        write is not an in-place ``+=``; if it is written more than once; or
        if its conditional-write flag is not ``'not_refractory'``.
    '''
    _, write, _ = codegen.array_read_write(stmts)

    post_write = set(write).intersection(vars_post)
    if len(post_write) == 0:
        raise NotImplementedError("GeNN does not support Synapses with no postsynaptic effect.")
    if len(post_write) > 1:
        raise NotImplementedError("GeNN only supports writing to a single postsynaptic variable.")

    post_write_var = next(iter(post_write))

    found_write_statement = False
    new_stmts = []
    for stmt in stmts:
        if stmt.var != post_write_var:
            new_stmts.append(stmt)
            continue
        if not stmt.inplace:
            # TODO: we could support expressions like v = v + expr, but this requires some additional work
            # namely, for an expression like v = expr we need to check if (expr-v) when simplified reduces to
            # an expression that only has postsynaptic variables using sympy
            raise NotImplementedError("GeNN only supports in-place modification of postsynaptic variables")
        if stmt.op != '+=':
            raise NotImplementedError("GeNN only supports the += in place operation on postsynaptic variables.")
        accumulation_expr = stmt.expr
        # "write-protect" a variable during refractoriness to match Brian's semantics
        if stmt.var in conditional_write_vars:
            # An explicit check rather than ``assert``, which ``python -O``
            # would strip and let wrong code be generated silently.
            if conditional_write_vars[stmt.var] != 'not_refractory':
                raise NotImplementedError("GeNN only supports conditional writes to postsynaptic "
                                          "variables that depend on 'not_refractory'.")
            accumulation_expr = 'int(not_refractory_post) * ({})'.format(accumulation_expr)
        # Reject a second write before building a statement for it
        if found_write_statement:
            raise NotImplementedError("GeNN does not support multiple writes to postsynaptic variables.")
        found_write_statement = True
        new_stmts.append(Statement('addtoinSyn', '=', '_hidden_weightmatrix*('+accumulation_expr+')',
                                   comment=stmt.comment, dtype=stmt.dtype))

    return post_write_var, new_stmts
Ejemplo n.º 49
0
def ufunc_at_vectorisation(statements, variables, indices, index):
    '''
    Translate synaptic statements into vectorised numpy code, using
    ``ufunc.at`` for in-place updates through indices other than ``index``.

    Parameters
    ----------
    statements : sequence of Statement
        Statements exposing ``var``, ``op``, ``expr`` and ``inplace``.
    variables
        Container of variable names; only iteration and membership are used.
    indices : dict
        Mapping from variable name to the name of its index array.
    index : str
        The main index; variables accessed through it are updated directly.

    Returns
    -------
    code : str
        The generated lines joined by newlines. Any
        ``_unique_<idx> = _numpy.unique(<idx>)`` lines come first, sorted by
        index name.

    Raises
    ------
    SynapseVectorisationError
        For an in-place operation other than ``+=`` or ``*=`` on a variable
        not accessed through ``index``.
    '''
    # We assume that the code has passed the test for synapse order independence

    main_index_variables = set(v for v in variables if indices[v] == index)

    lines = []
    need_unique_indices = set()

    for statement in statements:
        vars_in_expr = get_identifiers(statement.expr).intersection(variables)
        subs = {}
        for var in vars_in_expr:
            subs[var] = '{var}[{idx}]'.format(var=var, idx=indices[var])
        expr = word_substitute(statement.expr, subs)
        if statement.var in main_index_variables:
            line = '{var}[{idx}] {op} {expr}'.format(var=statement.var,
                                                     op=statement.op,
                                                     expr=expr,
                                                     idx=index)
            lines.append(line)
        else:
            if statement.inplace:
                if statement.op == '+=':
                    ufunc_name = '_numpy.add'
                elif statement.op == '*=':
                    ufunc_name = '_numpy.multiply'
                else:
                    raise SynapseVectorisationError()
                line = '{ufunc_name}.at({var}, {idx}, {expr})'.format(ufunc_name=ufunc_name,
                                                                      var=statement.var,
                                                                      idx=indices[statement.var],
                                                                      expr=expr)
                lines.append(line)
            else:
                # if statement is not in-place then we assume the expr has no synaptic
                # variables in it otherwise it would have failed the order independence
                # check. In this case, we only need to work with the unique indices
                need_unique_indices.add(indices[statement.var])
                idx = '_unique_' + indices[statement.var]
                expr = word_substitute(expr, {indices[statement.var]: idx})
                line = '{var}[{idx}] = {expr}'.format(var=statement.var,
                                                      idx=idx, expr=expr)
                lines.append(line)

    # Sort so the generated code does not depend on set iteration order
    # (which varies between runs with string hash randomisation).
    header = ['_unique_{idx} = _numpy.unique({idx})'.format(idx=unique_idx)
              for unique_idx in sorted(need_unique_indices)]

    return '\n'.join(header + lines)
Ejemplo n.º 50
0
def get_sensitivity_init(group, parameters, param_init):
    """
    Calculate the initial values for the sensitivity parameters (necessary if
    initial values are functions of parameters).

    Parameters
    ----------
    group : `NeuronGroup`
        The group of neurons that will be simulated.
    parameters : list of str
        Names of the parameters that are fit.
    param_init : dict
        The dictionary with expressions to initialize the model variables.
        Only string values are considered; other values are skipped.

    Returns
    -------
    sensitivity_init : dict
        Dictionary of expressions to initialize the sensitivity
        parameters, keyed ``S_<variable>_<parameter>``. Only non-zero
        derivatives get an entry.

    Raises
    ------
    NotImplementedError
        If an initialization expression refers to a subexpression, or if a
        sensitivity with a non-zero initial value belongs to a parameter that
        is a subexpression.
    """
    sensitivity_dict = {}
    for var_name, expr in param_init.items():
        if not isinstance(expr, str):
            continue
        identifiers = get_identifiers(expr)
        for identifier in identifiers:
            if (identifier in group.variables
                    and getattr(group.variables[identifier], 'type',
                                None) == SUBEXPRESSION):
                raise NotImplementedError('Initializations that refer to a '
                                          'subexpression are currently not '
                                          'supported')
        # The derivatives only depend on the expression itself, so compute
        # them once per initialized variable, not once per identifier.
        sympy_expr = str_to_sympy(expr)
        for parameter in parameters:
            diffed = sympy_expr.diff(str_to_sympy(parameter))
            if diffed != sympy.S.Zero:
                if getattr(group.variables[parameter], 'type',
                           None) == SUBEXPRESSION:
                    raise NotImplementedError(
                        'Sensitivity '
                        f'S_{var_name}_{parameter} '
                        'is initialized to a non-zero '
                        'value, but it has been '
                        'removed from the equations. '
                        'Set optimize=False to avoid '
                        'this.')
                init_expr = sympy_to_str(diffed)
                sensitivity_dict[f'S_{var_name}_{parameter}'] = init_expr
    return sensitivity_dict
Ejemplo n.º 51
0
    def ufunc_at_vectorisation(self, statement, variables, indices,
                               conditional_write_vars, created_vars, used_variables):
        '''
        Translate one statement into a line of numpy code, using ``ufunc.at``
        for in-place updates of variables not accessed via ``_idx``.

        Parameters
        ----------
        statement : Statement
            The statement to translate.
        variables : dict
            Mapping from variable names to variable objects.
        indices : dict
            Mapping from variable names to index names.
        conditional_write_vars
            Passed on to ``self.conditional_write``.
        created_vars
            Passed on to ``self.conditional_write``.
        used_variables : set
            Updated in place with the variables in the expression that are
            accessed through an index other than ``_idx``.

        Returns
        -------
        line : str
            The numpy code line, with the statement's comment appended if it
            has one.

        Raises
        ------
        VectorisationError
            If ufunc.at vectorisation is disabled, if the target variable is
            in ``used_variables`` (after this statement's additions), or for
            an in-place operator other than ``+=``, ``-=``, ``*=``, ``/=``.
        '''
        if not self._use_ufunc_at_vectorisation:
            raise VectorisationError()
        # Avoids circular import
        from brian2.devices.device import device

        # See https://github.com/brian-team/brian2/pull/531 for explanation
        used = set(get_identifiers(statement.expr))
        used = used.intersection(k for k in variables.keys() if k in indices and indices[k]!='_idx')
        used_variables.update(used)
        if statement.var in used_variables:
            raise VectorisationError()

        expr = NumpyNodeRenderer().render_expr(statement.expr)

        if statement.op == ':=' or indices[statement.var] == '_idx' or not statement.inplace:
            op = '=' if statement.op == ':=' else statement.op
            line = '{var} {op} {expr}'.format(var=statement.var, op=op, expr=expr)
        else:
            # Only in-place statements on non-'_idx' variables get here, so
            # the original trailing ``else: raise`` branch was unreachable.
            ufunc_names = {'+=': '_numpy.add',
                           '*=': '_numpy.multiply',
                           '/=': '_numpy.divide',
                           '-=': '_numpy.subtract'}
            if statement.op not in ufunc_names:
                raise VectorisationError()

            line = '{ufunc_name}.at({array_name}, {idx}, {expr})'.format(
                ufunc_name=ufunc_names[statement.op],
                array_name=device.get_array_name(variables[statement.var]),
                idx=indices[statement.var],
                expr=expr)
            line = self.conditional_write(line, statement, variables,
                                          conditional_write_vars=conditional_write_vars,
                                          created_vars=created_vars)

        if len(statement.comment):
            line += ' # ' + statement.comment

        return line
Ejemplo n.º 52
0
 def before_run(self, run_namespace=None):
     '''
     Check the units of a string ``rates`` expression, then run the parent
     class's ``before_run``.

     Only when the ``rates`` variable is a `Subexpression` are its
     identifiers resolved (marked as user identifiers) and its dimensions
     required to match Hz.
     '''
     rates_var = self.variables['rates']
     if isinstance(rates_var, Subexpression):
         # Check that the units of the expression make sense
         rates_expr = rates_var.expr
         rates_identifiers = get_identifiers(rates_expr)
         resolved = self.resolve_all(rates_identifiers,
                                     run_namespace,
                                     user_identifiers=rates_identifiers)
         dims = parse_expression_dimensions(rates_expr, resolved)
         fail_for_dimension_mismatch(dims, Hz, "The expression provided for "
                                               "PoissonGroup's 'rates' "
                                               "argument, has to have units "
                                               "of Hz")
     super(PoissonGroup, self).before_run(run_namespace)
Ejemplo n.º 53
0
 def before_run(self, run_namespace=None):
     '''
     Check the units of a string ``rates`` expression, then run the parent
     class's ``before_run``.

     Only when the ``rates`` variable is a `Subexpression` are its
     identifiers resolved (marked as user identifiers) and its dimensions
     required to match Hz.
     '''
     rates_var = self.variables['rates']
     if isinstance(rates_var, Subexpression):
         # Check that the units of the expression make sense
         rates_expr = rates_var.expr
         rates_identifiers = get_identifiers(rates_expr)
         resolved = self.resolve_all(rates_identifiers,
                                     run_namespace,
                                     user_identifiers=rates_identifiers)
         dims = parse_expression_dimensions(rates_expr, resolved)
         fail_for_dimension_mismatch(dims, Hz, "The expression provided for "
                                               "PoissonGroup's 'rates' "
                                               "argument, has to have units "
                                               "of Hz")
     super(PoissonGroup, self).before_run(run_namespace)
Ejemplo n.º 54
0
def check_for_order_independence(statements, variables, indices):
    '''
    Check that a block of statements does not depend on the order in which
    the (non-main) indices are processed.

    Parameters
    ----------
    statements : sequence of Statement
        The abstract code statements; each exposes ``var``, ``op`` and
        ``expr``.
    variables : dict
        Mapping from variable name to variable object. Index names other
        than ``'_idx'`` and ``'0'`` are looked up here too, to read their
        optional ``unique`` attribute.
    indices : dict
        Mapping from variable name to the name of the index used to access
        it.

    Raises
    ------
    OrderDependenceError
        If the result could depend on the processing order.
    '''
    # Variables accessed through the main index, a scalar index, or an index
    # flagged as ``unique``.
    main_index_variables = set([v for v in variables
                                if (indices[v] in ('_idx', '0')
                                    or getattr(variables[indices[v]],
                                               'unique',
                                               False))])
    different_index_variables = set(variables.keys()) - main_index_variables
    all_variables = variables.keys()

    # Start by assuming every non-main-index variable is permutation
    # independent; in-place writes remove variables from this set and the
    # statements are re-examined until the set no longer changes.
    permutation_independent = set(different_index_variables)
    changed_permutation_independent = True
    while changed_permutation_independent:
        changed_permutation_independent = False
        for statement in statements:
            vars_in_expr = get_identifiers(statement.expr).intersection(all_variables)
            nonsyn_vars_in_expr = vars_in_expr.intersection(different_index_variables)
            permdep = any(var not in permutation_independent for var in nonsyn_vars_in_expr)
            if statement.op == ':=':
                continue  # auxiliary variable
            elif statement.var in main_index_variables:
                if permdep:
                    raise OrderDependenceError()
            elif statement.var in different_index_variables:
                if statement.op == '+=' or statement.op == '*=':
                    if permdep:
                        raise OrderDependenceError()
                    if statement.var in permutation_independent:
                        permutation_independent.remove(statement.var)
                        changed_permutation_independent = True
                elif statement.op == '=':
                    # Variables accessed through an index different from the
                    # target's (main and scalar indices excluded)
                    otheridx = set(v for v in variables
                                   if indices[v] not in (indices[statement.var],
                                                         '_idx', '0'))
                    if any(var in otheridx for var in vars_in_expr):
                        raise OrderDependenceError()
                    if permdep:
                        raise OrderDependenceError()
                    if any(var in main_index_variables for var in vars_in_expr):
                        raise OrderDependenceError()
                else:
                    raise OrderDependenceError()
            else:
                raise AssertionError('Should never get here...')
Ejemplo n.º 55
0
 def array_read_write(self, statements):
     '''
     Helper function, gives the set of ArrayVariables that are read from and
     written to in the series of statements, together with the index arrays
     used to access them.

     Parameters
     ----------
     statements : sequence of Statement
         The statements to analyse.

     Returns
     -------
     read : set of str
         Names of `ArrayVariable` objects read by the statements (an in-place
         operation counts as a read of its target). Names that also appear
         in ``indices`` are removed.
     write : set of str
         Names of `ArrayVariable` objects written by the statements.
     indices : set of str
         Names of the index variables (other than ``'_idx'`` and ``'0'``)
         that are themselves `ArrayVariable` objects and are used to access
         any of the read or written variables.

     Raises
     ------
     SyntaxError
         If a scalar variable is written with an operation other than
         ``:=`` while ``self.allows_scalar_write`` is false, or if a scalar
         variable is assigned an expression referring to a vector
         `ArrayVariable`.
     '''
     variables = self.variables
     variable_indices = self.variable_indices
     read = set()
     write = set()
     for stmt in statements:
         ids = get_identifiers(stmt.expr)
         # if the operation is inplace this counts as a read.
         if stmt.inplace:
             ids.add(stmt.var)
         read |= ids
         if stmt.scalar or variable_indices[stmt.var] == '0':
             if stmt.op != ':=' and not self.allows_scalar_write:
                 raise SyntaxError(
                     ('Writing to scalar variable %s '
                      'not allowed in this context.' % stmt.var))
             for name in ids:
                 if (name in variables
                         and isinstance(variables[name], ArrayVariable)
                         and not (variables[name].scalar
                                  or variable_indices[name] == '0')):
                     raise SyntaxError(
                         ('Cannot write to scalar variable %s '
                          'with an expression referring to '
                          'vector variable %s') % (stmt.var, name))
         write.add(stmt.var)
     read = set(varname for varname, var in variables.items()
                if isinstance(var, ArrayVariable) and varname in read)
     write = set(varname for varname, var in variables.items()
                 if isinstance(var, ArrayVariable) and varname in write)
     # Gather the indices stored as arrays (ignore _idx which is special)
     indices = set(
         variable_indices[varname] for varname in read | write
         if variable_indices[varname] not in ('_idx', '0') and isinstance(
             variables[variable_indices[varname]], ArrayVariable))
     # don't list arrays that are read explicitly and used as indices twice
     read -= indices
     return read, write, indices
Ejemplo n.º 56
0
    def __init__(self, name, unit, owner, expr, device, dtype=None,
                 is_bool=False):
        '''
        Create a subexpression variable defined by the code string ``expr``.

        It is registered with the base class as non-scalar, non-constant and
        read-only.
        '''
        super(Subexpression, self).__init__(name=name, unit=unit,
                                            dtype=dtype, is_bool=is_bool,
                                            scalar=False, constant=False,
                                            read_only=True)
        #: The `Device` responsible for memory access
        self.device = device
        #: The `Group` to which this variable belongs
        self.owner = owner
        #: The identifiers used in the expression
        self.identifiers = get_identifiers(expr)
        #: The expression defining the subexpression (surrounding whitespace
        #: removed)
        self.expr = expr.strip()
Ejemplo n.º 57
0
def test_parse_expression_unit_wrong_expressions(expr):
    '''
    Check that determining the dimensions of ``expr`` raises `SyntaxError`.

    ``expr`` is presumably supplied by a pytest parametrization outside this
    view.
    '''
    Var = namedtuple('Var', ['dim', 'dtype'])
    float_type = np.float64
    variables = dict(a=Var(dim=(volt * amp).dim, dtype=float_type),
                     b=Var(dim=volt.dim, dtype=float_type),
                     c=Var(dim=amp.dim, dtype=float_type))
    group = SimpleGroup(namespace={}, variables=variables)
    # Locally defined variables are used as-is; every other identifier is
    # looked up through the group.
    all_variables = {name: (variables[name] if name in variables
                            else group._resolve(name, {}))
                     for name in get_identifiers(expr)}
    with pytest.raises(SyntaxError):
        parse_expression_dimensions(expr, all_variables)
Ejemplo n.º 58
0
 def __init__(self, code, namespace=None, exhaustive=True, level=0):
     '''
     Store ``code`` together with the namespaces used to resolve it.

     Parameters
     ----------
     code : str
         The code whose identifiers are extracted (stored as a set in
         ``self._identifiers``).
     namespace : dict-like, optional
         If given, a copy is stored under ``self._namespaces['user-defined']``.
     exhaustive : bool, optional
         If ``False``, or if no ``namespace`` is given, copies of the locals
         and globals of a calling frame are stored under
         ``self._namespaces['locals']`` and ``self._namespaces['globals']``.
     level : int, optional
         Number of frames above the direct caller of this constructor to
         take locals/globals from (0 means the direct caller). Expected to
         be non-negative.

     Raises
     ------
     IndexError
         If the call stack is shallower than ``level`` requires.
     '''
     self._code = code

     # extract identifiers from the code
     self._identifiers = set(get_identifiers(code))

     self._namespaces = {}

     if namespace is None or not exhaustive:
         # Walk the frame chain directly: inspect.stack() would build a
         # FrameInfo (including source context lines) for *every* frame on
         # the stack just to pick out a single one.
         frame = inspect.currentframe()  # the frame of this __init__
         try:
             for _ in range(level + 1):
                 frame = frame.f_back
                 if frame is None:
                     # same exception type inspect.stack()[level + 1] raised
                     raise IndexError('level %d exceeds the call stack '
                                      'depth' % level)
             self._namespaces['locals'] = dict(frame.f_locals)
             self._namespaces['globals'] = dict(frame.f_globals)
         finally:
             # Drop the frame reference promptly (see the inspect docs on
             # frame reference cycles).
             del frame

     if namespace is not None:
         self._namespaces['user-defined'] = dict(namespace)
Ejemplo n.º 59
0
 def array_read_write(self, statements):
     '''
     Helper function, gives the set of ArrayVariables that are read from and
     written to in the series of statements, together with the index arrays
     through which they are accessed.

     Parameters
     ----------
     statements : iterable
         Statement objects providing ``var``, ``expr``, ``op``, ``inplace``
         and ``scalar`` attributes.

     Returns
     -------
     read : set of str
         Names of `ArrayVariable` objects read in the statements (in-place
         operations count as reads), excluding names also used as indices.
     write : set of str
         Names of `ArrayVariable` objects written to in the statements.
     indices : set of str
         Index names (other than ``'_idx'`` and ``'0'``) of the read or
         written variables that themselves refer to `ArrayVariable` objects.

     Raises
     ------
     SyntaxError
         If a scalar variable is written to with an operator other than
         ``:=`` while ``self.allows_scalar_write`` is false, or if a scalar
         variable is assigned an expression referring to a vector
         `ArrayVariable`.
     '''
     variables = self.variables
     variable_indices = self.variable_indices
     read = set()
     write = set()
     for stmt in statements:
         # Work on a copy so that adding ``stmt.var`` below never mutates the
         # object returned by get_identifiers.
         ids = set(get_identifiers(stmt.expr))
         # if the operation is inplace this counts as a read.
         if stmt.inplace:
             ids.add(stmt.var)
         read.update(ids)
         # The target is scalar if the statement says so or it is indexed
         # with the special '0' index.
         if stmt.scalar or variable_indices[stmt.var] == '0':
             if stmt.op != ':=' and not self.allows_scalar_write:
                 raise SyntaxError(('Writing to scalar variable %s '
                                    'not allowed in this context.' % stmt.var))
             for name in ids:
                 if (name in variables
                         and isinstance(variables[name], ArrayVariable)
                         and not (variables[name].scalar or
                                  variable_indices[name] == '0')):
                     raise SyntaxError(('Cannot write to scalar variable %s '
                                        'with an expression referring to '
                                        'vector variable %s') %
                                       (stmt.var, name))
         write.add(stmt.var)
     # Restrict both sets to names that refer to ArrayVariables
     read = set(varname for varname, var in variables.items()
                if isinstance(var, ArrayVariable) and varname in read)
     write = set(varname for varname, var in variables.items()
                 if isinstance(var, ArrayVariable) and varname in write)
     # Gather the indices stored as arrays (ignore the special '_idx' and '0')
     indices = set(variable_indices[varname] for varname in read | write
                   if variable_indices[varname] not in ('_idx', '0')
                   and isinstance(variables[variable_indices[varname]],
                                  ArrayVariable))
     # don't list arrays that are read explicitly and used as indices twice
     read -= indices
     return read, write, indices