示例#1
0
def codeprint(code):
    '''
    Print ``code`` with its common leading indentation removed.

    Parameters
    ----------
    code : str or dict
        Either a single code string, or a mapping from section names to code
        strings. For a mapping, every section is printed as ``name:``
        followed by its deindented and then re-indented code.
    '''
    # Single-argument print(...) behaves identically on Python 2 and 3.
    if isinstance(code, str):
        print(deindent(code))
    else:
        for k, v in code.items():
            print(k + ':')
            print(indent(deindent(v)))
示例#2
0
 def __init__(self):
     '''
     Initialise the (empty) preference dictionaries.

     ``eval_namespace`` is filled by executing star-imports from ``numpy``,
     ``brian2.units`` and ``brian2.units.stdunits`` in it.
     '''
     self.prefs = {}
     self.backup_prefs = {}
     self.prefs_unvalidated = {}
     self.pref_register = {}
     self.eval_namespace = {}
     # The call form of ``exec`` with an explicit namespace is valid on both
     # Python 2 and Python 3 (the ``exec ... in ...`` statement is Py2-only).
     exec(deindent('''
         from numpy import *
         from brian2.units import *            
         from brian2.units.stdunits import *
         '''), self.eval_namespace)
示例#3
0
 def translate_statement_sequence(self, statements, specifiers):
     '''
     Translate a sequence of abstract code statements into C++ code.

     Parameters
     ----------
     statements : sequence
         The statements; each one is translated with
         ``self.translate_statement``.
     specifiers : dict
         Mapping from names to specifiers. Array variables provide the
         ``index``, ``dtype`` and ``array`` attributes used here; entries
         that are ``UserFunction`` instances contribute support code and
         ``#define`` code.

     Returns
     -------
     translation : dict
         Maps the template slots ``%CODE%``, ``%POINTERS%``,
         ``%SUPPORT_CODE%`` and ``%HASHDEFINES%`` to the generated code.
     '''
     read, write = self.array_read_write(statements, specifiers)
     # Visit variables in sorted order so that identical input always
     # produces identical code (set iteration order is arbitrary).
     read_vars = sorted(read)
     write_vars = sorted(write)
     lines = []
     # Load every array value that is read into a local variable; it is
     # const unless the code also writes to it.
     for var in read_vars:
         spec = specifiers[var]
         qualifier = '' if var in write else 'const '
         lines.append(qualifier + c_data_type(spec.dtype) + ' ' + var +
                      ' = _ptr' + spec.array + '[' + spec.index + '];')
     # Simply declare variables that will be written but not read
     for var in write_vars:
         if var not in read:
             spec = specifiers[var]
             lines.append(c_data_type(spec.dtype) + ' ' + var + ';')
     # The actual code
     lines.extend([self.translate_statement(stmt) for stmt in statements])
     # Store every written variable back into its array
     for var in write_vars:
         spec = specifiers[var]
         lines.append('_ptr' + spec.array + '[' + spec.index + '] = ' +
                      var + ';')
     code = '\n'.join(lines)
     # Set up the restricted pointers, these are used so that the compiler
     # knows there is no aliasing in the pointers, for optimisation
     pointer_lines = []
     for var in sorted(read.union(write)):
         spec = specifiers[var]
         pointer_lines.append(c_data_type(spec.dtype) + ' * ' + self.restrict +
                              '_ptr' + spec.array + ' = ' + spec.array + ';')
     pointers = '\n'.join(pointer_lines)
     # Set up the user-defined functions
     support_code = ''
     hash_defines = ''
     for var, spec in specifiers.items():
         if isinstance(spec, UserFunction):
             speccode = spec.code(self, var)
             support_code += '\n' + deindent(speccode['support_code'])
             # Separate the blocks with a newline (as for the support code)
             # so that #define directives of different functions can never
             # end up on the same line.
             hash_defines += '\n' + deindent(speccode['hashdefine_code'])
     return {'%CODE%': code,
             '%POINTERS%': pointers,
             '%SUPPORT_CODE%': support_code,
             '%HASHDEFINES%': hash_defines,
             }
示例#4
0
 def _as_pref_file(self, valuefunc):
     '''
     Helper function used to generate the preference file for the default or current preference values.
     '''
     s = ''
     for basename, (prefdefs, basedoc) in self.pref_register.items():
         s += '#' + '-' * 79 + '\n'
         s += '\n'.join(['# ' + line for line in deindent(basedoc, docstring=True).strip().split('\n')]) + '\n'
         s += '#' + '-' * 79 + '\n\n'
         s += '[' + basename + ']\n\n'
         for name in sorted(prefdefs.keys()):
             pref = prefdefs[name]
             s += '\n'.join(['# ' + line for line in deindent(pref.docs, docstring=True).strip().split('\n')]) + '\n\n'
             s += name + ' = ' + pref.representor(valuefunc(pref, basename + '.' + name)) + '\n\n'
     return s
示例#5
0
 def apply_template(self, code, template):
     '''
     Insert code sections into a template.

     Parameters
     ----------
     code : str or dict
         Either a single string, which fills the ``%CODE%`` slot, or a dict
         mapping slot names to the section inserted in that slot.
     template : str or dict
         Either a single template string (which is assigned to the slot
         ``%MAIN%``) or a dict mapping names to template strings.

     Returns
     -------
     output : str or dict
         The filled-in string if ``template`` was a string, otherwise a dict
         with the same keys as ``template``.
     '''
     sections = {'%CODE%': code} if isinstance(code, str) else code
     single = isinstance(template, str)
     templates = {'%MAIN%': template} if single else template
     filled = templates.copy()
     for key in templates:
         text = deindent(templates[key])
         for slot, section in sections.items():
             text = apply_code_template(section, text, placeholder=slot)
         filled[key] = text
     if single:
         return filled['%MAIN%']
     return filled
示例#6
0
def test_substitute_abstract_code_functions():
    '''
    Inlining abstract code functions must not change the computed values.

    The code is executed once with the Python functions ``f`` and ``g``
    available, and once after ``substitute_abstract_code_functions`` has
    replaced the calls. Both runs must give the same results.
    '''
    def f(x):
        y = x * x
        return y

    def g(x):
        return f(x) + 1

    code = '''
    z = f(x)
    z = f(x)+f(y)
    w = f(z)
    h = f(f(w))
    p = g(g(x))
    '''
    funcs = [abstract_code_from_function(func) for func in (f, g)]
    subcode = substitute_abstract_code_functions(code, funcs)
    samples = ((0, 1), (1, 0), (0.124323, 0.4549483))
    for x, y in samples:
        with_funcs = dict(x=x, y=y, f=f, g=g)
        substituted = dict(x=x, y=y)
        exec(deindent(code), with_funcs)
        exec(subcode, substituted)
        for name in ('z', 'w', 'h', 'p'):
            assert with_funcs[name] == substituted[name]
示例#7
0
def extract_abstract_code_functions(code):
    '''
    Returns a dict of abstract code functions from function definitions.

    Only ``def`` statements at the top level of the code are converted.
    Any other code in the string is ignored.

    Parameters
    ----------
    code : str
        The code string defining some functions. It is deindented before
        being parsed.

    Returns
    -------
    funcs : dict
        A mapping ``{name: func}`` where each ``func`` is an
        `AbstractCodeFunction`.
    '''
    funcs = {}
    for node in ast.parse(deindent(code), mode='exec').body:
        if isinstance(node, ast.FunctionDef):
            func = abstract_code_from_function(node)
            funcs[func.name] = func
    return funcs
示例#8
0
def extract_abstract_code_functions(code):
    '''
    Collect the top-level function definitions in ``code``.

    Each top-level ``def`` statement in the deindented code string is
    converted to an `AbstractCodeFunction`. All other statements are
    ignored.

    Parameters
    ----------
    code : str
        Source code containing function definitions.

    Returns
    -------
    funcs : dict
        Maps each function name to its `AbstractCodeFunction`.
    '''
    module = ast.parse(deindent(code), mode='exec')
    definitions = [stmt for stmt in module.body
                   if stmt.__class__ is ast.FunctionDef]
    converted = [abstract_code_from_function(stmt) for stmt in definitions]
    return dict((func.name, func) for func in converted)
示例#9
0
    def create_extension(self, code, force=False, name=None,
                         include_dirs=None,
                         library_dirs=None,
                         runtime_library_dirs=None,
                         extra_compile_args=None,
                         extra_link_args=None,
                         libraries=None,
                         compiler=None,
                         ):
        '''
        Compile (if needed) and load a Cython extension module for ``code``.

        Modules are stored in ``~/.brian/cython_extensions``.

        Parameters
        ----------
        code : str
            The Cython code; it is deindented before use.
        force : bool, optional
            If ``True``, add the current time to the key, so the key no longer
            matches an earlier cache entry. If ``name`` is not given, this also
            produces a new module name.
        name : str, optional
            Module name to use instead of one derived from an MD5 hash of the
            key.
        include_dirs, library_dirs, runtime_library_dirs, extra_compile_args, extra_link_args, libraries, compiler
            Passed on to ``self._load_module``.

        Returns
        -------
        module
            The entry cached in ``self._code_cache`` under the key built from
            the code, ``sys.version_info``, ``sys.executable`` and the Cython
            version. If there is no such entry, the result of
            ``self._load_module``.

        Raises
        ------
        ImportError
            If Cython is not available.
        '''
        if Cython is None:
            raise ImportError('Cython is not available')

        code = deindent(code)

        lib_dir = os.path.expanduser('~/.brian/cython_extensions')
        try:
            os.makedirs(lib_dir)
        except OSError:
            # An already existing directory is fine; anything else is not.
            if not os.path.exists(lib_dir):
                raise

        key = code, sys.version_info, sys.executable, Cython.__version__
        if force:
            # Force a new module name by adding the current time to the
            # key which is hashed to determine the module name.
            key += time.time(),

        if key in self._code_cache:
            return self._code_cache[key]

        if name is not None:
            module_name = name
        else:
            module_name = ("_cython_magic_" +
                           hashlib.md5(str(key).encode('utf-8')).hexdigest())

        module_path = os.path.join(lib_dir, module_name + self.so_ext)

        def load_module():
            return self._load_module(module_path, include_dirs, library_dirs,
                                     extra_compile_args, extra_link_args,
                                     libraries, code, lib_dir, module_name,
                                     runtime_library_dirs, compiler, key)

        if not prefs['codegen.runtime.cython.multiprocess_safe']:
            return load_module()

        # Serialise compiling/loading this module across processes with an
        # exclusive lock on a per-module lock file.
        lock_file = os.path.join(lib_dir, module_name + '.lock')
        with open(lock_file, 'w') as f:
            if msvcrt:
                # Opening with 'w' truncates the file to zero bytes, so its
                # size cannot serve as the lock length. Lock the first byte
                # instead; a locked region may extend past the end of file.
                msvcrt.locking(f.fileno(), msvcrt.LK_RLCK, 1)
                try:
                    return load_module()
                finally:
                    # Unlock the same region (starting at position 0) before
                    # the file is closed.
                    f.seek(0)
                    msvcrt.locking(f.fileno(), msvcrt.LK_UNLCK, 1)
            else:
                # The flock is released when the file is closed.
                fcntl.flock(f, fcntl.LOCK_EX)
                return load_module()
示例#10
0
    def _add_user_function(self, varname, variable):
        '''
        Collect the C++ code needed to use the user-defined function
        ``varname`` and, recursively, the functions it depends on.

        Dependencies whose name is already in ``self.variables`` are skipped.
        New dependencies are added to ``self.variables`` before they are
        processed.

        Parameters
        ----------
        varname : str
            Name under which the function is used.
        variable : object
            The function; its implementation is looked up in
            ``variable.implementations`` using ``self.codeobj_class``.

        Returns
        -------
        hash_defines, pointers, support_code : list of str
            Code fragments, with those of the dependencies first.
        user_functions : list of tuple
            ``(name, function)`` pairs: this function first, followed by
            the dependencies that were added.

        Raises
        ------
        NotImplementedError
            If a scalar array value is provided via the function namespace.
        '''
        impl = variable.implementations[self.codeobj_class]
        own_hash_defines, own_pointers, own_support_code = [], [], []
        user_functions = [(varname, variable)]

        funccode = impl.get_code(self.owner)
        if isinstance(funccode, basestring):
            # A plain string is the function's support code
            funccode = {'support_code': funccode}
        if funccode is not None:
            # Namespace values reach the function through static globals
            # ``_namespace<key>`` that are assigned in the main code.
            namespace = impl.get_namespace(self.owner) or {}
            for key, value in namespace.iteritems():
                if not hasattr(value, 'dtype'):  # e.g. a function
                    type_str = 'py::object'
                elif value.shape == ():
                    raise NotImplementedError((
                        'Directly replace scalar values in the function '
                        'instead of providing them via the namespace'))
                else:
                    type_str = c_data_type(value.dtype) + '*'
                own_support_code.append(
                    'static {0} _namespace{1};'.format(type_str, key))
                own_pointers.append('_namespace{0} = {0};'.format(key))
            own_support_code.append(deindent(funccode.get('support_code', '')))
            own_hash_defines.append(deindent(funccode.get('hashdefine_code', '')))

        dep_hash_defines, dep_pointers, dep_support_code = [], [], []
        if impl.dependencies is not None:
            for dep_name, dep in impl.dependencies.iteritems():
                if dep_name in self.variables:
                    continue  # do not add a dependency twice
                self.variables[dep_name] = dep
                hd, ps, sc, uf = self._add_user_function(dep_name, dep)
                dep_hash_defines += hd
                dep_pointers += ps
                dep_support_code += sc
                user_functions += uf

        return (dep_hash_defines + own_hash_defines,
                dep_pointers + own_pointers,
                dep_support_code + own_support_code,
                user_functions)
示例#11
0
    def _add_user_function(self, varname, variable):
        '''
        Collect the C++ code needed to use the user-defined function
        ``varname`` and, recursively, its dependencies.

        Parameters
        ----------
        varname : str
            Name under which the function is used.
        variable : object
            The function; its implementation is looked up in
            ``variable.implementations`` using ``self.codeobj_class``.

        Returns
        -------
        hash_defines, pointers, support_code : list of str
            Code fragments, with those of the dependencies first.
        user_functions : list of tuple
            ``(name, function)`` pairs for this function and for every
            dependency added along the way.

        Raises
        ------
        NotImplementedError
            If a scalar array value is provided via the function namespace.
        '''
        impl = variable.implementations[self.codeobj_class]
        support_code = []
        hash_defines = []
        pointers = []
        user_functions = [(varname, variable)]
        funccode = impl.get_code(self.owner)
        if isinstance(funccode, basestring):
            funccode = {'support_code': funccode}
        if funccode is not None:
            # To make namespace variables available to functions, we
            # create global variables and assign to them in the main
            # code
            func_namespace = impl.get_namespace(self.owner) or {}
            for ns_key, ns_value in func_namespace.iteritems():
                if hasattr(ns_value, 'dtype'):
                    if ns_value.shape == ():
                        raise NotImplementedError((
                        'Directly replace scalar values in the function '
                        'instead of providing them via the namespace'))
                    type_str = c_data_type(ns_value.dtype) + '*'
                else:  # e.g. a function
                    type_str = 'py::object'
                support_code.append('static {0} _namespace{1};'.format(type_str,
                                                                       ns_key))
                pointers.append('_namespace{0} = {1};'.format(ns_key, ns_key))
            support_code.append(deindent(funccode.get('support_code', '')))
            hash_defines.append(deindent(funccode.get('hashdefine_code', '')))

        dep_hash_defines = []
        dep_pointers = []
        dep_support_code = []
        if impl.dependencies is not None:
            for dep_name, dep in impl.dependencies.iteritems():
                # Add each dependency only once. A dependency shared by
                # several functions would otherwise have its support code
                # (and therefore its C++ definition) emitted repeatedly.
                # Names already in self.variables are assumed to be handled
                # already.
                if dep_name in self.variables:
                    continue
                self.variables[dep_name] = dep
                hd, ps, sc, uf = self._add_user_function(dep_name, dep)
                dep_hash_defines.extend(hd)
                dep_pointers.extend(ps)
                dep_support_code.extend(sc)
                user_functions.extend(uf)

        return (dep_hash_defines + hash_defines,
                dep_pointers + pointers,
                dep_support_code + support_code,
                user_functions)
示例#12
0
def abstract_code_from_function(func):
    '''
    Converts the body of the function to abstract code

    Parameters
    ----------
    func : function, str or ast.FunctionDef
        The function to convert. It may be given in one of three forms:

        - a Python function, whose source is retrieved with
          `inspect.getsource`;
        - a string of source code starting with the function definition,
          which may carry a common leading indentation;
        - an already parsed ``ast.FunctionDef``.

        The argument names are recorded. Default values, ``*args`` and
        ``**kwargs`` are not supported.

    Returns
    -------
    func : AbstractCodeFunction
        The corresponding abstract code function. The statements before the
        first ``return`` make up its code, and the returned expression (if
        any) is its return expression. Statements after that ``return`` are
        ignored.

    Raises
    ------
    TypeError
        If ``func`` has an unsupported type, or if its source does not start
        with a function definition.
    SyntaxError
        If unsupported features are used such as default values, ``*args``,
        ``**kwargs``, if statements or indexing.
    '''
    if isinstance(func, ast.FunctionDef):
        funcnode = func
    else:
        if callable(func):
            source = inspect.getsource(func)
        elif isinstance(func, str):
            source = func
        else:
            raise TypeError("Unsupported function type")
        # Deindent strings as well as the source of function objects, so
        # that indented definitions (e.g. from triple-quoted strings) parse.
        statements = ast.parse(deindent(source), mode='exec').body
        if not statements or not isinstance(statements[0], ast.FunctionDef):
            raise TypeError("Expected the source code to start with a "
                            "function definition")
        funcnode = statements[0]

    if funcnode.args.vararg is not None:
        raise SyntaxError("No support for variable number of arguments")
    if funcnode.args.kwarg is not None:
        raise SyntaxError("No support for arbitrary keyword arguments")
    if len(funcnode.args.defaults):
        raise SyntaxError("No support for default values in functions")

    nodes = funcnode.body
    nr = NodeRenderer()
    lines = []
    return_expr = None
    for node in nodes:
        if node.__class__ is ast.Return:
            return_expr = nr.render_node(node.value)
            break
        else:
            lines.append(nr.render_node(node))
    abstract_code = '\n'.join(lines)
    try:
        # Python 2
        args = [arg.id for arg in funcnode.args.args]
    except AttributeError:
        # Python 3
        args = [arg.arg for arg in funcnode.args.args]
    name = funcnode.name
    return AbstractCodeFunction(name, args, abstract_code, return_expr)
示例#13
0
def abstract_code_from_function(func):
    '''
    Converts the body of the function to abstract code

    Parameters
    ----------
    func : function, str or ast.FunctionDef
        The function to convert. It may be given in one of three forms:

        - a Python function, whose source is retrieved with
          `inspect.getsource`;
        - a string of source code starting with the function definition,
          which may carry a common leading indentation;
        - an already parsed ``ast.FunctionDef``.

        The argument names are recorded. Default values, ``*args`` and
        ``**kwargs`` are not supported.

    Returns
    -------
    func : AbstractCodeFunction
        The corresponding abstract code function. The statements before the
        first ``return`` make up its code, and the returned expression (if
        any) is its return expression. Statements after that ``return`` are
        ignored.

    Raises
    ------
    TypeError
        If ``func`` has an unsupported type, or if its source does not start
        with a function definition.
    SyntaxError
        If unsupported features are used such as default values, ``*args``,
        ``**kwargs``, if statements or indexing.
    '''
    if isinstance(func, ast.FunctionDef):
        funcnode = func
    else:
        if callable(func):
            source = inspect.getsource(func)
        elif isinstance(func, str):
            source = func
        else:
            raise TypeError("Unsupported function type")
        # Deindent strings as well as the source of function objects, so
        # that indented definitions (e.g. from triple-quoted strings) parse.
        statements = ast.parse(deindent(source), mode='exec').body
        if not statements or not isinstance(statements[0], ast.FunctionDef):
            raise TypeError("Expected the source code to start with a "
                            "function definition")
        funcnode = statements[0]

    if funcnode.args.vararg is not None:
        raise SyntaxError("No support for variable number of arguments")
    if funcnode.args.kwarg is not None:
        raise SyntaxError("No support for arbitrary keyword arguments")
    if len(funcnode.args.defaults):
        raise SyntaxError("No support for default values in functions")

    nodes = funcnode.body
    nr = NodeRenderer()
    lines = []
    return_expr = None
    for node in nodes:
        if node.__class__ is ast.Return:
            return_expr = nr.render_node(node.value)
            break
        else:
            lines.append(nr.render_node(node))
    abstract_code = '\n'.join(lines)
    try:
        # Python 2
        args = [arg.id for arg in funcnode.args.args]
    except AttributeError:
        # Python 3
        args = [arg.arg for arg in funcnode.args.args]
    name = funcnode.name
    return AbstractCodeFunction(name, args, abstract_code, return_expr)
示例#14
0
 def __init__(self):
     '''
     Initialise an empty preference store.

     All bookkeeping dictionaries start out empty. ``eval_namespace`` is then
     populated by executing star-imports from ``numpy``, ``brian2.units``
     and ``brian2.units.stdunits`` in it.
     '''
     for attribute in ('prefs', 'backup_prefs', 'prefs_unvalidated',
                       'pref_register', 'eval_namespace'):
         setattr(self, attribute, {})
     star_imports = deindent("""
         from numpy import *
         from brian2.units import *            
         from brian2.units.stdunits import *
         """)
     exec(star_imports, self.eval_namespace)
示例#15
0
def numerically_check_permutation_code(code):
    '''
    Numerically check that ``code`` does not depend on the synapse order.

    A presynaptic and a postsynaptic group of 3 neurons each are connected
    all-to-all, giving 9 synapses. Variables ending in ``_syn``, ``_pre``
    and ``_post`` are backed by arrays of size 9, 3 and 3 respectively.

    The check is repeated ten times. Each time, the arrays are filled with
    random values and the code is run once to obtain reference results.
    The code is then re-run ten times from the same initial values with the
    synapse order shuffled, and each result must match the reference.

    This is a test of the test itself: it makes sure that no example was
    accidentally assigned to the wrong (good/bad) class.

    Raises
    ------
    OrderDependenceError
        If any shuffled run gives a different result.
    '''
    code = deindent(code)
    # Map each array variable to the index used to access it, and create its
    # value array.
    indices = {}
    vals = {}
    for var in get_identifiers(code):
        if var.endswith('_syn'):
            indices[var] = '_idx'
            vals[var] = zeros(9)
        elif var.endswith('_pre'):
            indices[var] = '_presynaptic_idx'
            vals[var] = zeros(3)
        elif var.endswith('_post'):
            indices[var] = '_postsynaptic_idx'
            vals[var] = zeros(3)
    subs = dict((var, var + '[' + idx + ']') for var, idx in indices.items())
    code = word_substitute(code, subs)
    code = '''
from numpy import *
from numpy.random import rand, randn
for _idx in shuffled_indices:
    _presynaptic_idx = presyn[_idx]
    _postsynaptic_idx = postsyn[_idx]
{code}
    '''.format(code=indent(code))
    ns = vals.copy()
    ns['shuffled_indices'] = arange(9)
    ns['presyn'] = arange(9) % 3
    # Floor division: true division (Python 3) would give float indices,
    # which cannot be used to index the value arrays.
    ns['postsyn'] = arange(9) // 3
    for _ in range(10):
        origvals = {}
        for k, v in vals.items():
            v[:] = randn(len(v))
            origvals[k] = v.copy()
        # exec as a call with an explicit namespace works on Python 2 and 3
        exec(code, ns)
        endvals = {}
        for k, v in vals.items():
            endvals[k] = v.copy()
        for _ in range(10):
            for k, v in vals.items():
                v[:] = origvals[k]
            shuffle(ns['shuffled_indices'])
            exec(code, ns)
            for k, v in vals.items():
                try:
                    assert_allclose(v, endvals[k])
                except AssertionError:
                    raise OrderDependenceError()
