Пример #1
0
    def update_abstract_code(self, run_namespace=None, level=0):
        '''
        Regenerate ``self.abstract_code`` and ``self.user_code``.

        The abstract code consists of the refractory code followed by the
        code returned by the state update method that is determined for the
        equations of the group. ``self.method`` is set to that method.

        Parameters
        ----------
        run_namespace : optional
            Passed on to ``_get_refractory_code`` and ``group.resolve_all``.
        level : int, optional
            Passed on, increased by one, to the same two calls.
        '''
        group = self.group
        equations = group.equations

        # The code for the refractory period mechanism (not_refractory
        # variable) comes first
        self.abstract_code = self._get_refractory_code(run_namespace=run_namespace,
                                                       level=level+1)

        # Identifiers that the refractory code refers to
        _, refractory_known, refractory_unknown = analyse_identifiers(
            self.abstract_code, group.variables, recursive=True)

        # Names and identifiers of the equations; "dt" is always requested
        equation_names = equations.names
        equation_identifiers = equations.identifiers | {'dt'}
        to_resolve = (refractory_known | refractory_unknown
                      | equation_names | equation_identifiers)

        # The empty set is passed because no warnings have to be raised for
        # the user here, they will be raised in create_runner_codeobj
        variables = group.resolve_all(to_resolve, set(),
                                      run_namespace=run_namespace,
                                      level=level+1)

        # Not all functions were necessarily known at creation time, so the
        # choice of the numerical integration method is reconsidered here
        self.method = StateUpdateMethod.determine_stateupdater(
            equations, variables, self.method_choice)
        self.abstract_code += self.method(equations, variables)

        # One "variable = expression" line per substituted expression
        self.user_code = '\n'.join('{var} = {expr}'.format(var=var, expr=expr)
                                   for var, expr
                                   in equations.substituted_expressions)
Пример #2
0
    def update_abstract_code(self, run_namespace=None, level=0):
        '''
        Regenerate ``self.abstract_code`` for the equations of the group.

        The refractory code is generated first; the code returned by the
        (re-)determined state update method is appended to it.

        Parameters
        ----------
        run_namespace : optional
            Passed on to ``_get_refractory_code`` and ``group.resolve_all``.
        level : int, optional
            Passed on, increased by one, to the same two calls.
        '''
        next_level = level + 1
        equations = self.group.equations

        # Code updating the not_refractory variable for the refractory
        # period mechanism
        self.abstract_code = self._get_refractory_code(
            run_namespace=run_namespace, level=next_level)

        # Which identifiers does the refractory code use?
        analysis = analyse_identifiers(self.abstract_code,
                                       self.group.variables,
                                       recursive=True)
        _, refractory_known, refractory_unknown = analysis

        # Everything used in the equations, plus "dt" in any case
        equation_names = equations.names
        equation_identifiers = equations.identifiers | set(['dt'])
        to_resolve = (refractory_known | refractory_unknown
                      | equation_names | equation_identifiers)

        variables = self.group.resolve_all(to_resolve,
                                           run_namespace=run_namespace,
                                           level=next_level)

        # Not all functions were necessarily known at creation time, so the
        # numerical integration method is determined again at this point
        self.method = StateUpdateMethod.determine_stateupdater(
            equations, variables, self.method_choice)
        self.abstract_code += self.method(equations, variables)
Пример #3
0
    def update_abstract_code(self, run_namespace):
        '''
        Regenerate ``self.abstract_code`` and ``self.user_code``.

        The abstract code starts with the refractory code. If the equations
        of the group contain differential equations, the output of
        ``StateUpdateMethod.apply_stateupdater`` is appended.

        Parameters
        ----------
        run_namespace
            Passed on to ``_get_refractory_code`` and ``group.resolve_all``.
        '''
        group = self.group
        equations = group.equations

        # Code updating the not_refractory variable (refractory mechanism)
        self.abstract_code = self._get_refractory_code(run_namespace=run_namespace)

        # Identifiers that the refractory code refers to
        _, refractory_known, refractory_unknown = analyse_identifiers(
            self.abstract_code, group.variables, recursive=True)

        # Names and identifiers of the equations; "dt" is always requested
        equation_names = equations.names
        equation_identifiers = equations.identifiers | {'dt'}
        to_resolve = (refractory_known | refractory_unknown
                      | equation_names | equation_identifiers)

        # An empty set of user identifiers is passed: no warnings have to be
        # raised for the user here, they will be raised in
        # create_runner_codeobj
        variables = group.resolve_all(to_resolve, run_namespace,
                                      user_identifiers=set())

        # Integration code is only requested if there is something to
        # integrate
        if len(equations.diff_eq_names) > 0:
            self.abstract_code += StateUpdateMethod.apply_stateupdater(
                equations, variables, self.method_choice,
                group_name=group.name)

        # One "variable = expression" line per substituted expression
        self.user_code = '\n'.join(
            '{var} = {expr}'.format(var=var, expr=expr)
            for var, expr in equations.get_substituted_expressions(variables))
