コード例 #1
0
ファイル: test_codegen.py プロジェクト: treestreamymw/brian2
def test_nested_subexpressions():
    '''
    Check code translation for subexpressions that depend on other
    subexpressions: ``a`` is defined via ``b`` and ``d``, and ``b`` is
    defined via ``c``.
    '''
    code = '''
    x = a + b + c
    c = 1
    x = a + b + c
    d = 1
    x = a + b + c
    '''
    sub_a = Subexpression(name='a', dtype=np.float32,
                          owner=FakeGroup(variables={}), device=None,
                          expr='b*b+d')
    sub_b = Subexpression(name='b', dtype=np.float32,
                          owner=FakeGroup(variables={}), device=None,
                          expr='c*c*c')
    variables = {'a': sub_a,
                 'b': sub_b,
                 'c': Variable(name='c'),
                 'd': Variable(name='d')}
    scalar_stmts, vector_stmts = make_statements(code, variables, np.float32)
    # No scalar statements are expected for this code
    assert len(scalar_stmts) == 0
    # Expected evaluation order: every use of "x" is preceded by "b" and "a".
    # Earlier versions of this test did not expect the final "b" (its value
    # could not have changed, since c was not modified), but that
    # subexpression caching was removed because it did not seem to apply in
    # practical use cases.
    evaluated = [stmt.var for stmt in vector_stmts]
    assert ''.join(evaluated) == 'baxcbaxdbax'
コード例 #2
0
def test_nested_subexpressions():
    '''
    Check code translation for subexpressions that depend on other
    subexpressions (``a`` uses ``b`` and ``d``; ``b`` uses ``c``).
    '''
    code = '''
    x = a + b + c
    c = 1
    x = a + b + c
    d = 1
    x = a + b + c
    '''
    sub_a = Subexpression(name='a', unit=Unit(1), dtype=np.float32,
                          owner=FakeGroup(variables={}), device=None,
                          expr='b*b+d')
    sub_b = Subexpression(name='b', unit=Unit(1), dtype=np.float32,
                          owner=FakeGroup(variables={}), device=None,
                          expr='c*c*c')
    variables = dict(a=sub_a,
                     b=sub_b,
                     c=Variable(unit=None, name='c'),
                     d=Variable(unit=None, name='d'))
    statements = make_statements(code, variables, np.float32)
    # Expected order of evaluation; note that the last "x" is preceded by
    # "a" only, without a new "b"
    order = ''.join([s.var for s in statements])
    assert order == 'baxcbaxdax'
コード例 #3
0
ファイル: test_codegen.py プロジェクト: divyashivaram/brian2
def test_get_identifiers_recursively():
    '''
    Check that identifier lookup also collects the identifiers used inside
    subexpressions, including a subexpression nested in another one.
    '''
    variables = {}
    # "sub1" refers to "sub2", which refers to "y". Both subexpressions
    # receive the same ``variables`` dict object, which is only fully
    # populated afterwards.
    sub1 = Subexpression(Unit(1), np.float32, 'sub2 * z', variables, {})
    variables['sub1'] = sub1
    sub2 = Subexpression(Unit(1), np.float32, '5 + y', variables, {})
    variables['sub2'] = sub2
    variables['x'] = Variable(unit=None)
    found = get_identifiers_recursively('_x = sub1 + x', variables)
    assert found == {'x', '_x', 'y', 'z', 'sub1', 'sub2'}
コード例 #4
0
ファイル: test_codegen.py プロジェクト: treestreamymw/brian2
def test_get_identifiers_recursively():
    '''
    Check that identifier lookup descends into subexpressions and collects
    the identifiers of nested subexpressions as well.
    '''
    sub1 = Subexpression(name='sub1', dtype=np.float32, expr='sub2 * z',
                         owner=FakeGroup(variables={}), device=None)
    sub2 = Subexpression(name='sub2', dtype=np.float32, expr='5 + y',
                         owner=FakeGroup(variables={}), device=None)
    variables = dict(sub1=sub1, sub2=sub2, x=Variable(name='x'))
    # The statement mentions only _x, sub1 and x directly; sub2, y and z
    # are reached through the subexpressions
    found = get_identifiers_recursively(['_x = sub1 + x'], variables)
    expected = set(['x', '_x', 'y', 'z', 'sub1', 'sub2'])
    assert found == expected