示例#16
0
 def _get_documentation(self):
     s = ''
     for name in sorted(self._values.keys()):
         default = self._default_values[name]
         doc = str(self._docs[name])
         # Make a link target
         s += '.. _brian-pref-{name}:\n\n'.format(name=name.replace('_', '-'))
         s += '``{name}`` = ``{default}``\n'.format(name=name,
                                                    default=repr(default))
         s += indent(deindent(doc))
         s += '\n\n'
     return s
示例#17
0
    def _get_one_documentation(self, basename, link_targets):
        '''
        Document a single category of preferences.
        '''

        s = ''
        if not basename in self.pref_register:
            raise ValueError('No preferences under the name "%s" are registered' % basename)
        prefdefs, basedoc = self.pref_register[basename]
        s += deindent(basedoc, docstring=True).strip() + '\n\n'
        for name in sorted(prefdefs.keys()):
            pref = prefdefs[name]
            name = basename + '.' + name
            linkname = name.replace('_', '-').replace('.', '-')
            if link_targets:
                # Make a link target
                s += '.. _brian-pref-{name}:\n\n'.format(name=linkname)
            s += '``{name}`` = ``{default}``\n'.format(name=name,
                                                       default=pref.representor(pref.default))
            s += indent(deindent(pref.docs, docstring=True))
            s += '\n\n'
        return s
示例#18
0
 def _as_pref_file(self, valuefunc):
     '''
     Helper function used to generate the preference file for the default or current preference values.
     '''
     s = ''
     for basename, (prefdefs, basedoc) in self.pref_register.items():
         s += '#' + '-' * 79 + '\n'
         s += '\n'.join([
             '# ' + line for line in deindent(
                 basedoc, docstring=True).strip().split('\n')
         ]) + '\n'
         s += '#' + '-' * 79 + '\n\n'
         s += '[' + basename + ']\n\n'
         for name in sorted(prefdefs.keys()):
             pref = prefdefs[name]
             s += '\n'.join([
                 '# ' + line for line in deindent(
                     pref.docs, docstring=True).strip().split('\n')
             ]) + '\n\n'
             s += name + ' = ' + pref.representor(
                 valuefunc(pref, basename + '.' + name)) + '\n\n'
     return s
示例#19
0
    def _get_one_documentation(self, basename, link_targets):
        '''
        Document a single category of preferences.
        '''

        s = ''
        if not basename in self.pref_register:
            raise ValueError('No preferences under the name "%s" are registered' % basename)
        prefdefs, basedoc = self.pref_register[basename]
        s += deindent(basedoc, docstring=True).strip() + '\n\n'
        for name in sorted(prefdefs.keys()):
            pref = prefdefs[name]
            name = basename + '.' + name
            linkname = name.replace('_', '-').replace('.', '-')
            if link_targets:
                # Make a link target
                s += '.. _brian-pref-{name}:\n\n'.format(name=linkname)
            s += '``{name}`` = ``{default}``\n'.format(name=name,
                                                       default=pref.representor(pref.default))
            s += indent(deindent(pref.docs, docstring=True))
            s += '\n\n'
        return s
示例#20
0
    def _get_one_documentation(self, basename, link_targets):
        """
        Document a single category of preferences.
        """

        s = ''
        if not basename in self.pref_register:
            raise ValueError(
                f"No preferences under the name '{basename}' are registered")
        prefdefs, basedoc = self.pref_register[basename]
        s += deindent(basedoc, docstring=True).strip() + '\n\n'
        for name in sorted(prefdefs.keys()):
            pref = prefdefs[name]
            name = basename + '.' + name
            linkname = name.replace('_', '-').replace('.', '-')
            if link_targets:
                # Make a link target
                s += f".. _brian-pref-{linkname}:\n\n"
            s += f"``{name}`` = ``{pref.representor(pref.default)}``\n"
            s += indent(deindent(pref.docs, docstring=True))
            s += "\n\n"
        return s
示例#21
0
def test_substitute_abstract_code_functions():
    '''
    Check that inlining abstract code functions preserves the results.

    The code is executed once with the Python functions ``f`` and ``g``
    in the namespace, and once after ``substitute_abstract_code_functions``
    has inlined them. All computed variables must agree.
    '''
    def f(x):
        y = x*x
        return y
    def g(x):
        return f(x)+1
    code = '''
    z = f(x)
    z = f(x)+f(y)
    w = f(z)
    h = f(f(w))
    p = g(g(x))
    '''
    funcs = [abstract_code_from_function(f),
             abstract_code_from_function(g),
             ]
    subcode = substitute_abstract_code_functions(code, funcs)
    for x, y in [(0, 1), (1, 0), (0.124323, 0.4549483)]:
        ns1 = {'x': x, 'y': y, 'f': f, 'g': g}
        ns2 = {'x': x, 'y': y}
        # exec as a call with an explicit namespace works on Python 2 and 3
        # (the ``exec ... in ...`` statement is a SyntaxError on Python 3)
        exec(deindent(code), ns1)
        exec(subcode, ns2)
        for k in ['z', 'w', 'h', 'p']:
            assert ns1[k] == ns2[k]
示例#22
0
def apply_code_template(code, template, placeholder='%CODE%'):
    '''
    Insert ``code`` into ``template`` in place of ``placeholder``.

    ``code`` is deindented and stripped of empty lines. It is then indented
    with spaces to match the template line that contains ``placeholder``.
    That whole line is replaced, so the placeholder should appear on its own
    line; every line containing it is replaced in this way.

    All tab characters in the template are first replaced by four spaces.
    '''
    body = strip_empty_lines(deindent(code))
    result = []
    for line in template.replace('\t', '    ').split('\n'):
        if placeholder not in line:
            result.append(line)
            continue
        depth = len(line) - len(line.lstrip())
        result.append(indent(body, depth, tab=' '))
    return '\n'.join(result)
示例#23
0
def _mod_support_code():
    '''
    Generate ``_brian_mod`` overloads for every combination of ``int``,
    ``float`` and ``double`` arguments.

    Each overload casts both arguments to the "higher" of the two types
    (in the order int, float, double) and evaluates
    ``fmod(fmod(x, y)+y, y)`` if either type is floating point, or
    ``((x%y)+y)%y`` otherwise. This is presumably meant to give Python-style
    modulo results. Unless ``CPU_ONLY`` is defined, each function is also
    declared ``__host__ __device__``.
    '''
    typestrs = ['int', 'float', 'double']
    floattypestrs = ['float', 'double']
    template = '''
            #ifdef CPU_ONLY
            inline {hightype} _brian_mod({xtype} ux, {ytype} uy)
            #else
            __host__ __device__ inline {hightype} _brian_mod({xtype} ux, {ytype} uy)
            #endif
            {{
                const {hightype} x = ({hightype})ux;
                const {hightype} y = ({hightype})uy;
                return {expr};
            }}
            '''
    pieces = []
    for ix, xtype in enumerate(typestrs):
        for iy, ytype in enumerate(typestrs):
            uses_float = xtype in floattypestrs or ytype in floattypestrs
            expr = 'fmod(fmod(x, y)+y, y)' if uses_float else '((x%y)+y)%y'
            pieces.append(template.format(hightype=typestrs[max(ix, iy)],
                                          xtype=xtype, ytype=ytype, expr=expr))
    return deindent(''.join(pieces))
示例#24
0
    for iy, ytype in enumerate(typestrs):
        hightype = typestrs[max(ix, iy)]
        if xtype in floattypestrs or ytype in floattypestrs:
            expr = 'fmod(fmod(x, y)+y, y)'
        else:
            expr = '((x%y)+y)%y'
        mod_support_code += '''
        inline {hightype} _brian_mod({xtype} ux, {ytype} uy)
        {{
            const {hightype} x = ({hightype})ux;
            const {hightype} y = ({hightype})uy;
            return {expr};
        }}
        '''.format(hightype=hightype, xtype=xtype, ytype=ytype, expr=expr)

# Support code shared by generated C++ code. It consists of the
# ``_brian_mod`` overloads accumulated in ``mod_support_code`` (deindented),
# followed by a ``_brian_pow(x, y)`` macro wrapping ``pow``. Under MSVC
# (``_MSC_VER``) the base is cast to ``double`` first, presumably so that
# integer bases resolve to an unambiguous ``pow`` overload
# (NOTE(review): confirm).
_universal_support_code = deindent(mod_support_code) + '''
#ifdef _MSC_VER
#define _brian_pow(x, y) (pow((double)(x), (y)))
#else
#define _brian_pow(x, y) (pow((x), (y)))
#endif
'''