Пример #4
0
    def update_abstract_code(self, run_namespace=None, level=0):
        '''
        Regenerate ``self.abstract_code`` and ``self.user_code``.

        Sets ``self.method`` to the state update method determined for the
        equations of the group and appends the code it returns to the
        refractory code.

        Parameters
        ----------
        run_namespace : optional
            Passed on to ``_get_refractory_code`` and ``group.resolve_all``.
        level : int, optional
            Passed on, increased by one, to the same two calls.
        '''
        call_level = level + 1
        equations = self.group.equations

        # Start with the code for the refractory period mechanism
        # (update of the not_refractory variable)
        refractory_code = self._get_refractory_code(run_namespace=run_namespace,
                                                    level=call_level)
        self.abstract_code = refractory_code

        # Names that the refractory code makes use of
        _, refractory_known, refractory_unknown = analyse_identifiers(
            self.abstract_code, self.group.variables, recursive=True)

        # Names appearing in the equations ("dt" is always included)
        equation_names = equations.names
        equation_identifiers = equations.identifiers | {'dt'}
        wanted = (refractory_known | refractory_unknown | equation_names
                  | equation_identifiers)

        # Second argument is an empty set: warnings for the user are not
        # needed here, they will be raised in create_runner_codeobj
        variables = self.group.resolve_all(wanted, set(),
                                           run_namespace=run_namespace,
                                           level=call_level)

        # Since not all functions were necessarily known at creation time,
        # the numerical integration method is reconsidered
        self.method = StateUpdateMethod.determine_stateupdater(
            equations, variables, self.method_choice)
        self.abstract_code += self.method(equations, variables)

        lines = ['{var} = {expr}'.format(var=var, expr=expr)
                 for var, expr in equations.substituted_expressions]
        self.user_code = '\n'.join(lines)
Пример #5
0
    def update_abstract_code(self, run_namespace):
        '''
        Regenerate ``self.abstract_code`` and ``self.user_code``.

        The abstract code is the refractory code followed by the output of
        ``StateUpdateMethod.apply_stateupdater`` for the group's equations.

        Parameters
        ----------
        run_namespace
            Passed on to ``_get_refractory_code`` and ``group.resolve_all``.
        '''
        group = self.group
        equations = group.equations

        # Code updating the not_refractory variable (refractory mechanism)
        self.abstract_code = self._get_refractory_code(run_namespace=run_namespace)

        # Identifiers that the refractory code refers to
        _, refractory_known, refractory_unknown = analyse_identifiers(
            self.abstract_code, group.variables, recursive=True)

        # Names and identifiers of the equations; "dt" is always requested
        equation_names = equations.names
        equation_identifiers = equations.identifiers | {'dt'}
        to_resolve = (refractory_known | refractory_unknown
                      | equation_names | equation_identifiers)

        # The empty set is passed because no warnings have to be raised for
        # the user here, they will be raised in create_runner_codeobj
        variables = group.resolve_all(to_resolve, set(),
                                      run_namespace=run_namespace)

        self.abstract_code += StateUpdateMethod.apply_stateupdater(
            equations, variables, self.method_choice)

        # One "variable = expression" line per substituted expression
        self.user_code = '\n'.join('{var} = {expr}'.format(var=var, expr=expr)
                                   for var, expr
                                   in equations.substituted_expressions)
Пример #6
0
def check_code_units(code, group, additional_variables=None,
                     additional_namespace=None,
                     ignore_keyerrors=False):
    '''
    Check statements for correct units.

    Parameters
    ----------
    code : str
        The series of statements to check
    group : `Group`
        The context for the code execution
    additional_variables : dict-like, optional
        A mapping of names to `Variable` objects, used in addition to the
        variables saved in `group`.
    additional_namespace : dict-like, optional
        An additional namespace, as provided to `Group.pre_run`
    ignore_keyerrors : boolean, optional
        Whether to silently ignore unresolvable identifiers. Should be set
        to ``False`` (the default) if the namespace is expected to be
        complete (e.g. in `Group.pre_run`) but to ``True`` when the check
        is done during object initialisation where the namespace is not
        necessarily complete yet

    Raises
    ------
    KeyError
        If resolving the external identifiers raises a `KeyError` and
        `ignore_keyerrors` is ``False``.
    DimensionMismatchError
        If `code` has unit mismatches
    '''
    all_variables = dict(group.variables)
    if additional_variables is not None:
        all_variables.update(additional_variables)

    # Resolve the namespace, resulting in a dictionary containing only the
    # external variables that are needed by the code -- keep the units for
    # the unit checks
    # Note that here we do not need to recursively descend into
    # subexpressions. For unit checking, we only need to know the units of
    # the subexpressions not what variables they refer to
    _, _, unknown = analyse_identifiers(code, all_variables)
    try:
        resolved_namespace = group.namespace.resolve_all(unknown,
                                                         additional_namespace,
                                                         strip_units=False)
    except KeyError as ex:
        if not ignore_keyerrors:
            # Bare raise: re-raise the active exception with its original
            # traceback untouched
            raise
        # The namespace may legitimately be incomplete at this point, skip
        # the unit check altogether
        logger.debug('Namespace not complete (yet), ignoring: %s ' % str(ex),
                     'check_code_units')
        return

    check_units_statements(code, resolved_namespace, all_variables)
