예제 #1
0
def test_write_to_subexpression():
    '''
    Assigning to a subexpression must be rejected with a ``SyntaxError``.
    '''
    subexpr = Subexpression(name='a',
                            dtype=np.float32,
                            owner=FakeGroup(variables={}),
                            device=None,
                            expr='2*z')
    variables = {'a': subexpr, 'z': Variable(name='z')}

    # 'a' is defined by the expression 2*z, so it cannot be an assignment target
    with pytest.raises(SyntaxError):
        make_statements('a = z', variables, np.float32)
예제 #2
0
def test_repeated_subexpressions():
    '''
    Check when the subexpression ``a`` (defined as ``2*z``) can be computed
    once and reused, and when it has to be recomputed because ``z`` changed.
    '''
    variables = dict(
        a=Subexpression(name='a', dtype=np.float32,
                        owner=FakeGroup(variables={}), device=None,
                        expr='2*z'),
        x=Variable(name='x'),
        y=Variable(name='y'),
        z=Variable(name='z'),
    )

    def vector_targets(abstract_code):
        # Translate the code, make sure no scalar statements were generated and
        # return the vector statements together with the names they assign to
        scalar, vector = make_statements(abstract_code, variables, np.float32)
        assert len(scalar) == 0
        return vector, [stmt.var for stmt in vector]

    # subexpression a (referring to z) is used twice, but can be reused the
    # second time (no change to z)
    reuse_code = '''
    x = a
    y = a
    '''
    stmts, targets = vector_targets(reuse_code)
    assert targets == ['a', 'x', 'y']
    assert stmts[0].constant

    # z changes after the only use of a. The subexpression is currently still
    # not marked as constant here, because a later use would have to redefine
    # it and the algorithm does not detect that no such use follows.
    unused_afterwards_code = '''
    x = a
    z *= 2
    '''
    _, targets = vector_targets(unused_afterwards_code)
    assert targets == ['a', 'x', 'z']

    # a refers to z, therefore a has to be redefined after z changed, and a
    # cannot be constant
    redefine_code = '''
    x = a
    z *= 2
    y = a
    '''
    stmts, targets = vector_targets(redefine_code)
    assert targets == ['a', 'x', 'z', 'a', 'y']
    assert not any(stmt.constant for stmt in stmts)
예제 #3
0
def test_nested_subexpressions():
    '''
    Check the evaluation order produced for nested subexpressions: ``a``
    depends on the subexpression ``b`` and on ``d``; ``b`` depends on ``c``.
    '''
    def subexpression(name, expr):
        # All subexpressions in this test share dtype, owner and device
        return Subexpression(name=name, dtype=np.float32,
                             owner=FakeGroup(variables={}), device=None,
                             expr=expr)

    variables = {
        'a': subexpression('a', 'b*b+d'),
        'b': subexpression('b', 'c*c*c'),
        'c': Variable(name='c'),
        'd': Variable(name='d'),
    }
    code = '''
    x = a + b + c
    c = 1
    x = a + b + c
    d = 1
    x = a + b + c
    '''
    scalar_stmts, vector_stmts = make_statements(code, variables, np.float32)
    assert not scalar_stmts
    # Expected evaluation order. An earlier version of this test did not
    # expect the final re-evaluation of "b" (c had not changed), but that
    # subexpression caching was removed because it did not seem to matter
    # in practical use cases.
    assert ''.join([stmt.var for stmt in vector_stmts]) == 'baxcbaxdbax'
예제 #4
0
    def translate(self, code, dtype):
        '''
        Translates an abstract code block into the target language.

        Each abstract code block in ``code`` is converted into scalar and
        vector statements with `make_statements`. The vector statements of
        every block are then checked for order independence; a failed check
        is reported as a warning, not as an error.

        Parameters
        ----------
        code : dict
            Mapping from abstract-code block name to the abstract code.
        dtype : dtype
            Default floating point type, passed on to `make_statements`.

        Returns
        -------
        The result of `self.translate_statement_sequence` for the collected
        scalar and vector statements (both dicts keyed by block name).
        '''
        scalar_statements = {}
        vector_statements = {}
        # items()/values() instead of the Python-2-only iteritems()/
        # itervalues(), so that this also runs on Python 3
        for ac_name, ac_code in code.items():
            scalar_statements[ac_name], vector_statements[
                ac_name] = make_statements(ac_code, self.variables, dtype)
        for vs in vector_statements.values():
            # Check that the statements are meaningful independent of the order
            # of execution (e.g. for synapses)
            try:
                check_for_order_independence(vs, self.variables,
                                             self.variable_indices)
            except OrderDependenceError:
                # If the abstract code is only one line, display it in full
                if len(vs) <= 1:
                    error_msg = 'Abstract code: "%s"\n' % vs[0]
                else:
                    error_msg = ('%d lines of abstract code, first line is: '
                                 '"%s"\n') % (len(vs), vs[0])
                logger.warn(('Came across an abstract code block that is not '
                             'well-defined: the outcome may depend on the '
                             'order of execution. ' + error_msg))

        return self.translate_statement_sequence(scalar_statements,
                                                 vector_statements)
예제 #5
0
def test_nested_subexpressions():
    '''
    Code translation with nested subexpressions (``a`` uses the subexpression
    ``b``) must produce the expected order of (re-)evaluations.
    '''
    code = '''
    x = a + b + c
    c = 1
    x = a + b + c
    d = 1
    x = a + b + c
    '''
    variables = {}
    for sub_name, sub_expr in (('a', 'b*b+d'), ('b', 'c*c*c')):
        variables[sub_name] = Subexpression(name=sub_name, dtype=np.float32,
                                            owner=FakeGroup(variables={}),
                                            device=None, expr=sub_expr)
    for plain_name in ('c', 'd'):
        variables[plain_name] = Variable(name=plain_name)

    scalar_part, vector_part = make_statements(code, variables, np.float32)
    assert len(scalar_part) == 0
    # Note: the last "b" evaluation is expected even though c did not change
    # before it; the former subexpression caching that avoided it was removed
    # because it did not seem to apply in practical use cases.
    expected_order = 'baxcbaxdbax'
    assert ''.join(stmt.var for stmt in vector_part) == expected_order