class CPPCodeGenerator(CodeGenerator):
    '''
    C++ language
    
    C++ code templates should provide Jinja2 macros with the following names:
    
    ``main``
示例#25
0
def test_ufunc_at_vectorisation():
    '''
    Check the numpy ``ufunc.at`` vectorisation of synaptic code.

    Each example in ``permutation_analysis_good_examples`` is used as the
    ``pre`` code of a ``Synapses`` object between two 3-neuron groups.
    Examples that support it are run both without and with ``ufunc.at``
    vectorisation; the others are run with it enabled only. Every run starts
    from the same random initial values, and all runs must reach the same
    final values.

    No log message may be emitted for a vectorisable example. Exactly one
    "Failed to vectorise code" warning must be emitted for an example
    marked ``NOT_UFUNC_AT_VECTORISABLE``.
    '''
    if prefs.codegen.target != 'numpy':
        raise SkipTest('numpy-only test')
    for code in permutation_analysis_good_examples:
        vectorisable = 'NOT_UFUNC_AT_VECTORISABLE' not in code
        use_ufunc_at_list = [False, True] if vectorisable else [True]
        code = deindent(code)
        identifiers = get_identifiers(code)
        vars_src = [name[:-4] for name in identifiers if name.endswith('_pre')]
        vars_tgt = [name[:-5] for name in identifiers if name.endswith('_post')]
        vars_syn = [name[:-4] for name in identifiers if name.endswith('_syn')]
        eqs_src = '\n'.join(name + ':1' for name in vars_src)
        eqs_tgt = '\n'.join(name + ':1' for name in vars_tgt)
        eqs_syn = '\n'.join(name + ':1' for name in vars_syn)
        origvals = {}
        endvals = {}
        try:
            BrianLogger._log_messages.clear()
            with catch_logs() as caught_logs:
                for use_ufunc_at in use_ufunc_at_list:
                    NumpyCodeGenerator._use_ufunc_at_vectorisation = use_ufunc_at
                    src = NeuronGroup(3, eqs_src, threshold='True', name='src')
                    tgt = NeuronGroup(3, eqs_tgt, name='tgt')
                    syn = Synapses(src, tgt, eqs_syn,
                                   pre=code.replace('_syn', ''),
                                   connect=True, name='syn')
                    groups = [(src, vars_src), (tgt, vars_tgt), (syn, vars_syn)]
                    # Use the same initial values in every run
                    for group, names in groups:
                        for name in names:
                            key = name + group.name
                            if key in origvals:
                                group.state(name)[:] = origvals[key]
                            else:
                                initial = rand(len(group))
                                group.state(name)[:] = initial
                                origvals[key] = initial.copy()
                    Network(src, tgt, syn).run(defaultclock.dt)
                    # All runs must end up with the same values
                    for group, names in groups:
                        for name in names:
                            key = name + group.name
                            final = group.state(name)[:].copy()
                            if key in endvals:
                                assert_allclose(final, endvals[key])
                            else:
                                endvals[key] = final
                if vectorisable:
                    assert len(caught_logs) == 0
                else:
                    assert len(caught_logs) == 1
                    log_lev, log_mod, log_msg = caught_logs[0]
                    assert log_lev == 'WARNING'
                    assert log_mod == 'brian2.codegen.generators.numpy_generator'
                    assert log_msg.startswith('Failed to vectorise code')
        finally:
            NumpyCodeGenerator._use_ufunc_at_vectorisation = True  # restore it
示例#26
0
    def _add_user_function(self, varname, var):
        user_functions = []
        load_namespace = []
        support_code = []
        impl = var.implementations[self.codeobj_class]
        func_code = impl.get_code(self.owner)
        # Implementation can be None if the function is already
        # available in Cython (possibly under a different name)
        if func_code is not None:
            if isinstance(func_code, basestring):
                # Function is provided as Cython code
                # To make namespace variables available to functions, we
                # create global variables and assign to them in the main
                # code
                user_functions.append((varname, var))
                func_namespace = impl.get_namespace(self.owner) or {}
                for ns_key, ns_value in func_namespace.iteritems():
                    load_namespace.append('# namespace for function %s' %
                                          varname)
                    if hasattr(ns_value, 'dtype'):
                        if ns_value.shape == ():
                            raise NotImplementedError((
                                'Directly replace scalar values in the function '
                                'instead of providing them via the namespace'))
                        newlines = [
                            "global _namespace{var_name}",
                            "global _namespace_num{var_name}",
                            "cdef _numpy.ndarray[{cpp_dtype}, ndim=1, mode='c'] _buf_{var_name} = _namespace['{var_name}']",
                            "_namespace{var_name} = <{cpp_dtype} *> _buf_{var_name}.data",
                            "_namespace_num{var_name} = len(_namespace['{var_name}'])"
                        ]
                        support_code.append(
                            "cdef {cpp_dtype} *_namespace{var_name}".format(
                                cpp_dtype=get_cpp_dtype(ns_value.dtype),
                                var_name=ns_key))

                    else:  # e.g. a function
                        newlines = [
                            "_namespace{var_name} = namespace['{var_name}']"
                        ]
                    for line in newlines:
                        load_namespace.append(
                            line.format(
                                cpp_dtype=get_cpp_dtype(ns_value.dtype),
                                numpy_dtype=get_numpy_dtype(ns_value.dtype),
                                var_name=ns_key))
                support_code.append(deindent(func_code))
            elif callable(func_code):
                self.variables[varname] = func_code
                line = '{0} = _namespace["{1}"]'.format(varname, varname)
                load_namespace.append(line)
            else:
                raise TypeError(('Provided function implementation '
                                 'for function %s is neither a string '
                                 'nor callable (is type %s instead)') %
                                (varname, type(func_code)))

        dep_support_code = []
        dep_load_namespace = []
        dep_user_functions = []
        if impl.dependencies is not None:
            for dep_name, dep in impl.dependencies.iteritems():
                self.variables[dep_name] = dep
                sc, ln, uf = self._add_user_function(dep_name, dep)
                dep_support_code.extend(sc)
                dep_load_namespace.extend(ln)
                dep_user_functions.extend(uf)

        return (support_code + dep_support_code,
                dep_load_namespace + load_namespace,
                dep_user_functions + user_functions)
示例#27
0
def abstract_code_dependencies(code, known_vars=None, known_funcs=None):
    '''
    Determine which identifiers an abstract code block uses, and how.

    Parameters
    ----------

    code : str
        Abstract code; it is deindented and parsed with ``ast.parse``.
    known_vars : set, optional
        Names of variables that already exist. A non-set iterable is
        converted to a set; ``None`` means "no known variables".
    known_funcs : set, optional
        Names of functions that already exist, handled like ``known_vars``.

    Returns
    -------

    results : namedtuple ``AbstractCodeDependencies`` with the fields
        ``all``
            Every identifier in the block, functions included.
        ``read``
            Values that are read (functions excluded).
        ``write``
            Values that are written to.
        ``funcs``
            Function names used in the block.
        ``known_all``
            Identifiers of the block found in ``known_vars`` or
            ``known_funcs``.
        ``known_read``
            ``read`` restricted to ``known_vars``.
        ``known_write``
            ``write`` restricted to ``known_vars``.
        ``known_funcs``
            ``funcs`` restricted to ``known_funcs``.
        ``unknown_read``
            ``read - known_vars``.
        ``unknown_write``
            ``write - known_vars``.
        ``unknown_funcs``
            ``funcs - known_funcs``.
        ``undefined_read``
            Variables read by a top-level statement without being known or
            written by an earlier statement. A non-empty set usually means
            an error: a variable that is read should either be known already
            or be defined earlier in the block (and then appear in
            ``newly_defined``).
        ``newly_defined``
            Variables written by the block that were not known and had not
            been read before being written.
    '''
    known_vars = set([]) if known_vars is None else known_vars
    known_funcs = set([]) if known_funcs is None else known_funcs
    if not isinstance(known_vars, set):
        known_vars = set(known_vars)
    if not isinstance(known_funcs, set):
        known_funcs = set(known_funcs)

    tree = ast.parse(deindent(code, docstring=True), mode='exec')

    # Order-insensitive summary of the whole block
    allids, read, write, funcs = get_read_write_funcs(tree)

    # Walk the top-level statements in order, tracking which names are
    # available, to spot reads that happen before any definition.
    available = known_vars.copy()
    newly_defined = set([])
    undefined_read = set([])
    for statement in tree.body:
        _, stmt_reads, stmt_writes, _ = get_read_write_funcs(statement)
        undefined_read |= stmt_reads - available
        newly_defined |= (stmt_writes - available) - undefined_read
        available |= stmt_writes

    fields = dict(
        all=allids,
        read=read,
        write=write,
        funcs=funcs,
        known_all=allids.intersection(known_vars.union(known_funcs)),
        known_read=read.intersection(known_vars),
        known_write=write.intersection(known_vars),
        known_funcs=funcs.intersection(known_funcs),
        unknown_read=read - known_vars,
        unknown_write=write - known_vars,
        unknown_funcs=funcs - known_funcs,
        undefined_read=undefined_read,
        newly_defined=newly_defined,
    )
    AbstractCodeDependencies = namedtuple('AbstractCodeDependencies',
                                          fields.keys())
    return AbstractCodeDependencies(**fields)
示例#28
0
    def _add_user_function(self, varname, var):
        '''
        Gather the Cython code needed to make a user-defined function usable.

        Dependencies of the function (``impl.dependencies``) are processed
        recursively and registered in ``self.variables``.

        Parameters
        ----------
        varname : str
            Name under which the function appears in the generated code.
        var : Function
            The function object; ``var.implementations[self.codeobj_class]``
            is used.

        Returns
        -------
        support_code : list of str
            Support-code snippets. This function's snippets come first,
            followed by those of its dependencies.
        load_namespace : list of str
            Lines that load namespace values. The dependencies' lines come
            first.
        user_functions : list of tuple
            ``(name, function)`` pairs for functions provided as Cython code.
            The dependencies' pairs come first.

        Raises
        ------
        NotImplementedError
            If a function namespace contains a 0-d (scalar) array.
        TypeError
            If the implementation's code is neither a string nor callable.
        '''
        user_functions = []
        load_namespace = []
        support_code = []
        impl = var.implementations[self.codeobj_class]
        func_code = impl.get_code(self.owner)
        # Implementation can be None if the function is already
        # available in Cython (possibly under a different name)
        if func_code is not None:
            if isinstance(func_code, basestring):
                # Function is provided as Cython code
                # To make namespace variables available to functions, we
                # create global variables and assign to them in the main
                # code
                user_functions.append((varname, var))
                func_namespace = impl.get_namespace(self.owner) or {}
                for ns_key, ns_value in func_namespace.iteritems():
                    load_namespace.append(
                        '# namespace for function %s' % varname)
                    if hasattr(ns_value, 'dtype'):
                        if ns_value.shape == ():
                            raise NotImplementedError((
                            'Directly replace scalar values in the function '
                            'instead of providing them via the namespace'))
                        format_kwds = dict(
                            cpp_dtype=get_cpp_dtype(ns_value.dtype),
                            numpy_dtype=get_numpy_dtype(ns_value.dtype),
                            var_name=ns_key)
                        newlines = [
                            "global _namespace{var_name}",
                            "global _namespace_num{var_name}",
                            "cdef _numpy.ndarray[{cpp_dtype}, ndim=1, mode='c'] _buf_{var_name} = _namespace['{var_name}'].view(dtype=_numpy.{numpy_dtype})",
                            "_namespace{var_name} = <{cpp_dtype} *> _buf_{var_name}.data",
                            "_namespace_num{var_name} = len(_namespace['{var_name}'])"
                        ]
                        support_code.append(
                            "cdef {cpp_dtype} *_namespace{var_name}".format(
                                **format_kwds))

                    else:  # e.g. a function
                        # Values without a dtype only need their name
                        # substituted; reading ns_value.dtype here would
                        # raise AttributeError.
                        format_kwds = dict(var_name=ns_key)
                        # NOTE(review): this reads ``namespace[...]`` while
                        # the array branch reads ``_namespace[...]`` --
                        # confirm which name the generated code defines.
                        newlines = [
                            "_namespace{var_name} = namespace['{var_name}']"
                        ]
                    for line in newlines:
                        load_namespace.append(line.format(**format_kwds))
                support_code.append(deindent(func_code))
            elif callable(func_code):
                # A Python callable is stored in self.variables and loaded
                # from _namespace under its own name.
                self.variables[varname] = func_code
                line = '{0} = _namespace["{1}"]'.format(varname, varname)
                load_namespace.append(line)
            else:
                raise TypeError(('Provided function implementation '
                                 'for function %s is neither a string '
                                 'nor callable (is type %s instead)') % (
                                varname,
                                type(func_code)))

        dep_support_code = []
        dep_load_namespace = []
        dep_user_functions = []
        if impl.dependencies is not None:
            for dep_name, dep in impl.dependencies.iteritems():
                self.variables[dep_name] = dep
                sc, ln, uf = self._add_user_function(dep_name, dep)
                dep_support_code.extend(sc)
                dep_load_namespace.extend(ln)
                dep_user_functions.extend(uf)

        return (support_code + dep_support_code,
                dep_load_namespace + load_namespace,
                dep_user_functions + user_functions)
示例#29
0
    def create_extension(self, code, force=False, name=None,
                         include_dirs=None,
                         library_dirs=None,
                         runtime_library_dirs=None,
                         extra_compile_args=None,
                         extra_link_args=None,
                         libraries=None,
                         compiler=None,
                         owner_name='',
                         ):
        '''
        Build, or fetch from the in-memory cache, the Cython module for
        ``code``.

        The cache key combines the deindented code, the Python version and
        executable, and the Cython version. With ``force`` set, the current
        time is added to the key so that a fresh module name is produced.
        ``name`` overrides the generated module name. ``owner_name`` is only
        used for a diagnostic log message. The remaining arguments are passed
        through to ``self._load_module``.

        Returns the cached entry for the key when present. Otherwise returns
        whatever ``self._load_module`` returns. When the
        ``codegen.runtime.cython.multiprocess_safe`` preference is set,
        ``_load_module`` runs while an exclusive lock on
        ``<module_name>.lock`` is held.

        Raises ImportError if Cython is unavailable. Raises IOError if the
        cache directory cannot be created.
        '''
        self._simplify_paths()

        if Cython is None:
            raise ImportError('Cython is not available')

        code = deindent(code)

        lib_dir = prefs.codegen.runtime.cython.cache_dir
        if lib_dir is None:
            lib_dir = os.path.join(get_cython_cache_dir(), 'brian_extensions')
        if '~' in lib_dir:
            lib_dir = os.path.expanduser(lib_dir)
        try:
            os.makedirs(lib_dir)
        except OSError:
            # makedirs also fails when the directory already exists; only a
            # directory that is still missing is an error.
            if not os.path.exists(lib_dir):
                raise IOError("Couldn't create Cython cache directory '%s', try setting the "
                              "cache directly with prefs.codegen.runtime.cython.cache_dir." % lib_dir)

        key = code, sys.version_info, sys.executable, Cython.__version__
        if force:
            # Adding the current time forces a new (hashed) module name
            key += time.time(),

        if key in self._code_cache:
            return self._code_cache[key]

        module_name = (name if name is not None else
                       "_cython_magic_" + hashlib.md5(str(key).encode('utf-8')).hexdigest())
        if owner_name:
            message = '"{owner_name}" using Cython module "{module_name}"'.format(
                owner_name=owner_name, module_name=module_name)
            logger.diagnostic(message)

        module_path = os.path.join(lib_dir, module_name + self.so_ext)

        load_args = (module_path, include_dirs, library_dirs,
                     extra_compile_args, extra_link_args,
                     libraries, code, lib_dir, module_name,
                     runtime_library_dirs, compiler, key)

        if not prefs['codegen.runtime.cython.multiprocess_safe']:
            return self._load_module(*load_args)

        # Hold an exclusive per-module lock while loading/compiling
        lock_file = os.path.join(lib_dir, module_name + '.lock')
        with open(lock_file, 'w') as f:
            if msvcrt:
                # NOTE(review): mode 'w' has just truncated the file, so
                # st_size is presumably 0 here -- confirm the intended lock
                # length on Windows.
                msvcrt.locking(f.fileno(), msvcrt.LK_RLCK,
                               os.stat(lock_file).st_size)
            else:
                fcntl.flock(f, fcntl.LOCK_EX)
            return self._load_module(*load_args)
示例#30
0
    def _add_user_function(self, varname, variable, added):
        impl = variable.implementations[self.codeobj_class]
        if (impl.name, variable) in added:
            return  # nothing to do
        else:
            added.add((impl.name, variable))
        support_code = []
        hash_defines = []
        pointers = []
        user_functions = [(varname, variable)]
        funccode = impl.get_code(self.owner)
        if isinstance(funccode, str):
            # Rename references to any dependencies if necessary
            for dep_name, dep in impl.dependencies.items():
                dep_impl = dep.implementations[self.codeobj_class]
                dep_impl_name = dep_impl.name
                if dep_impl_name is None:
                    dep_impl_name = dep.pyfunc.__name__
                if dep_name != dep_impl_name:
                    funccode = word_substitute(funccode,
                                               {dep_name: dep_impl_name})
            funccode = {'support_code': funccode}
        if funccode is not None:
            # To make namespace variables available to functions, we
            # create global variables and assign to them in the main
            # code
            func_namespace = impl.get_namespace(self.owner) or {}
            for ns_key, ns_value in func_namespace.items():
                if hasattr(ns_value, 'dtype'):
                    if ns_value.shape == ():
                        raise NotImplementedError(
                            ('Directly replace scalar values in the function '
                             'instead of providing them via the namespace'))
                    type_str = self.c_data_type(ns_value.dtype) + '*'
                else:  # e.g. a function
                    type_str = 'py::object'
                support_code.append('static {0} _namespace{1};'.format(
                    type_str, ns_key))
                pointers.append('_namespace{0} = {1};'.format(ns_key, ns_key))
            support_code.append(deindent(funccode.get('support_code', '')))
            hash_defines.append(deindent(funccode.get('hashdefine_code', '')))

        dep_hash_defines = []
        dep_pointers = []
        dep_support_code = []
        if impl.dependencies is not None:
            for dep_name, dep in impl.dependencies.items():
                if dep_name not in self.variables:
                    self.variables[dep_name] = dep
                    dep_impl = dep.implementations[self.codeobj_class]
                    if dep_name != dep_impl.name:
                        self.func_name_replacements[dep_name] = dep_impl.name
                    user_function = self._add_user_function(
                        dep_name, dep, added)
                    if user_function is not None:
                        hd, ps, sc, uf = user_function
                        dep_hash_defines.extend(hd)
                        dep_pointers.extend(ps)
                        dep_support_code.extend(sc)
                        user_functions.extend(uf)

        return (dep_hash_defines + hash_defines, dep_pointers + pointers,
                dep_support_code + support_code, user_functions)
示例#31
0
def abstract_code_dependencies(code, known_vars=None, known_funcs=None):
    '''
    Classify the identifiers appearing in an abstract code block.

    Parameters
    ----------

    code : str
        The abstract code (deindented, then parsed with ``ast``).
    known_vars : set or iterable, optional
        Variable names that already exist; ``None`` stands for none.
    known_funcs : set or iterable, optional
        Function names that already exist; ``None`` stands for none.

    Returns
    -------

    results : namedtuple ``AbstractCodeDependencies`` with fields
        ``all``
            All identifiers of the block, including functions.
        ``read``
            Values read, excluding functions.
        ``write``
            Values written to.
        ``funcs``
            Function names.
        ``known_all``
            Identifiers that are in ``known_vars`` or ``known_funcs``.
        ``known_read``
            Read values that are in ``known_vars``.
        ``known_write``
            Written values that are in ``known_vars``.
        ``known_funcs``
            Used functions that are in ``known_funcs``.
        ``unknown_read``
            ``read - known_vars``.
        ``unknown_write``
            ``write - known_vars``.
        ``unknown_funcs``
            ``funcs - known_funcs``.
        ``undefined_read``
            Unknown variables read by a top-level statement before any
            earlier statement wrote them. A non-empty set usually indicates
            an error.
        ``newly_defined``
            Variables introduced by this block: written, not previously
            known, and not read before their first write.
    '''
    if known_vars is None:
        known_vars = ()
    if known_funcs is None:
        known_funcs = ()
    known_vars = known_vars if isinstance(known_vars, set) else set(known_vars)
    known_funcs = (known_funcs if isinstance(known_funcs, set)
                   else set(known_funcs))

    source = deindent(code, docstring=True)
    parsed = ast.parse(source, mode='exec')

    # Read/write/function sets for the whole block, ignoring order
    allids, read, write, funcs = get_read_write_funcs(parsed)

    # Sequential pass over top-level statements to detect reads of names
    # that are neither known nor defined by an earlier statement.
    seen = known_vars.copy()
    newly_defined, undefined_read = set(), set()
    for node in parsed.body:
        _, reads, writes, _ = get_read_write_funcs(node)
        undefined_read.update(reads - seen)
        newly_defined.update((writes - seen) - undefined_read)
        seen.update(writes)

    summary = dict(
        all=allids,
        read=read,
        write=write,
        funcs=funcs,
        known_all=allids.intersection(known_vars.union(known_funcs)),
        known_read=read.intersection(known_vars),
        known_write=write.intersection(known_vars),
        known_funcs=funcs.intersection(known_funcs),
        unknown_read=read - known_vars,
        unknown_write=write - known_vars,
        unknown_funcs=funcs - known_funcs,
        undefined_read=undefined_read,
        newly_defined=newly_defined,
        )
    result_type = namedtuple('AbstractCodeDependencies', summary.keys())
    return result_type(**summary)
示例#32
0
class GeNNCodeGenerator(CodeGenerator):
    '''
    "GeNN language"

    Generates C++-style statements (rendered with ``CPPNodeRenderer``) for
    the GeNN backend.

    For user-defined functions, there are two keys to provide:

    ``support_code``
        The function definition which will be added to the support code.
    ``hashdefine_code``
        The ``#define`` code added to the main loop.
    '''

    class_name = 'genn'

    # Support code appended for every code object: the output of
    # _mod_support_code() plus a _brian_pow macro (which casts the base to
    # double when compiling with MSVC).
    universal_support_code = _mod_support_code() + deindent('''
    #ifdef _MSC_VER
    #define _brian_pow(x, y) (pow((double)(x), (y)))
    #else
    #define _brian_pow(x, y) (pow((x), (y)))
    #endif
    ''')

    def __init__(self, *args, **kwds):
        super(GeNNCodeGenerator, self).__init__(*args, **kwds)
        self.c_data_type = c_data_type

    @property
    def restrict(self):
        '''The restrict keyword from the C++ preferences, plus a space.'''
        return prefs['codegen.generators.cpp.restrict_keyword'] + ' '

    @property
    def flush_denormals(self):
        '''Value of the ``codegen.generators.cpp.flush_denormals`` pref.'''
        return prefs['codegen.generators.cpp.flush_denormals']

    @staticmethod
    def get_array_name(var, access_data=True):
        '''
        Return the name used for ``var``'s array in generated code.

        With ``access_data`` the restricted pointer name (``'_ptr'`` plus the
        device's array name) is returned. Otherwise the device's name for
        the array, queried with ``access_data=False``, is returned.
        '''
        # We have to do the import here to avoid circular import dependencies.
        from brian2.devices.device import get_device
        device = get_device()
        if access_data:
            return '_ptr' + device.get_array_name(var)
        else:
            return device.get_array_name(var, access_data=False)

    def translate_expression(self, expr):
        '''
        Render ``expr`` as C++ code. User-defined functions are replaced by
        their implementation names first.
        '''
        for varname, var in self.variables.iteritems():
            if isinstance(var, Function):
                impl_name = var.implementations[self.codeobj_class].name
                if impl_name is not None:
                    expr = word_substitute(expr, {varname: impl_name})
        return CPPNodeRenderer().render_expr(expr).strip()

    def translate_statement(self, statement):
        '''
        Translate one statement into a single line of code.

        ``:=`` declares the variable using the C type of ``statement.dtype``.
        A statement comment is appended as a ``//`` comment.
        '''
        var, op, expr, comment = (statement.var, statement.op, statement.expr,
                                  statement.comment)
        if op == ':=':
            decl = self.c_data_type(statement.dtype) + ' '
            op = '='
        else:
            decl = ''
        code = decl + var + ' ' + op + ' ' + self.translate_expression(
            expr) + ';'
        if len(comment):
            code += ' // ' + comment
        return code

    def translate_to_read_arrays(self, statements):
        return []

    def translate_to_declarations(self, statements):
        return []

    def translate_to_statements(self, statements):
        '''
        Translate ``statements`` into lines of code. Writes to conditionally
        written variables are wrapped in ``if(<condition>)``.
        '''
        # Only the conditional-write information is needed here
        _, _, _, conditional_write_vars = self.arrays_helper(statements)
        lines = []
        # the actual code
        for stmt in statements:
            line = self.translate_statement(stmt)
            if stmt.var in conditional_write_vars:
                condvar = conditional_write_vars[stmt.var]
                lines.append('if(%s)' % condvar)
                lines.append('    ' + line)
            else:
                lines.append(line)
        return lines

    def translate_to_write_arrays(self, statements):
        return []

    def translate_one_statement_sequence(self, statements, scalar=False):
        '''
        Translate a statement sequence into a list of deindented lines.

        For the ``'synapses'`` template in a code object whose name contains
        ``'_pre_codeobject'``, the statements are first passed through
        ``check_pre_code``. The returned post-synaptic write variable is
        stored on the owner as ``_genn_post_write_var``. ``scalar`` is
        accepted for interface compatibility and is not used here.
        '''
        if len(statements) and self.template_name == 'synapses':
            _, _, _, conditional_write_vars = self.arrays_helper(statements)
            if '_pre_codeobject' in self.name:
                # Group variable names by the index they are accessed with
                vars_pre = [
                    k for k, v in self.variable_indices.items()
                    if v == '_presynaptic_idx'
                ]
                vars_syn = [
                    k for k, v in self.variable_indices.items() if v == '_idx'
                ]
                vars_post = [
                    k for k, v in self.variable_indices.items()
                    if v == '_postsynaptic_idx'
                ]
                post_write_var, statements = check_pre_code(
                    self, statements, vars_pre, vars_syn, vars_post,
                    conditional_write_vars)
                self.owner._genn_post_write_var = post_write_var
        code = '\n'.join(self.translate_to_statements(statements))
        return stripped_deindented_lines(code)

    def denormals_to_zero_code(self):
        '''
        Return code that sets the flush-to-zero bit in the MXCSR register
        when ``flush_denormals`` is enabled, otherwise an empty string.
        '''
        if self.flush_denormals:
            return '''
            #define CSR_FLUSH_TO_ZERO         (1 << 15)
            unsigned csr = __builtin_ia32_stmxcsr();
            csr |= CSR_FLUSH_TO_ZERO;
            __builtin_ia32_ldmxcsr(csr);
            '''
        else:
            return ''

    def _add_user_function(self, varname, variable):
        '''
        Collect the code for a user-defined function and its dependencies.

        Dependencies not yet in ``self.variables`` are added there and
        processed recursively.

        Returns ``(hash_defines, pointers, support_code, user_functions)``.
        In the first three lists the dependencies' entries come first.
        ``user_functions`` starts with ``(varname, variable)``.

        Raises NotImplementedError if the function namespace contains a 0-d
        (scalar) array.
        '''
        impl = variable.implementations[self.codeobj_class]
        support_code = []
        hash_defines = []
        pointers = []
        user_functions = [(varname, variable)]
        funccode = impl.get_code(self.owner)
        if isinstance(funccode, basestring):
            # A plain string is treated as support code only
            funccode = {'support_code': funccode}
        if funccode is not None:
            # To make namespace variables available to functions, we
            # create global variables and assign to them in the main
            # code
            func_namespace = impl.get_namespace(self.owner) or {}
            for ns_key, ns_value in func_namespace.iteritems():
                if hasattr(ns_value, 'dtype'):
                    if ns_value.shape == ():
                        raise NotImplementedError(
                            ('Directly replace scalar values in the function '
                             'instead of providing them via the namespace'))
                    type_str = c_data_type(ns_value.dtype) + '*'
                else:  # e.g. a function
                    type_str = 'py::object'
                support_code.append('static {0} _namespace{1};'.format(
                    type_str, ns_key))
                pointers.append('_namespace{0} = {1};'.format(ns_key, ns_key))
            support_code.append(deindent(funccode.get('support_code', '')))
            hash_defines.append(deindent(funccode.get('hashdefine_code', '')))

        dep_hash_defines = []
        dep_pointers = []
        dep_support_code = []
        if impl.dependencies is not None:
            for dep_name, dep in impl.dependencies.iteritems():
                if dep_name not in self.variables:  # do not add a dependency twice
                    self.variables[dep_name] = dep
                    hd, ps, sc, uf = self._add_user_function(dep_name, dep)
                    dep_hash_defines.extend(hd)
                    dep_pointers.extend(ps)
                    dep_support_code.extend(sc)
                    user_functions.extend(uf)

        return (dep_hash_defines + hash_defines, dep_pointers + pointers,
                dep_support_code + support_code, user_functions)

    def determine_keywords(self):
        '''
        Build the template keywords for this code object.

        Side effects: every user function (dependencies included) is removed
        from ``self.variables``, and its implementation namespace (if any) is
        merged into ``self.variables``.

        Returns a dict with ``pointers_lines``, ``support_code_lines``,
        ``hashdefine_lines`` and ``denormals_code_lines``. Each value is the
        result of ``stripped_deindented_lines`` on the corresponding code.
        '''
        # set up the restricted pointers, these are used so that the compiler
        # knows there is no aliasing in the pointers, for optimisation
        pointers = []
        # It is possible that several different variable names refer to the
        # same array. E.g. in gapjunction code, v_pre and v_post refer to the
        # same array if a group is connected to itself
        handled_pointers = set()
        # Again, do the import here to avoid a circular dependency.
        from brian2.devices.device import get_device
        device = get_device()
        for var in self.variables.itervalues():
            if isinstance(var, ArrayVariable):
                # This is the "true" array name, not the restricted pointer.
                array_name = device.get_array_name(var)
                pointer_name = self.get_array_name(var)
                if pointer_name in handled_pointers:
                    continue
                if get_var_ndim(var, 1) > 1:
                    continue  # multidimensional (dynamic) arrays have to be treated differently
                line = '{0}* {1} {2} = {3};'.format(
                    self.c_data_type(var.dtype), self.restrict, pointer_name,
                    array_name)
                pointers.append(line)
                handled_pointers.add(pointer_name)

        # set up the functions
        user_functions = []
        support_code = []
        hash_defines = []
        # items() returns a list copy in Python 2, so _add_user_function may
        # safely add dependencies to self.variables during this loop.
        for varname, variable in self.variables.items():
            if isinstance(variable, Function):
                hd, ps, sc, uf = self._add_user_function(varname, variable)
                user_functions.extend(uf)
                support_code.extend(sc)
                pointers.extend(ps)
                hash_defines.extend(hd)

        # delete the user-defined functions from the namespace and add the
        # function namespaces (if any)
        for funcname, func in user_functions:
            del self.variables[funcname]
            func_namespace = func.implementations[
                self.codeobj_class].get_namespace(self.owner)
            if func_namespace is not None:
                self.variables.update(func_namespace)

        support_code.append(self.universal_support_code)

        return {
            'pointers_lines':
            stripped_deindented_lines('\n'.join(pointers)),
            'support_code_lines':
            stripped_deindented_lines('\n'.join(support_code)),
            'hashdefine_lines':
            stripped_deindented_lines('\n'.join(hash_defines)),
            # denormals_to_zero_code() returns one string; joining it with
            # '\n' would put every character on its own line.
            'denormals_code_lines':
            stripped_deindented_lines(self.denormals_to_zero_code()),
        }
示例#33
0
    placeholder should appear on its own line.
    
    All tab characters are replaced by four spaces.
    '''
    code = deindent(code)
    code = strip_empty_lines(code)
    template = template.replace('\t', ' '*4)
    lines = template.split('\n')
    newlines = []
    for line in lines:
        if placeholder in line:
            indentlevel = len(line)-len(line.lstrip())
            newlines.append(indent(code, indentlevel, tab=' '))
        else:
            newlines.append(line)
    return '\n'.join(newlines)

if __name__=='__main__':
    code = '''
    if cond:
        do_something()
    '''
    template = '''
    for arg in args:
        cond = f(arg)
        %CODE%
        do_something_else(arg)
    '''
    
    print deindent(strip_empty_lines(apply_code_template(code, template)))
示例#34
0
def make_statements(code, variables, dtype):
    '''
    Turn a series of abstract code statements into Statement objects, inferring
    whether each line is a set/declare operation, whether the variables are
    constant or not, and handling the cacheing of subexpressions. Returns a
    list of Statement objects. For arguments, see documentation for
    :func:`translate`.

    Parameters
    ----------
    code : str
        Abstract code; statements are separated by newlines or semicolons.
    variables : dict-like
        Maps identifier names to their objects (variables, `Function` and
        `Subexpression` objects, ...).
    dtype : dtype
        The dtype given to newly declared variables that do not already have
        a dtype via ``variables``.

    Returns
    -------
    statements : list of `Statement`
        One statement per line of ``code``, in order, each preceded where
        needed by statements that (re)compute the subexpressions it reads.

    Raises
    ------
    SyntaxError
        If a scalar variable is assigned to after a non-scalar variable has
        already been written to in the same block.
    '''
    code = strip_empty_lines(deindent(code))
    # Split into individual statements on newlines and semicolons, dropping
    # empty fragments
    lines = re.split(r'[;\n]', code)
    lines = [LineInfo(code=line) for line in lines if len(line)]
    if DEBUG:
        print 'INPUT CODE:'
        print code
    # dtypes of all non-function identifiers; newly declared variables are
    # added in the loop below
    dtypes = dict((name, var.dtype) for name, var in variables.iteritems()
                  if not isinstance(var, Function))
    # we will do inference to work out which lines are := and which are =
    # (AuxiliaryVariable entries are not considered defined, so their first
    # assignment is turned into a declaration)
    defined = set(k for k, v in variables.iteritems()
                  if not isinstance(v, AuxiliaryVariable))
    # Names of scalar variables; newly declared variables are added below if
    # their right-hand side only depends on scalars
    scalars = set(k for k, v in variables.iteritems()
                  if getattr(v, 'scalar', False))
    for line in lines:
        # parse statement into "var op expr"
        var, op, expr, comment = parse_statement(line.code)
        if op == '=':
            # first assignment to a name that is not yet defined: declaration
            if var not in defined:
                op = ':='
                defined.add(var)
                if var not in dtypes:
                    dtypes[var] = dtype
                # determine whether this is a scalar variable
                identifiers = get_identifiers_recursively([expr], variables)
                # In the following we assume that all unknown identifiers are
                # scalar constants -- this should cover numerical literals and
                # e.g. "True" or "inf".
                is_scalar = all((name in scalars) or not (name in defined)
                                for name in identifiers)
                if is_scalar:
                    scalars.add(var)

        statement = Statement(var,
                              op,
                              expr,
                              comment,
                              dtype=dtypes[var],
                              scalar=var in scalars)
        line.statement = statement
        # for each line will give the variable being written to
        line.write = var
        # each line will give a set of variables which are read
        line.read = get_identifiers_recursively([expr], variables)

    # All writes to scalar variables must happen before writes to vector
    # variables
    scalar_write_done = False
    for line in lines:
        stmt = line.statement
        if stmt.op != ':=' and stmt.var in scalars and scalar_write_done:
            raise SyntaxError(
                ('All writes to scalar variables in a code block '
                 'have to be made before writes to vector '
                 'variables. Illegal write to %s.') % line.write)
        elif not stmt.var in scalars:
            scalar_write_done = True

    if DEBUG:
        print 'PARSED STATEMENTS:'
        for line in lines:
            print line.statement, 'Read:' + str(
                line.read), 'Write:' + line.write

    # all variables which are written to at some point in the code block
    # used to determine whether they should be const or not
    # (in this function, only used for the debug output below)
    all_write = set(line.write for line in lines)

    if DEBUG:
        print 'ALL WRITE:', all_write

    # backwards compute whether or not variables will be read again
    # note that will_read for a line gives the set of variables it will read
    # on the current line or subsequent ones. will_write gives the set of
    # variables that will be written after the current line
    will_read = set()
    will_write = set()
    for line in lines[::-1]:
        will_read = will_read.union(line.read)
        line.will_read = will_read.copy()
        line.will_write = will_write.copy()
        will_write.add(line.write)

    if DEBUG:
        print 'WILL READ/WRITE:'
        for line in lines:
            print line.statement, 'Read:' + str(
                line.will_read), 'Write:' + str(line.will_write)

    # generate cacheing statements for common subexpressions
    # cached subexpressions need to be recomputed whenever they are to be used
    # on the next line, and currently invalid (meaning that the current value
    # stored in the subexpression variable is no longer accurate because one
    # of the variables appearing in it has changed). All subexpressions start
    # as invalid, and are invalidated whenever one of the variables appearing
    # in the RHS changes value.
    subexpressions = dict((name, val) for name, val in variables.items()
                          if isinstance(val, Subexpression))
    # sort subexpressions into an order so that subexpressions that don't depend
    # on other subexpressions are first
    subexpr_deps = dict((name, [dep for dep in subexpr.identifiers if dep in subexpressions]) for \
                                                            name, subexpr in subexpressions.items())
    sorted_subexpr_vars = topsort(subexpr_deps)

    if DEBUG:
        print 'SUBEXPRESSIONS:', subexpressions.keys()
    statements = []
    # all start as invalid
    valid = dict((name, False) for name in subexpressions.keys())
    # none are yet defined (or declared)
    subdefined = dict((name, False) for name in subexpressions.keys())
    for line in lines:
        stmt = line.statement
        read = line.read
        # note: write and will_read are not used further in this loop
        write = line.write
        will_read = line.will_read
        will_write = line.will_write
        # check that all subexpressions in expr are valid, and if not
        # add a definition/set its value, and set it to be valid
        # scan through in sorted order so that recursive subexpression dependencies
        # are handled in the right order
        for var in sorted_subexpr_vars:
            if var not in read:
                continue
            # if subexpression, and invalid
            if not valid.get(var, True):  # all non-subexpressions are valid
                subexpression = subexpressions[var]
                # if already defined/declared
                if subdefined[var]:
                    op = '='
                    constant = False
                else:
                    op = ':='
                    subdefined[var] = True
                    dtypes[var] = variables[var].dtype
                    # set to constant only if we will not write to it again
                    constant = var not in will_write
                    # check all subvariables are not written to again as well
                    if constant:
                        ids = subexpression.identifiers
                        constant = all(v not in will_write for v in ids)
                valid[var] = True
                statement = Statement(var,
                                      op,
                                      subexpression.expr,
                                      comment='',
                                      dtype=variables[var].dtype,
                                      constant=constant,
                                      subexpression=True,
                                      scalar=var in scalars)
                statements.append(statement)
        var, op, expr, comment = stmt.var, stmt.op, stmt.expr, stmt.comment
        # invalidate any subexpressions including var, recursively
        # we do this by having a set of variables that are invalid that we
        # start with the changed var and increase by any subexpressions we
        # find that have a dependency on something in the invalid set. We
        # go through in sorted subexpression order so that the invalid set
        # is increased in the right order
        invalid = set([var])
        for subvar in sorted_subexpr_vars:
            spec = subexpressions[subvar]
            spec_ids = set(spec.identifiers)
            if spec_ids.intersection(invalid):
                valid[subvar] = False
                invalid.add(subvar)
        # constant only if we are declaring a new variable and we will not
        # write to it again
        constant = op == ':=' and var not in will_write
        statement = Statement(var,
                              op,
                              expr,
                              comment,
                              dtype=dtypes[var],
                              constant=constant,
                              scalar=var in scalars)
        statements.append(statement)

    if DEBUG:
        print 'OUTPUT STATEMENTS:'
        for stmt in statements:
            print stmt

    return statements
示例#35
0
    for iy, ytype in enumerate(typestrs):
        hightype = typestrs[max(ix, iy)]
        if xtype in floattypestrs or ytype in floattypestrs:
            expr = 'fmod(fmod(x, y)+y, y)'
        else:
            expr = '((x%y)+y)%y'
        mod_support_code += '''
        inline {hightype} _brian_mod({xtype} ux, {ytype} uy)
        {{
            const {hightype} x = ({hightype})ux;
            const {hightype} y = ({hightype})uy;
            return {expr};
        }}
        '''.format(hightype=hightype, xtype=xtype, ytype=ytype, expr=expr)

# C++ support code bundled under one name: the ``_brian_mod`` overloads
# accumulated in ``mod_support_code`` by the loop above (deindented), followed
# by the ``_brian_pow`` macro. Under MSVC (``_MSC_VER``) the base is cast to
# double before calling ``pow`` -- presumably to avoid ambiguous ``pow``
# overloads for integer arguments (NOTE(review): confirm).
_universal_support_code = ''.join([deindent(mod_support_code), '''
#ifdef _MSC_VER
#define _brian_pow(x, y) (pow((double)(x), (y)))
#else
#define _brian_pow(x, y) (pow((x), (y)))
#endif
'''])


class CPPCodeGenerator(CodeGenerator):
    '''
    C++ language
    
    C++ code templates should provide Jinja2 macros with the following names:
    
    ``main``
示例#36
0
def substitute_abstract_code_functions(code, funcs):
    '''
    Inline all calls to the given abstract code functions.

    Inlining passes are repeated until the rendered code no longer changes,
    so functions whose bodies call other functions in ``funcs`` end up fully
    expanded as well.

    Parameters
    ----------
    code : str
        The abstract code to make inline substitutions into.
    funcs : list, dict or set of AbstractCodeFunction
        The functions to inline. For a dict the keys are ignored; every
        function is identified by its ``name`` attribute.

    Returns
    -------
    code : str
        The code with inline substitutions performed.
    '''
    if isinstance(funcs, (list, set)):
        funcs = dict((func.name, func) for func in funcs)

    while True:
        code = deindent(code)
        statements = ast.parse(code, mode='exec').body

        # Slightly nasty hack: temporaries from earlier inlining passes are
        # expected to be named ``_inline_<funcname>_<number>...``. Scan the
        # identifiers already present so numbering continues after the
        # highest number in use for each function.
        existing_ids = get_identifiers(code)
        next_index = {}
        for func in funcs.values():
            prefix = '_inline_' + func.name + '_'
            suffixes = set(name.replace(prefix, '') for name in existing_ids
                           if name.startswith(prefix))
            used = []
            for suffix in suffixes:
                cut = suffix.find('_')
                if cut > 0:
                    suffix = suffix[:cut]
                used.append(int(suffix))
            next_index[func.name] = max(used) + 1 if used else 0

        # Each statement is replaced by the statements that perform the
        # inlining, followed by the rewritten statement itself.
        rewritten = []
        for stmt in statements:
            for func in funcs.values():
                rewriter = FunctionRewriter(func, next_index[func.name])
                stmt = rewriter.visit(stmt)
                rewritten.extend(rewriter.pre)
                next_index[func.name] = rewriter.numcalls
            rewritten.append(stmt)

        renderer = NodeRenderer()
        new_code = '\n'.join(renderer.render_node(node) for node in rewritten)

        # A pass that changes nothing means every call has been expanded;
        # otherwise go again, since inlined bodies may call other functions.
        if new_code == code:
            return new_code
        code = new_code
示例#37
0
    def _add_user_function(self, varname, variable):
        '''
        Collect the code needed to use the function ``varname`` in generated
        code, recursively including the functions it depends on.

        Parameters
        ----------
        varname : str
            The name under which the function is used in the code.
        variable
            The function object; its implementation for
            ``self.codeobj_class`` is used.

        Returns
        -------
        hash_defines, pointers, support_code, user_functions, kernel_lines : tuple of lists
            ``#define`` code, host-side namespace pointer assignments,
            support code, the ``(name, function)`` pairs that were added, and
            device-side (kernel) namespace pointer assignments. Entries coming
            from dependencies precede those of the function itself.

        Raises
        ------
        NotImplementedError
            If the function namespace contains a scalar (0-d) value.
        '''
        impl = variable.implementations[self.codeobj_class]
        support_code = []
        hash_defines = []
        pointers = []
        kernel_lines = []
        user_functions = [(varname, variable)]
        funccode = impl.get_code(self.owner)

        ### Different from CPPCodeGenerator: We format the funccode dtypes here
        def _format_code(code, **format_kwds):
            # ``get_code`` may return a string, a dict of code snippets or
            # None (all three are handled further below). Format strings and
            # the string entries of dicts -- building a new dict so the
            # implementation's own dict is left untouched -- and pass anything
            # else through. Calling ``.format`` directly on a dict or on None
            # would raise an AttributeError.
            if isinstance(code, basestring):
                return code.format(**format_kwds)
            if isinstance(code, dict):
                return dict((key, _format_code(value, **format_kwds))
                            for key, value in code.items())
            return code

        if varname in functions_C99:
            funccode = _format_code(funccode,
                                    default_type=self.default_func_type,
                                    other_type=self.other_func_type)
        if varname == 'clip':
            funccode = _format_code(funccode, float_dtype=self.float_dtype)
        ###

        if isinstance(funccode, basestring):
            funccode = {'support_code': funccode}
        if funccode is not None:
            # To make namespace variables available to functions, we
            # create global variables and assign to them in the main
            # code
            func_namespace = impl.get_namespace(self.owner) or {}
            for ns_key, ns_value in func_namespace.iteritems():
                # This section is adapted from CPPCodeGenerator such that file
                # global namespace pointers can be used in both host and device
                # code.
                assert hasattr(ns_value, 'dtype'), \
                    'This should not have happened. Please report at ' \
                    'https://github.com/brian-team/brian2cuda/issues/new'
                if ns_value.shape == ():
                    raise NotImplementedError(
                        ('Directly replace scalar values in the function '
                         'instead of providing them via the namespace'))
                type_str = self.c_data_type(ns_value.dtype) + '*'
                namespace_ptr = '''
                    #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ > 0))
                    __device__ {dtype} _namespace{name};
                    #else
                    {dtype} _namespace{name};
                    #endif
                    '''.format(dtype=type_str, name=ns_key)
                support_code.append(namespace_ptr)
                # pointer lines will be used in codeobjects running on the host
                pointers.append(
                    '_namespace{name} = {name};'.format(name=ns_key))
                # kernel lines will be used in codeobjects running on the device
                kernel_lines.append('''
                    #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ > 0))
                    _namespace{name} = d{name};
                    #else
                    _namespace{name} = {name};
                    #endif
                    '''.format(name=ns_key))
            support_code.append(deindent(funccode.get('support_code', '')))
            hash_defines.append(deindent(funccode.get('hashdefine_code', '')))

        # Recursively add dependencies that are not yet known; registering
        # them in self.variables first prevents adding them twice.
        dep_hash_defines = []
        dep_pointers = []
        dep_support_code = []
        dep_kernel_lines = []
        if impl.dependencies is not None:
            for dep_name, dep in impl.dependencies.iteritems():
                if dep_name not in self.variables:
                    self.variables[dep_name] = dep
                    hd, ps, sc, uf, kl = self._add_user_function(dep_name, dep)
                    dep_hash_defines.extend(hd)
                    dep_pointers.extend(ps)
                    dep_support_code.extend(sc)
                    user_functions.extend(uf)
                    dep_kernel_lines.extend(kl)

        return (dep_hash_defines + hash_defines, dep_pointers + pointers,
                dep_support_code + support_code, user_functions,
                dep_kernel_lines + kernel_lines)
示例#38
0
def make_statements(code, variables, dtype):
    '''
    Turn a series of abstract code statements into Statement objects, inferring
    whether each line is a set/declare operation, whether the variables are
    constant or not, and handling the cacheing of subexpressions. Returns a
    list of Statement objects. For arguments, see documentation for
    :func:`translate`.

    Parameters
    ----------
    code : str
        Abstract code; statements are separated by newlines or semicolons.
    variables : dict-like
        Maps identifier names to their objects (variables, `Function` and
        `Subexpression` objects, ...).
    dtype : dtype
        The dtype given to newly declared variables that do not already have
        a dtype via ``variables``.

    Returns
    -------
    statements : list of `Statement`
        One statement per line of ``code``, in order, each preceded where
        needed by statements that (re)compute the subexpressions it reads.

    Raises
    ------
    SyntaxError
        If a scalar variable is assigned to after a non-scalar variable has
        already been written to in the same block.
    '''
    code = strip_empty_lines(deindent(code))
    # Split into individual statements on newlines and semicolons, dropping
    # empty fragments
    lines = re.split(r'[;\n]', code)
    lines = [LineInfo(code=line) for line in lines if len(line)]
    if DEBUG:
        print 'INPUT CODE:'
        print code
    # dtypes of all non-function identifiers; newly declared variables are
    # added in the loop below
    dtypes = dict((name, var.dtype) for name, var in variables.iteritems()
                  if not isinstance(var, Function))
    # we will do inference to work out which lines are := and which are =
    # (AuxiliaryVariable entries are not considered defined, so their first
    # assignment is turned into a declaration)
    defined = set(k for k, v in variables.iteritems()
                  if not isinstance(v, AuxiliaryVariable))
    # Names of scalar variables; newly declared variables are added below if
    # their right-hand side only depends on scalars
    scalars = set(k for k,v in variables.iteritems()
                  if getattr(v, 'scalar', False))
    for line in lines:
        # parse statement into "var op expr"
        var, op, expr = parse_statement(line.code)
        if op=='=':
            # first assignment to a name that is not yet defined: declaration
            if var not in defined:
                op = ':='
                defined.add(var)
                if var not in dtypes:
                    dtypes[var] = dtype
                # determine whether this is a scalar variable
                identifiers = get_identifiers_recursively(expr, variables)
                # In the following we assume that all unknown identifiers are
                # scalar constants -- this should cover numerical literals and
                # e.g. "True" or "inf".
                is_scalar = all((name in scalars) or not (name in defined)
                                for name in identifiers)
                if is_scalar:
                    scalars.add(var)

        statement = Statement(var, op, expr, dtypes[var], scalar=var in scalars)
        line.statement = statement
        # for each line will give the variable being written to
        line.write = var
        # each line will give a set of variables which are read
        line.read = get_identifiers_recursively(expr, variables)

    # All writes to scalar variables must happen before writes to vector
    # variables
    scalar_write_done = False
    for line in lines:
        stmt = line.statement
        if stmt.op != ':=' and stmt.var in scalars and scalar_write_done:
            raise SyntaxError(('All writes to scalar variables in a code block '
                               'have to be made before writes to vector '
                               'variables. Illegal write to %s.') % line.write)
        elif not stmt.var in scalars:
            scalar_write_done = True

    if DEBUG:
        print 'PARSED STATEMENTS:'
        for line in lines:
            print line.statement, 'Read:'+str(line.read), 'Write:'+line.write

    # all variables which are written to at some point in the code block
    # used to determine whether they should be const or not
    # (in this function, only used for the debug output below)
    all_write = set(line.write for line in lines)

    if DEBUG:
        print 'ALL WRITE:', all_write

    # backwards compute whether or not variables will be read again
    # note that will_read for a line gives the set of variables it will read
    # on the current line or subsequent ones. will_write gives the set of
    # variables that will be written after the current line
    will_read = set()
    will_write = set()
    for line in lines[::-1]:
        will_read = will_read.union(line.read)
        line.will_read = will_read.copy()
        line.will_write = will_write.copy()
        will_write.add(line.write)

    if DEBUG:
        print 'WILL READ/WRITE:'
        for line in lines:
            print line.statement, 'Read:'+str(line.will_read), 'Write:'+str(line.will_write)

    # generate cacheing statements for common subexpressions
    # cached subexpressions need to be recomputed whenever they are to be used
    # on the next line, and currently invalid (meaning that the current value
    # stored in the subexpression variable is no longer accurate because one
    # of the variables appearing in it has changed). All subexpressions start
    # as invalid, and are invalidated whenever one of the variables appearing
    # in the RHS changes value.
    subexpressions = dict((name, val) for name, val in variables.items() if isinstance(val, Subexpression))
    # sort subexpressions into an order so that subexpressions that don't depend
    # on other subexpressions are first
    subexpr_deps = dict((name, [dep for dep in subexpr.identifiers if dep in subexpressions]) for \
                                                            name, subexpr in subexpressions.items())
    sorted_subexpr_vars = topsort(subexpr_deps)

    if DEBUG:
        print 'SUBEXPRESSIONS:', subexpressions.keys()
    statements = []
    # all start as invalid
    valid = dict((name, False) for name in subexpressions.keys())
    # none are yet defined (or declared)
    subdefined = dict((name, False) for name in subexpressions.keys())
    for line in lines:
        stmt = line.statement
        read = line.read
        # note: write and will_read are not used further in this loop
        write = line.write
        will_read = line.will_read
        will_write = line.will_write
        # check that all subexpressions in expr are valid, and if not
        # add a definition/set its value, and set it to be valid
        # scan through in sorted order so that recursive subexpression dependencies
        # are handled in the right order
        for var in sorted_subexpr_vars:
            if var not in read:
                continue
            # if subexpression, and invalid
            if not valid.get(var, True): # all non-subexpressions are valid
                subexpression = subexpressions[var]
                # if already defined/declared
                if subdefined[var]:
                    op = '='
                    constant = False
                else:
                    op = ':='
                    subdefined[var] = True
                    dtypes[var] = variables[var].dtype
                    # set to constant only if we will not write to it again
                    constant = var not in will_write
                    # check all subvariables are not written to again as well
                    if constant:
                        ids = subexpression.identifiers
                        constant = all(v not in will_write for v in ids)
                valid[var] = True
                statement = Statement(var, op, subexpression.expr,
                                      variables[var].dtype, constant=constant,
                                      subexpression=True, scalar=var in scalars)
                statements.append(statement)
        var, op, expr = stmt.var, stmt.op, stmt.expr
        # invalidate any subexpressions including var, recursively
        # we do this by having a set of variables that are invalid that we
        # start with the changed var and increase by any subexpressions we
        # find that have a dependency on something in the invalid set. We
        # go through in sorted subexpression order so that the invalid set
        # is increased in the right order
        invalid = set([var])
        for subvar in sorted_subexpr_vars:
            spec = subexpressions[subvar]
            spec_ids = set(spec.identifiers)
            if spec_ids.intersection(invalid):
                valid[subvar] = False
                invalid.add(subvar)
        # constant only if we are declaring a new variable and we will not
        # write to it again
        constant = op==':=' and var not in will_write
        statement = Statement(var, op, expr, dtypes[var],
                              constant=constant, scalar=var in scalars)
        statements.append(statement)

    if DEBUG:
        print 'OUTPUT STATEMENTS:'
        for stmt in statements:
            print stmt

    return statements
示例#39
0
    print 'Saving notebook and converting to RST'
    exporter = NotebookExporter()
    output, _ = exporter.from_notebook_node(notebook)
    with codecs.open(output_ipynb_fname, 'w', encoding='utf-8') as f:
        f.write(output)

    # Insert a note about ipython notebooks with a download link
    note = deindent(u'''
    .. only:: html

        .. |launchbinder| image:: http://mybinder.org/badge.svg
        .. _launchbinder: http://mybinder.org:/repo/brian-team/brian2-binder/notebooks/tutorials/{tutorial}.ipynb
    
        .. note::
           This tutorial is a static non-editable version. You can launch an
           interactive, editable version without installing any local files
           using the Binder service (although note that at some times this
           may be slow or fail to open): |launchbinder|_
    
           Alternatively, you can download a copy of the notebook file
           to use locally: :download:`{tutorial}.ipynb`
    
           See the :doc:`tutorial overview page <index>` for more details.

    '''.format(tutorial=basename))
    notebook.cells.insert(1, NotebookNode(cell_type=u'raw', metadata={},
                                          source=note))
    exporter = RSTExporter()
    output, resources = exporter.from_notebook_node(notebook,
                                                    resources={'unique_key': basename+'_image'})
    with codecs.open(output_rst_fname, 'w', encoding='utf-8') as f:
        f.write(output)
示例#40
0
def make_statements(code, variables, dtype):
    '''
    Turn a series of abstract code statements into Statement objects, inferring
    whether each line is a set/declare operation, whether the variables are
    constant or not, and handling the cacheing of subexpressions.

    Parameters
    ----------
    code : str
        A (multi-line) string of statements.
    variables : dict-like
        A dictionary of with `Variable` and `Function` objects for every
        identifier used in the `code`.
    dtype : `dtype`
        The data type to use for temporary variables

    Returns
    -------
    scalar_statements, vector_statements : (list of `Statement`, list of `Statement`)
        Lists with statements that are to be executed once and statements that
        are to be executed once for every neuron/synapse/... (or in a vectorised
        way)

    Notes
    -----
    The `scalar_statements` may include newly introduced scalar constants that
    have been identified as loop-invariant and have therefore been pulled out
    of the vector statements.
    '''
    code = strip_empty_lines(deindent(code))
    lines = re.split(r'[;\n]', code)
    lines = [LineInfo(code=line) for line in lines if len(line)]
    if DEBUG:
        print 'INPUT CODE:'
        print code
    # Do a copy so we can add stuff without altering the original dict
    variables = dict(variables)
    # we will do inference to work out which lines are := and which are =
    defined = set(k for k, v in variables.iteritems()
                  if not isinstance(v, AuxiliaryVariable))
    for line in lines:
        # parse statement into "var op expr"
        var, op, expr, comment = parse_statement(line.code)
        if op=='=':
            if var not in defined:
                op = ':='
                defined.add(var)
                if var not in variables:
                    is_scalar = is_scalar_expression(expr, variables)
                    new_var = AuxiliaryVariable(var, Unit(1), # doesn't matter here
                                                dtype=dtype, scalar=is_scalar)
                    variables[var] = new_var


        statement = Statement(var, op, expr, comment,
                              dtype=variables[var].dtype,
                              scalar=variables[var].scalar)
        line.statement = statement
        # for each line will give the variable being written to
        line.write = var 
        # each line will give a set of variables which are read
        line.read = get_identifiers_recursively([expr], variables)

    # All writes to scalar variables must happen before writes to vector
    # variables
    scalar_write_done = False
    for line in lines:
        stmt = line.statement
        if stmt.op != ':=' and variables[stmt.var].scalar and scalar_write_done:
            raise SyntaxError(('All writes to scalar variables in a code block '
                               'have to be made before writes to vector '
                               'variables. Illegal write to %s.') % line.write)
        elif not variables[stmt.var].scalar:
            scalar_write_done = True

    if DEBUG:
        print 'PARSED STATEMENTS:'
        for line in lines:
            print line.statement, 'Read:'+str(line.read), 'Write:'+line.write
    
    # all variables which are written to at some point in the code block
    # used to determine whether they should be const or not
    all_write = set(line.write for line in lines)

    if DEBUG:
        print 'ALL WRITE:', all_write
        
    # backwards compute whether or not variables will be read again
    # note that will_read for a line gives the set of variables it will read
    # on the current line or subsequent ones. will_write gives the set of
    # variables that will be written after the current line
    will_read = set()
    will_write = set()
    for line in lines[::-1]:
        will_read = will_read.union(line.read)
        line.will_read = will_read.copy()
        line.will_write = will_write.copy()
        will_write.add(line.write)

    if DEBUG:
        print 'WILL READ/WRITE:'
        for line in lines:
            print line.statement, 'Read:'+str(line.will_read), 'Write:'+str(line.will_write)
        
    # generate cacheing statements for common subexpressions
    # cached subexpressions need to be recomputed whenever they are to be used
    # on the next line, and currently invalid (meaning that the current value
    # stored in the subexpression variable is no longer accurate because one
    # of the variables appearing in it has changed). All subexpressions start
    # as invalid, and are invalidated whenever one of the variables appearing
    # in the RHS changes value.
    subexpressions = dict((name, val) for name, val in variables.items() if isinstance(val, Subexpression))
    # sort subexpressions into an order so that subexpressions that don't depend
    # on other subexpressions are first
    subexpr_deps = dict((name, [dep for dep in subexpr.identifiers if dep in subexpressions]) for \
                                                            name, subexpr in subexpressions.items())
    sorted_subexpr_vars = topsort(subexpr_deps)

    if DEBUG:
        print 'SUBEXPRESSIONS:', subexpressions.keys()
    statements = []
    # all start as invalid
    valid = dict((name, False) for name in subexpressions.keys())
    # none are yet defined (or declared)
    subdefined = dict((name, False) for name in subexpressions.keys())
    for line in lines:
        stmt = line.statement
        read = line.read
        write = line.write
        will_read = line.will_read
        will_write = line.will_write
        # check that all subexpressions in expr are valid, and if not
        # add a definition/set its value, and set it to be valid
        # scan through in sorted order so that recursive subexpression dependencies
        # are handled in the right order
        for var in sorted_subexpr_vars:
            if var not in read:
                continue
            # if subexpression, and invalid
            if not valid.get(var, True): # all non-subexpressions are valid
                subexpression = subexpressions[var]
                # if already defined/declared
                if subdefined[var]:
                    op = '='
                    constant = False
                else:
                    op = ':='
                    subdefined[var] = True
                    # set to constant only if we will not write to it again
                    constant = var not in will_write
                    # check all subvariables are not written to again as well
                    if constant:
                        ids = subexpression.identifiers
                        constant = all(v not in will_write for v in ids)
                valid[var] = True
                statement = Statement(var, op, subexpression.expr, comment='',
                                      dtype=variables[var].dtype,
                                      constant=constant,
                                      subexpression=True,
                                      scalar=variables[var].scalar)
                statements.append(statement)
        var, op, expr, comment = stmt.var, stmt.op, stmt.expr, stmt.comment
        # invalidate any subexpressions including var, recursively
        # we do this by having a set of variables that are invalid that we
        # start with the changed var and increase by any subexpressions we
        # find that have a dependency on something in the invalid set. We
        # go through in sorted subexpression order so that the invalid set
        # is increased in the right order
        invalid = {var}
        for subvar in sorted_subexpr_vars:
            spec = subexpressions[subvar]
            spec_ids = set(spec.identifiers)
            if spec_ids.intersection(invalid):
                valid[subvar] = False
                invalid.add(subvar)
        # constant only if we are declaring a new variable and we will not
        # write to it again
        constant = op==':=' and var not in will_write
        statement = Statement(var, op, expr, comment,
                              dtype=variables[var].dtype,
                              constant=constant,
                              scalar=variables[var].scalar)
        statements.append(statement)

    if DEBUG:
        print 'OUTPUT STATEMENTS:'
        for stmt in statements:
            print stmt

    scalar_statements = [s for s in statements if s.scalar]
    vector_statements = [s for s in statements if not s.scalar]

    if prefs.codegen.loop_invariant_optimisations:
        scalar_constants, vector_statements = apply_loop_invariant_optimisations(vector_statements,
                                                                                 variables,
                                                                                 dtype)
        scalar_statements.extend(scalar_constants)

    return scalar_statements, vector_statements
示例#41
0
    def _add_user_function(self, varname, var, added):
        """
        Collect the Cython code needed to make a user-provided function (and,
        recursively, its dependencies) available in the generated code.

        Parameters
        ----------
        varname : str
            The name under which the function is referenced in the code.
        var : `Function`
            The function object; its implementation for ``self.codeobj_class``
            is used.
        added : set
            ``(implementation name, function)`` pairs that have already been
            handled. Updated in place so that shared dependencies are only
            added once.

        Returns
        -------
        support_code, load_namespace, user_functions : (list, list, list) or None
            Support code lines, lines for the main code that load namespace
            values, and ``(name, function)`` pairs. Dependencies come after
            this function in ``support_code`` but before it in
            ``load_namespace`` and ``user_functions``. Returns ``None`` if the
            implementation was already in ``added``.

        Raises
        ------
        NotImplementedError
            If the function namespace provides a scalar (0-d) array.
        TypeError
            If the implementation code is neither a string nor callable.
        """
        user_functions = []
        load_namespace = []
        support_code = []
        impl = var.implementations[self.codeobj_class]
        if (impl.name, var) in added:
            return  # nothing to do
        added.add((impl.name, var))
        # A missing dependency mapping is treated as "no dependencies"
        # consistently in both places where dependencies are used.
        dependencies = impl.dependencies or {}
        func_code = impl.get_code(self.owner)
        # Implementation can be None if the function is already
        # available in Cython (possibly under a different name)
        if func_code is not None:
            if isinstance(func_code, str):
                # Function is provided as Cython code.
                # To make namespace variables available to functions, we
                # create global variables and assign to them in the main
                # code
                user_functions.append((varname, var))
                func_namespace = impl.get_namespace(self.owner) or {}
                for ns_key, ns_value in func_namespace.items():
                    load_namespace.append(
                        f'# namespace for function {varname}')
                    if hasattr(ns_value, 'dtype'):
                        if ns_value.shape == ():
                            raise NotImplementedError((
                                'Directly replace scalar values in the function '
                                'instead of providing them via the namespace'))
                        cpp_dtype = get_cpp_dtype(ns_value.dtype)
                        newlines = [
                            "global _namespace{var_name}",
                            "global _namespace_num{var_name}",
                            "cdef _numpy.ndarray[{cpp_dtype}, ndim=1, mode='c'] _buf_{var_name} = _namespace['{var_name}']",
                            "_namespace{var_name} = <{cpp_dtype} *> _buf_{var_name}.data",
                            "_namespace_num{var_name} = len(_namespace['{var_name}'])"
                        ]
                        support_code.append(
                            f"cdef {cpp_dtype} *_namespace{ns_key}")
                    else:  # e.g. a function
                        # Values without a dtype must not be asked for one.
                        # They become Python-level globals and are read from
                        # the same ``_namespace`` dict as the arrays above.
                        cpp_dtype = None
                        newlines = [
                            "global _namespace{var_name}",
                            "_namespace{var_name} = _namespace['{var_name}']"
                        ]
                    for line in newlines:
                        load_namespace.append(
                            line.format(cpp_dtype=cpp_dtype, var_name=ns_key))
                # Rename references to any dependencies if necessary
                for dep_name, dep in dependencies.items():
                    dep_impl_name = dep.implementations[self.codeobj_class].name
                    if dep_impl_name is None:
                        dep_impl_name = dep.pyfunc.__name__
                    if dep_name != dep_impl_name:
                        func_code = word_substitute(func_code,
                                                    {dep_name: dep_impl_name})
                support_code.append(deindent(func_code))
            elif callable(func_code):
                self.variables[varname] = func_code
                line = f'{varname} = _namespace["{varname}"]'
                load_namespace.append(line)
            else:
                raise TypeError(
                    f"Provided function implementation for function "
                    f"'{varname}' is neither a string nor callable (is "
                    f"type {type(func_code)} instead).")

        # Recursively add dependencies that are not yet known in this code object
        dep_support_code = []
        dep_load_namespace = []
        dep_user_functions = []
        for dep_name, dep in dependencies.items():
            if dep_name in self.variables:
                continue
            self.variables[dep_name] = dep
            user_func = self._add_user_function(dep_name, dep, added)
            if user_func is not None:
                sc, ln, uf = user_func
                dep_support_code.extend(sc)
                dep_load_namespace.extend(ln)
                dep_user_functions.extend(uf)

        return (support_code + dep_support_code,
                dep_load_namespace + load_namespace,
                dep_user_functions + user_functions)
示例#42
0
    def determine_keywords(self):
        """
        Assemble the template keywords for a CUDA code object.

        The returned dictionary holds the restricted pointer declarations for
        all one-dimensional array variables (one per distinct pointer name),
        the support code and ``#define`` lines of every `Function` (with the
        C99 conversion types and the clip dtype filled in), the universal
        support code, the denormals-to-zero code and the ``uses_atomics`` flag.

        Side effects on ``self.variables``: each function is removed and
        replaced by its namespace entries (if any). A Python implementation,
        if present, is added as ``_python_<name>`` unless that name exists.
        """
        # Deferred import: importing at module level would be circular.
        from brian2.devices.device import get_device
        device = get_device()

        # Restricted pointers tell the compiler that arrays do not alias,
        # enabling optimisations. Several names may refer to one array (e.g.
        # v_pre and v_post in gap-junction code for a group connected to
        # itself), so each pointer is declared at most once.
        declared = set()
        pointer_lines = []
        for _, var in self.variables.iteritems():
            if not isinstance(var, ArrayVariable):
                continue
            # The real array name, as opposed to the restricted pointer name
            array_name = device.get_array_name(var)
            pointer_name = self.get_array_name(var)
            # multidimensional (dynamic) arrays have to be treated differently
            if pointer_name in declared or getattr(var, 'ndim', 1) > 1:
                continue
            declaration = (self.c_data_type(var.dtype) + ' * ' + self.restrict +
                           pointer_name + ' = ' + array_name + ';')
            pointer_lines.append(declaration)
            declared.add(pointer_name)

        # Conversion types for the standard C99 functions in device code
        integral_conversion = prefs.codegen.generators.cuda.default_functions_integral_convertion
        if integral_conversion == np.float64:
            default_func_type, other_func_type = 'double', 'float'
        else:  # np.float32
            default_func_type, other_func_type = 'float', 'double'
        # The clip function uses either all-float or all-double arguments
        # (see #51). NOTE(review): float64 maps to 'float' here, which looks
        # inverted; presumably this is the "other" overload type. Confirm
        # against the clip code template.
        if prefs['core.default_float_dtype'] == np.float64:
            float_dtype = 'float'
        else:  # np.float32
            float_dtype = 'double'

        functions = []
        support_snippets = []
        define_snippets = []
        # Iterate over a snapshot: '_python_*' entries are added in the loop.
        for name, func in list(self.variables.items()):
            if not isinstance(func, Function):
                continue
            functions.append((name, func))
            code = func.implementations[self.codeobj_class].get_code(self.owner)
            if name in functions_C99:
                code = code.format(default_type=default_func_type,
                                   other_type=other_func_type)
            if name == 'clip':
                code = code.format(float_dtype=float_dtype)
            if isinstance(code, basestring):
                code = {'support_code': code}
            if code is not None:
                support_snippets.append(deindent(code.get('support_code', '')))
                define_snippets.append(deindent(code.get('hashdefine_code', '')))
            # Expose the Python implementation as '_python_<name>', so the
            # generated code can fall back to it (e.g. for randn).
            if func.pyfunc is not None:
                python_name = '_python_' + name
                if python_name not in self.variables:
                    self.variables[python_name] = func.pyfunc
                else:
                    logger.warn(('Namespace already contains function %s, '
                                 'not replacing it') % python_name)

        # Functions are not passed on themselves; their namespaces (if any) are
        for name, func in functions:
            del self.variables[name]
            namespace = func.implementations[self.codeobj_class].get_namespace(self.owner)
            if namespace is not None:
                self.variables.update(namespace)

        support_snippets.append(deindent(self.universal_support_code))
        support_code = ''.join('\n' + snippet for snippet in support_snippets)
        hash_defines = ''.join('\n' + snippet for snippet in define_snippets)

        return {'pointers_lines': stripped_deindented_lines('\n'.join(pointer_lines)),
                'support_code_lines': stripped_deindented_lines(support_code),
                'hashdefine_lines': stripped_deindented_lines(hash_defines),
                'denormals_code_lines': stripped_deindented_lines(self.denormals_to_zero_code()),
                'uses_atomics': self.uses_atomics,
                }
示例#43
0
    def create_extension(self, code, force=False, name=None,
                         include=None, library_dirs=None, compile_args=None, link_args=None, lib=None,
                         ):
        """
        Compile Cython ``code`` into an extension module and return it.

        Modules are cached in memory (``self._code_cache``) and on disk in
        ``~/.brian/cython_extensions``. An existing compiled module file with
        the same module name is loaded instead of being rebuilt.

        Parameters
        ----------
        code : str
            The Cython source code (it is deindented before use).
        force : bool, optional
            Force a new module by including the current time in the cache key.
        name : str, optional
            Module name to use instead of one derived from a hash of the key.
        include, library_dirs, compile_args, link_args, lib : list, optional
            Include directories, library directories, extra compile/link
            arguments and libraries for the `Extension`. The lists passed by
            the caller are not modified.

        Returns
        -------
        module or None
            The loaded extension module, or ``None`` if Cython reported a
            compilation error.

        Raises
        ------
        ImportError
            If Cython is not available.
        """
        if Cython is None:
            raise ImportError('Cython is not available')

        code = deindent(code)

        lib_dir = os.path.expanduser('~/.brian/cython_extensions')
        if not os.path.exists(lib_dir):
            os.makedirs(lib_dir)

        key = code, sys.version_info, sys.executable, Cython.__version__
        if force:
            # Force a new module name by adding the current time to the
            # key which is hashed to determine the module name.
            key += time.time(),

        if key in self._code_cache:
            return self._code_cache[key]

        if name is not None:
            module_name = name
        else:
            module_name = "_cython_magic_" + hashlib.md5(str(key).encode('utf-8')).hexdigest()

        module_path = os.path.join(lib_dir, module_name + self.so_ext)

        if not os.path.isfile(module_path):
            # Copy the include list: the numpy include dir may be appended
            # below and must not leak into the caller's list.
            c_include_dirs = list(include) if include is not None else []
            if library_dirs is None:
                library_dirs = []
            if compile_args is None:
                compile_args = []
            if link_args is None:
                link_args = []
            if lib is None:
                lib = []

            if 'numpy' in code:
                import numpy
                c_include_dirs.append(numpy.get_include())
            pyx_file = os.path.join(lib_dir, module_name + '.pyx')
            # Close (and thereby flush) the file before Cython reads it
            with open(pyx_file, 'w') as f:
                f.write(code)

            extension = Extension(
                name=module_name,
                sources=[pyx_file],
                include_dirs=c_include_dirs,
                library_dirs=library_dirs,
                extra_compile_args=compile_args,
                extra_link_args=link_args,
                libraries=lib,
                language='c++',
                )
            build_extension = self._get_build_extension()
            try:
                opts = dict(
                    quiet=True,
                    annotate=False,
                    force=True,
                    )
                # suppresses the output on stdout
                with std_silent():
                    build_extension.extensions = Cython_Build.cythonize([extension], **opts)

                    build_extension.build_temp = os.path.dirname(pyx_file)
                    build_extension.build_lib = lib_dir
                    build_extension.run()
            except Cython_Compiler.Errors.CompileError:
                return

        module = imp.load_dynamic(module_name, module_path)
        self._code_cache[key] = module
        return module
示例#44
0
    results = dict(
        all=allids,
        read=read,
        write=write,
        funcs=funcs,
        known_all=allids.intersection(known_vars.union(known_funcs)),
        known_read=read.intersection(known_vars),
        known_write=write.intersection(known_vars),
        known_funcs=funcs.intersection(known_funcs),
        unknown_read=read - known_vars,
        unknown_write=write - known_vars,
        unknown_funcs=funcs - known_funcs,
        undefined_read=undefined_read,
        newly_defined=newly_defined,
    )
    return namedtuple('AbstractCodeDependencies', results.keys())(**results)


if __name__ == '__main__':
    # Small demonstration: analyse the dependencies of a two-line code block.
    # The print calls below produce the same output on Python 2 and 3.
    code = '''
    x = y+z
    a = f(b)
    '''
    known_vars = set(['y', 'z'])
    print(deindent(code))
    print('known_vars: %s' % (known_vars,))
    print('')
    r = abstract_code_dependencies(code, known_vars)
    # ``_asdict()`` is the documented namedtuple API; the ``__dict__``
    # attribute is not available on namedtuples in modern Python 3.
    for k, v in r._asdict().items():
        print('%s: %s' % (k, ', '.join(v)))
示例#45
0
def make_statements(code, variables, dtype):
    '''
    Turn a series of abstract code statements into Statement objects, inferring
    whether each line is a set/declare operation, whether the variables are
    constant or not, and handling the cacheing of subexpressions.

    Parameters
    ----------
    code : str
        Abstract code; statements are separated by newlines or semicolons.
    variables : dict-like
        Maps identifiers to `Variable`/`Function` objects. Entries that are
        `Subexpression` objects are emitted as cached subexpression
        statements before the lines that read them.
    dtype : `dtype`
        Data type for newly declared variables that are not in ``variables``.

    Returns
    -------
    statements : list of `Statement`
        The statements, including the statements that (re)compute cached
        subexpressions. For more on the arguments, see the documentation for
        :func:`translate`.
    '''
    # Split into single statements on ';' or newline, dropping empty pieces
    code = strip_empty_lines(deindent(code))
    lines = re.split(r'[;\n]', code)
    lines = [LineInfo(code=line) for line in lines if len(line)]
    if DEBUG:
        print 'INPUT CODE:'
        print code
    # Known dtypes of all non-function variables; extended below with the
    # dtypes of newly declared variables
    dtypes = dict((name, var.dtype) for name, var in variables.iteritems()
                  if not isinstance(var, Function))
    # we will do inference to work out which lines are := and which are =
    # (auxiliary variables do not count as defined, so their first
    # assignment becomes a declaration)
    defined = set(k for k, v in variables.iteritems()
                  if not isinstance(v, AuxiliaryVariable))

    for line in lines:
        # parse statement into "var op expr"
        var, op, expr = parse_statement(line.code)
        # the first plain assignment to an undefined name declares it
        if op=='=' and var not in defined:
            op = ':='
            defined.add(var)
            if var not in dtypes:
                dtypes[var] = dtype
        # NOTE(review): for any other op on a name missing from ``dtypes``
        # this lookup raises KeyError -- presumably such code is rejected
        # earlier; confirm.
        statement = Statement(var, op, expr, dtypes[var])
        line.statement = statement
        # for each line will give the variable being written to
        line.write = var 
        # each line will give a set of variables which are read
        line.read = get_identifiers_recursively(expr, variables)
        
    if DEBUG:
        print 'PARSED STATEMENTS:'
        for line in lines:
            print line.statement, 'Read:'+str(line.read), 'Write:'+line.write
    
    # all variables which are written to at some point in the code block.
    # Only printed for debugging here; constness is decided from the
    # per-line ``will_write`` sets computed below.
    all_write = set(line.write for line in lines)
    if DEBUG:
        print 'ALL WRITE:', all_write
        
    # backwards compute whether or not variables will be read again
    # note that will_read for a line gives the set of variables it will read
    # on the current line or subsequent ones. will_write gives the set of
    # variables that will be written after the current line
    will_read = set()
    will_write = set()
    for line in lines[::-1]:
        will_read = will_read.union(line.read)
        line.will_read = will_read.copy()
        line.will_write = will_write.copy()
        will_write.add(line.write)

    if DEBUG:
        print 'WILL READ/WRITE:'
        for line in lines:
            print line.statement, 'Read:'+str(line.will_read), 'Write:'+str(line.will_write)
        
    # generate cacheing statements for common subexpressions
    # cached subexpressions need to be recomputed whenever they are to be used
    # on the next line, and currently invalid (meaning that the current value
    # stored in the subexpression variable is no longer accurate because one
    # of the variables appearing in it has changed). All subexpressions start
    # as invalid, and are invalidated whenever one of the variables appearing
    # in the RHS changes value.
    #subexpressions = get_all_subexpressions()
    subexpressions = dict((name, val) for name, val in variables.items() if isinstance(val, Subexpression))

    subexpressions = translate_subexpressions(subexpressions, variables)

    if DEBUG:
        print 'SUBEXPRESSIONS:', subexpressions.keys()
    statements = []
    # all start as invalid
    valid = dict((name, False) for name in subexpressions.keys())
    # none are yet defined (or declared)
    subdefined = dict((name, False) for name in subexpressions.keys())
    for line in lines:
        stmt = line.statement
        read = line.read
        write = line.write
        will_read = line.will_read
        will_write = line.will_write
        # check that all subexpressions in expr are valid, and if not
        # add a definition/set its value, and set it to be valid
        # NOTE(review): ``read`` is iterated in unspecified (set) order. If
        # one subexpression refers to another, the dependency is not
        # guaranteed to be emitted first -- verify that
        # translate_subexpressions removes such nesting.
        for var in read:
            # if subexpression, and invalid
            if not valid.get(var, True): # all non-subexpressions are valid
                # if already defined/declared
                if subdefined[var]:
                    op = '='
                    constant = False
                else:
                    op = ':='
                    subdefined[var] = True
                    dtypes[var] = variables[var].dtype
                    # set to constant only if we will not write to it again
                    constant = var not in will_write
                    # check all subvariables are not written to again as well
                    if constant:
                        ids = subexpressions[var].identifiers
                        constant = all(v not in will_write for v in ids)
                valid[var] = True
                statement = Statement(var, op, subexpressions[var].expr,
                                      variables[var].dtype, constant=constant,
                                      subexpression=True)
                statements.append(statement)
        var, op, expr = stmt.var, stmt.op, stmt.expr
        # invalidate any subexpressions including var
        # (only direct identifiers are checked; invalidation does not
        # propagate through subexpressions that refer to other subexpressions)
        for subvar, spec in subexpressions.items():
            if var in spec.identifiers:
                valid[subvar] = False
        # constant only if we are declaring a new variable and we will not
        # write to it again
        constant = op==':=' and var not in will_write
        statement = Statement(var, op, expr, dtypes[var],
                              constant=constant)
        statements.append(statement)

    if DEBUG:
        print 'OUTPUT STATEMENTS:'
        for stmt in statements:
            print stmt

    return statements
示例#46
0
    def determine_keywords(self):
        """
        Assemble the keywords passed to the code object's template.

        The returned dictionary holds the restricted pointer declarations for
        all one-dimensional array variables (one per distinct pointer name),
        the support code and ``#define`` lines collected from every `Function`
        in ``self.variables``, and the denormals-to-zero code.

        Side effects on ``self.variables``: each function is removed and
        replaced by its namespace entries (if it has a namespace). Its Python
        implementation, if present, is added as ``_python_<name>`` unless that
        name is already taken.
        """
        # Deferred import: importing at module level would be circular.
        from brian2.devices.device import get_device
        device = get_device()

        # Restricted pointers tell the compiler that arrays do not alias,
        # enabling optimisations. Several names may refer to one array (e.g.
        # v_pre and v_post in gap-junction code for a group connected to
        # itself), so each pointer is declared at most once.
        declared = set()
        pointer_lines = []
        for _, var in self.variables.iteritems():
            if not isinstance(var, ArrayVariable):
                continue
            # The real array name, as opposed to the restricted pointer name
            array_name = device.get_array_name(var)
            pointer_name = self.get_array_name(var)
            # multidimensional (dynamic) arrays have to be treated differently
            if pointer_name in declared or getattr(var, 'dimensions', 1) > 1:
                continue
            declaration = (self.c_data_type(var.dtype) + ' * ' + self.restrict +
                           pointer_name + ' = ' + array_name + ';')
            pointer_lines.append(declaration)
            declared.add(pointer_name)

        functions = []
        support_snippets = []
        define_snippets = []
        # Iterate over a snapshot: '_python_*' entries are added in the loop.
        for name, func in list(self.variables.items()):
            if not isinstance(func, Function):
                continue
            functions.append((name, func))
            code = func.implementations[self.codeobj_class].get_code(self.owner)
            if isinstance(code, basestring):
                code = {'support_code': code}
            if code is not None:
                support_snippets.append(deindent(code.get('support_code', '')))
                define_snippets.append(deindent(code.get('hashdefine_code', '')))
            # Expose the Python implementation as '_python_<name>', so the
            # generated code can call it via weave (e.g. for randn).
            if func.pyfunc is not None:
                python_name = '_python_' + name
                if python_name not in self.variables:
                    self.variables[python_name] = func.pyfunc
                else:
                    logger.warn(('Namespace already contains function %s, '
                                 'not replacing it') % python_name)

        # Functions are not passed on themselves; their namespaces (if any) are
        for name, func in functions:
            del self.variables[name]
            namespace = func.implementations[self.codeobj_class].get_namespace(self.owner)
            if namespace is not None:
                self.variables.update(namespace)

        support_code = ''.join('\n' + snippet for snippet in support_snippets)
        hash_defines = ''.join('\n' + snippet for snippet in define_snippets)

        return {'pointers_lines': stripped_deindented_lines('\n'.join(pointer_lines)),
                'support_code_lines': stripped_deindented_lines(support_code),
                'hashdefine_lines': stripped_deindented_lines(hash_defines),
                'denormals_code_lines': stripped_deindented_lines(self.denormals_to_zero_code()),
                }
示例#47
0
def make_statements(code, variables, dtype, optimise=True):
    '''
    Turn a series of abstract code statements into Statement objects, inferring
    whether each line is a set/declare operation, whether the variables are
    constant or not, and handling the cacheing of subexpressions.

    Parameters
    ----------
    code : str
        A (multi-line) string of statements. Statements are separated by
        newlines or semicolons.
    variables : dict-like
        A dictionary with `Variable` and `Function` objects for every
        identifier used in the `code`. It is copied, so the caller's
        dictionary is not modified.
    dtype : `dtype`
        The data type to use for temporary variables
    optimise : bool, optional
        Whether to optimise expressions, including pulling out loop invariant
        expressions and putting them in new scalar constants. Defaults to
        ``True``; pass ``False`` in contexts where this kind of optimisation
        is not of interest. The optimisation is only performed if the
        `codegen.loop_invariant_optimisations` preference is also set.

    Returns
    -------
    scalar_statements, vector_statements : (list of `Statement`, list of `Statement`)
        Lists with statements that are to be executed once and statements that
        are to be executed once for every neuron/synapse/... (or in a vectorised
        way)

    Raises
    ------
    SyntaxError
        If a scalar variable is assigned to (other than by a declaration)
        after a statement that writes to a non-scalar variable.

    Notes
    -----
    Independently of ``optimise``, assignments to already defined, non-boolean
    variables are replaced by augmented assignments where sympy can collect
    the expression accordingly, i.e. a statement such as ``w = w + 1`` will be
    replaced by ``w += 1`` and ``w = w * 2`` by ``w *= 2``.

    If ``optimise`` is ``True`` (and the preference is set), the
    ``scalar_statements`` may include newly introduced scalar constants that
    have been identified as loop-invariant and have therefore been pulled out
    of the vector statements. Also, statements involving booleans will have
    additional information added to them (see `Statement` for details)
    describing how the statement can be reformulated as a sequence of if/then
    statements. Calls `~brian2.codegen.optimisation.optimise_statements`.
    '''
    code = strip_empty_lines(deindent(code))
    lines = re.split(r'[;\n]', code)
    lines = [LineInfo(code=line) for line in lines if len(line)]
    # Do a copy so we can add stuff without altering the original dict
    variables = dict(variables)
    # we will do inference to work out which lines are := and which are =
    # (auxiliary variables from a previous block do not count as defined)
    defined = set(k for k, v in variables.items()
                  if not isinstance(v, AuxiliaryVariable))
    for line in lines:
        statement = None
        # parse statement into "var op expr"
        var, op, expr, comment = parse_statement(line.code)
        if op == '=':
            if var not in defined:
                op = ':='
                defined.add(var)
                if var not in variables:
                    is_scalar = is_scalar_expression(expr, variables)
                    new_var = AuxiliaryVariable(var, Unit(1), # doesn't matter here
                                                dtype=dtype, scalar=is_scalar)
                    variables[var] = new_var
            elif not variables[var].is_boolean:
                # Try to rewrite "var = var + x" / "var = var * x" as an
                # augmented assignment
                sympy_expr = str_to_sympy(expr, variables)
                sympy_var = sympy.Symbol(var, real=True)
                try:
                    collected = sympy.collect(sympy_expr, sympy_var,
                                              exact=True, evaluate=False)
                except AttributeError:
                    # If something goes wrong during collection, e.g. collect
                    # does not work for logical expressions
                    collected = {1: sympy_expr}

                if (len(collected) == 2 and
                        set(collected.keys()) == {1, sympy_var} and
                        collected[sympy_var] == 1):
                    # We can replace this statement by a += assignment
                    statement = Statement(var, '+=',
                                          sympy_to_str(collected[1]),
                                          comment,
                                          dtype=variables[var].dtype,
                                          scalar=variables[var].scalar)
                elif len(collected) == 1 and sympy_var in collected:
                    # We can replace this statement by a *= assignment
                    statement = Statement(var, '*=',
                                          sympy_to_str(collected[sympy_var]),
                                          comment,
                                          dtype=variables[var].dtype,
                                          scalar=variables[var].scalar)
        if statement is None:
            statement = Statement(var, op, expr, comment,
                                  dtype=variables[var].dtype,
                                  scalar=variables[var].scalar)

        line.statement = statement
        # for each line will give the variable being written to
        line.write = var
        # each line will give a set of variables which are read
        line.read = get_identifiers_recursively([expr], variables)

    # All writes to scalar variables must happen before writes to vector
    # variables
    scalar_write_done = False
    for line in lines:
        stmt = line.statement
        if stmt.op != ':=' and variables[stmt.var].scalar and scalar_write_done:
            raise SyntaxError(('All writes to scalar variables in a code block '
                               'have to be made before writes to vector '
                               'variables. Illegal write to %s.') % line.write)
        elif not variables[stmt.var].scalar:
            scalar_write_done = True

    # backwards compute the set of variables that will be written after the
    # current line; used to decide whether a declaration can be const
    will_write = set()
    for line in lines[::-1]:
        line.will_write = will_write.copy()
        will_write.add(line.write)

    # generate cacheing statements for common subexpressions
    # cached subexpressions need to be recomputed whenever they are to be used
    # on the next line, and currently invalid (meaning that the current value
    # stored in the subexpression variable is no longer accurate because one
    # of the variables appearing in it has changed). All subexpressions start
    # as invalid, and are invalidated whenever one of the variables appearing
    # in the RHS changes value.
    subexpressions = dict((name, val) for name, val in variables.items()
                          if isinstance(val, Subexpression))
    # sort subexpressions into an order so that subexpressions that don't depend
    # on other subexpressions are first
    subexpr_deps = dict((name, [dep for dep in subexpr.identifiers
                                if dep in subexpressions])
                        for name, subexpr in subexpressions.items())
    sorted_subexpr_vars = topsort(subexpr_deps)

    statements = []
    # all start as invalid
    valid = dict((name, False) for name in subexpressions.keys())
    # none are yet defined (or declared)
    subdefined = dict((name, False) for name in subexpressions.keys())
    for line in lines:
        stmt = line.statement
        read = line.read
        will_write = line.will_write
        # check that all subexpressions in expr are valid, and if not
        # add a definition/set its value, and set it to be valid
        # scan through in sorted order so that recursive subexpression dependencies
        # are handled in the right order
        for var in sorted_subexpr_vars:
            if var not in read:
                continue
            # if subexpression, and invalid
            if not valid.get(var, True): # all non-subexpressions are valid
                subexpression = subexpressions[var]
                # if already defined/declared
                if subdefined[var]:
                    op = '='
                    constant = False
                else:
                    op = ':='
                    subdefined[var] = True
                    # set to constant only if we will not write to it again
                    constant = var not in will_write
                    # check all subvariables are not written to again as well
                    if constant:
                        ids = subexpression.identifiers
                        constant = all(v not in will_write for v in ids)
                valid[var] = True
                statement = Statement(var, op, subexpression.expr, comment='',
                                      dtype=variables[var].dtype,
                                      constant=constant,
                                      subexpression=True,
                                      scalar=variables[var].scalar)
                statements.append(statement)
        var, op, expr, comment = stmt.var, stmt.op, stmt.expr, stmt.comment
        # invalidate any subexpressions including var, recursively
        # we do this by having a set of variables that are invalid that we
        # start with the changed var and increase by any subexpressions we
        # find that have a dependency on something in the invalid set. We
        # go through in sorted subexpression order so that the invalid set
        # is increased in the right order
        invalid = {var}
        for subvar in sorted_subexpr_vars:
            spec = subexpressions[subvar]
            spec_ids = set(spec.identifiers)
            if spec_ids.intersection(invalid):
                valid[subvar] = False
                invalid.add(subvar)
        # constant only if we are declaring a new variable and we will not
        # write to it again
        constant = op == ':=' and var not in will_write
        statement = Statement(var, op, expr, comment,
                              dtype=variables[var].dtype,
                              constant=constant,
                              scalar=variables[var].scalar)
        statements.append(statement)

    scalar_statements = [s for s in statements if s.scalar]
    vector_statements = [s for s in statements if not s.scalar]

    if optimise and prefs.codegen.loop_invariant_optimisations:
        scalar_statements, vector_statements = optimise_statements(scalar_statements,
                                                                   vector_statements,
                                                                   variables)

    return scalar_statements, vector_statements
示例#48
0
    def create_extension(
        self,
        code,
        force=False,
        name=None,
        define_macros=None,
        include_dirs=None,
        library_dirs=None,
        runtime_library_dirs=None,
        extra_compile_args=None,
        extra_link_args=None,
        libraries=None,
        compiler=None,
        owner_name='',
    ):
        '''
        Return the Cython extension module for ``code``, loading it via
        ``self._load_module``.

        Unless ``name`` is given, the module name is derived from an MD5 hash
        of ``code`` together with the Python version, interpreter path and
        Cython version. If the resulting key is already in
        ``self._code_cache``, the cached entry is returned directly.

        Parameters
        ----------
        code : str
            Cython source code; it is deindented before use.
        force : bool, optional
            If ``True``, the current time is added to the hashed key so that
            a new module name is used.
        name : str, optional
            Explicit module name, used instead of the hash-based name.
        define_macros, include_dirs, library_dirs, runtime_library_dirs,
        extra_compile_args, extra_link_args, libraries, compiler
            Passed unchanged to ``self._load_module``.
        owner_name : str, optional
            Only used in a diagnostic log message.

        Raises
        ------
        ImportError
            If Cython is not available.
        IOError
            If the cache directory is missing and cannot be created.
        '''
        self._simplify_paths()

        if Cython is None:
            raise ImportError('Cython is not available')

        code = deindent(code)

        lib_dir = prefs.codegen.runtime.cython.cache_dir
        if lib_dir is None:
            lib_dir = os.path.join(get_cython_cache_dir(), 'brian_extensions')
        if '~' in lib_dir:
            lib_dir = os.path.expanduser(lib_dir)
        try:
            os.makedirs(lib_dir)
        except OSError:
            # an already existing directory is fine
            if not os.path.exists(lib_dir):
                raise IOError(
                    "Couldn't create Cython cache directory '%s', try setting the "
                    "cache directly with prefs.codegen.runtime.cython.cache_dir."
                    % lib_dir)

        key = code, sys.version_info, sys.executable, Cython.__version__

        if force:
            # Force a new module name by adding the current time to the
            # key which is hashed to determine the module name.
            key += time.time(),

        if key in self._code_cache:
            return self._code_cache[key]

        if name is not None:
            module_name = name  #py3compat.unicode_to_str(args.name)
        else:
            module_name = "_cython_magic_" + hashlib.md5(
                str(key).encode('utf-8')).hexdigest()
        if owner_name:
            logger.diagnostic(
                '"{owner_name}" using Cython module "{module_name}"'.format(
                    owner_name=owner_name, module_name=module_name))

        module_path = os.path.join(lib_dir, module_name + self.so_ext)

        load_kwargs = dict(define_macros=define_macros,
                           include_dirs=include_dirs,
                           library_dirs=library_dirs,
                           extra_compile_args=extra_compile_args,
                           extra_link_args=extra_link_args,
                           libraries=libraries,
                           code=code,
                           lib_dir=lib_dir,
                           module_name=module_name,
                           runtime_library_dirs=runtime_library_dirs,
                           compiler=compiler,
                           key=key)

        if not prefs['codegen.runtime.cython.multiprocess_safe']:
            return self._load_module(module_path, **load_kwargs)

        # Serialise loading/compiling this module across processes with an
        # exclusive lock on a per-module lock file.
        lock_file = os.path.join(lib_dir, module_name + '.lock')
        with open(lock_file, 'w') as f:
            if msvcrt:
                # open(..., 'w') truncates the file, so its size is always 0;
                # locking a zero-length region would not exclude other
                # processes. Lock one byte at offset 0 instead (locking
                # beyond the end of the file is allowed).
                msvcrt.locking(f.fileno(), msvcrt.LK_RLCK, 1)
            else:
                fcntl.flock(f, fcntl.LOCK_EX)
            try:
                return self._load_module(module_path, **load_kwargs)
            finally:
                # Release the lock explicitly before the file is closed
                if msvcrt:
                    f.seek(0)  # unlock must start at the locked offset
                    msvcrt.locking(f.fileno(), msvcrt.LK_UNLCK, 1)
                else:
                    fcntl.flock(f, fcntl.LOCK_UN)
示例#49
0
def make_statements(code, variables, dtype, optimise=True, blockname=''):
    '''
    Turn a series of abstract code statements into Statement objects, inferring
    whether each line is a set/declare operation, whether the variables are
    constant or not, and handling the cacheing of subexpressions.

    Parameters
    ----------
    code : str
        A (multi-line) string of statements. Statements are separated by
        newlines or semicolons.
    variables : dict-like
        A dictionary with `Variable` and `Function` objects for every
        identifier used in the `code`. It is copied, so the caller's
        dictionary is not modified.
    dtype : `dtype`
        The data type to use for temporary variables
    optimise : bool, optional
        Whether to optimise expressions, including pulling out loop invariant
        expressions and putting them in new scalar constants. Defaults to
        ``True``; pass ``False`` in contexts where this kind of optimisation
        is not of interest. The optimisation is only performed if the
        `codegen.loop_invariant_optimisations` preference is also set.
    blockname : str, optional
        A name for the block (used to name intermediate variables to avoid
        name clashes when multiple blocks are used together)

    Returns
    -------
    scalar_statements, vector_statements : (list of `Statement`, list of `Statement`)
        Lists with statements that are to be executed once and statements that
        are to be executed once for every neuron/synapse/... (or in a vectorised
        way)

    Raises
    ------
    SyntaxError
        If a scalar variable is assigned to (other than by a declaration)
        after a statement that writes to a non-scalar variable.

    Notes
    -----
    Independently of ``optimise``, assignments to already defined, non-boolean
    variables are replaced by augmented assignments where sympy can collect
    the expression accordingly, i.e. a statement such as ``w = w + 1`` will be
    replaced by ``w += 1`` and ``w = w * 2`` by ``w *= 2``.

    Subexpressions are not cached across statements: every statement that
    reads a subexpression is preceded by a statement (re)computing it.

    If ``optimise`` is ``True`` (and the preference is set), the
    ``scalar_statements`` may include newly introduced scalar constants that
    have been identified as loop-invariant and have therefore been pulled out
    of the vector statements. Also, statements involving booleans will have
    additional information added to them (see `Statement` for details)
    describing how the statement can be reformulated as a sequence of if/then
    statements. Calls `~brian2.codegen.optimisation.optimise_statements`.
    '''
    code = strip_empty_lines(deindent(code))
    lines = re.split(r'[;\n]', code)
    lines = [LineInfo(code=line) for line in lines if len(line)]
    # Do a copy so we can add stuff without altering the original dict
    variables = dict(variables)
    # we will do inference to work out which lines are := and which are =
    # (auxiliary variables from a previous block do not count as defined)
    defined = set(k for k, v in variables.items()
                  if not isinstance(v, AuxiliaryVariable))
    for line in lines:
        statement = None
        # parse statement into "var op expr"
        var, op, expr, comment = parse_statement(line.code)
        if op == '=':
            if var not in defined:
                op = ':='
                defined.add(var)
                if var not in variables:
                    is_scalar = is_scalar_expression(expr, variables)
                    new_var = AuxiliaryVariable(var,
                                                dtype=dtype,
                                                scalar=is_scalar)
                    variables[var] = new_var
            elif not variables[var].is_boolean:
                # Try to rewrite "var = var + x" / "var = var * x" as an
                # augmented assignment
                sympy_expr = str_to_sympy(expr, variables)
                sympy_var = sympy.Symbol(var, real=True)
                try:
                    collected = sympy.collect(sympy_expr,
                                              sympy_var,
                                              exact=True,
                                              evaluate=False)
                except AttributeError:
                    # If something goes wrong during collection, e.g. collect
                    # does not work for logical expressions
                    collected = {1: sympy_expr}

                if (len(collected) == 2
                        and set(collected.keys()) == {1, sympy_var}
                        and collected[sympy_var] == 1):
                    # We can replace this statement by a += assignment
                    statement = Statement(var,
                                          '+=',
                                          sympy_to_str(collected[1]),
                                          comment,
                                          dtype=variables[var].dtype,
                                          scalar=variables[var].scalar)
                elif len(collected) == 1 and sympy_var in collected:
                    # We can replace this statement by a *= assignment
                    statement = Statement(var,
                                          '*=',
                                          sympy_to_str(collected[sympy_var]),
                                          comment,
                                          dtype=variables[var].dtype,
                                          scalar=variables[var].scalar)
        if statement is None:
            statement = Statement(var,
                                  op,
                                  expr,
                                  comment,
                                  dtype=variables[var].dtype,
                                  scalar=variables[var].scalar)

        line.statement = statement
        # for each line will give the variable being written to
        line.write = var
        # each line will give a set of variables which are read
        line.read = get_identifiers_recursively([expr], variables)

    # All writes to scalar variables must happen before writes to vector
    # variables
    scalar_write_done = False
    for line in lines:
        stmt = line.statement
        is_scalar_var = variables[stmt.var].scalar
        if stmt.op != ':=' and is_scalar_var and scalar_write_done:
            raise SyntaxError(
                ('All writes to scalar variables in a code block '
                 'have to be made before writes to vector '
                 'variables. Illegal write to %s.') % line.write)
        elif not is_scalar_var:
            scalar_write_done = True

    # backwards compute the set of variables that will be written after the
    # current line; used to decide whether a declaration can be const
    will_write = set()
    for line in lines[::-1]:
        line.will_write = will_write.copy()
        will_write.add(line.write)

    subexpressions = dict((name, val) for name, val in variables.items()
                          if isinstance(val, Subexpression))
    # sort subexpressions into an order so that subexpressions that don't depend
    # on other subexpressions are first
    subexpr_deps = dict((name, [dep for dep in subexpr.identifiers
                                if dep in subexpressions])
                        for name, subexpr in subexpressions.items())
    sorted_subexpr_vars = topsort(subexpr_deps)

    statements = []

    # none are yet defined (or declared)
    subdefined = dict((name, False) for name in subexpressions.keys())
    for line in lines:
        stmt = line.statement
        read = line.read
        will_write = line.will_write
        # update/define all subexpressions needed by this statement, in
        # dependency order
        for var in sorted_subexpr_vars:
            if var not in read:
                continue

            subexpression = subexpressions[var]
            # if already defined/declared
            if subdefined[var]:
                op = '='
                constant = False
            else:
                op = ':='
                subdefined[var] = True
                # set to constant only if we will not write to it again
                constant = var not in will_write
                # check all subvariables are not written to again as well
                if constant:
                    ids = subexpression.identifiers
                    constant = all(v not in will_write for v in ids)

            statement = Statement(var,
                                  op,
                                  subexpression.expr,
                                  comment='',
                                  dtype=variables[var].dtype,
                                  constant=constant,
                                  subexpression=True,
                                  scalar=variables[var].scalar)
            statements.append(statement)

        var, op, expr, comment = stmt.var, stmt.op, stmt.expr, stmt.comment

        # constant only if we are declaring a new variable and we will not
        # write to it again
        constant = op == ':=' and var not in will_write
        statement = Statement(var,
                              op,
                              expr,
                              comment,
                              dtype=variables[var].dtype,
                              constant=constant,
                              scalar=variables[var].scalar)
        statements.append(statement)

    scalar_statements = [s for s in statements if s.scalar]
    vector_statements = [s for s in statements if not s.scalar]

    if optimise and prefs.codegen.loop_invariant_optimisations:
        scalar_statements, vector_statements = optimise_statements(
            scalar_statements,
            vector_statements,
            variables,
            blockname=blockname)

    return scalar_statements, vector_statements
示例#50
0
def make_statements(code, specifiers, dtype):
    '''
    Turn a series of abstract code statements into Statement objects, inferring
    whether each line is a set/declare operation, whether the variables are
    constant or not, and handling the cacheing of subexpressions. Returns a
    list of Statement objects. For arguments, see documentation for
    :func:`translate`.

    Each line of ``code`` must contain an assignment operator (``=`` or an
    augmented form such as ``+=``) whose left-hand side is a single variable
    name. A plain ``=`` to a name that is not yet defined (not in
    ``specifiers``, or an `OutputVariable`, and not declared on an earlier
    line) is turned into a declaration ``:=``.

    Before each statement, every subexpression it reads that is currently
    invalid is (re)computed into a variable of the same name. A
    subexpression is invalidated whenever a statement writes to one of its
    identifiers.

    Raises
    ------
    ValueError
        If a line contains no assignment operator, or its left-hand side is
        not a single variable name.
    '''
    code = strip_empty_lines(deindent(code))
    lines = code.split('\n')
    lines = [LineInfo(code=line) for line in lines]
    if DEBUG:
        print 'INPUT CODE:'
        print code
    # known dtypes of the specifiers; new temporaries get ``dtype`` below
    dtypes = dict((name, value.dtype) for name, value in specifiers.items() if hasattr(value, 'dtype'))
    # we will do inference to work out which lines are := and which are =
    #defined = set(specifiers.keys()) # variables which are already defined
    defined = set(var for var, spec in specifiers.items() if not isinstance(spec, OutputVariable))
    for line in lines:
        # parse statement into "var op expr"
        # the first (possibly augmented) assignment operator found splits the
        # line into the LHS variable and the RHS expression
        m = re.search(r'(\+|\-|\*|/|//|%|\*\*|>>|<<|&|\^|\|)?=', line.code)
        if not m:
            raise ValueError("Could not extract statement from: "+line.code)
        start, end = m.start(), m.end()
        op = line.code[start:end].strip()
        var = line.code[:start].strip()
        expr = line.code[end:].strip()
        # var should be a single word
        if len(re.findall(r'^[A-Za-z_][A-Za-z0-9_]*$', var))!=1:
            raise ValueError("LHS in statement must be single variable name, line: "+line.code)
        if op=='=' and var not in defined:
            op = ':='
            defined.add(var)
            if var not in dtypes:
                dtypes[var] = dtype
        statement = Statement(var, op, expr, dtypes[var])
        line.statement = statement
        # for each line will give the variable being written to
        line.write = var 
        # each line will give a set of variables which are read
        # NOTE(review): only the identifiers appearing directly in expr; the
        # identifiers of any subexpression referenced here are not included.
        line.read = set(get_identifiers(expr))
        
    if DEBUG:
        print 'PARSED STATEMENTS:'
        for line in lines:
            print line.statement, 'Read:'+str(line.read), 'Write:'+line.write
    
    # all variables which are written to at some point in the code block
    # used to determine whether they should be const or not
    # (in this function it is only used for the DEBUG output below)
    all_write = set(line.write for line in lines)
    if DEBUG:
        print 'ALL WRITE:', all_write
        
    # backwards compute whether or not variables will be read again
    # note that will_read for a line gives the set of variables it will read
    # on the current line or subsequent ones. will_write gives the set of
    # variables that will be written after the current line
    will_read = set()
    will_write = set()
    for line in lines[::-1]:
        will_read = will_read.union(line.read)
        line.will_read = will_read.copy()
        line.will_write = will_write.copy()
        will_write.add(line.write)

    if DEBUG:
        print 'WILL READ/WRITE:'
        for line in lines:
            print line.statement, 'Read:'+str(line.will_read), 'Write:'+str(line.will_write)
        
    # generate cacheing statements for common subexpressions
    # cached subexpressions need to be recomputed whenever they are to be used
    # on the next line, and currently invalid (meaning that the current value
    # stored in the subexpression variable is no longer accurate because one
    # of the variables appearing in it has changed). All subexpressions start
    # as invalid, and are invalidated whenever one of the variables appearing
    # in the RHS changes value.
    subexpressions = dict((name, val) for name, val in specifiers.items() if isinstance(val, Subexpression))
    if DEBUG:
        print 'SUBEXPRESSIONS:', subexpressions.keys()
    statements = []
    # all start as invalid
    valid = dict((name, False) for name in subexpressions.keys())
    # none are yet defined (or declared)
    subdefined = dict((name, False) for name in subexpressions.keys())
    for line in lines:
        stmt = line.statement
        read = line.read
        write = line.write
        will_read = line.will_read
        will_write = line.will_write
        # check that all subexpressions in expr are valid, and if not
        # add a definition/set its value, and set it to be valid
        # NOTE(review): ``read`` is a set, so the order in which several
        # subexpression statements are emitted follows set iteration order,
        # and subexpressions depending on other subexpressions are not
        # ordered by dependency -- confirm nested subexpressions never occur.
        for var in read:
            # if subexpression, and invalid
            if not valid.get(var, True): # all non-subexpressions are valid
                # if already defined/declared
                if subdefined[var]:
                    op = '='
                    constant = False
                else:
                    op = ':='
                    subdefined[var] = True
                    dtypes[var] = dtype # default dtype
                    # set to constant only if we will not write to it again
                    constant = var not in will_write
                    # check all subvariables are not written to again as well
                    if constant:
                        ids = subexpressions[var].identifiers
                        constant = all(v not in will_write for v in ids)
                valid[var] = True
                statement = Statement(var, op, subexpressions[var].expr,
                                      dtype, constant=constant,
                                      subexpression=True)
                statements.append(statement)
        var, op, expr = stmt.var, stmt.op, stmt.expr
        # invalidate any subexpressions including var
        # NOTE(review): not recursive -- a subexpression that depends on var
        # only through another subexpression stays valid.
        for subvar, spec in subexpressions.items():
            if var in spec.identifiers:
                valid[subvar] = False
        # constant only if we are declaring a new variable and we will not
        # write to it again
        constant = op==':=' and var not in will_write
        statement = Statement(var, op, expr, dtypes[var],
                              constant=constant)
        statements.append(statement)

    if DEBUG:
        print 'OUTPUT STATEMENTS:'
        for stmt in statements:
            print stmt

    return statements
示例#51
0
    def create_extension(
        self,
        code,
        force=False,
        name=None,
        define_macros=None,
        include_dirs=None,
        library_dirs=None,
        runtime_library_dirs=None,
        extra_compile_args=None,
        extra_link_args=None,
        libraries=None,
        compiler=None,
        sources=None,
        owner_name='',
    ):
        '''
        Return the Cython extension module for ``code``, loading it via
        ``self._load_module``.

        Unless ``name`` is given, the module name is an MD5 hash of a key made
        of the deindented code, Python version, interpreter path, Cython
        version and numpy major.minor version. If the key is already in
        ``self._code_cache``, the cached entry is returned instead.

        Parameters
        ----------
        code : str
            Cython source code; it is deindented before use.
        force : bool, optional
            If ``True``, the current time is appended to the key so that a
            new module name is used.
        name : str, optional
            Explicit module name, used instead of the hash-based name.
        define_macros, include_dirs, library_dirs, runtime_library_dirs,
        extra_compile_args, extra_link_args, libraries, compiler, sources
            Passed unchanged to ``self._load_module`` (``sources`` defaults to
            an empty list).
        owner_name : str, optional
            Only used in a diagnostic log message.

        Raises
        ------
        ImportError
            If Cython is not available.
        IOError
            If the cache directory is missing and cannot be created.
        '''
        sources = [] if sources is None else sources
        self._simplify_paths()

        if Cython is None:
            raise ImportError('Cython is not available')

        code = deindent(code)

        cache_dir = get_cython_cache_dir()
        if '~' in cache_dir:
            cache_dir = os.path.expanduser(cache_dir)
        try:
            os.makedirs(cache_dir)
        except OSError:
            # A pre-existing directory is fine; only fail if it is missing
            if not os.path.exists(cache_dir):
                raise IOError(
                    "Couldn't create Cython cache directory '%s', try setting the "
                    "cache directly with prefs.codegen.runtime.cython.cache_dir."
                    % cache_dir)

        # only the major.minor part of the numpy version goes into the key
        major_minor = numpy.__version__.split('.')[:2]
        key = (code, sys.version_info, sys.executable, Cython.__version__,
               '.'.join(major_minor))
        if force:
            # a timestamp in the hashed key forces a fresh module name
            key = key + (time.time(),)

        if key in self._code_cache:
            return self._code_cache[key]

        if name is None:
            digest = hashlib.md5(str(key).encode('utf-8')).hexdigest()
            module_name = '_cython_magic_' + digest
        else:
            module_name = name
        if owner_name:
            message = '"{owner_name}" using Cython module "{module_name}"'
            logger.diagnostic(message.format(owner_name=owner_name,
                                             module_name=module_name))

        module_path = os.path.join(cache_dir, module_name + self.so_ext)
        load_kwargs = {
            'define_macros': define_macros,
            'include_dirs': include_dirs,
            'library_dirs': library_dirs,
            'extra_compile_args': extra_compile_args,
            'extra_link_args': extra_link_args,
            'libraries': libraries,
            'code': code,
            'lib_dir': cache_dir,
            'module_name': module_name,
            'runtime_library_dirs': runtime_library_dirs,
            'compiler': compiler,
            'key': key,
            'sources': sources,
        }

        if not prefs['codegen.runtime.cython.multiprocess_safe']:
            return self._load_module(module_path, **load_kwargs)

        # hold a per-module file lock while loading, so that concurrent
        # processes do not build the same module at the same time
        with FileLock(os.path.join(cache_dir, module_name + '.lock')):
            return self._load_module(module_path, **load_kwargs)
示例#52
0
def _next_inline_index(identifiers, funcname):
    '''
    Return the first inline-call index not yet used for ``funcname``.

    The parsing below assumes that inlined calls of ``funcname`` introduce
    identifiers named ``_inline_<funcname>_<index>`` optionally followed by
    ``_<suffix>``. Returns one more than the largest ``<index>`` found in
    ``identifiers``, or 0 if there is none.
    '''
    prefix = '_inline_' + funcname + '_'
    used = []
    for identifier in identifiers:
        if not identifier.startswith(prefix):
            continue
        index_str = identifier[len(prefix):].split('_', 1)[0]
        # A non-numeric component means the identifier belongs to a different
        # function whose name merely starts with ``funcname + '_'`` (e.g.
        # ``f_g`` when scanning for ``f``); skip it instead of failing in int().
        if index_str.isdigit():
            used.append(int(index_str))
    return max(used) + 1 if used else 0


def substitute_abstract_code_functions(code, funcs):
    '''
    Performs inline substitution of all the functions in the code

    Parameters
    ----------
    code : str
        The abstract code to make inline substitutions into.
    funcs : list, set, tuple or dict of AbstractCodeFunction
        The function substitutions to use, note in the case of a dict, the
        keys are ignored and the function name is used.

    Returns
    -------
    code : str
        The code with inline substitutions performed.
    '''
    if not hasattr(funcs, 'values'):
        # Any plain iterable of functions (list, set, tuple, ...) is keyed by
        # function name; mappings are used as-is (only their values matter).
        funcs = {f.name: f for f in funcs}

    code = deindent(code)
    lines = ast.parse(code, mode='exec').body

    # This is a slightly nasty hack: look at the existing identifiers to see
    # how many inline operations have already been performed by previous
    # calls to this function, so that new names continue the numbering.
    ids = get_identifiers(code)
    funcstarts = {func.name: _next_inline_index(ids, func.name)
                  for func in funcs.values()}

    # Rewrite each statement: the rewriter returns the modified statement,
    # and the statements collected in ``rw.pre`` are emitted before it.
    newlines = []
    for line in lines:
        for func in funcs.values():
            rw = FunctionRewriter(func, funcstarts[func.name])
            line = rw.visit(line)
            newlines.extend(rw.pre)
            # presumably the next free call index for this function
            funcstarts[func.name] = rw.numcalls
        newlines.append(line)

    # Now we render to a code string
    nr = NodeRenderer()
    newcode = '\n'.join(nr.render_node(line) for line in newlines)

    # We recurse until no changes in the code to ensure that all functions
    # are expanded if one function refers to another, etc.
    if newcode == code:
        return newcode
    return substitute_abstract_code_functions(newcode, funcs)
示例#53
0
def test_atomics_parallelisation():
    '''
    Check CUDA code generation with and without atomics on every example in
    ``permutation_analysis_good_examples``.

    Adapted from brian2.test_synapses:test_ufunc_at_vectorisation(). Each
    example is used as the ``on_pre`` code of a ``Synapses`` object between
    two ``NeuronGroup``s and run once per atomics setting, always from the
    same random initial values; the final values must agree across runs.

    Identifier suffixes in the examples select where a variable lives:
    ``_pre`` (source group), ``_post`` (target group), ``_syn`` (synapses),
    ``_shared`` (shared synaptic variable) and ``_const`` (namespace constant
    set to 42). Examples marked ``NOT_UFUNC_AT_VECTORISABLE`` are only run with
    atomics enabled and must produce exactly one "Failed to parallelise code"
    message from the CUDA generator; all other examples must produce none.
    '''
    for n, code in enumerate(permutation_analysis_good_examples):
        should_be_able_to_use_ufunc_at = ('NOT_UFUNC_AT_VECTORISABLE'
                                          not in code)
        if should_be_able_to_use_ufunc_at:
            use_ufunc_at_list = [False, True]
        else:
            use_ufunc_at_list = [True]
        code = deindent(code)
        identifiers = get_identifiers(code)
        vars_src = []
        vars_tgt = []
        vars_syn = []
        vars_shared = []
        vars_const = {}
        # Strip the suffix to get the name used in the model equations
        for var in identifiers:
            if var.endswith('_pre'):
                vars_src.append(var[:-4])
            elif var.endswith('_post'):
                vars_tgt.append(var[:-5])
            elif var.endswith('_syn'):
                vars_syn.append(var[:-4])
            elif var.endswith('_shared'):
                vars_shared.append(var[:-7])
            elif var.endswith('_const'):
                vars_const[var[:-6]] = 42
        eqs_src = '\n'.join(var + ':1' for var in vars_src)
        eqs_tgt = '\n'.join(var + ':1' for var in vars_tgt)
        eqs_syn = '\n'.join(var + ':1' for var in vars_syn)
        eqs_syn += '\n' + '\n'.join(var + ':1 (shared)' for var in vars_shared)
        # origvals: initial values reused by every run of this example
        # endvals: final values of the first run, compared against later runs
        origvals = {}
        endvals = {}
        group_size = 1000
        syn_size = group_size**2
        try:
            BrianLogger._log_messages.clear()
            with catch_logs(log_level=logging.INFO) as caught_logs:
                for use_ufunc_at in use_ufunc_at_list:
                    set_device('cuda_standalone',
                               directory=None,
                               compile=True,
                               run=True,
                               debug=False)
                    CUDACodeGenerator._use_atomics = use_ufunc_at
                    src = NeuronGroup(group_size,
                                      eqs_src,
                                      threshold='True',
                                      name='src')
                    tgt = NeuronGroup(group_size, eqs_tgt, name='tgt')
                    # _pre/_post are kept; the other suffixes are only a
                    # naming convention of the examples and are stripped.
                    syn = Synapses(src,
                                   tgt,
                                   eqs_syn,
                                   on_pre=code.replace('_syn', '').replace(
                                       '_const', '').replace('_shared', ''),
                                   name='syn',
                                   namespace=vars_const)
                    syn.connect()
                    groups_and_vars = [(src, vars_src), (tgt, vars_tgt),
                                       (syn, vars_syn)]
                    for G, group_vars in groups_and_vars:
                        for var in group_vars:
                            fullvar = var + G.name
                            if fullvar in origvals:
                                G.state(var)[:] = origvals[fullvar]
                            else:
                                if isinstance(G, Synapses):
                                    val = rand(syn_size)
                                else:
                                    val = rand(len(G))
                                G.state(var)[:] = val
                                origvals[fullvar] = val.copy()
                    Network(src, tgt, syn).run(5 * defaultclock.dt)
                    for G, group_vars in groups_and_vars:
                        for var in group_vars:
                            fullvar = var + G.name
                            val = G.state(var)[:].copy()
                            if fullvar in endvals:
                                assert_allclose(val,
                                                endvals[fullvar],
                                                err_msg='%d: %s' % (n, code),
                                                rtol=1e-5)
                            else:
                                endvals[fullvar] = val
                    device.reinit()
                    device.activate()
                cuda_generator_messages = [
                    entry for entry in caught_logs
                    if entry[1] == 'brian2.codegen.generators.cuda_generator'
                ]
                if should_be_able_to_use_ufunc_at:
                    assert len(
                        cuda_generator_messages) == 0, cuda_generator_messages
                else:
                    assert len(
                        cuda_generator_messages) == 1, cuda_generator_messages
                    _, _, log_msg = cuda_generator_messages[0]
                    assert log_msg.startswith(
                        'Failed to parallelise code'), log_msg
        finally:
            CUDACodeGenerator._use_atomics = False  # restore it
            device.reinit()
            device.activate()
示例#54
0
    print 'Saving notebook and converting to RST'
    exporter = NotebookExporter()
    output, _ = exporter.from_notebook_node(notebook)
    with codecs.open(output_ipynb_fname, 'w', encoding='utf-8') as f:
        f.write(output)

    # Insert a note about ipython notebooks with a download link
    note = deindent(u'''
    .. only:: html

        .. |launchbinder| image:: http://mybinder.org/badge.svg
        .. _launchbinder: http://mybinder.org:/repo/brian-team/brian2-binder/notebooks/tutorials/{tutorial}.ipynb
    
        .. note::
           This tutorial is a static non-editable version. You can launch an
           interactive, editable version without installing any local files
           using the Binder service (although note that at some times this
           may be slow or fail to open): |launchbinder|_
    
           Alternatively, you can download a copy of the notebook file
           to use locally: :download:`{tutorial}.ipynb`
    
           See the :doc:`tutorial overview page <index>` for more details.

    '''.format(tutorial=basename))
    notebook.cells.insert(
        1, NotebookNode(cell_type=u'raw', metadata={}, source=note))
    exporter = RSTExporter()
    output, resources = exporter.from_notebook_node(
        notebook, resources={'unique_key': basename + '_image'})
    with codecs.open(output_rst_fname, 'w', encoding='utf-8') as f:
        f.write(output)
示例#55
0
def make_statements(code, variables, dtype):
    '''
    Turn a series of abstract code statements into Statement objects, inferring
    whether each line is a set/declare operation, whether the variables are
    constant or not, and handling the cacheing of subexpressions. Returns a
    list of Statement objects. For arguments, see documentation for
    :func:`translate`.

    Notes
    -----
    A subexpression is only declared ``const`` if neither it nor any of the
    identifiers it depends on is written on the line where it is declared or
    on any later line. Writing one of its identifiers invalidates the cached
    value; a later read then re-assigns it with ``=``, which would be illegal
    for a ``const`` declaration.
    '''
    code = strip_empty_lines(deindent(code))
    lines = re.split(r'[;\n]', code)
    lines = [LineInfo(code=line) for line in lines if len(line)]
    if DEBUG:
        print('INPUT CODE:')
        print(code)
    dtypes = dict((name, var.dtype) for name, var in variables.items())
    # we will do inference to work out which lines are := and which are =
    defined = set(variables.keys())

    for line in lines:
        # parse statement into "var op expr"
        var, op, expr = parse_statement(line.code)
        if op == '=' and var not in defined:
            op = ':='
            defined.add(var)
            if var not in dtypes:
                dtypes[var] = dtype
        line.statement = Statement(var, op, expr, dtypes[var])
        # the variable written to by this line
        line.write = var
        # the set of variables read by this line
        line.read = get_identifiers_recursively(expr, variables)

    if DEBUG:
        print('PARSED STATEMENTS:')
        for line in lines:
            print('%s Read:%s Write:%s' % (line.statement, line.read,
                                           line.write))
        # all variables which are written to at some point in the code block
        print('ALL WRITE: %s' % (set(line.write for line in lines),))

    # backwards compute whether or not variables will be read again
    # note that will_read for a line gives the set of variables it will read
    # on the current line or subsequent ones. will_write gives the set of
    # variables that will be written after the current line
    will_read = set()
    will_write = set()
    for line in lines[::-1]:
        will_read = will_read.union(line.read)
        line.will_read = will_read.copy()
        line.will_write = will_write.copy()
        will_write.add(line.write)

    if DEBUG:
        print('WILL READ/WRITE:')
        for line in lines:
            print('%s Read:%s Write:%s' % (line.statement, line.will_read,
                                           line.will_write))

    # generate cacheing statements for common subexpressions
    # cached subexpressions need to be recomputed whenever they are to be used
    # on the next line, and currently invalid (meaning that the current value
    # stored in the subexpression variable is no longer accurate because one
    # of the variables appearing in it has changed). All subexpressions start
    # as invalid, and are invalidated whenever one of the variables appearing
    # in the RHS changes value.
    subexpressions = dict((name, val) for name, val in variables.items()
                          if isinstance(val, Subexpression))
    if DEBUG:
        print('SUBEXPRESSIONS: %s' % (list(subexpressions.keys()),))
    statements = []
    # all start as invalid
    valid = dict((name, False) for name in subexpressions)
    # none are yet defined (or declared)
    subdefined = dict((name, False) for name in subexpressions)
    for line in lines:
        stmt = line.statement
        will_write = line.will_write
        # Variables written on this line or on any later line. Any of them
        # can invalidate a subexpression declared here, forcing a later
        # re-assignment, so they all rule out a const declaration.
        written_from_here = will_write | {line.write}
        # check that all subexpressions in expr are valid, and if not
        # add a definition/set its value, and set it to be valid
        for var in line.read:
            # if subexpression, and invalid
            if not valid.get(var, True):  # all non-subexpressions are valid
                # if already defined/declared
                if subdefined[var]:
                    op = '='
                    constant = False
                else:
                    op = ':='
                    subdefined[var] = True
                    dtypes[var] = dtype  # default dtype
                    # constant only if neither the subexpression nor any of
                    # its identifiers is written from this line onwards
                    constant = (var not in written_from_here and
                                all(v not in written_from_here
                                    for v in subexpressions[var].identifiers))
                valid[var] = True
                statement = Statement(var,
                                      op,
                                      subexpressions[var].expr,
                                      dtype,
                                      constant=constant,
                                      subexpression=True)
                statements.append(statement)
        var, op, expr = stmt.var, stmt.op, stmt.expr
        # invalidate any subexpressions including var
        for subvar, spec in subexpressions.items():
            if var in spec.identifiers:
                valid[subvar] = False
        # constant only if we are declaring a new variable and we will not
        # write to it again
        constant = op == ':=' and var not in will_write
        statement = Statement(var, op, expr, dtypes[var], constant=constant)
        statements.append(statement)

    if DEBUG:
        print('OUTPUT STATEMENTS:')
        for stmt in statements:
            print(stmt)

    return statements
示例#56
0
class CPPCodeGenerator(CodeGenerator):
    '''
    C++ language

    C++ code templates should provide Jinja2 macros with the following names:

    ``main``
        The main loop.
    ``support_code``
        The support code (function definitions, etc.), compiled in a separate
        file.

    For user-defined functions, there are two keys to provide:

    ``support_code``
        The function definition which will be added to the support code.
    ``hashdefine_code``
        The ``#define`` code added to the main loop.

    See `TimedArray` for an example of these keys.
    '''

    class_name = 'cpp'

    universal_support_code = deindent(mod_support_code)

    def __init__(self, *args, **kwds):
        super(CPPCodeGenerator, self).__init__(*args, **kwds)
        self.c_data_type = c_data_type

    @property
    def restrict(self):
        '''The configured restrict keyword followed by a single space.'''
        return prefs['codegen.generators.cpp.restrict_keyword'] + ' '

    @property
    def flush_denormals(self):
        '''Whether the preferences request flushing denormals to zero.'''
        return prefs['codegen.generators.cpp.flush_denormals']

    @staticmethod
    def get_array_name(var, access_data=True):
        '''
        Return the C++ name for ``var``'s array.

        With ``access_data=True`` (the default) this is the device's array name
        prefixed with ``_ptr`` (the restricted pointer); otherwise it is the
        device's name obtained with ``access_data=False``.
        '''
        # We have to do the import here to avoid circular import dependencies.
        from brian2.devices.device import get_device
        device = get_device()
        if access_data:
            return '_ptr' + device.get_array_name(var)
        else:
            return device.get_array_name(var, access_data=False)

    def translate_expression(self, expr):
        '''Translate an abstract code expression into a C++ expression.'''
        expr = word_substitute(expr, self.func_name_replacements)
        return CPPNodeRenderer().render_expr(expr).strip()

    def translate_statement(self, statement):
        '''
        Translate a single abstract code ``Statement`` into C++ code.

        Statements that use boolean variables become a chain of ``if``/``else``
        branches, each assigning the boolean-simplified expression valid for
        that combination of boolean values. Other statements become a single
        assignment, declared (and ``const`` if ``statement.constant``) when
        the operator is ``:=``. A non-empty comment is appended as ``//``.
        '''
        var, op, expr, comment = (statement.var, statement.op, statement.expr,
                                  statement.comment)
        # For C++ we replace complex expressions involving boolean variables
        # into a sequence of if/then expressions with simpler expressions.
        # This is provided by the optimise_statements function.
        if statement.used_boolean_variables:
            used_boolvars = statement.used_boolean_variables
            bool_simp = statement.boolean_simplified_expressions
            if op == ':=':
                # we have to declare the variable outside the if/then
                # statement (which unfortunately means we can't make it const
                # but the optimisation is worth it anyway).
                codelines = [
                    self.c_data_type(statement.dtype) + ' ' + var + ';'
                ]
                op = '='
            else:
                codelines = []
            firstline = True
            # bool_assigns is a sequence of (var, value) pairs giving the
            # conditions under which the simplified expression simp_expr holds
            for bool_assigns, simp_expr in bool_simp.iteritems():
                # generate a boolean expression like ``var1 && var2 && !var3``
                atomics = [boolvar if boolval else '!' + boolvar
                           for boolvar, boolval in bool_assigns]
                line = '' if firstline else 'else '
                # only need another if statement when we have more than one
                # boolean variable
                if firstline or len(used_boolvars) > 1:
                    line += 'if(' + (' && '.join(atomics)) + ')'
                line += '\n    '
                line += var + ' ' + op + ' ' + self.translate_expression(
                    simp_expr) + ';'
                codelines.append(line)
                firstline = False
            code = '\n'.join(codelines)
        else:
            if op == ':=':
                decl = self.c_data_type(statement.dtype) + ' '
                op = '='
                if statement.constant:
                    decl = 'const ' + decl
            else:
                decl = ''
            code = decl + var + ' ' + op + ' ' + self.translate_expression(
                expr) + ';'
        # ``if comment`` also tolerates a missing (None) comment
        if comment:
            code += ' // ' + comment
        return code

    def translate_to_read_arrays(self, statements):
        '''
        Return C++ lines loading index and read variables from their arrays
        into locals; locals that are never written are declared ``const``.
        '''
        read, write, indices, conditional_write_vars = self.arrays_helper(
            statements)
        lines = []
        # index and read arrays (index arrays first)
        for varname in itertools.chain(indices, read):
            index_var = self.variable_indices[varname]
            var = self.variables[varname]
            line = '' if varname in write else 'const '
            line += self.c_data_type(var.dtype) + ' ' + varname + ' = '
            line += self.get_array_name(var) + '[' + index_var + '];'
            lines.append(line)
        return lines

    def translate_to_declarations(self, statements):
        '''
        Return C++ declarations for variables that are written but neither
        read nor used as indices (so were not loaded from their arrays).
        '''
        read, write, indices, conditional_write_vars = self.arrays_helper(
            statements)
        lines = []
        for varname in write:
            if varname not in read and varname not in indices:
                var = self.variables[varname]
                line = self.c_data_type(var.dtype) + ' ' + varname + ';'
                lines.append(line)
        return lines

    def translate_to_statements(self, statements):
        '''
        Return the translated statements; statements writing a conditionally
        written variable are guarded by ``if(<condition variable>)``.
        '''
        read, write, indices, conditional_write_vars = self.arrays_helper(
            statements)
        lines = []
        for stmt in statements:
            line = self.translate_statement(stmt)
            if stmt.var in conditional_write_vars:
                condvar = conditional_write_vars[stmt.var]
                lines.append('if(%s)' % condvar)
                lines.append('    ' + line)
            else:
                lines.append(line)
        return lines

    def translate_to_write_arrays(self, statements):
        '''Return C++ lines storing every written variable into its array.'''
        read, write, indices, conditional_write_vars = self.arrays_helper(
            statements)
        lines = []
        for varname in write:
            index_var = self.variable_indices[varname]
            var = self.variables[varname]
            line = (self.get_array_name(var) + '[' + index_var + '] = ' +
                    varname + ';')
            lines.append(line)
        return lines

    def translate_one_statement_sequence(self, statements, scalar=False):
        '''
        Translate a statement sequence into C++ lines: array reads, local
        declarations, the statements themselves, then array writes.
        ``scalar`` is accepted but not used by this implementation.
        '''
        # This function is refactored into four functions which perform the
        # four necessary operations. It's done like this so that code
        # deriving from this class can overwrite specific parts.
        lines = []
        # index and read arrays (index arrays first)
        lines += self.translate_to_read_arrays(statements)
        # simply declare variables that will be written but not read
        lines += self.translate_to_declarations(statements)
        # the actual code
        lines += self.translate_to_statements(statements)
        # write arrays
        lines += self.translate_to_write_arrays(statements)
        code = '\n'.join(lines)
        return stripped_deindented_lines(code)

    def denormals_to_zero_code(self):
        '''
        Return C code setting the flush-to-zero bit (1 << 15) in the MXCSR
        register via GCC builtins, or '' if flushing denormals is disabled.
        '''
        if self.flush_denormals:
            return '''
            #define CSR_FLUSH_TO_ZERO         (1 << 15)
            unsigned csr = __builtin_ia32_stmxcsr();
            csr |= CSR_FLUSH_TO_ZERO;
            __builtin_ia32_ldmxcsr(csr);
            '''
        else:
            return ''

    def _add_user_function(self, varname, variable):
        '''
        Collect the code needed for the user function ``variable``.

        Dependencies are added to ``self.variables`` and processed
        recursively. Returns a tuple ``(hash_defines, pointers, support_code,
        user_functions)`` where each code list has the dependencies' entries
        first and ``user_functions`` lists ``(name, function)`` pairs.
        '''
        impl = variable.implementations[self.codeobj_class]
        support_code = []
        hash_defines = []
        pointers = []
        user_functions = [(varname, variable)]
        funccode = impl.get_code(self.owner)
        if isinstance(funccode, basestring):
            funccode = {'support_code': funccode}
        if funccode is not None:
            # To make namespace variables available to functions, we
            # create global variables and assign to them in the main
            # code
            func_namespace = impl.get_namespace(self.owner) or {}
            for ns_key, ns_value in func_namespace.iteritems():
                if hasattr(ns_value, 'dtype'):
                    if ns_value.shape == ():
                        raise NotImplementedError(
                            ('Directly replace scalar values in the function '
                             'instead of providing them via the namespace'))
                    type_str = c_data_type(ns_value.dtype) + '*'
                else:  # e.g. a function
                    type_str = 'py::object'
                support_code.append('static {0} _namespace{1};'.format(
                    type_str, ns_key))
                pointers.append('_namespace{0} = {1};'.format(ns_key, ns_key))
            support_code.append(deindent(funccode.get('support_code', '')))
            hash_defines.append(deindent(funccode.get('hashdefine_code', '')))

        dep_hash_defines = []
        dep_pointers = []
        dep_support_code = []
        if impl.dependencies is not None:
            for dep_name, dep in impl.dependencies.iteritems():
                self.variables[dep_name] = dep
                hd, ps, sc, uf = self._add_user_function(dep_name, dep)
                dep_hash_defines.extend(hd)
                dep_pointers.extend(ps)
                dep_support_code.extend(sc)
                user_functions.extend(uf)

        return (dep_hash_defines + hash_defines, dep_pointers + pointers,
                dep_support_code + support_code, user_functions)

    def determine_keywords(self):
        '''
        Return the template keywords: restricted pointer lines, support code,
        ``#define`` lines and denormal-flushing code, each processed with
        ``stripped_deindented_lines``. User functions are removed from
        ``self.variables`` and their namespaces (if any) merged into it.
        '''
        # set up the restricted pointers, these are used so that the compiler
        # knows there is no aliasing in the pointers, for optimisation
        pointers = []
        # It is possible that several different variable names refer to the
        # same array. E.g. in gapjunction code, v_pre and v_post refer to the
        # same array if a group is connected to itself
        handled_pointers = set()
        template_kwds = {}
        # Again, do the import here to avoid a circular dependency.
        from brian2.devices.device import get_device
        device = get_device()
        for varname, var in self.variables.iteritems():
            if isinstance(var, ArrayVariable):
                # This is the "true" array name, not the restricted pointer.
                array_name = device.get_array_name(var)
                pointer_name = self.get_array_name(var)
                if pointer_name in handled_pointers:
                    continue
                if getattr(var, 'dimensions', 1) > 1:
                    continue  # multidimensional (dynamic) arrays have to be treated differently
                line = '{0}* {1} {2} = {3};'.format(
                    self.c_data_type(var.dtype), self.restrict, pointer_name,
                    array_name)
                pointers.append(line)
                handled_pointers.add(pointer_name)

        # set up the functions
        user_functions = []
        support_code = []
        hash_defines = []
        # Iterate over a snapshot: _add_user_function adds dependencies to
        # self.variables while we loop.
        for varname, variable in list(self.variables.items()):
            if isinstance(variable, Function):
                hd, ps, sc, uf = self._add_user_function(varname, variable)
                user_functions.extend(uf)
                support_code.extend(sc)
                pointers.extend(ps)
                hash_defines.extend(hd)

        # delete the user-defined functions from the namespace and add the
        # function namespaces (if any)
        for funcname, func in user_functions:
            del self.variables[funcname]
            func_namespace = func.implementations[
                self.codeobj_class].get_namespace(self.owner)
            if func_namespace is not None:
                self.variables.update(func_namespace)

        support_code.append(self.universal_support_code)

        keywords = {
            'pointers_lines':
            stripped_deindented_lines('\n'.join(pointers)),
            'support_code_lines':
            stripped_deindented_lines('\n'.join(support_code)),
            'hashdefine_lines':
            stripped_deindented_lines('\n'.join(hash_defines)),
            # denormals_to_zero_code() returns a single string; joining it
            # with '\n' would put every character on its own line.
            'denormals_code_lines':
            stripped_deindented_lines(self.denormals_to_zero_code()),
        }
        keywords.update(template_kwds)
        return keywords
示例#57
0
文件: cpp_lang.py 项目: yayyme/brian2
    def translate_statement_sequence(self, statements, variables, namespace,
                                     variable_indices, iterate_all):
        '''
        Translate a sequence of abstract code statements into C++ code.

        Parameters
        ----------
        statements : list
            The ``Statement`` objects to translate.
        variables : dict
            Maps names to variable objects; array variables provide
            ``arrayname`` and ``dtype``.
        namespace : dict
            Modified in place: every ``Function`` entry is removed and, for
            functions with a ``pyfunc``, a ``_python_<name>`` entry is added
            unless one already exists.
        variable_indices : dict
            Maps each array variable name to the name of its index variable.
        iterate_all : bool
            Ignored; C++ code always loops over all elements.

        Returns
        -------
        code, keywords
            The main code passed through ``stripped_deindented_lines`` and a
            dict with the keys ``pointers_lines``, ``support_code_lines``,
            ``hashdefine_lines`` and ``denormals_code_lines``.
        '''
        # Note that C++ code does not care about the iterate_all argument -- it
        # always has to loop over the elements

        read, write = self.array_read_write(statements, variables)
        lines = []
        # read arrays; values that are never written are declared const
        for varname in read:
            index_var = variable_indices[varname]
            var = variables[varname]
            if varname not in write:
                line = 'const '
            else:
                line = ''
            line = line + self.c_data_type(var.dtype) + ' ' + varname + ' = '
            line = line + '_ptr' + var.arrayname + '[' + index_var + '];'
            lines.append(line)
        # simply declare variables that will be written but not read
        for varname in write:
            if varname not in read:
                var = variables[varname]
                line = self.c_data_type(var.dtype) + ' ' + varname + ';'
                lines.append(line)
        # the actual code
        lines.extend([self.translate_statement(stmt) for stmt in statements])
        # write arrays
        for varname in write:
            index_var = variable_indices[varname]
            var = variables[varname]
            line = '_ptr' + var.arrayname + '[' + index_var + '] = ' + varname + ';'
            lines.append(line)
        code = '\n'.join(lines)
        # set up the restricted pointers, these are used so that the compiler
        # knows there is no aliasing in the pointers, for optimisation
        lines = []
        # It is possible that several different variable names refer to the
        # same array. E.g. in gapjunction code, v_pre and v_post refer to the
        # same array if a group is connected to itself
        arraynames = set()
        for varname, var in variables.iteritems():
            if isinstance(var, ArrayVariable):
                arrayname = var.arrayname
                if arrayname not in arraynames:
                    line = self.c_data_type(var.dtype) + ' * ' + self.restrict + '_ptr' + arrayname + ' = ' + arrayname + ';'
                    lines.append(line)
                    arraynames.add(arrayname)
        pointers = '\n'.join(lines)

        # set up the functions
        user_functions = []
        support_code = ''
        hash_defines = ''
        # iterate over a snapshot: the loop adds ``_python_*`` entries
        for varname, variable in list(namespace.items()):
            if isinstance(variable, Function):
                user_functions.append(varname)
                speccode = variable.code(self, varname)
                support_code += '\n' + deindent(speccode['support_code'])
                # Separate each function's #define code with a newline too;
                # otherwise code that does not end in a newline runs into the
                # next function's first #define on the same line.
                hash_defines += '\n' + deindent(speccode['hashdefine_code'])
                # add the Python function with a leading '_python', if it
                # exists. This allows the function to make use of the Python
                # function via weave if necessary (e.g. in the case of randn)
                if variable.pyfunc is not None:
                    pyfunc_name = '_python_' + varname
                    if pyfunc_name in namespace:
                        logger.warn(('Namespace already contains function %s, '
                                     'not replacing it') % pyfunc_name)
                    else:
                        namespace[pyfunc_name] = variable.pyfunc

        # delete the user-defined functions from the namespace
        for func in user_functions:
            del namespace[func]

        # return
        return (stripped_deindented_lines(code),
                {'pointers_lines': stripped_deindented_lines(pointers),
                 'support_code_lines': stripped_deindented_lines(support_code),
                 'hashdefine_lines': stripped_deindented_lines(hash_defines),
                 'denormals_code_lines': stripped_deindented_lines(self.denormals_to_zero_code()),
                 })
示例#58
0
    def translate_statement_sequence(self, statements, variables, namespace,
                                     variable_indices, iterate_all):
        '''
        Translate a sequence of abstract code statements into C++ code.

        Parameters
        ----------
        statements : list
            The ``Statement`` objects to translate.
        variables : dict
            Maps names to variable objects; array variables provide
            ``arrayname`` and ``dtype``.
        namespace : dict
            Modified in place: every ``Function`` entry is removed and, for
            functions with a ``pyfunc``, a ``_python_<name>`` entry is added
            unless one already exists.
        variable_indices : dict
            Maps each array variable name to the name of its index variable.
        iterate_all : bool
            Ignored; C++ code always loops over all elements.

        Returns
        -------
        code, keywords
            The main code passed through ``stripped_deindented_lines`` and a
            dict with the keys ``pointers_lines``, ``support_code_lines``,
            ``hashdefine_lines`` and ``denormals_code_lines``.
        '''
        # Note that C++ code does not care about the iterate_all argument -- it
        # always has to loop over the elements

        read, write = self.array_read_write(statements, variables)
        lines = []
        # read arrays; values that are never written are declared const
        for varname in read:
            index_var = variable_indices[varname]
            var = variables[varname]
            if varname not in write:
                line = 'const '
            else:
                line = ''
            line = line + self.c_data_type(var.dtype) + ' ' + varname + ' = '
            line = line + '_ptr' + var.arrayname + '[' + index_var + '];'
            lines.append(line)
        # simply declare variables that will be written but not read
        for varname in write:
            if varname not in read:
                var = variables[varname]
                line = self.c_data_type(var.dtype) + ' ' + varname + ';'
                lines.append(line)
        # the actual code
        lines.extend([self.translate_statement(stmt) for stmt in statements])
        # write arrays
        for varname in write:
            index_var = variable_indices[varname]
            var = variables[varname]
            line = '_ptr' + var.arrayname + '[' + index_var + '] = ' + varname + ';'
            lines.append(line)
        code = '\n'.join(lines)
        # set up the restricted pointers, these are used so that the compiler
        # knows there is no aliasing in the pointers, for optimisation
        lines = []
        # It is possible that several different variable names refer to the
        # same array. E.g. in gapjunction code, v_pre and v_post refer to the
        # same array if a group is connected to itself
        arraynames = set()
        for varname, var in variables.iteritems():
            if isinstance(var, ArrayVariable):
                arrayname = var.arrayname
                if arrayname not in arraynames:
                    line = self.c_data_type(
                        var.dtype
                    ) + ' * ' + self.restrict + '_ptr' + arrayname + ' = ' + arrayname + ';'
                    lines.append(line)
                    arraynames.add(arrayname)
        pointers = '\n'.join(lines)

        # set up the functions
        user_functions = []
        support_code = ''
        hash_defines = ''
        # iterate over a snapshot: the loop adds ``_python_*`` entries
        for varname, variable in list(namespace.items()):
            if isinstance(variable, Function):
                user_functions.append(varname)
                speccode = variable.code(self, varname)
                support_code += '\n' + deindent(speccode['support_code'])
                # Separate each function's #define code with a newline too;
                # otherwise code that does not end in a newline runs into the
                # next function's first #define on the same line.
                hash_defines += '\n' + deindent(speccode['hashdefine_code'])
                # add the Python function with a leading '_python', if it
                # exists. This allows the function to make use of the Python
                # function via weave if necessary (e.g. in the case of randn)
                if variable.pyfunc is not None:
                    pyfunc_name = '_python_' + varname
                    if pyfunc_name in namespace:
                        logger.warn(('Namespace already contains function %s, '
                                     'not replacing it') % pyfunc_name)
                    else:
                        namespace[pyfunc_name] = variable.pyfunc

        # delete the user-defined functions from the namespace
        for func in user_functions:
            del namespace[func]

        # return
        return (stripped_deindented_lines(code), {
            'pointers_lines':
            stripped_deindented_lines(pointers),
            'support_code_lines':
            stripped_deindented_lines(support_code),
            'hashdefine_lines':
            stripped_deindented_lines(hash_defines),
            'denormals_code_lines':
            stripped_deindented_lines(self.denormals_to_zero_code()),
        })
示例#59
0
def make_statements(code, variables, dtype):
    '''
    Turn a series of abstract code statements into Statement objects, inferring
    whether each line is a set/declare operation, whether the variables are
    constant or not, and handling the caching of subexpressions.

    Parameters
    ----------
    code : str
        A (multi-line) string of statements. Individual statements are
        separated by newlines or semicolons.
    variables : dict-like
        A dictionary with `Variable` and `Function` objects for every
        identifier used in the `code`. It is copied internally and not
        modified; `AuxiliaryVariable` objects created for newly introduced
        names are only added to the internal copy.
    dtype : `dtype`
        The data type to use for temporary variables

    Returns
    -------
    scalar_statements, vector_statements : (list of `Statement`, list of `Statement`)
        Lists with statements that are to be executed once and statements that
        are to be executed once for every neuron/synapse/... (or in a vectorised
        way)

    Raises
    ------
    SyntaxError
        If a statement other than a ``:=`` declaration writes to a scalar
        variable after an earlier statement has written to (or declared) a
        vector variable.

    Notes
    -----
    The `scalar_statements` may include newly introduced scalar constants that
    have been identified as loop-invariant and have therefore been pulled out
    of the vector statements (only when the
    ``codegen.loop_invariant_optimisations`` preference is enabled). The
    resulting statements will also use augmented assignments where possible,
    i.e. a statement such as ``w = w + 1`` will be replaced by ``w += 1`` and
    ``w = 2*w`` by ``w *= 2``. This rewriting is not applied to boolean
    variables.

    For every `Subexpression` in `variables`, a statement computing it is
    inserted before the first line that reads it, and again before any later
    line that reads it after one of its dependencies has been written to.
    '''
    # Split the code into individual statements (one per line or per
    # ';'-separated fragment); empty fragments are dropped.
    # NOTE(review): a fragment consisting only of whitespace (e.g. after a
    # trailing "; ") passes the len() check -- confirm parse_statement
    # copes with it.
    code = strip_empty_lines(deindent(code))
    lines = re.split(r'[;\n]', code)
    lines = [LineInfo(code=line) for line in lines if len(line)]
    if DEBUG:
        print 'INPUT CODE:'
        print code
    # Do a copy so we can add stuff without altering the original dict
    variables = dict(variables)
    # we will do inference to work out which lines are := and which are =
    # All known variables except auxiliary (temporary) ones count as already
    # defined; the first plain '=' to any other name becomes a declaration.
    defined = set(k for k, v in variables.iteritems()
                  if not isinstance(v, AuxiliaryVariable))
    for line in lines:
        statement = None
        # parse statement into "var op expr"
        var, op, expr, comment = parse_statement(line.code)
        if op == '=':
            if var not in defined:
                # First assignment to this name: make it a declaration
                op = ':='
                defined.add(var)
                if var not in variables:
                    # Entirely new name: introduce a temporary variable with
                    # the requested dtype; it is scalar if its defining
                    # expression is scalar (according to is_scalar_expression)
                    is_scalar = is_scalar_expression(expr, variables)
                    new_var = AuxiliaryVariable(
                        var,
                        Unit(1),  # doesn't matter here
                        dtype=dtype,
                        scalar=is_scalar)
                    variables[var] = new_var
            elif not variables[var].is_boolean:
                # Re-assignment of an existing non-boolean variable: collect
                # the terms in `var` to check whether the statement can be
                # expressed as an augmented assignment (+= or *=)
                sympy_expr = str_to_sympy(expr)
                sympy_var = sympy.Symbol(var, real=True)
                try:
                    collected = sympy.collect(sympy_expr,
                                              sympy_var,
                                              exact=True,
                                              evaluate=False)
                except AttributeError:
                    # If something goes wrong during collection, e.g. collect
                    # does not work for logical expressions
                    collected = {1: sympy_expr}

                if (len(collected) == 2
                        and set(collected.keys()) == {1, sympy_var}
                        and collected[sympy_var] == 1):
                    # We can replace this statement by a += assignment
                    # (expr has the form ``var + rest`` -> ``var += rest``)
                    statement = Statement(var,
                                          '+=',
                                          sympy_to_str(collected[1]),
                                          comment,
                                          dtype=variables[var].dtype,
                                          scalar=variables[var].scalar)
                elif len(collected) == 1 and sympy_var in collected:
                    # We can replace this statement by a *= assignment
                    # (expr has the form ``var * factor`` -> ``var *= factor``)
                    statement = Statement(var,
                                          '*=',
                                          sympy_to_str(collected[sympy_var]),
                                          comment,
                                          dtype=variables[var].dtype,
                                          scalar=variables[var].scalar)
        if statement is None:
            # No rewriting applied (declaration, boolean variable, augmented
            # assignment already present in the input, or no matching
            # pattern): keep the statement as it was parsed
            statement = Statement(var,
                                  op,
                                  expr,
                                  comment,
                                  dtype=variables[var].dtype,
                                  scalar=variables[var].scalar)

        line.statement = statement
        # for each line will give the variable being written to
        line.write = var
        # each line will give a set of variables which are read
        # (presumably including identifiers used inside any subexpressions
        # referenced by expr -- confirm in get_identifiers_recursively)
        line.read = get_identifiers_recursively([expr], variables)

    # All writes to scalar variables must happen before writes to vector
    # variables (':=' declarations of new scalar variables are exempt)
    scalar_write_done = False
    for line in lines:
        stmt = line.statement
        if stmt.op != ':=' and variables[
                stmt.var].scalar and scalar_write_done:
            raise SyntaxError(
                ('All writes to scalar variables in a code block '
                 'have to be made before writes to vector '
                 'variables. Illegal write to %s.') % line.write)
        elif not variables[stmt.var].scalar:
            scalar_write_done = True

    if DEBUG:
        print 'PARSED STATEMENTS:'
        for line in lines:
            print line.statement, 'Read:' + str(
                line.read), 'Write:' + line.write

    # all variables which are written to at some point in the code block.
    # NOTE: within this function it is only used for the DEBUG output below;
    # whether a declared variable is const is decided from `will_write`.
    all_write = set(line.write for line in lines)

    if DEBUG:
        print 'ALL WRITE:', all_write

    # backwards compute whether or not variables will be read again
    # note that will_read for a line gives the set of variables it will read
    # on the current line or subsequent ones. will_write gives the set of
    # variables that will be written after the current line
    will_read = set()
    will_write = set()
    for line in lines[::-1]:
        will_read = will_read.union(line.read)
        line.will_read = will_read.copy()
        line.will_write = will_write.copy()
        will_write.add(line.write)

    if DEBUG:
        print 'WILL READ/WRITE:'
        for line in lines:
            print line.statement, 'Read:' + str(
                line.will_read), 'Write:' + str(line.will_write)

    # generate caching statements for common subexpressions
    # cached subexpressions need to be recomputed whenever they are to be used
    # on the next line, and currently invalid (meaning that the current value
    # stored in the subexpression variable is no longer accurate because one
    # of the variables appearing in it has changed). All subexpressions start
    # as invalid, and are invalidated whenever one of the variables appearing
    # in the RHS changes value.
    subexpressions = dict((name, val) for name, val in variables.items()
                          if isinstance(val, Subexpression))
    # sort subexpressions into an order so that subexpressions that don't depend
    # on other subexpressions are first
    subexpr_deps = dict((name, [dep for dep in subexpr.identifiers if dep in subexpressions]) for \
                                                            name, subexpr in subexpressions.items())
    sorted_subexpr_vars = topsort(subexpr_deps)

    if DEBUG:
        print 'SUBEXPRESSIONS:', subexpressions.keys()
    statements = []
    # all start as invalid
    valid = dict((name, False) for name in subexpressions.keys())
    # none are yet defined (or declared)
    subdefined = dict((name, False) for name in subexpressions.keys())
    for line in lines:
        stmt = line.statement
        read = line.read
        write = line.write
        will_read = line.will_read
        will_write = line.will_write
        # check that all subexpressions in expr are valid, and if not
        # add a definition/set its value, and set it to be valid
        # scan through in sorted order so that recursive subexpression dependencies
        # are handled in the right order
        for var in sorted_subexpr_vars:
            if var not in read:
                continue
            # if subexpression, and invalid
            if not valid.get(var, True):  # all non-subexpressions are valid
                subexpression = subexpressions[var]
                # if already defined/declared
                if subdefined[var]:
                    op = '='
                    constant = False
                else:
                    op = ':='
                    subdefined[var] = True
                    # set to constant only if we will not write to it again
                    constant = var not in will_write
                    # check all subvariables are not written to again as well
                    if constant:
                        ids = subexpression.identifiers
                        constant = all(v not in will_write for v in ids)
                valid[var] = True
                statement = Statement(var,
                                      op,
                                      subexpression.expr,
                                      comment='',
                                      dtype=variables[var].dtype,
                                      constant=constant,
                                      subexpression=True,
                                      scalar=variables[var].scalar)
                statements.append(statement)
        # now emit the line itself
        var, op, expr, comment = stmt.var, stmt.op, stmt.expr, stmt.comment
        # invalidate any subexpressions including var, recursively
        # we do this by having a set of variables that are invalid that we
        # start with the changed var and increase by any subexpressions we
        # find that have a dependency on something in the invalid set. We
        # go through in sorted subexpression order so that the invalid set
        # is increased in the right order
        invalid = {var}
        for subvar in sorted_subexpr_vars:
            spec = subexpressions[subvar]
            spec_ids = set(spec.identifiers)
            if spec_ids.intersection(invalid):
                valid[subvar] = False
                invalid.add(subvar)
        # constant only if we are declaring a new variable and we will not
        # write to it again
        constant = op == ':=' and var not in will_write
        statement = Statement(var,
                              op,
                              expr,
                              comment,
                              dtype=variables[var].dtype,
                              constant=constant,
                              scalar=variables[var].scalar)
        statements.append(statement)

    if DEBUG:
        print 'OUTPUT STATEMENTS:'
        for stmt in statements:
            print stmt

    # Split into statements executed once (scalar) and statements executed
    # for every element (vector), preserving their relative order
    scalar_statements = [s for s in statements if s.scalar]
    vector_statements = [s for s in statements if not s.scalar]

    # Optionally pull loop-invariant parts of the vector statements out into
    # new scalar constants, which are appended to the scalar statements
    if prefs.codegen.loop_invariant_optimisations:
        scalar_constants, vector_statements = apply_loop_invariant_optimisations(
            vector_statements, variables, dtype)
        scalar_statements.extend(scalar_constants)

    return scalar_statements, vector_statements