Пример #7
0
def test_analyse_identifiers():
    '''
    Check `analyse_identifiers` on a short, unambiguous two-line snippet.
    '''
    code = '''
    a = b+c
    d = e+f
    '''
    known = ['b', 'c', 'd', 'g']

    # The three returned sets: names defined by the code, known names that
    # are used, and names that are used but not known
    result = analyse_identifiers(code, known)
    defined, used_known, dependent = result

    expected_defined = set(['a'])
    expected_used_known = set(['b', 'c', 'd'])
    expected_dependent = set(['e', 'f'])

    assert defined == expected_defined
    assert used_known == expected_used_known
    assert dependent == expected_dependent
Пример #8
0
def test_analyse_identifiers():
    '''
    Run `analyse_identifiers` on a minimal example and compare all three
    returned sets with the expected names.
    '''
    code = '''
    a = b+c
    d = e+f
    '''
    # "g" is known but never used in the code
    known = ['b', 'c', 'd', 'g']

    defined, used_known, dependent = analyse_identifiers(code, known)

    # "a" is newly defined, "d" is assigned to but already known
    assert defined == set(['a'])
    # "g" is not used and therefore must not show up here
    assert used_known == set(['b', 'c', 'd'])
    # "e" and "f" are used without being known
    assert dependent == set(['e', 'f'])
Пример #9
0
def test_analyse_identifiers():
    '''
    Check `analyse_identifiers` on a short, unambiguous two-line snippet.
    '''
    code = '''
    a = b+c
    d = e+f
    '''
    # "g" is known but never used in the code
    known = {name: Variable(name=name) for name in ('b', 'c', 'd', 'g')}

    defined, used_known, dependent = analyse_identifiers(code, known)

    # Only membership is checked for "a": there might be an additional
    # constant added by the loop-invariant optimisation
    assert 'a' in defined
    assert used_known == {'b', 'c', 'd'}
    assert dependent == {'e', 'f'}
Пример #10
0
def test_analyse_identifiers():
    '''
    Run `analyse_identifiers` on a minimal example with `Variable` objects
    as the known names.
    '''
    code = '''
    a = b+c
    d = e+f
    '''
    known_names = ('b', 'c', 'd', 'g')
    known = {name: Variable(name=name) for name in known_names}

    result = analyse_identifiers(code, known)
    defined, used_known, dependent = result

    # The loop-invariant optimisation might add an additional constant to
    # the defined names, therefore only check that "a" is among them
    assert 'a' in defined
    # "g" is known but not used
    assert used_known == {'b', 'c', 'd'}
    # "e" and "f" are used without being known
    assert dependent == {'e', 'f'}
Пример #11
0
def test_analyse_identifiers():
    '''
    Check `analyse_identifiers` on a short, unambiguous two-line snippet.
    '''
    code = '''
    a = b+c
    d = e+f
    '''
    # "g" is known but never used in the code
    known = dict((name, Variable(unit=None, name=name))
                 for name in ('b', 'c', 'd', 'g'))

    result = analyse_identifiers(code, known)
    defined, used_known, dependent = result

    expected_defined = set(['a'])
    expected_used_known = set(['b', 'c', 'd'])
    expected_dependent = set(['e', 'f'])

    assert defined == expected_defined
    assert used_known == expected_used_known
    assert dependent == expected_dependent
Пример #12
0
    def update_abstract_code(self, run_namespace):
        '''
        Regenerate ``self.abstract_code`` and ``self.user_code``.

        The abstract code starts with the refractory code. If the equations
        of the group contain differential equations, the output of
        ``StateUpdateMethod.apply_stateupdater`` is appended: directly if it
        is a string, otherwise the result of calling it with this object.

        Parameters
        ----------
        run_namespace
            Passed on to ``_get_refractory_code`` and ``group.resolve_all``.
        '''
        group = self.group
        equations = group.equations

        # Code updating the not_refractory variable (refractory mechanism)
        self.abstract_code = self._get_refractory_code(run_namespace=run_namespace)

        # Identifiers that the refractory code refers to
        _, refractory_known, refractory_unknown = analyse_identifiers(
            self.abstract_code, group.variables, recursive=True)

        # Names and identifiers of the equations; "dt" is always requested
        equation_names = equations.names
        equation_identifiers = equations.identifiers | {'dt'}
        to_resolve = (refractory_known | refractory_unknown
                      | equation_names | equation_identifiers)

        # An empty set of user identifiers is passed: no warnings have to be
        # raised for the user here, they will be raised in
        # create_runner_codeobj
        variables = group.resolve_all(to_resolve, run_namespace,
                                      user_identifiers=set())

        if len(equations.diff_eq_names) > 0:
            stateupdate_output = StateUpdateMethod.apply_stateupdater(
                equations, variables, self.method_choice,
                method_options=self.method_options,
                group_name=group.name)
            if not isinstance(stateupdate_output, basestring):
                # The output is called with self so that this StateUpdater
                # can be modified, e.g. the GSL StateUpdateMethod adds a
                # custom CodeObject together with some auxiliary information
                self.abstract_code += stateupdate_output(self)
            else:
                self.abstract_code += stateupdate_output

        # One "variable = expression" line per substituted expression
        self.user_code = '\n'.join(
            '{var} = {expr}'.format(var=var, expr=expr)
            for var, expr in equations.get_substituted_expressions(variables))