예제 #6
0
파일: base.py 프로젝트: appusom/brian2
    def translate(self, code, dtype):
        '''
        Translates an abstract code block into the target language.

        Every abstract code block is turned into scalar and vector statements
        via `make_statements`; the vector statements are checked for order
        independence and a warning is logged if that check fails.

        Parameters
        ----------
        code : dict
            Mapping from abstract-code block name to the abstract code.
        dtype : dtype
            Default floating point type, passed on to `make_statements`.

        Returns
        -------
        The result of `self.translate_statement_sequence` for the collected
        scalar and vector statements (both dicts keyed by block name).
        '''
        scalar_statements = {}
        vector_statements = {}
        # items()/values() instead of the Python-2-only iteritems()/
        # itervalues(), so that this also runs on Python 3
        for ac_name, ac_code in code.items():
            scalar_statements[ac_name], vector_statements[ac_name] = make_statements(ac_code,
                                                                                     self.variables,
                                                                                     dtype)
        for vs in vector_statements.values():
            # Check that the statements are meaningful independent of the order
            # of execution (e.g. for synapses)
            try:
                check_for_order_independence(vs,
                                             self.variables,
                                             self.variable_indices)
            except OrderDependenceError:
                # If the abstract code is only one line, display it in full
                if len(vs) <= 1:
                    error_msg = 'Abstract code: "%s"\n' % vs[0]
                else:
                    error_msg = ('%d lines of abstract code, first line is: '
                                 '"%s"\n') % (len(vs), vs[0])
                logger.warn(('Came across an abstract code block that is not '
                             'well-defined: the outcome may depend on the '
                             'order of execution. ' + error_msg))

        return self.translate_statement_sequence(scalar_statements,
                                                 vector_statements)
예제 #7
0
def test_nested_subexpressions():
    '''
    Check the evaluation order produced for nested subexpressions: ``a``
    depends on the subexpression ``b`` and on ``d``; ``b`` depends on ``c``.
    '''
    def subexpression(name, expr):
        # All subexpressions in this test share unit, dtype, owner and device
        return Subexpression(name=name, unit=Unit(1), dtype=np.float32,
                             owner=FakeGroup(variables={}), device=None,
                             expr=expr)

    variables = {
        'a': subexpression('a', 'b*b+d'),
        'b': subexpression('b', 'c*c*c'),
        'c': Variable(unit=None, name='c'),
        'd': Variable(unit=None, name='d'),
    }
    code = '''
    x = a + b + c
    c = 1
    x = a + b + c
    d = 1
    x = a + b + c
    '''
    all_stmts = make_statements(code, variables, np.float32)
    # This is the order that variables ought to be evaluated in
    assert ''.join([stmt.var for stmt in all_stmts]) == 'baxcbaxdax'
예제 #8
0
 def translate(self, code, dtype):
     '''
     Translates an abstract code block into the target language.

     Every abstract code block is turned into (optimised) scalar and vector
     statements via `make_statements`. Vector statements that use repeated
     indices are checked for order independence; a failed check is reported
     as a warning, not as an error.

     Parameters
     ----------
     code : dict
         Mapping from abstract-code block name to the abstract code.
     dtype : dtype
         Default floating point type, passed on to `make_statements`.

     Returns
     -------
     The result of `self.translate_statement_sequence` for the collected
     scalar and vector statements (both dicts keyed by block name).
     '''
     scalar_statements = {}
     vector_statements = {}
     # items()/values() instead of the Python-2-only iteritems()/
     # itervalues(), so that this also runs on Python 3
     for ac_name, ac_code in code.items():
         statements = make_statements(ac_code,
                                      self.variables,
                                      dtype,
                                      optimise=True,
                                      blockname=ac_name)
         scalar_statements[ac_name], vector_statements[ac_name] = statements
     for vs in vector_statements.values():
         # Check that the statements are meaningful independent of the order
         # of execution (e.g. for synapses)
         try:
             if self.has_repeated_indices(vs):  # only do order dependence if there are repeated indices
                 check_for_order_independence(vs,
                                              self.variables,
                                              self.variable_indices)
         except OrderDependenceError:
             # If the abstract code is only one line, display it in full
             if len(vs) <= 1:
                 error_msg = 'Abstract code: "%s"\n' % vs[0]
             else:
                 error_msg = ('%d lines of abstract code, first line is: '
                              '"%s"\n') % (len(vs), vs[0])
             logger.warn(('Came across an abstract code block that may not be '
                          'well-defined: the outcome may depend on the '
                          'order of execution. You can ignore this warning if '
                          'you are sure that the order of operations does not '
                          'matter. ' + error_msg))
     return self.translate_statement_sequence(scalar_statements,
                                              vector_statements)
예제 #9
0
파일: base.py 프로젝트: msGenDev/brian2
 def translate(self, code, dtype):
     '''
     Translates an abstract code block into the target language.

     Parameters
     ----------
     code : dict
         Mapping from abstract-code block name to the abstract code.
     dtype : dtype
         Default floating point type, passed on to `make_statements`.

     Returns
     -------
     Whatever `self.translate_statement_sequence` returns for the dictionary
     of statements (keyed by block name).
     '''
     # items() instead of the Python-2-only iteritems(), so that this also
     # runs on Python 3
     statements = {ac_name: make_statements(ac_code, self.variables, dtype)
                   for ac_name, ac_code in code.items()}
     return self.translate_statement_sequence(statements)
예제 #10
0
 def translate(self, code, dtype):
     '''
     Translates an abstract code block into the target language.

     Each abstract code block in ``code`` is converted with `make_statements`
     and the resulting statements (keyed by block name) are handed to
     `self.translate_statement_sequence`, whose result is returned.
     '''
     statements = {}
     # items() instead of the Python-2-only iteritems(), so that this also
     # runs on Python 3
     for ac_name, ac_code in code.items():
         statements[ac_name] = make_statements(ac_code, self.variables, dtype)
     return self.translate_statement_sequence(statements)