コード例 #5
0
def test_write_to_subexpression():
    # "a" is a subexpression defined in terms of the plain variable "z"
    subexpr = Subexpression(name='a', dtype=np.float32,
                            owner=FakeGroup(variables={}), device=None,
                            expr='2*z')
    variables = {'a': subexpr, 'z': Variable(name='z')}

    # Assigning to a subexpression has to be rejected with a SyntaxError
    assignment = 'a = z'
    assert_raises(SyntaxError, make_statements, assignment, variables,
                  np.float32)
コード例 #6
0
    def _create_variables(self):
        '''
        Create the variables dictionary for this `NeuronGroup`, containing
        entries for the equation variables and some standard entries.

        Starts from the dictionary returned by `Group._create_variables` and
        adds ``_spikespace``, ``_spikes``, one entry per equation (an
        `ArrayVariable` for differential equations and parameters, a
        `Subexpression` for static equations) and a `StochasticVariable` for
        every stochastic variable of the equations.

        Returns
        -------
        s
            The extended variables dictionary.

        Raises
        ------
        AssertionError
            If an equation is neither a differential equation, a parameter
            nor a static equation.
        '''
        # Get the standard variables for all groups
        s = Group._create_variables(self)

        # Standard variables always present
        s.update({
            '_spikespace':
            ArrayVariable('_spikespace',
                          Unit(1),
                          self._spikespace,
                          group_name=self.name)
        })
        s.update({
            '_spikes':
            AttributeVariable(Unit(1), self, 'spikes', constant=False)
        })

        for eq in self.equations.itervalues():
            if eq.type in (DIFFERENTIAL_EQUATION, PARAMETER):
                array = self.arrays[eq.varname]
                constant = ('constant' in eq.flags)
                s.update({
                    eq.varname:
                    ArrayVariable(eq.varname,
                                  eq.unit,
                                  array,
                                  group_name=self.name,
                                  constant=constant,
                                  is_bool=eq.is_bool)
                })

            elif eq.type == STATIC_EQUATION:
                # The subexpression is handed the variables dictionary that
                # is still being built here
                s.update({
                    eq.varname:
                    Subexpression(eq.unit,
                                  brian_prefs['core.default_scalar_dtype'],
                                  str(eq.expr),
                                  variables=s,
                                  namespace=self.namespace,
                                  is_bool=eq.is_bool)
                })
            else:
                # The equation kind lives in ``eq.type`` (as checked above);
                # the previous ``eq.eq_type`` would have raised an
                # AttributeError instead of this error. ``str`` avoids a
                # TypeError if the type is not a string.
                raise AssertionError('Unknown type of equation: ' +
                                     str(eq.type))

        # Stochastic variables
        for xi in self.equations.stochastic_variables:
            s.update({xi: StochasticVariable()})

        return s
コード例 #7
0
ファイル: test_codegen.py プロジェクト: rgerkin/brian2
def test_repeated_subexpressions():
    '''
    Check when the subexpression ``a`` (defined as ``2*z``) is evaluated
    once and reused, when it has to be re-evaluated, and when its statement
    is marked as constant.
    '''
    variables = {
        'a': Subexpression(name='a', dtype=np.float32,
                           owner=FakeGroup(variables={}), device=None,
                           expr='2*z'),
        'x': Variable(name='x'),
        'y': Variable(name='y'),
        'z': Variable(name='z'),
    }

    def _vector_statements(snippet):
        # Translate the snippet, check that it produced no scalar statements
        # and hand back the vector statements
        scalar_stmts, vector_stmts = make_statements(snippet, variables,
                                                     np.float32)
        assert len(scalar_stmts) == 0
        return vector_stmts

    # "a" is used twice while z stays unchanged, so it is evaluated only once
    # and its statement is marked constant
    code = '''
    x = a
    y = a
    '''
    stmts = _vector_statements(code)
    assert [stmt.var for stmt in stmts] == ['a', 'x', 'y']
    assert stmts[0].constant

    # z changes after the only use of "a". The statement for "a" is currently
    # not marked constant in this case, because a use after "z *=2" would
    # redefine it and the algorithm does not detect that no such use exists.
    # This is deliberately not asserted.
    code = '''
    x = a
    z *= 2
    '''
    stmts = _vector_statements(code)
    assert [stmt.var for stmt in stmts] == ['a', 'x', 'z']

    # "a" depends on z, so it has to be re-evaluated after z changed and
    # none of the statements can be constant
    code = '''
    x = a
    z *= 2
    y = a
    '''
    stmts = _vector_statements(code)
    assert [stmt.var for stmt in stmts] == ['a', 'x', 'z', 'a', 'y']
    assert all(not stmt.constant for stmt in stmts)