Пример #13
0
def test_analyse_identifiers():
    '''
    Run `analyse_identifiers` on a minimal example with `Variable` objects
    as the known names and compare all three returned sets.
    '''
    code = '''
    a = b+c
    d = e+f
    '''
    known_names = ('b', 'c', 'd', 'g')
    known = dict((name, Variable(unit=None, name=name))
                 for name in known_names)

    defined, used_known, dependent = analyse_identifiers(code, known)

    # "a" is newly defined, "d" is assigned to but already known
    assert defined == set(['a'])
    # "g" is known but not used
    assert used_known == set(['b', 'c', 'd'])
    # "e" and "f" are used without being known
    assert dependent == set(['e', 'f'])
Пример #14
0
    def update_abstract_code(self, run_namespace=None, level=0):
        '''
        Regenerate ``self.abstract_code`` for the equations of the group.

        Sets ``self.method`` to the state update method determined for the
        equations and appends the code it returns to the refractory code.

        Parameters
        ----------
        run_namespace : optional
            Passed on to ``_get_refractory_code`` and ``group.resolve_all``.
        level : int, optional
            Passed on, increased by one, to the same two calls.
        '''
        group = self.group
        equations = group.equations

        # Start with the code for the refractory period mechanism
        # (update of the not_refractory variable)
        self.abstract_code = self._get_refractory_code(run_namespace=run_namespace,
                                                       level=level+1)

        # Names that the refractory code makes use of
        _, refractory_known, refractory_unknown = analyse_identifiers(
            self.abstract_code, group.variables, recursive=True)

        # Names appearing in the equations ("dt" is always included)
        equation_names = equations.names
        equation_identifiers = equations.identifiers | set(['dt'])
        wanted = (refractory_known | refractory_unknown | equation_names
                  | equation_identifiers)

        variables = group.resolve_all(wanted, run_namespace=run_namespace,
                                      level=level+1)

        # Since not all functions were necessarily known at creation time,
        # the numerical integration method is reconsidered
        self.method = StateUpdateMethod.determine_stateupdater(
            equations, variables, self.method_choice)
        self.abstract_code += self.method(equations, variables)
Пример #15
0
    def update_abstract_code(self, run_namespace):
        '''
        Regenerate ``self.abstract_code`` and ``self.user_code``.

        The refractory code comes first. For equations that contain
        differential equations, ``StateUpdateMethod.apply_stateupdater`` is
        called and its output appended: a string output is used as it is,
        any other output is called with this object and the result is used.

        Parameters
        ----------
        run_namespace
            Passed on to ``_get_refractory_code`` and ``group.resolve_all``.
        '''
        equations = self.group.equations

        # Start with the code for the refractory period mechanism
        # (update of the not_refractory variable)
        refractory_code = self._get_refractory_code(run_namespace=run_namespace)
        self.abstract_code = refractory_code

        # Names that the refractory code makes use of
        analysis = analyse_identifiers(self.abstract_code,
                                       self.group.variables,
                                       recursive=True)
        _, refractory_known, refractory_unknown = analysis

        # Names appearing in the equations ("dt" is always included)
        equation_names = equations.names
        equation_identifiers = equations.identifiers | {'dt'}
        wanted = (refractory_known | refractory_unknown | equation_names
                  | equation_identifiers)

        # user_identifiers is an empty set: warnings for the user are not
        # needed here, they will be raised in create_runner_codeobj
        variables = self.group.resolve_all(wanted, run_namespace,
                                           user_identifiers=set())

        has_diff_eqs = len(equations.diff_eq_names) > 0
        if has_diff_eqs:
            stateupdate_output = StateUpdateMethod.apply_stateupdater(
                equations,
                variables,
                self.method_choice,
                method_options=self.method_options,
                group_name=self.group.name)
            if isinstance(stateupdate_output, basestring):
                self.abstract_code += stateupdate_output
            else:
                # self is sent along so that the StateUpdater can be
                # modified, i.e. in the GSL StateUpdateMethod a custom
                # CodeObject gets added to the StateUpdater together with
                # some auxiliary information
                self.abstract_code += stateupdate_output(self)

        lines = ['{var} = {expr}'.format(var=var, expr=expr)
                 for var, expr
                 in equations.get_substituted_expressions(variables)]
        self.user_code = '\n'.join(lines)