예제 #11
0
def test_automatic_augmented_assignments():
    '''
    Statements that could be rewritten as augmented assignments must be
    correctly rewritten, and others must be left alone. Right-hand sides are
    compared with sympy to test for symbolic equality, so equivalent
    expressions with a different term order are accepted.
    '''
    variables = {
        'x': ArrayVariable('x', owner=None, size=10, device=device),
        'y': ArrayVariable('y', owner=None, size=10, device=device),
        # The variable's name has to match its key (it used to be a copy of 'y')
        'z': ArrayVariable('z', owner=None, size=10, device=device),
        'b': ArrayVariable('b', owner=None, size=10, dtype=bool,
                           device=device),
        'clip': DEFAULT_FUNCTIONS['clip'],
        'inf': DEFAULT_CONSTANTS['inf']
    }
    statements = [
        # examples that should be rewritten
        # Note that using our approach, we will never get -= or /= but always
        # the equivalent += or *= statements
        ('x = x + 1.0', 'x += 1.0'),
        ('x = 2.0 * x', 'x *= 2.0'),
        ('x = x - 3.0', 'x += -3.0'),
        ('x = x/2.0', 'x *= 0.5'),
        ('x = y + (x + 1.0)', 'x += y + 1.0'),
        ('x = x + x', 'x *= 2.0'),
        ('x = x + y + z', 'x += y + z'),
        # examples that should not be rewritten
        ('x = 1.0/x', 'x = 1.0/x'),
        ('x = 1.0', 'x = 1.0'),
        ('x = 2.0*(x + 1.0)', 'x = 2.0*(x + 1.0)'),
        ('x = clip(x + y, 0.0, inf)', 'x = clip(x + y, 0.0, inf)'),
        ('b = b or False', 'b = b or False')
    ]
    for orig, rewritten in statements:
        scalar, vector = make_statements(orig, variables, np.float32)
        try:  # we augment the assertion error with the original statement
            assert len(scalar) == 0, (
                'Did not expect any scalar statements but got ' + str(scalar))
            assert len(vector) == 1, (
                'Did expect a single statement but got ' + str(vector))
            statement = vector[0]
            expected_var, expected_op, expected_expr, _ = parse_statement(
                rewritten)
            assert expected_var == statement.var, (
                'expected write to variable %s, not to %s' % (expected_var,
                                                              statement.var))
            assert expected_op == statement.op, (
                'expected operation %s, not %s' % (expected_op, statement.op))
            # Compare the two expressions using sympy to allow for different order etc.
            sympy_expected = str_to_sympy(expected_expr)
            sympy_actual = str_to_sympy(statement.expr)
            assert sympy_expected == sympy_actual, (
                'RHS expressions "%s" and "%s" are not identical' %
                (sympy_to_str(sympy_expected), sympy_to_str(sympy_actual)))
        except AssertionError as ex:
            raise AssertionError(
                'Transformation for statement "%s" gave an unexpected result: %s'
                % (orig, str(ex)))
예제 #12
0
def test_repeated_subexpressions():
    '''
    Subexpression reuse: ``a`` (``2*z``) may be evaluated once as long as
    ``z`` does not change, but must be re-evaluated after ``z`` is modified.
    '''
    variables = {'a': Subexpression(name='a',
                                    dtype=np.float32,
                                    owner=FakeGroup(variables={}),
                                    device=None,
                                    expr='2*z')}
    for var_name in ['x', 'y', 'z']:
        variables[var_name] = Variable(name=var_name)

    # subexpression a (referring to z) is used twice, but can be reused the
    # second time (no change to z)
    code = '''
    x = a
    y = a
    '''
    scalars, vectors = make_statements(code, variables, np.float32)
    assert not scalars
    assert [s.var for s in vectors] == ['a', 'x', 'y']
    assert vectors[0].constant

    # Here the subexpression is currently not marked as constant: its use
    # after "z *= 2" would redefine it, and the algorithm is not smart enough
    # to detect that it is not used afterwards.
    code = '''
    x = a
    z *= 2
    '''
    scalars, vectors = make_statements(code, variables, np.float32)
    assert not scalars
    assert [s.var for s in vectors] == ['a', 'x', 'z']

    # a refers to z, therefore we have to redefine a after z changed, and a
    # cannot be constant
    code = '''
    x = a
    z *= 2
    y = a
    '''
    scalars, vectors = make_statements(code, variables, np.float32)
    assert not scalars
    assert [s.var for s in vectors] == ['a', 'x', 'z', 'a', 'y']
    assert all(not s.constant for s in vectors)
예제 #13
0
def test_automatic_augmented_assignments():
    '''
    Statements that could be rewritten as augmented assignments must be
    correctly rewritten, and others must be left alone. Right-hand sides are
    compared with sympy to test for symbolic equality.
    '''
    variables = {
        'x': ArrayVariable('x', owner=None, size=10,
                           device=device),
        'y': ArrayVariable('y', owner=None, size=10,
                           device=device),
        # The variable's name has to match its key (it used to be a copy of 'y')
        'z': ArrayVariable('z', owner=None, size=10,
                           device=device),
        # builtin bool: ``np.bool`` was only an alias for it and has been
        # removed from NumPy (1.24)
        'b': ArrayVariable('b', owner=None, size=10,
                           dtype=bool, device=device),
        'clip': DEFAULT_FUNCTIONS['clip'],
        'inf': DEFAULT_CONSTANTS['inf']
    }
    statements = [
        # examples that should be rewritten
        # Note that using our approach, we will never get -= or /= but always
        # the equivalent += or *= statements
        ('x = x + 1', 'x += 1'),
        ('x = 2 * x', 'x *= 2'),
        ('x = x - 3', 'x += -3'),
        ('x = x/2', 'x *= 0.5'),
        ('x = y + (x + 1)', 'x += y + 1'),
        ('x = x + x', 'x *= 2'),
        ('x = x + y + z', 'x += y + z'),
        # examples that should not be rewritten
        ('x = 1/x', 'x = 1/x'),
        ('x = 1', 'x = 1'),
        ('x = 2*(x + 1)', 'x = 2*(x + 1)'),
        ('x = clip(x + y, 0, inf)', 'x = clip(x + y, 0, inf)'),
        ('b = b or False', 'b = b or False')
    ]
    for orig, rewritten in statements:
        scalar, vector = make_statements(orig, variables, np.float32)
        try:  # we augment the assertion error with the original statement
            assert len(scalar) == 0, 'Did not expect any scalar statements but got ' + str(scalar)
            assert len(vector) == 1, 'Did expect a single statement but got ' + str(vector)
            statement = vector[0]
            expected_var, expected_op, expected_expr, _ = parse_statement(rewritten)
            assert expected_var == statement.var, 'expected write to variable %s, not to %s' % (expected_var, statement.var)
            assert expected_op == statement.op, 'expected operation %s, not %s' % (expected_op, statement.op)
            # Compare the two expressions using sympy to allow for different order etc.
            sympy_expected = str_to_sympy(expected_expr)
            sympy_actual = str_to_sympy(statement.expr)
            assert sympy_expected == sympy_actual, ('RHS expressions "%s" and "%s" are not identical' % (sympy_to_str(sympy_expected),
                                                                                                         sympy_to_str(sympy_actual)))
        except AssertionError as ex:
            raise AssertionError('Transformation for statement "%s" gave an unexpected result: %s' % (orig, str(ex)))