コード例 #8
0
ファイル: synapses.py プロジェクト: divyashivaram/brian2
    def _create_variables(self):
        '''
        Create the variables dictionary for this `Synapses`, containing
        entries for the equation variables and some standard entries.

        Variables of the source group that are `ArrayVariable` or
        `Subexpression` objects are added with a ``_pre`` suffix. The same
        kinds of target-group variables are added with a ``_post`` suffix
        and also without a suffix. A state variable of this `Synapses` with
        the same name replaces the unsuffixed entry.

        As a side effect, ``self.variable_indices`` is recreated. It maps
        ``_pre`` names to ``'_presynaptic_idx'``, ``_post`` and unsuffixed
        target names to ``'_postsynaptic_idx'``, and any other name to
        ``'_idx'`` (its default). The arrays of differential equations and
        parameters are registered with ``self.item_mapping``.

        Returns
        -------
        v : dict
            The variables dictionary.

        Raises
        ------
        AssertionError
            If an equation is neither a differential equation, a parameter
            nor a static equation.
        '''
        v = {}
        self.variable_indices = defaultdict(lambda: '_idx')
        # Add all the pre and post variables with _pre and _post suffixes
        for name, var in getattr(self.source, 'variables', {}).iteritems():
            if isinstance(var, (ArrayVariable, Subexpression)):
                v[name + '_pre'] = var
                self.variable_indices[name + '_pre'] = '_presynaptic_idx'
        for name, var in getattr(self.target, 'variables', {}).iteritems():
            if isinstance(var, (ArrayVariable, Subexpression)):
                v[name + '_post'] = var
                self.variable_indices[name + '_post'] = '_postsynaptic_idx'
                # Also add all the post variables without a suffix -- if this
                # clashes with the name of a state variable defined in this
                # Synapses group, the latter will overwrite the entry later and
                # take precedence
                v[name] = var
                self.variable_indices[name] = '_postsynaptic_idx'

        # Standard variables always present
        v.update({
            't':
            AttributeVariable(second, self.clock, 't_', constant=False),
            'dt':
            AttributeVariable(second, self.clock, 'dt_', constant=True),
            '_num_source_neurons':
            Variable(Unit(1), len(self.source), constant=True),
            '_num_target_neurons':
            Variable(Unit(1), len(self.target), constant=True),
            '_synaptic_pre':
            DynamicArrayVariable('_synaptic_pre', Unit(1),
                                 self.item_mapping.synaptic_pre),
            # The variable's name has to match its key (was '_synaptic_pre'
            # due to a copy-paste error)
            '_synaptic_post':
            DynamicArrayVariable('_synaptic_post', Unit(1),
                                 self.item_mapping.synaptic_post),
            # We don't need "proper" specifier for these -- they go
            # back to Python code currently
            '_pre_synaptic':
            Variable(Unit(1), self.item_mapping.pre_synaptic),
            '_post_synaptic':
            Variable(Unit(1), self.item_mapping.post_synaptic)
        })

        for eq in itertools.chain(
                self.equations.itervalues(),
                self.event_driven.itervalues()
                if self.event_driven is not None else []):
            if eq.type in (DIFFERENTIAL_EQUATION, PARAMETER):
                array = self.arrays[eq.varname]
                constant = ('constant' in eq.flags)
                # We are dealing with dynamic arrays here, code generation
                # shouldn't directly access the specifier.array attribute but
                # use specifier.get_value() to get a reference to the underlying
                # array
                v[eq.varname] = DynamicArrayVariable(eq.varname,
                                                     eq.unit,
                                                     array,
                                                     group_name=self.name,
                                                     constant=constant,
                                                     is_bool=eq.is_bool)
                if eq.varname in self.variable_indices:
                    # we are overwriting a postsynaptic variable of the same
                    # name, delete the reference to the postsynaptic index
                    del self.variable_indices[eq.varname]
                # Register the array with the `SynapticItemMapping` object so
                # it gets automatically resized
                self.item_mapping.register_variable(array)
            elif eq.type == STATIC_EQUATION:
                # The subexpression is handed the variables dictionary that
                # is still being built here
                v[eq.varname] = Subexpression(
                    eq.unit,
                    brian_prefs['core.default_scalar_dtype'],
                    str(eq.expr),
                    variables=v,
                    namespace=self.namespace,
                    is_bool=eq.is_bool)
            else:
                # The equation kind lives in ``eq.type`` (as checked above);
                # the previous ``eq.eq_type`` would have raised an
                # AttributeError instead of this error
                raise AssertionError('Unknown type of equation: ' +
                                     str(eq.type))

        # Stochastic variables
        for xi in self.equations.stochastic_variables:
            v[xi] = StochasticVariable()

        return v