Пример #16
0
def check_units_statements(code, variables):
    '''
    Check the units for a series of statements. Setting a model variable has to
    use the correct unit. For newly introduced temporary variables, the unit
    is determined and used to check the following statements to ensure
    consistency.

    Note that `variables` is modified in place: an entry is added for every
    temporary variable that the statements newly define.

    Parameters
    ----------
    code : str
        The statements as a (multi-line) string
    variables : dict of `Variable` objects
        The information about all variables used in `code` (including
        `Constant` objects for external variables)

    Raises
    ------
    KeyError
        In case one of the identifiers cannot be resolved.
    DimensionMismatchError
        If a unit mismatch occurs during the evaluation.
    AssertionError
        If `code` uses identifiers that are not in `variables`, if a
        statement uses an unsupported operator, or if an assigned variable
        is neither known nor newly defined.
    '''
    # Avoid a circular import
    from brian2.codegen.translation import analyse_identifiers
    known = set(variables)
    newly_defined, _, unknown = analyse_identifiers(code, known)

    if len(unknown):
        raise AssertionError(('Encountered unknown identifiers, this should '
                              'not happen at this stage. Unknown identifiers: %s'
                              % unknown))

    # Statements are separated by semicolons or newlines
    for line in re.split(r'[;\n]', code):
        line = line.strip()
        if not line:
            continue  # skip empty lines

        varname, op, expr, _ = parse_statement(line)
        if op in ('+=', '-=', '*=', '/=', '%='):
            # Replace statements such as "w *=2" by "w = w * 2"
            expr = '{var} {op_first} {expr}'.format(var=varname,
                                                    op_first=op[0],
                                                    expr=expr)
        elif op != '=':
            raise AssertionError('Unknown operator "%s"' % op)

        expr_unit = parse_expression_unit(expr, variables)

        if varname in variables:
            fail_for_dimension_mismatch(variables[varname].unit,
                                        expr_unit,
                                        ('Code statement "%s" does not use '
                                         'correct units' % line))
        elif varname in newly_defined:
            # note the unit for later
            variables[varname] = Variable(name=varname, unit=expr_unit,
                                          scalar=False)
        else:
            raise AssertionError(('Variable "%s" is neither in the variables '
                                  'dictionary nor in the list of undefined '
                                  'variables.' % varname))
Пример #17
0
def collect_Synapses(synapses, run_namespace):
    """
    Describe a `brian2.synapses.synapses.Synapses` object as a dictionary.

    The entries ``'name'``, ``'source'`` and ``'target'`` are always present;
    ``'equations'``, ``'user_method'``, ``'summed_variables'``, ``'pathways'``
    and ``'identifiers'`` are only added when there is something to store.

    Parameters
    ----------
    synapses : brian2.synapses.synapses.Synapses
        Synapses object

    run_namespace : dict
        Namespace dictionary, used to resolve the collected identifiers

    Returns
    -------
    synapse_dict : dict
        Standard dictionary format with collected information
    """
    # name of the object and its source and target groups
    synapse_dict = {'name': synapses.name}
    synapse_dict['source'] = collect_SpikeSource(synapses.source)
    synapse_dict['target'] = collect_SpikeSource(synapses.target)

    # governing equations and the identifiers that they use
    equations_info = collect_Equations(synapses.equations)
    identifiers = set() | synapses.equations.identifiers
    event_driven = synapses.event_driven
    if event_driven:
        equations_info.update(collect_Equations(event_driven))
        identifiers = identifiers | event_driven.identifiers
    # only store the equations if there are any
    if equations_info:
        synapse_dict['equations'] = equations_info

    # the method choice is only stored if it is given as a string
    state_updater = synapses.state_updater
    if state_updater and isinstance(state_updater.method_choice, str):
        synapse_dict['user_method'] = state_updater.method_choice

    # go through the contained objects, looking for summed variable
    # updaters and synaptic pathways
    summed_variables = []
    pathways = []
    for contained in synapses.contained_objects:
        if isinstance(contained, SummedVariableUpdater):
            summed_variables.append({
                'code': contained.expression,
                'target': collect_SpikeSource(contained.target),
                'name': contained.name,
                'dt': contained.clock.dt,
                'when': contained.when,
                'order': contained.order
            })
        if isinstance(contained, SynapticPathway):
            pathway = {
                'prepost': contained.prepost,
                'event': contained.event,
                'code': contained.code,
                'source': collect_SpikeSource(contained.source),
                'target': collect_SpikeSource(contained.target),
                'name': contained.name,
                'dt': contained.clock.dt,
                'order': contained.order,
                'when': contained.when
            }
            # the delay is only stored if the delay variable is scalar
            if contained.variables['delay'].scalar:
                pathway['delay'] = contained.delay[:]
            pathways.append(pathway)
            # identifiers of the pathway code that are not variables of
            # the pathway
            _, _, pathway_unknown = analyse_identifiers(contained.code,
                                                        contained.variables)
            identifiers = identifiers | pathway_unknown

    # empty lists are not stored
    if summed_variables:
        synapse_dict['summed_variables'] = summed_variables
    if pathways:
        synapse_dict['pathways'] = pathways

    # resolve the collected identifiers and store them if any remain
    resolved = synapses.resolve_all(identifiers, run_namespace)
    prepared = _prepare_identifiers(resolved)
    if prepared:
        synapse_dict['identifiers'] = prepared

    return synapse_dict