예제 #14
0
def check_permutation_code(code):
    '''
    Check that synaptic abstract ``code`` does not depend on the order in
    which it is executed.

    Identifiers ending in ``_pre`` / ``_post`` are indexed with
    ``_presynaptic_idx`` / ``_postsynaptic_idx``, identifiers ending in
    ``_syn`` with ``_idx``; every such identifier gets a size-10
    `ArrayVariable`. Errors from `check_for_order_independence` propagate.
    '''
    from collections import defaultdict
    identifiers = get_identifiers(code)
    # Any name not assigned explicitly below falls back to the synaptic index
    indices = defaultdict(lambda: '_idx')
    for identifier in identifiers:
        if identifier.endswith('_syn'):
            indices[identifier] = '_idx'
        elif identifier.endswith('_pre'):
            indices[identifier] = '_presynaptic_idx'
        elif identifier.endswith('_post'):
            indices[identifier] = '_postsynaptic_idx'
    variables = dict()
    for var in indices:
        variables[var] = ArrayVariable(var, 1, None, 10, device)
    # Name each index array after its own key. The previous code reused the
    # leftover loop variable as the name, which gave these arrays a wrong name
    # and raised a NameError when the code had no identifiers.
    for index_name in ('_presynaptic_idx', '_postsynaptic_idx'):
        variables[index_name] = ArrayVariable(index_name, 1, None, 10, device)
    scalar_statements, vector_statements = make_statements(code, variables, float64)
    check_for_order_independence(vector_statements, variables, indices)
예제 #15
0
def check_permutation_code(code):
    '''
    Check that synaptic abstract ``code`` does not depend on the order in
    which it is executed.

    The suffix of each identifier selects its index: ``_pre`` ->
    ``_presynaptic_idx``, ``_post`` -> ``_postsynaptic_idx``, ``_syn`` ->
    ``_idx``. Errors from `check_for_order_independence` propagate.
    '''
    from collections import defaultdict
    suffix_to_index = (('_syn', '_idx'),
                       ('_pre', '_presynaptic_idx'),
                       ('_post', '_postsynaptic_idx'))
    # Any name not assigned explicitly below falls back to the synaptic index
    indices = defaultdict(lambda: '_idx')
    for identifier in get_identifiers(code):
        for suffix, index_name in suffix_to_index:
            if identifier.endswith(suffix):
                indices[identifier] = index_name
                break
    variables = dict()
    for var in indices:
        variables[var] = ArrayVariable(var, 1, None, 10, device)
    # Name each index array after its own key. The previous code reused the
    # leftover loop variable as the name, which gave these arrays a wrong name
    # and raised a NameError when the code had no identifiers.
    for index_name in ('_presynaptic_idx', '_postsynaptic_idx'):
        variables[index_name] = ArrayVariable(index_name, 1, None, 10, device)
    scalar_statements, vector_statements = make_statements(
        code, variables, float64)
    check_for_order_independence(vector_statements, variables, indices)
예제 #16
0
파일: base.py 프로젝트: brian-team/brian2
    def translate(self, code, dtype):
        """
        Translate a dictionary of abstract code blocks into the target language.

        Every block is converted into (optimised) scalar and vector statements
        with `make_statements`. For every block whose vector statements use
        repeated indices, order independence is checked; if that check fails,
        a warning is logged instead of raising an error.

        Returns whatever `translate_statement_sequence` returns for the
        collected scalar and vector statements.
        """
        scalar_statements, vector_statements = {}, {}
        for block_name, abstract_code in code.items():
            block_scalar, block_vector = make_statements(abstract_code,
                                                         self.variables,
                                                         dtype,
                                                         optimise=True,
                                                         blockname=block_name)
            scalar_statements[block_name] = block_scalar
            vector_statements[block_name] = block_vector

        for block_stmts in vector_statements.values():
            try:
                # Order dependence (e.g. for synapses) can only arise when the
                # same index occurs several times
                repeated = self.has_repeated_indices(block_stmts)
                if repeated:
                    check_for_order_independence(block_stmts, self.variables,
                                                 self.variable_indices)
            except OrderDependenceError:
                first_line = block_stmts[0]
                if len(block_stmts) > 1:
                    # Longer blocks: only show the first line
                    error_msg = (f"{len(block_stmts)} lines of abstract code, "
                                 f"first line is: '{first_line}'\n")
                else:
                    error_msg = f"Abstract code: '{first_line}'\n"
                logger.warn('Came across an abstract code block that may not be '
                            'well-defined: the outcome may depend on the '
                            'order of execution. You can ignore this warning if '
                            'you are sure that the order of operations does not '
                            'matter. ' + error_msg)

        return self.translate_statement_sequence(scalar_statements,
                                                 vector_statements)
예제 #17
0
def test_nested_subexpressions():
    '''
    Code translation with nested subexpressions (``a`` uses the subexpression
    ``b``) must produce the expected order of (re-)evaluations.
    '''
    code = '''
    x = a + b + c
    c = 1
    x = a + b + c
    d = 1
    x = a + b + c
    '''
    variables = {}
    for sub_name, sub_expr in (('a', 'b*b+d'), ('b', 'c*c*c')):
        variables[sub_name] = Subexpression(name=sub_name, unit=Unit(1),
                                            dtype=np.float32,
                                            owner=FakeGroup(variables={}),
                                            device=None, expr=sub_expr)
    for plain_name in ('c', 'd'):
        variables[plain_name] = Variable(unit=None, name=plain_name)

    stmts = make_statements(code, variables, np.float32)
    # This is the order that variables ought to be evaluated in
    expected_order = 'baxcbaxdax'
    assert ''.join(stmt.var for stmt in stmts) == expected_order
예제 #18
0
# Small demo: translate the abstract statement "x = a*b+c" with the Cython
# code generator and print the generated code.
# The names used below were previously not imported in this snippet
# (compare the import list of the otherwise identical gen.py example).
from brian2 import Unit, get_device
from brian2.codegen.generators.cython_generator import CythonCodeGenerator
from brian2.codegen.translation import make_statements
from brian2.core.variables import Variable, ArrayVariable

owner = None

code = '''
x = a*b+c
'''

variables = {
    'a': ArrayVariable('a', Unit(1), owner, 10, get_device()),
    'b': ArrayVariable('b', Unit(1), owner, 10, get_device()),
    'x': ArrayVariable('x', Unit(1), owner, 10, get_device()),
}
namespace = {}
variable_indices = {'a': '_idx', 'b': '_idx', 'x': '_idx'}

gen = CythonCodeGenerator(variables,
                          variable_indices,
                          owner,
                          iterate_all=True,
                          codeobj_class=None)

stmts = make_statements(code, variables, float)

# print() call syntax works on Python 2 and 3 (the print statement is
# a syntax error on Python 3)
print('\n'.join(gen.translate_one_statement_sequence(stmts)))
예제 #19
0
 def code_object(self,
                 owner,
                 name,
                 abstract_code,
                 variables,
                 template_name,
                 variable_indices,
                 codeobj_class=None,
                 template_kwds=None,
                 override_conditional_write=None):
     '''
     Create a code object after adding CUDA-standalone specific template
     keywords.

     Adds ``profile`` (if this device has a ``profile`` attribute),
     ``no_or_const_delay_mode``, ``synaptic_effects`` (``synapses`` template
     only) and ``multisynaptic_idx_var`` (synapse creation templates, if the
     owner has a multisynaptic index) to ``template_kwds``, then delegates to
     the parent class's ``code_object``. A ``template_kwds`` dict passed in by
     the caller is modified in place.
     '''
     if template_kwds is None:
         template_kwds = {}
     if hasattr(self, 'profile'):
         template_kwds['profile'] = self.profile
     no_or_const_delay_mode = False
     if isinstance(
             owner, (SynapticPathway, Synapses)
     ) and "delay" in owner.variables and owner.variables["delay"].scalar:
         # catches Synapses(..., delay=...) syntax, does not catch the case when no delay is specified at all
         no_or_const_delay_mode = True
     template_kwds["no_or_const_delay_mode"] = no_or_const_delay_mode
     if template_name == "synapses":
         ##################################################################
         # This code is copied from CodeGenerator.translate() and CodeGenerator.array_read_write()
         # and should give us a set of variables to which will be written in `vector_code`
         vector_statements = {}
         # items()/values() instead of the Python-2-only iteritems()/itervalues()
         for ac_name, ac_code in abstract_code.items():
             statements = make_statements(ac_code,
                                          variables,
                                          prefs['core.default_float_dtype'],
                                          optimise=True,
                                          blockname=ac_name)
             _, vector_statements[ac_name] = statements
         written_names = set(stmt.var
                             for statements in vector_statements.values()
                             for stmt in statements)
         write = set(varname for varname, var in variables.items()
                     if isinstance(var, ArrayVariable) and varname in written_names)
         ##################################################################
         prepost = template_kwds['pathway'].prepost
         # Index names that address the pathway's source / target group
         source_idx = {'pre': '_presynaptic_idx',
                       'post': '_postsynaptic_idx'}.get(prepost)
         target_idx = {'pre': '_postsynaptic_idx',
                       'post': '_presynaptic_idx'}.get(prepost)
         written_indices = set(variable_indices[varname] for varname in write)
         # Modifications of source group variables take precedence over
         # modifications of target group variables; if neither group is
         # written to, only synaptic variables are modified.
         if source_idx is not None and source_idx in written_indices:
             synaptic_effects = "source"
         elif target_idx is not None and target_idx in written_indices:
             synaptic_effects = "target"
         else:
             synaptic_effects = "synapse"
         template_kwds["synaptic_effects"] = synaptic_effects
         logger.debug(
             "Synaptic effects of Synapses object {syn} modify {mod} group variables."
             .format(syn=name, mod=synaptic_effects))
     if template_name in [
             "synapses_create_generator", "synapses_create_array"
     ]:
         if owner.multisynaptic_index is not None:
             template_kwds["multisynaptic_idx_var"] = owner.variables[
                 owner.multisynaptic_index]
     codeobj = super(CUDAStandaloneDevice, self).code_object(
         owner,
         name,
         abstract_code,
         variables,
         template_name,
         variable_indices,
         codeobj_class=codeobj_class,
         template_kwds=template_kwds,
         override_conditional_write=override_conditional_write)
     return codeobj