Пример #18
0
def check_units_statements(code, variables):
    '''
    Check the units for a series of statements. Setting a model variable has to
    use the correct unit. For newly introduced temporary variables, the unit
    is determined and used to check the following statements to ensure
    consistency.

    Note that `variables` is modified in place: an entry is added for every
    temporary variable that the statements newly define.

    Parameters
    ----------
    code : str
        The statements as a (multi-line) string
    variables : dict of `Variable` objects
        The information about all variables used in `code` (including
        `Constant` objects for external variables)

    Raises
    ------
    KeyError
        In case one of the identifiers cannot be resolved.
    DimensionMismatchError
        If a unit mismatch occurs during the evaluation.
    AssertionError
        If `code` uses identifiers that are not in `variables`, if a
        statement uses an unsupported operator, or if an assigned variable
        is neither known nor newly defined.
    '''
    # Avoid a circular import
    from brian2.codegen.translation import analyse_identifiers
    known = set(variables)
    newly_defined, _, unknown = analyse_identifiers(code, known)

    if len(unknown):
        raise AssertionError(('Encountered unknown identifiers, this should '
                              'not happen at this stage. Unknown identifiers: %s'
                              % unknown))

    # Statements are separated by semicolons or newlines
    for line in re.split(r'[;\n]', code):
        line = line.strip()
        if not line:
            continue  # skip empty lines

        varname, op, expr, _ = parse_statement(line)
        if op in ('+=', '-=', '*=', '/=', '%='):
            # Replace statements such as "w *=2" by "w = w * 2"
            expr = '{var} {op_first} {expr}'.format(var=varname,
                                                    op_first=op[0],
                                                    expr=expr)
        elif op != '=':
            raise AssertionError('Unknown operator "%s"' % op)

        expr_unit = parse_expression_unit(expr, variables)

        if varname in variables:
            expected_unit = variables[varname].unit
            fail_for_dimension_mismatch(expr_unit, expected_unit,
                                        ('The right-hand-side of code '
                                         'statement "%s" does not have the '
                                         'expected unit %r') % (line,
                                                                expected_unit))
        elif varname in newly_defined:
            # note the unit for later
            variables[varname] = Variable(name=varname, unit=expr_unit,
                                          scalar=False)
        else:
            raise AssertionError(('Variable "%s" is neither in the variables '
                                  'dictionary nor in the list of undefined '
                                  'variables.' % varname))
Пример #19
0
def create_runner_codeobj(group, code, template_name, indices=None,
                          variable_indices=None,
                          name=None, check_units=True,
                          additional_variables=None,
                          additional_namespace=None,
                          template_kwds=None):
    ''' Create a `CodeObject` for the execution of code in the context of a
    `Group`.

    Parameters
    ----------
    group : `Group`
        The group where the code is to be run
    code : str
        The code to be executed.
    template_name
        Identifies the template to use for the code; it is looked up with
        ``get_codeobject_template`` and passed on to ``create_codeobject``.
    indices : dict-like, optional
        A mapping from index name to `Index` objects, describing the indices
        used for the variables in the code. If none are given, uses the
        corresponding attribute of `group`.
    variable_indices : dict-like, optional
        A mapping from `Variable` objects to index names (strings).  If none is
        given, uses the corresponding attribute of `group`.
    name : str, optional
        A name for this code object, will use ``group + '_codeobject*'`` if
        none is given.
    check_units : bool, optional
        Whether to check units in the statement. Defaults to ``True``.
        Note that this argument is not referenced in the body of this
        function.
    additional_variables : dict-like, optional
        A mapping of names to `Variable` objects, used in addition to the
        variables saved in `group`.
    additional_namespace : dict-like, optional
        A mapping from names to objects, used in addition to the namespace
        saved in `group`.
    template_kwds : dict, optional
        A dictionary of additional information that is passed to the template.

    Returns
    -------
    The return value of ``create_codeobject``.
    '''
    logger.debug('Creating code object for abstract code:\n' + str(code))

    template = get_codeobject_template(template_name,
                                       codeobj_class=group.codeobj_class)

    all_variables = dict(group.variables)
    if additional_variables is not None:
        all_variables.update(additional_variables)

    # Determine the identifiers that were used
    _, used_known, unknown = analyse_identifiers(code, all_variables,
                                                 recursive=True)

    logger.debug('Unknown identifiers in the abstract code: ' + str(unknown))
    resolved_namespace = group.namespace.resolve_all(unknown,
                                                     additional_namespace)

    # Only pass the variables that are actually used
    variables = {}
    for var in used_known:
        if not isinstance(all_variables[var], StochasticVariable):
            variables[var] = all_variables[var]

    # Also add the variables that the template needs
    for var in template.variables:
        try:
            variables[var] = all_variables[var]
        except KeyError:
            # We abuse template.variables here to also store names of things
            # from the namespace (e.g. rand) that are needed, so try to find
            # the name in the group's namespace.
            # `group` cannot be None here, its attributes have already been
            # accessed above.
            # TODO: Improve all of this namespace/specifier handling
            resolved_namespace[var] = group.namespace.resolve(var,
                                                              additional_namespace)

    if name is None:
        name = group.name + '_codeobject*'

    if indices is None:
        indices = group.indices
    if variable_indices is None:
        variable_indices = group.variable_indices

    return create_codeobject(name,
                             code,
                             resolved_namespace,
                             variables,
                             template_name,
                             indices=indices,
                             variable_indices=variable_indices,
                             template_kwds=template_kwds,
                             codeobj_class=group.codeobj_class)