예제 #20
0
    def translate(
        self, code, dtype
    ):  # TODO: it's not so nice we have to copy the contents of this function..
        '''
        Translates an abstract code block into the target language.

        Besides the normal translation (a copy of the generic translate logic,
        see the TODO above), this generates the GSL-specific support code and
        template keywords.

        Parameters
        ----------
        code : dict
            Abstract code blocks keyed by block name. The GSL-specific part
            only processes the block stored under the key ``None``.
        dtype : dtype
            Default floating point type, passed on to ``make_statements``.

        Returns
        -------
        scalar_code, vector_code, kwds
            As returned by ``self.generator.translate_statement_sequence``,
            with ``scalar_code['GSL']`` and several GSL keywords added.

        Raises
        ------
        ValueError
            If the user code uses a variable name reserved for GSL internals.
        UnsupportedEquationsException
            If the generated code uses indexing other than ``_idx``.
        AssertionError
            If not all variables used in the vector code were added to the
            ``_GSL_dataholder``.
        '''
        # first check if user code is not using variables that are also used by GSL
        reserved_variables = [
            '_dataholder', '_fill_y_vector', '_empty_y_vector',
            '_GSL_dataholder', '_GSL_y', '_GSL_func'
        ]
        if any(var in self.variables for var in reserved_variables):
            raise ValueError(("The variables %s are reserved for the GSL "
                              "internal code." % (str(reserved_variables))))

        # if the following statements are not added, Brian translates the
        # differential expressions in the abstract code for GSL to scalar statements
        # in the case no non-scalar variables are used in the expression
        diff_vars = self.find_differential_variables(code.values())
        self.add_gsl_variables_as_non_scalar(diff_vars)

        # add arrays we want to use in generated code before self.generator.translate() so
        # brian does namespace unpacking for us
        pointer_names = self.add_meta_variables(self.method_options)

        scalar_statements = {}
        vector_statements = {}
        # items()/values() instead of the Python-2-only iteritems()/itervalues()
        for ac_name, ac_code in code.items():
            statements = make_statements(ac_code,
                                         self.variables,
                                         dtype,
                                         optimise=True,
                                         blockname=ac_name)
            scalar_statements[ac_name], vector_statements[ac_name] = statements
        for vs in vector_statements.values():
            # Check that the statements are meaningful independent of the order of
            # execution (e.g. for synapses)
            try:
                if self.has_repeated_indices(
                        vs
                ):  # only do order dependence if there are repeated indices
                    check_for_order_independence(
                        vs, self.generator.variables,
                        self.generator.variable_indices)
            except OrderDependenceError:
                # If the abstract code is only one line, display it in full
                if len(vs) <= 1:
                    error_msg = 'Abstract code: "%s"\n' % vs[0]
                else:
                    # '%d' -- this format string used to be corrupted into
                    # '%_GSL_driver', which raised a ValueError here
                    error_msg = ('%d lines of abstract code, first line is: '
                                 '"%s"\n') % (len(vs), vs[0])
                # The message used to be built but never reported; warn like
                # the other translate implementations do. Imported locally so
                # that this method does not rely on a module-level logger.
                from brian2.utils.logger import get_logger
                get_logger(__name__).warn(
                    'Came across an abstract code block that may not be '
                    'well-defined: the outcome may depend on the '
                    'order of execution. You can ignore this warning if '
                    'you are sure that the order of operations does not '
                    'matter. ' + error_msg)

        # save function names because self.generator.translate_statement_sequence
        # deletes these from self.variables but we need to know which identifiers
        # we can safely ignore (i.e. we can ignore the functions because they are
        # handled by the original generator)
        self.function_names = self.find_function_names()

        scalar_code, vector_code, kwds = self.generator.translate_statement_sequence(
            scalar_statements, vector_statements)

        ############ translate code for GSL

        # first check if any indexing other than '_idx' is used (currently not supported)
        # (list() so the concatenation also works with Python 3 dict views)
        for code_list in list(scalar_code.values()) + list(vector_code.values()):
            for code_line in code_list:
                # raw string: '\[' and '\w' are invalid escapes in a normal string
                m = re.search(r'\[(\w+)\]', code_line)
                if m is not None:
                    if m.group(1) != '0' and m.group(1) != '_idx':
                        # import here to avoid circular import
                        from brian2.stateupdaters.base import UnsupportedEquationsException
                        raise UnsupportedEquationsException(
                            ("Equations result in state "
                             "updater code with indexing "
                             "other than '_idx', which "
                             "is currently not supported "
                             "in combination with the "
                             "GSL stateupdater."))

        # differential variable specific operations
        to_replace = self.diff_var_to_replace(diff_vars)
        GSL_support_code = self.get_dimension_code(len(diff_vars))
        GSL_support_code += self.yvector_code(diff_vars)

        # analyze all needed variables; if not in self.variables: put in separate dic.
        # also keep track of variables needed for scalar statements and vector statements
        other_variables = self.find_undefined_variables(
            scalar_statements[None] + vector_statements[None])
        variables_in_scalar = self.find_used_variables(scalar_statements[None],
                                                       other_variables)
        variables_in_vector = self.find_used_variables(vector_statements[None],
                                                       other_variables)
        # so that _dataholder holds diff_vars as well, even if they don't occur
        # in the actual statements
        for var in diff_vars:
            if var not in variables_in_vector:
                variables_in_vector[var] = self.variables[var]
        # lets keep track of the variables that eventually need to be added to
        # the _GSL_dataholder somehow (a list snapshot, as keys() returned on
        # Python 2, rather than a live Python 3 dict view)
        self.variables_to_be_processed = list(variables_in_vector.keys())

        # add code for _dataholder struct
        GSL_support_code = self.write_dataholder(
            variables_in_vector) + GSL_support_code
        # add e.g. _lio_1 --> _GSL_dataholder._lio_1 to replacer
        to_replace.update(
            self.to_replace_vector_vars(variables_in_vector,
                                        ignore=diff_vars.keys()))
        # write statements that unpack (python) namespace to _dataholder struct
        # or local namespace
        GSL_main_code = self.unpack_namespace(variables_in_vector,
                                              variables_in_scalar, ['t'])

        # rewrite actual calculations described by vector_code and put them in _GSL_func
        func_code = self.translate_one_statement_sequence(
            vector_statements[None], scalar=False)
        GSL_support_code += self.make_function_code(
            self.translate_vector_code(func_code, to_replace))
        scalar_func_code = self.translate_one_statement_sequence(
            scalar_statements[None], scalar=True)
        # rewrite scalar code, keep variables that are needed in scalar code normal
        # and add variables to _dataholder for vector_code
        GSL_main_code += '\n' + self.translate_scalar_code(
            scalar_func_code, variables_in_scalar, variables_in_vector)
        if len(self.variables_to_be_processed) > 0:
            raise AssertionError(
                ("Not all variables that will be used in the vector "
                 "code have been added to the _GSL_dataholder. This "
                 "might mean that the _GSL_func is using unitialized "
                 "variables."
                 "\nThe unprocessed variables "
                 "are: %s" % (str(self.variables_to_be_processed))))

        scalar_code['GSL'] = GSL_main_code
        kwds['define_GSL_scale_array'] = self.scale_array_code(
            diff_vars, self.method_options)
        kwds['n_diff_vars'] = len(diff_vars)
        kwds['GSL_settings'] = dict(self.method_options)
        kwds['GSL_settings']['integrator'] = self.integrator
        kwds['support_code_lines'] += GSL_support_code.split('\n')
        kwds['t_array'] = self.get_array_name(self.variables['t']) + '[0]'
        kwds['dt_array'] = self.get_array_name(self.variables['dt']) + '[0]'
        kwds['define_dt'] = 'dt' not in variables_in_scalar
        kwds['cpp_standalone'] = self.is_cpp_standalone()
        for key, value in pointer_names.items():
            kwds[key] = value
        return scalar_code, vector_code, kwds
예제 #21
0
파일: gen.py 프로젝트: Kwartke/brian2
# Demo script (Python 2 syntax): translate a one-line abstract code snippet
# into Cython with CythonCodeGenerator and print the generated lines.
from pylab import *
from brian2 import *
from brian2.codegen.generators.cython_generator import CythonCodeGenerator
from brian2.codegen.statements import Statement
from brian2.codegen.translation import make_statements
from brian2.core.variables import Variable, ArrayVariable

# The variables below are not attached to any owner group.
owner = None

# Abstract code to translate.
# NOTE(review): 'c' is read here but has no entry in `variables` below --
# confirm how make_statements / the generator handle it.
code = '''
x = a*b+c
'''

# One ArrayVariable per array name used in `code`; 10 is presumably the array
# size -- confirm against the ArrayVariable signature of this brian2 version.
variables = {'a':ArrayVariable('a', Unit(1), owner, 10, get_device()),
             'b':ArrayVariable('b', Unit(1), owner, 10, get_device()),
             'x':ArrayVariable('x', Unit(1), owner, 10, get_device()),
             }
# Not used anywhere below.
namespace = {}
# Every array is accessed with the same index variable, '_idx'.
variable_indices = {'a': '_idx', 'b': '_idx', 'x': '_idx'}

gen = CythonCodeGenerator(variables, variable_indices, owner, iterate_all=True, codeobj_class=None)

#print gen.translate_expression('a*b+c')
#print gen.translate_statement(Statement('x', '=', 'a*b+c', '', float))

# NOTE(review): in brian2 versions where make_statements returns a
# (scalar_statements, vector_statements) tuple, `stmts` is that tuple and
# would need unpacking before translation -- confirm.
stmts = make_statements(code, variables, float)
#for stmt in stmts:
# print stmt

# Print the translated code lines joined by newlines (Python 2 print statement).
print '\n'.join(gen.translate_one_statement_sequence(stmts))
예제 #22
0
        #                                           expr=statement.expr)
        #     line = word_substitute(line, subs)
        #     lines.append(line)
        # return '\n'.join(lines)