Пример #20
0
def check_units_statements(code, variables):
    """
    Verify the dimensional consistency of a sequence of statements.

    The statements are processed in order. An assignment to a name that is
    already described in `variables` must have a right-hand side whose
    dimensions match that variable. An assignment to a name that is newly
    introduced by `code` records the dimensions of its right-hand side, so
    that later statements using this temporary can be checked as well.

    Parameters
    ----------
    code : str
        One or more statements, separated by newlines or semicolons.
    variables : dict of `Variable` objects
        The information about all variables used in `code` (including
        `Constant` objects for external variables). A shallow copy of this
        dictionary is used internally; the dictionary passed by the caller
        is not modified.

    Raises
    ------
    AssertionError
        If `code` refers to unknown identifiers, if a statement uses an
        operator other than ``=``, ``+=``, ``-=``, ``*=``, ``/=`` or ``%=``,
        or if the target of a statement is neither a known variable nor a
        newly defined one.
    KeyError
        If one of the identifiers cannot be resolved (presumably raised by
        the expression parsing helper, not directly by this function).
    DimensionMismatchError
        If a unit mismatch occurs during the evaluation.
    """
    # Temporaries get added while walking the statements, hence the copy
    known_vars = dict(variables)
    # Avoid a circular import
    from brian2.codegen.translation import analyse_identifiers
    fresh_names, _, unresolved = analyse_identifiers(code, known_vars)

    if unresolved:
        raise AssertionError(
            f"Encountered unknown identifiers, this should not "
            f"happen at this stage. Unknown identifiers: {unresolved}")

    augmented_ops = ('+=', '-=', '*=', '/=', '%=')

    for raw_statement in re.split(r'[;\n]', code):
        statement = raw_statement.strip()
        if not statement:
            # Nothing to check for blank statements
            continue

        target, op, rhs, _comment = parse_statement(statement)
        if op in augmented_ops:
            # Expand e.g. "w *= 2" into the right-hand side "w * 2"
            rhs = f'{target} {op[0]} {rhs}'
        elif op != '=':
            raise AssertionError(f'Unknown operator "{op}"')

        rhs_dim = parse_expression_dimensions(rhs, known_vars)

        if target in known_vars:
            # Existing variable: right-hand side has to match its dimensions
            target_dim = known_vars[target].dim
            message = ('The right-hand-side of code '
                       'statement "%s" does not have the '
                       'expected unit {expected}') % statement
            fail_for_dimension_mismatch(rhs_dim,
                                        target_dim,
                                        message,
                                        expected=target_dim)
            continue

        if target not in fresh_names:
            raise AssertionError(
                f"Variable '{target}' is neither in the variables "
                f"dictionary nor in the list of undefined "
                f"variables.")

        # New temporary: remember its dimensions for the following statements
        known_vars[target] = Variable(name=target,
                                      dimensions=get_dimensions(rhs_dim),
                                      scalar=False)