if __name__ == '__main__':
    # Manual demo (Python 2 syntax): turn a small piece of synaptic abstract
    # code into statements and print the result of vectorise_synapses_code.
    # NOTE(review): vectorise_synapses_code is not defined in this chunk;
    # presumably it is defined earlier in this module.
    from brian2.codegen.translation import make_statements
    from brian2.core.variables import ArrayVariable
    from brian2 import device
    from numpy import float64
    code = '''
    w_syn = v_pre
    v_pre += -a_syn # change operator -= or += to see efficient/inefficient code
    x_post = y_post
    '''
    # Index variable used to access each array: '_idx' for the *_syn
    # variables, the pre-/postsynaptic index for *_pre / *_post variables.
    # 'u_pre' is listed here but not used in `code` above.
    indices = {
        'w_syn': '_idx',
        'a_syn': '_idx',
        'u_pre': '_presynaptic_idx',
        'v_pre': '_presynaptic_idx',
        'x_post': '_postsynaptic_idx',
        'y_post': '_postsynaptic_idx'
    }
    # Create one ArrayVariable per name in `indices`; the positional arguments
    # (1, None, 10) presumably mean unit/dimensions, owner and size -- confirm
    # against the ArrayVariable signature of the brian2 version in use.
    variables = dict()
    for var in indices:
        variables[var] = ArrayVariable(var, 1, None, 10, device)
    # Only the vector statements are passed on; the scalar ones are unused.
    scalar_statements, vector_statements = make_statements(
        code, variables, float64)
    print vectorise_synapses_code(vector_statements, variables, indices)
예제 #23
0
    def translate(self, code, dtype):  # TODO: it's not so nice we have to copy the contents of this function..
        '''
        Translates an abstract code block into the target language.

        In addition to the regular translation performed by
        ``self.generator.translate_statement_sequence``, this rewrites the
        abstract code block named ``None`` into the support code needed by the
        GSL integrator (``_GSL_func``, the ``_dataholder`` struct and the code
        unpacking the namespace into it).

        Parameters
        ----------
        code : dict
            Mapping from abstract code block name to abstract code string.
        dtype
            Default dtype forwarded to ``make_statements``.

        Returns
        -------
        scalar_code, vector_code, kwds
            The translated scalar and vector code blocks (``scalar_code`` gets
            an additional ``'GSL'`` entry) and the template keywords, extended
            with the GSL-specific settings.

        Raises
        ------
        ValueError
            If the model uses a variable name that is reserved for the GSL
            internal code.
        UnsupportedEquationsException
            If the generated code uses any indexing other than ``[0]`` or
            ``[_idx]``.
        AssertionError
            If not all variables used by the vector code ended up in the
            ``_GSL_dataholder``.
        '''
        # first check if user code is not using variables that are also used by GSL
        reserved_variables = ['_dataholder', '_fill_y_vector', '_empty_y_vector',
                              '_GSL_dataholder', '_GSL_y', '_GSL_func']
        if any(var in self.variables for var in reserved_variables):
            raise ValueError(("The variables %s are reserved for the GSL "
                              "internal code." % (str(reserved_variables))))

        # if the following statements are not added, Brian translates the
        # differential expressions in the abstract code for GSL to scalar statements
        # in the case no non-scalar variables are used in the expression
        diff_vars = self.find_differential_variables(code.values())
        self.add_gsl_variables_as_non_scalar(diff_vars)

        # add arrays we want to use in generated code before self.generator.translate() so
        # brian does namespace unpacking for us
        pointer_names = self.add_meta_variables(self.method_options)

        scalar_statements = {}
        vector_statements = {}
        for ac_name, ac_code in code.items():
            statements = make_statements(ac_code,
                                         self.variables,
                                         dtype,
                                         optimise=True,
                                         blockname=ac_name)
            scalar_statements[ac_name], vector_statements[ac_name] = statements

        for vs in vector_statements.values():
            # Check that the statements are meaningful independent on the order of
            # execution (e.g. for synapses)
            try:
                # only do order dependence if there are repeated indices
                if self.has_repeated_indices(vs):
                    check_for_order_independence(vs,
                                                 self.generator.variables,
                                                 self.generator.variable_indices)
            except OrderDependenceError:
                # If the abstract code is only one line, display it in full
                if len(vs) <= 1:
                    error_msg = 'Abstract code: "%s"\n' % vs[0]
                else:
                    # '%d' (not a mangled placeholder): a bad conversion
                    # character would raise ValueError here instead of warning.
                    error_msg = ('%d lines of abstract code, first line is: '
                                 '"%s"\n') % (len(vs), vs[0])
                # Report the possible order dependence instead of silently
                # discarding the message that was just built.
                from brian2.utils.logger import get_logger
                get_logger(__name__).warn(
                    ('Came across an abstract code block that is possibly not '
                     'well-defined: the outcome may depend on the order of '
                     'execution. %s' % error_msg),
                    'codegen_order_dependence')

        # save function names because self.generator.translate_statement_sequence
        # deletes these from self.variables but we need to know which identifiers
        # we can safely ignore (i.e. we can ignore the functions because they are
        # handled by the original generator)
        self.function_names = self.find_function_names()

        scalar_code, vector_code, kwds = self.generator.translate_statement_sequence(
            scalar_statements, vector_statements)

        ############ translate code for GSL

        # first check if any indexing other than '_idx' is used (currently not
        # supported). Every bracketed index on a line is checked, not only the
        # first one.
        all_code_lists = list(scalar_code.values()) + list(vector_code.values())
        for code_list in all_code_lists:
            for line in code_list:
                for index in re.findall(r'\[(\w+)\]', line):
                    if index != '0' and index != '_idx':
                        from brian2.stateupdaters.base import UnsupportedEquationsException
                        raise UnsupportedEquationsException(("Equations result in state "
                                                             "updater code with indexing "
                                                             "other than '_idx', which "
                                                             "is currently not supported "
                                                             "in combination with the "
                                                             "GSL stateupdater."))

        # differential variable specific operations
        to_replace = self.diff_var_to_replace(diff_vars)
        GSL_support_code = self.get_dimension_code(len(diff_vars))
        GSL_support_code += self.yvector_code(diff_vars)

        # analyze all needed variables; if not in self.variables: put in separate dic.
        # also keep track of variables needed for scalar statements and vector statements
        other_variables = self.find_undefined_variables(scalar_statements[None] +
                                                        vector_statements[None])
        variables_in_scalar = self.find_used_variables(scalar_statements[None],
                                                       other_variables)
        variables_in_vector = self.find_used_variables(vector_statements[None],
                                                       other_variables)
        # so that _dataholder holds diff_vars as well, even if they don't occur
        # in the actual statements
        for var in diff_vars:
            if var not in variables_in_vector:
                variables_in_vector[var] = self.variables[var]
        # lets keep track of the variables that eventually need to be added to
        # the _GSL_dataholder somehow (an explicit list, so that it can be
        # consumed by the translation steps below)
        self.variables_to_be_processed = list(variables_in_vector.keys())

        # add code for _dataholder struct
        GSL_support_code = self.write_dataholder(variables_in_vector) + GSL_support_code
        # add e.g. _lio_1 --> _GSL_dataholder._lio_1 to replacer
        to_replace.update(self.to_replace_vector_vars(variables_in_vector,
                                                      ignore=list(diff_vars.keys())))
        # write statements that unpack (python) namespace to _dataholder struct
        # or local namespace
        GSL_main_code = self.unpack_namespace(variables_in_vector, variables_in_scalar, ['t'])

        # rewrite actual calculations described by vector_code and put them in _GSL_func
        func_code = self.translate_one_statement_sequence(vector_statements[None],
                                                          scalar=False)
        GSL_support_code += self.make_function_code(self.translate_vector_code(func_code,
                                                                               to_replace))
        scalar_func_code = self.translate_one_statement_sequence(scalar_statements[None],
                                                                 scalar=True)
        # rewrite scalar code, keep variables that are needed in scalar code normal
        # and add variables to _dataholder for vector_code
        GSL_main_code += '\n' + self.translate_scalar_code(scalar_func_code,
                                                           variables_in_scalar,
                                                           variables_in_vector)
        if len(self.variables_to_be_processed) > 0:
            raise AssertionError(("Not all variables that will be used in the vector "
                                  "code have been added to the _GSL_dataholder. This "
                                  "might mean that the _GSL_func is using unitialized "
                                  "variables."
                                  "\nThe unprocessed variables "
                                  "are: %s" % (str(self.variables_to_be_processed))))

        scalar_code['GSL'] = GSL_main_code
        kwds['define_GSL_scale_array'] = self.scale_array_code(diff_vars,
                                                               self.method_options)
        kwds['n_diff_vars'] = len(diff_vars)
        kwds['GSL_settings'] = dict(self.method_options)
        kwds['GSL_settings']['integrator'] = self.integrator
        kwds['support_code_lines'] += GSL_support_code.split('\n')
        kwds['t_array'] = self.get_array_name(self.variables['t']) + '[0]'
        kwds['dt_array'] = self.get_array_name(self.variables['dt']) + '[0]'
        kwds['define_dt'] = 'dt' not in variables_in_scalar
        kwds['cpp_standalone'] = self.is_cpp_standalone()
        # expose the pointer names of the meta variables to the template
        for key, value in pointer_names.items():
            kwds[key] = value
        return scalar_code, vector_code, kwds
예제 #24
0
# Demo script (Python 2 syntax): show the intermediate statements for the
# abstract code "V += w" and the code generated from it for each language.
# NOTE(review): ArrayVariable, Index, float64, make_statements, translate,
# PythonLanguage and CPPLanguage are not imported in this chunk -- presumably
# provided by imports outside the visible part; confirm.
from codeprint import codeprint

abstract = """
V += w
"""

# Variable/index specifiers: V and w are arrays accessed through different
# index names; the remaining entries declare those index names.
specifiers = {
    "V": ArrayVariable("_array_V", "_postsynaptic_idx", float64),
    "w": ArrayVariable("_array_w", "_synapse_idx", float64),
    "_spiking_synapse_idx": Index(),
    "_postsynaptic_idx": Index(all=False),
    "_synapse_idx": Index(all=False),
    "_presynaptic_idx": Index(all=False),
}

intermediate = make_statements(abstract, specifiers, float64)

print "ABSTRACT CODE:"
print abstract
print "INTERMEDIATE STATEMENTS:"
print
for stmt in intermediate:
    print stmt
print

# Translate the same abstract code for every language, fill in its synapses
# template and print the result under an underlined language-name heading.
for lang in [PythonLanguage(), CPPLanguage()]:
    innercode = translate(abstract, specifiers, float64, lang)
    code = lang.apply_template(innercode, lang.template_synapses())
    print lang.__class__.__name__
    print "=" * len(lang.__class__.__name__)
    codeprint(code)
# Demo script (Python 2 syntax): build a NeuronGroup integrated with rk4 and
# print its equations, abstract code and intermediate statements.
# NOTE(review): ms, NeuronGroup, make_statements and float64 are not imported
# in this chunk -- presumably provided by imports outside the visible part.
from codeprint import codeprint

# Flags not used in the visible part of this script; presumably consumed by
# code further down -- confirm.
test_compile = True
do_plot = True

tau = 10 * ms # external variable
eqs = '''   
    dV/dt = (-V + I + J)/tau : 1
    dI/dt = -I/tau : 1
    J = V * 0.1 : 1
    '''
from brian2.stateupdaters.integration import euler, rk2, rk4
G = NeuronGroup(1, eqs, method=rk4)

intermediate = make_statements(G.abstract_code, G.specifiers, float64)

print 'EQUATIONS:'
print eqs
print 'ABSTRACT CODE:'
print G.abstract_code
print 'INTERMEDIATE STATEMENTS:'
print
for stmt in intermediate:
    print stmt
print

def getlang(cls, *args, **kwds):
    try:
        return cls(*args, **kwds)
    except Exception as e:
예제 #26
0
        #     subs[var] = '{var}[{idx}]'.format(var=var, idx=idx)
        # for statement in statements:
        #     line = '    {var} {op} {expr}'.format(var=statement.var, op=statement.op,
        #                                           expr=statement.expr)
        #     line = word_substitute(line, subs)
        #     lines.append(line)
        # return '\n'.join(lines)


if __name__=='__main__':
    # Manual demo (Python 2 syntax): turn a small piece of synaptic abstract
    # code into statements and print the result of vectorise_synapses_code.
    # NOTE(review): vectorise_synapses_code is not defined in this chunk;
    # presumably it is defined earlier in this module.
    from brian2.codegen.translation import make_statements
    from brian2.core.variables import ArrayVariable
    from brian2 import device
    from numpy import float64
    code = '''
    w_syn = v_pre
    v_pre += -a_syn # change operator -= or += to see efficient/inefficient code
    x_post = y_post
    '''
    # Index variable used to access each array: '_idx' for the *_syn
    # variables, the pre-/postsynaptic index for *_pre / *_post variables.
    # 'u_pre' is listed here but not used in `code` above.
    indices = {'w_syn': '_idx',
               'a_syn': '_idx',
               'u_pre': '_presynaptic_idx',
               'v_pre': '_presynaptic_idx',
               'x_post': '_postsynaptic_idx',
               'y_post': '_postsynaptic_idx'}
    # Create one ArrayVariable per name in `indices`; the positional arguments
    # (1, None, 10) presumably mean unit/dimensions, owner and size -- confirm
    # against the ArrayVariable signature of the brian2 version in use.
    variables = dict()
    for var in indices:
        variables[var] = ArrayVariable(var, 1, None, 10, device)
    # Only the vector statements are passed on; the scalar ones are unused.
    scalar_statements, vector_statements = make_statements(code, variables, float64)
    print vectorise_synapses_code(vector_statements, variables, indices)