Пример #21
0
def check_units_statements(code, namespace, variables):
    '''
    Verify the unit consistency of a sequence of statements.

    The statements are processed in order. An assignment to a name that is
    already described in `variables` must have a right-hand side whose unit
    matches that variable. An assignment to a name that is newly introduced
    by `code` records the unit of its right-hand side, so that later
    statements using this temporary can be checked as well.

    Parameters
    ----------
    code : str
        One or more statements, separated by newlines or semicolons.
    namespace : dict-like
        The namespace of external variables.
    variables : dict of `Variable` objects
        The information about the internal variables. A shallow copy of this
        dictionary is used internally; the dictionary passed by the caller
        is not modified.

    Raises
    ------
    AssertionError
        If `code` refers to unknown identifiers, if a statement uses an
        operator other than ``=``, ``+=``, ``-=``, ``*=``, ``/=`` or ``%=``,
        or if the target of a statement is neither a known variable nor a
        newly defined one.
    KeyError
        If one of the identifiers cannot be resolved (presumably raised by
        the expression parsing helper, not directly by this function).
    DimensionMismatchError
        If a unit mismatch occurs during the evaluation.
    '''
    known_names = set(variables.keys()).union(namespace.keys())
    fresh_names, _, unresolved = analyse_identifiers(code, known_names)

    if unresolved:
        message = ('Encountered unknown identifiers, this should '
                   'not happen at this stage. Unkown identifiers: %s')
        raise AssertionError(message % unresolved)

    # Temporaries get added while walking the statements, so work on a copy
    # instead of the dictionary handed in by the caller
    known_vars = dict(variables)

    augmented_ops = ('+=', '-=', '*=', '/=', '%=')

    for raw_statement in re.split(r'[;\n]', code):
        statement = raw_statement.strip()
        if not statement:
            # Nothing to check for blank statements
            continue

        target, op, rhs = parse_statement(statement)
        if op in augmented_ops:
            # Expand e.g. "w *= 2" into the right-hand side "w * 2"
            rhs = '%s %s %s' % (target, op[0], rhs)
        elif op != '=':
            raise AssertionError('Unknown operator "%s"' % op)

        rhs_unit = parse_expression_unit(rhs, namespace, known_vars)

        if target in known_vars:
            # Existing variable: right-hand side has to match its unit
            mismatch_message = ('Code statement "%s" does not use '
                                'correct units' % statement)
            fail_for_dimension_mismatch(known_vars[target].unit, rhs_unit,
                                        mismatch_message)
            continue

        if target not in fresh_names:
            raise AssertionError(('Variable "%s" is neither in the variables '
                                  'dictionary nor in the list of undefined '
                                  'variables.' % target))

        # New temporary: remember its unit for the following statements
        known_vars[target] = Variable(rhs_unit,
                                      is_bool=False,
                                      scalar=False)
Пример #22
0
def check_units_statements(code, namespace, variables):
    '''
    Check that every statement in `code` assigns a value of the right unit.

    Statements are handled one after the other. When the assigned name is
    described in `variables`, the unit of the right-hand side is compared to
    the unit of that variable. When the assigned name is first defined by
    `code` itself, the unit of the right-hand side is stored for that name
    and used when checking the statements that follow.

    Parameters
    ----------
    code : str
        The statements as a string; individual statements are separated by
        newlines or semicolons.
    namespace : dict-like
        The namespace of external variables.
    variables : dict of `Variable` objects
        The information about the internal variables. Only a shallow copy is
        extended with temporaries, the caller's dictionary stays untouched.

    Raises
    ------
    AssertionError
        For unknown identifiers in `code`, for operators other than ``=``,
        ``+=``, ``-=``, ``*=``, ``/=`` and ``%=``, and for assignment targets
        that are neither known variables nor newly defined names.
    KeyError
        If one of the identifiers cannot be resolved (presumably raised by
        the expression parsing helper, not directly by this function).
    DimensionMismatchError
        If a unit mismatch occurs during the evaluation.
    '''
    internal_names = set(variables.keys())
    external_names = set(namespace.keys())
    defined_here, _, not_found = analyse_identifiers(
        code, internal_names | external_names)

    if not_found:
        raise AssertionError(
            'Encountered unknown identifiers, this should '
            'not happen at this stage. Unkown identifiers: %s' % not_found)

    # The copy receives the temporaries introduced by the statements
    var_info = dict(variables)

    compound_assignments = ('+=', '-=', '*=', '/=', '%=')
    stripped = (part.strip() for part in re.split(r'[;\n]', code))

    for stmt in stripped:
        if not stmt:
            continue  # empty statement, nothing to verify

        name, op, value_expr = parse_statement(stmt)
        if op in compound_assignments:
            # "w *= 2" is checked as if it were "w = w * 2"
            value_expr = '{0} {1} {2}'.format(name, op[0], value_expr)
        elif op != '=':
            raise AssertionError('Unknown operator "%s"' % op)

        value_unit = parse_expression_unit(value_expr, namespace, var_info)

        if name not in var_info:
            if name not in defined_here:
                raise AssertionError(
                    'Variable "%s" is neither in the variables '
                    'dictionary nor in the list of undefined '
                    'variables.' % name)
            # First definition of a temporary: keep its unit for later
            var_info[name] = Variable(value_unit, is_bool=False,
                                      scalar=False)
        else:
            # Assignment to a known variable: units have to agree
            fail_for_dimension_mismatch(
                var_info[name].unit,
                value_unit,
                'Code statement "%s" does not use '
                'correct units' % stmt)