Ejemplo n.º 1
0
    def _sort_subexpressions(self):
        '''
        Assign an ``update_order`` to every equation stored in
        ``self._equations``.

        Subexpressions receive consecutive numbers in the order returned by
        `topsort` for their mutual dependencies. All differential equations
        share the next number and all parameters share the number after that,
        so both are sorted after the subexpressions.

        Raises
        ------
        ValueError
            If `topsort` raises a `ValueError`, which is reported as a cycle
            in the dependencies between the subexpressions.
        '''
        equations = self._equations

        def refers_to_subexpression(name):
            # Only names of known equations that are themselves
            # subexpressions count as dependencies; parameters and
            # differential equations are ignored.
            return name in equations and equations[name].type == SUBEXPRESSION

        # Map each subexpression name to the subexpressions it refers to
        dependencies = {}
        for equation in equations.itervalues():
            if equation.type != SUBEXPRESSION:
                continue
            dependencies[equation.varname] = [
                name for name in equation.identifiers
                if refers_to_subexpression(name)
            ]

        try:
            ordered_names = topsort(dependencies)
        except ValueError:
            raise ValueError('Cannot resolve dependencies between static '
                             'equations, dependencies contain a cycle.')

        # Number the subexpressions according to the sorted order
        for position, name in enumerate(ordered_names):
            equations[name].update_order = position

        # Differential equations and parameters are placed after all
        # subexpressions
        n_subexpressions = len(ordered_names)
        for equation in equations.itervalues():
            if equation.type == DIFFERENTIAL_EQUATION:
                equation.update_order = n_subexpressions
            elif equation.type == PARAMETER:
                equation.update_order = n_subexpressions + 1
Ejemplo n.º 2
0
    def _sort_subexpressions(self):
        '''
        Compute the update order of all equations.

        The subexpressions are sorted with `topsort` according to their
        dependencies on each other and numbered in that order via their
        ``update_order`` attribute. Differential equations and parameters get
        the two following numbers, i.e. they come after every subexpression.

        Raises
        ------
        ValueError
            If the dependencies cannot be sorted (`topsort` raised a
            `ValueError`); the error message reports a dependency cycle.
        '''
        # Build the dependency graph between subexpressions only; references
        # to parameters, differential equations or unknown names are dropped
        graph = {}
        for eq in self._equations.itervalues():
            if eq.type == SUBEXPRESSION:
                needed = []
                for identifier in eq.identifiers:
                    if identifier not in self._equations:
                        continue
                    if self._equations[identifier].type == SUBEXPRESSION:
                        needed.append(identifier)
                graph[eq.varname] = needed

        try:
            sorted_names = topsort(graph)
        except ValueError:
            raise ValueError('Cannot resolve dependencies between static '
                             'equations, dependencies contain a cycle.')

        # Store the position of each subexpression in the sorted sequence
        for rank, subexpression_name in enumerate(sorted_names):
            self._equations[subexpression_name].update_order = rank

        # Everything that is not a subexpression is sorted behind them:
        # first the differential equations, then the parameters
        order_after_subexpressions = len(sorted_names)
        for eq in self._equations.itervalues():
            if eq.type == DIFFERENTIAL_EQUATION:
                eq.update_order = order_after_subexpressions
            elif eq.type == PARAMETER:
                eq.update_order = order_after_subexpressions + 1
Ejemplo n.º 3
0
def make_statements(code, variables, dtype, optimise=True, blockname=''):
    '''
    Turn a series of abstract code statements into Statement objects, inferring
    whether each line is a set/declare operation, whether the variables are
    constant or not, and handling the cacheing of subexpressions.

    Parameters
    ----------
    code : str
        A (multi-line) string of statements. Statements are separated by
        newlines or semicolons.
    variables : dict-like
        A dictionary with `Variable` and `Function` objects for every
        identifier used in the `code`. The dictionary is copied and the
        original is not modified.
    dtype : `dtype`
        The data type to use for temporary variables
    optimise : bool, optional
        Whether to optimise expressions, including
        pulling out loop invariant expressions and putting them in new
        scalar constants. Defaults to ``True``. The optimisation is only
        applied if the `codegen.loop_invariant_optimisations` preference is
        set as well. Pass ``False`` in contexts where this kind of
        optimisation is not of interest.
    blockname : str, optional
        A name for the block (used to name intermediate variables to avoid
        name clashes when multiple blocks are used together). It is passed
        on to `~brian2.codegen.optimisation.optimise_statements`.

    Returns
    -------
    scalar_statements, vector_statements : (list of `Statement`, list of `Statement`)
        Lists with statements that are to be executed once and statements that
        are to be executed once for every neuron/synapse/... (or in a vectorised
        way)

    Raises
    ------
    SyntaxError
        If a scalar variable is written to (other than by its declaration)
        after a write to a vector variable.

    Notes
    -----
    Independent of ``optimise``, the resulting statements use augmented
    assignments where possible, i.e. a statement such as ``w = w + 1`` will be
    replaced by ``w += 1``.

    If ``optimise`` is ``True`` (and the preference is set), the statements
    are passed through `~brian2.codegen.optimisation.optimise_statements`.
    The ``scalar_statements`` may then include newly introduced scalar
    constants that have been identified as loop-invariant and have therefore
    been pulled out of the vector statements. Also, statements involving
    booleans will have additional information added to them (see `Statement`
    for details) describing how the statement can be reformulated as a
    sequence of if/then statements.
    '''
    code = strip_empty_lines(deindent(code))
    lines = [LineInfo(code=line) for line in re.split(r'[;\n]', code) if line]
    # Do a copy so we can add stuff without altering the original dict
    variables = dict(variables)
    # we will do inference to work out which lines are := and which are =
    # (auxiliary variables are not considered to be defined yet)
    defined = {k for k, v in variables.items()
               if not isinstance(v, AuxiliaryVariable)}
    for line in lines:
        statement = None
        # parse statement into "var op expr"
        var, op, expr, comment = parse_statement(line.code)
        if op == '=':
            if var not in defined:
                # First assignment to this name: turn it into a declaration
                op = ':='
                defined.add(var)
                if var not in variables:
                    is_scalar = is_scalar_expression(expr, variables)
                    new_var = AuxiliaryVariable(var,
                                                dtype=dtype,
                                                scalar=is_scalar)
                    variables[var] = new_var
            elif not variables[var].is_boolean:
                # Assignment to an existing non-boolean variable: check
                # whether it can be written as an augmented assignment by
                # collecting the terms of the expression with respect to var
                sympy_expr = str_to_sympy(expr, variables)
                sympy_var = sympy.Symbol(var, real=True)
                try:
                    collected = sympy.collect(sympy_expr,
                                              sympy_var,
                                              exact=True,
                                              evaluate=False)
                except AttributeError:
                    # If something goes wrong during collection, e.g. collect
                    # does not work for logical expressions
                    collected = {1: sympy_expr}

                if (len(collected) == 2
                        and set(collected.keys()) == {1, sympy_var}
                        and collected[sympy_var] == 1):
                    # expr has the form "var + rest":
                    # we can replace this statement by a += assignment
                    statement = Statement(var,
                                          '+=',
                                          sympy_to_str(collected[1]),
                                          comment,
                                          dtype=variables[var].dtype,
                                          scalar=variables[var].scalar)
                elif len(collected) == 1 and sympy_var in collected:
                    # expr has the form "var * factor":
                    # we can replace this statement by a *= assignment
                    statement = Statement(var,
                                          '*=',
                                          sympy_to_str(collected[sympy_var]),
                                          comment,
                                          dtype=variables[var].dtype,
                                          scalar=variables[var].scalar)
        if statement is None:
            statement = Statement(var,
                                  op,
                                  expr,
                                  comment,
                                  dtype=variables[var].dtype,
                                  scalar=variables[var].scalar)

        line.statement = statement
        # for each line will give the variable being written to
        line.write = var
        # each line will give a set of variables which are read (based on the
        # original expression, not on a rewritten augmented assignment)
        line.read = get_identifiers_recursively([expr], variables)

    # All writes to scalar variables must happen before writes to vector
    # variables
    scalar_write_done = False
    for line in lines:
        stmt = line.statement
        if (stmt.op != ':=' and variables[stmt.var].scalar
                and scalar_write_done):
            raise SyntaxError(
                ('All writes to scalar variables in a code block '
                 'have to be made before writes to vector '
                 'variables. Illegal write to %s.') % line.write)
        elif not variables[stmt.var].scalar:
            scalar_write_done = True

    # backwards compute whether or not variables will be read again
    # note that will_read for a line gives the set of variables it will read
    # on the current line or subsequent ones. will_write gives the set of
    # variables that will be written after the current line
    will_read = set()
    will_write = set()
    for line in reversed(lines):
        will_read = will_read.union(line.read)
        line.will_read = will_read.copy()
        line.will_write = will_write.copy()
        will_write.add(line.write)

    subexpressions = {name: val for name, val in variables.items()
                      if isinstance(val, Subexpression)}
    # sort subexpressions into an order so that subexpressions that don't depend
    # on other subexpressions are first
    subexpr_deps = {name: [dep for dep in subexpr.identifiers
                           if dep in subexpressions]
                    for name, subexpr in subexpressions.items()}
    sorted_subexpr_vars = topsort(subexpr_deps)

    statements = []

    # none are yet defined (or declared)
    subdefined = dict.fromkeys(subexpressions, False)
    for line in lines:
        stmt = line.statement
        read = line.read
        will_write = line.will_write
        # update/define all subexpressions needed by this statement
        for var in sorted_subexpr_vars:
            if var not in read:
                continue

            subexpression = subexpressions[var]
            # if already defined/declared
            if subdefined[var]:
                op = '='
                constant = False
            else:
                op = ':='
                subdefined[var] = True
                # set to constant only if we will not write to it again
                constant = var not in will_write
                # check all subvariables are not written to again as well
                if constant:
                    ids = subexpression.identifiers
                    constant = all(v not in will_write for v in ids)

            statement = Statement(var,
                                  op,
                                  subexpression.expr,
                                  comment='',
                                  dtype=variables[var].dtype,
                                  constant=constant,
                                  subexpression=True,
                                  scalar=variables[var].scalar)
            statements.append(statement)

        var, op, expr, comment = stmt.var, stmt.op, stmt.expr, stmt.comment

        # constant only if we are declaring a new variable and we will not
        # write to it again
        constant = op == ':=' and var not in will_write
        statement = Statement(var,
                              op,
                              expr,
                              comment,
                              dtype=variables[var].dtype,
                              constant=constant,
                              scalar=variables[var].scalar)
        statements.append(statement)

    scalar_statements = [s for s in statements if s.scalar]
    vector_statements = [s for s in statements if not s.scalar]

    if optimise and prefs.codegen.loop_invariant_optimisations:
        scalar_statements, vector_statements = optimise_statements(
            scalar_statements,
            vector_statements,
            variables,
            blockname=blockname)

    return scalar_statements, vector_statements
Ejemplo n.º 4
0
def make_statements(code, variables, dtype, optimise=True):
    '''
    Turn a series of abstract code statements into Statement objects, inferring
    whether each line is a set/declare operation, whether the variables are
    constant or not, and handling the cacheing of subexpressions.

    Parameters
    ----------
    code : str
        A (multi-line) string of statements. Statements are separated by
        newlines or semicolons.
    variables : dict-like
        A dictionary with `Variable` and `Function` objects for every
        identifier used in the `code`. The dictionary is copied and the
        original is not modified.
    dtype : `dtype`
        The data type to use for temporary variables
    optimise : bool, optional
        Whether to optimise expressions, including
        pulling out loop invariant expressions and putting them in new
        scalar constants. Defaults to ``True``. The optimisation is only
        applied if the `codegen.loop_invariant_optimisations` preference is
        set as well. Pass ``False`` in contexts where this kind of
        optimisation is not of interest.

    Returns
    -------
    scalar_statements, vector_statements : (list of `Statement`, list of `Statement`)
        Lists with statements that are to be executed once and statements that
        are to be executed once for every neuron/synapse/... (or in a vectorised
        way)

    Raises
    ------
    SyntaxError
        If a scalar variable is written to (other than by its declaration)
        after a write to a vector variable.

    Notes
    -----
    Independent of ``optimise``, the resulting statements use augmented
    assignments where possible, i.e. a statement such as ``w = w + 1`` will be
    replaced by ``w += 1``.

    If ``optimise`` is ``True`` (and the preference is set), the statements
    are passed through `~brian2.codegen.optimisation.optimise_statements`.
    The ``scalar_statements`` may then include newly introduced scalar
    constants that have been identified as loop-invariant and have therefore
    been pulled out of the vector statements. Also, statements involving
    booleans will have additional information added to them (see `Statement`
    for details) describing how the statement can be reformulated as a
    sequence of if/then statements.
    '''
    code = strip_empty_lines(deindent(code))
    # Statements are separated by newlines or semicolons; empty pieces are
    # dropped
    lines = re.split(r'[;\n]', code)
    lines = [LineInfo(code=line) for line in lines if len(line)]
    # Do a copy so we can add stuff without altering the original dict
    variables = dict(variables)
    # we will do inference to work out which lines are := and which are =
    # (auxiliary variables are not considered to be defined yet)
    defined = set(k for k, v in variables.iteritems()
                  if not isinstance(v, AuxiliaryVariable))
    for line in lines:
        statement = None
        # parse statement into "var op expr"
        var, op, expr, comment = parse_statement(line.code)
        if op == '=':
            if var not in defined:
                # First assignment to this name: turn it into a declaration
                op = ':='
                defined.add(var)
                if var not in variables:
                    # Unknown name: create a temporary variable for it
                    is_scalar = is_scalar_expression(expr, variables)
                    new_var = AuxiliaryVariable(var, Unit(1), # doesn't matter here
                                                dtype=dtype, scalar=is_scalar)
                    variables[var] = new_var
            elif not variables[var].is_boolean:
                # Assignment to an existing non-boolean variable: check
                # whether it can be written as an augmented assignment by
                # collecting the terms of the expression with respect to var
                sympy_expr = str_to_sympy(expr, variables)
                sympy_var = sympy.Symbol(var, real=True)
                try:
                    collected = sympy.collect(sympy_expr, sympy_var,
                                              exact=True, evaluate=False)
                except AttributeError:
                    # If something goes wrong during collection, e.g. collect
                    # does not work for logical expressions
                    collected = {1: sympy_expr}

                if (len(collected) == 2 and
                        set(collected.keys()) == {1, sympy_var} and
                        collected[sympy_var] == 1):
                    # expr has the form "var + rest":
                    # We can replace this statement by a += assignment
                    statement = Statement(var, '+=',
                                          sympy_to_str(collected[1]),
                                          comment,
                                          dtype=variables[var].dtype,
                                          scalar=variables[var].scalar)
                elif len(collected) == 1 and sympy_var in collected:
                    # expr has the form "var * factor":
                    # We can replace this statement by a *= assignment
                    statement = Statement(var, '*=',
                                          sympy_to_str(collected[sympy_var]),
                                          comment,
                                          dtype=variables[var].dtype,
                                          scalar=variables[var].scalar)
        if statement is None:
            # No augmented assignment was created, use the parsed statement
            statement = Statement(var, op, expr, comment,
                                  dtype=variables[var].dtype,
                                  scalar=variables[var].scalar)

        line.statement = statement
        # for each line will give the variable being written to
        line.write = var 
        # each line will give a set of variables which are read (based on the
        # original expression, not on a rewritten augmented assignment)
        line.read = get_identifiers_recursively([expr], variables)

    # All writes to scalar variables must happen before writes to vector
    # variables
    scalar_write_done = False
    for line in lines:
        stmt = line.statement
        if stmt.op != ':=' and variables[stmt.var].scalar and scalar_write_done:
            raise SyntaxError(('All writes to scalar variables in a code block '
                               'have to be made before writes to vector '
                               'variables. Illegal write to %s.') % line.write)
        elif not variables[stmt.var].scalar:
            scalar_write_done = True

    # all variables which are written to at some point in the code block
    # used to determine whether they should be const or not
    # NOTE(review): all_write is not referenced again in this function
    all_write = set(line.write for line in lines)

    # backwards compute whether or not variables will be read again
    # note that will_read for a line gives the set of variables it will read
    # on the current line or subsequent ones. will_write gives the set of
    # variables that will be written after the current line
    will_read = set()
    will_write = set()
    for line in lines[::-1]:
        will_read = will_read.union(line.read)
        line.will_read = will_read.copy()
        line.will_write = will_write.copy()
        will_write.add(line.write)

    # generate cacheing statements for common subexpressions
    # cached subexpressions need to be recomputed whenever they are to be used
    # on the next line, and currently invalid (meaning that the current value
    # stored in the subexpression variable is no longer accurate because one
    # of the variables appearing in it has changed). All subexpressions start
    # as invalid, and are invalidated whenever one of the variables appearing
    # in the RHS changes value.
    subexpressions = dict((name, val) for name, val in variables.items() if isinstance(val, Subexpression))
    # sort subexpressions into an order so that subexpressions that don't depend
    # on other subexpressions are first
    subexpr_deps = dict((name, [dep for dep in subexpr.identifiers if dep in subexpressions]) for \
                                                            name, subexpr in subexpressions.items())
    sorted_subexpr_vars = topsort(subexpr_deps)

    statements = []
    # all start as invalid
    valid = dict((name, False) for name in subexpressions.keys())
    # none are yet defined (or declared)
    subdefined = dict((name, False) for name in subexpressions.keys())
    for line in lines:
        stmt = line.statement
        read = line.read
        # NOTE(review): write and will_read are not used in the loop body
        write = line.write
        will_read = line.will_read
        will_write = line.will_write
        # check that all subexpressions in expr are valid, and if not
        # add a definition/set its value, and set it to be valid
        # scan through in sorted order so that recursive subexpression dependencies
        # are handled in the right order
        for var in sorted_subexpr_vars:
            if var not in read:
                continue
            # if subexpression, and invalid
            if not valid.get(var, True): # all non-subexpressions are valid
                subexpression = subexpressions[var]
                # if already defined/declared
                if subdefined[var]:
                    op = '='
                    constant = False
                else:
                    op = ':='
                    subdefined[var] = True
                    # set to constant only if we will not write to it again
                    constant = var not in will_write
                    # check all subvariables are not written to again as well
                    if constant:
                        ids = subexpression.identifiers
                        constant = all(v not in will_write for v in ids)
                valid[var] = True
                # The statement (re)computing the subexpression is inserted
                # before the statement that reads it
                statement = Statement(var, op, subexpression.expr, comment='',
                                      dtype=variables[var].dtype,
                                      constant=constant,
                                      subexpression=True,
                                      scalar=variables[var].scalar)
                statements.append(statement)
        # var/op were reused as loop variables above, restore the values of
        # the current statement
        var, op, expr, comment = stmt.var, stmt.op, stmt.expr, stmt.comment
        # invalidate any subexpressions including var, recursively
        # we do this by having a set of variables that are invalid that we
        # start with the changed var and increase by any subexpressions we
        # find that have a dependency on something in the invalid set. We
        # go through in sorted subexpression order so that the invalid set
        # is increased in the right order
        invalid = {var}
        for subvar in sorted_subexpr_vars:
            spec = subexpressions[subvar]
            spec_ids = set(spec.identifiers)
            if spec_ids.intersection(invalid):
                valid[subvar] = False
                invalid.add(subvar)
        # constant only if we are declaring a new variable and we will not
        # write to it again
        constant = op==':=' and var not in will_write
        statement = Statement(var, op, expr, comment,
                              dtype=variables[var].dtype,
                              constant=constant,
                              scalar=variables[var].scalar)
        statements.append(statement)

    # Split by the scalar flag of each statement
    scalar_statements = [s for s in statements if s.scalar]
    vector_statements = [s for s in statements if not s.scalar]

    # Optimise only if requested by the caller and enabled by the preference
    if optimise and prefs.codegen.loop_invariant_optimisations:
        scalar_statements, vector_statements = optimise_statements(scalar_statements,
                                                                   vector_statements,
                                                                   variables)

    return scalar_statements, vector_statements
Ejemplo n.º 5
0
def make_statements(code, variables, dtype):
    '''
    Turn a series of abstract code statements into Statement objects, inferring
    whether each line is a set/declare operation, whether the variables are
    constant or not, and handling the cacheing of subexpressions.

    Parameters
    ----------
    code : str
        A (multi-line) string of statements. Statements are separated by
        newlines or semicolons.
    variables : dict-like
        A dictionary with `Variable` and `Function` objects for every
        identifier used in the `code`. The dictionary is copied and the
        original is not modified.
    dtype : `dtype`
        The data type to use for temporary variables

    Returns
    -------
    scalar_statements, vector_statements : (list of `Statement`, list of `Statement`)
        Lists with statements that are to be executed once and statements that
        are to be executed once for every neuron/synapse/... (or in a vectorised
        way)

    Raises
    ------
    SyntaxError
        If a scalar variable is written to (other than by its declaration)
        after a write to a vector variable.

    Notes
    -----
    If the `codegen.loop_invariant_optimisations` preference is set, the
    `scalar_statements` may include newly introduced scalar constants that
    have been identified as loop-invariant and have therefore been pulled out
    of the vector statements.

    If the ``DEBUG`` flag (defined outside of this function) is true,
    intermediate results are printed.
    '''
    code = strip_empty_lines(deindent(code))
    # Statements are separated by newlines or semicolons; empty pieces are
    # dropped
    lines = re.split(r'[;\n]', code)
    lines = [LineInfo(code=line) for line in lines if len(line)]
    if DEBUG:
        print 'INPUT CODE:'
        print code
    # Do a copy so we can add stuff without altering the original dict
    variables = dict(variables)
    # we will do inference to work out which lines are := and which are =
    # (auxiliary variables are not considered to be defined yet)
    defined = set(k for k, v in variables.iteritems()
                  if not isinstance(v, AuxiliaryVariable))
    for line in lines:
        # parse statement into "var op expr"
        var, op, expr, comment = parse_statement(line.code)
        if op=='=':
            if var not in defined:
                # First assignment to this name: turn it into a declaration
                op = ':='
                defined.add(var)
                if var not in variables:
                    # Unknown name: create a temporary variable for it
                    is_scalar = is_scalar_expression(expr, variables)
                    new_var = AuxiliaryVariable(var, Unit(1), # doesn't matter here
                                                dtype=dtype, scalar=is_scalar)
                    variables[var] = new_var


        statement = Statement(var, op, expr, comment,
                              dtype=variables[var].dtype,
                              scalar=variables[var].scalar)
        line.statement = statement
        # for each line will give the variable being written to
        line.write = var 
        # each line will give a set of variables which are read
        line.read = get_identifiers_recursively([expr], variables)

    # All writes to scalar variables must happen before writes to vector
    # variables
    scalar_write_done = False
    for line in lines:
        stmt = line.statement
        if stmt.op != ':=' and variables[stmt.var].scalar and scalar_write_done:
            raise SyntaxError(('All writes to scalar variables in a code block '
                               'have to be made before writes to vector '
                               'variables. Illegal write to %s.') % line.write)
        elif not variables[stmt.var].scalar:
            scalar_write_done = True

    if DEBUG:
        print 'PARSED STATEMENTS:'
        for line in lines:
            print line.statement, 'Read:'+str(line.read), 'Write:'+line.write

    # all variables which are written to at some point in the code block
    # used to determine whether they should be const or not
    # NOTE(review): apart from the debug output, all_write is not referenced
    # again in this function
    all_write = set(line.write for line in lines)

    if DEBUG:
        print 'ALL WRITE:', all_write

    # backwards compute whether or not variables will be read again
    # note that will_read for a line gives the set of variables it will read
    # on the current line or subsequent ones. will_write gives the set of
    # variables that will be written after the current line
    will_read = set()
    will_write = set()
    for line in lines[::-1]:
        will_read = will_read.union(line.read)
        line.will_read = will_read.copy()
        line.will_write = will_write.copy()
        will_write.add(line.write)

    if DEBUG:
        print 'WILL READ/WRITE:'
        for line in lines:
            print line.statement, 'Read:'+str(line.will_read), 'Write:'+str(line.will_write)

    # generate cacheing statements for common subexpressions
    # cached subexpressions need to be recomputed whenever they are to be used
    # on the next line, and currently invalid (meaning that the current value
    # stored in the subexpression variable is no longer accurate because one
    # of the variables appearing in it has changed). All subexpressions start
    # as invalid, and are invalidated whenever one of the variables appearing
    # in the RHS changes value.
    subexpressions = dict((name, val) for name, val in variables.items() if isinstance(val, Subexpression))
    # sort subexpressions into an order so that subexpressions that don't depend
    # on other subexpressions are first
    subexpr_deps = dict((name, [dep for dep in subexpr.identifiers if dep in subexpressions]) for \
                                                            name, subexpr in subexpressions.items())
    sorted_subexpr_vars = topsort(subexpr_deps)

    if DEBUG:
        print 'SUBEXPRESSIONS:', subexpressions.keys()
    statements = []
    # all start as invalid
    valid = dict((name, False) for name in subexpressions.keys())
    # none are yet defined (or declared)
    subdefined = dict((name, False) for name in subexpressions.keys())
    for line in lines:
        stmt = line.statement
        read = line.read
        # NOTE(review): write and will_read are not used in the loop body
        write = line.write
        will_read = line.will_read
        will_write = line.will_write
        # check that all subexpressions in expr are valid, and if not
        # add a definition/set its value, and set it to be valid
        # scan through in sorted order so that recursive subexpression dependencies
        # are handled in the right order
        for var in sorted_subexpr_vars:
            if var not in read:
                continue
            # if subexpression, and invalid
            if not valid.get(var, True): # all non-subexpressions are valid
                subexpression = subexpressions[var]
                # if already defined/declared
                if subdefined[var]:
                    op = '='
                    constant = False
                else:
                    op = ':='
                    subdefined[var] = True
                    # set to constant only if we will not write to it again
                    constant = var not in will_write
                    # check all subvariables are not written to again as well
                    if constant:
                        ids = subexpression.identifiers
                        constant = all(v not in will_write for v in ids)
                valid[var] = True
                # The statement (re)computing the subexpression is inserted
                # before the statement that reads it
                statement = Statement(var, op, subexpression.expr, comment='',
                                      dtype=variables[var].dtype,
                                      constant=constant,
                                      subexpression=True,
                                      scalar=variables[var].scalar)
                statements.append(statement)
        # var/op were reused as loop variables above, restore the values of
        # the current statement
        var, op, expr, comment = stmt.var, stmt.op, stmt.expr, stmt.comment
        # invalidate any subexpressions including var, recursively
        # we do this by having a set of variables that are invalid that we
        # start with the changed var and increase by any subexpressions we
        # find that have a dependency on something in the invalid set. We
        # go through in sorted subexpression order so that the invalid set
        # is increased in the right order
        invalid = {var}
        for subvar in sorted_subexpr_vars:
            spec = subexpressions[subvar]
            spec_ids = set(spec.identifiers)
            if spec_ids.intersection(invalid):
                valid[subvar] = False
                invalid.add(subvar)
        # constant only if we are declaring a new variable and we will not
        # write to it again
        constant = op==':=' and var not in will_write
        statement = Statement(var, op, expr, comment,
                              dtype=variables[var].dtype,
                              constant=constant,
                              scalar=variables[var].scalar)
        statements.append(statement)

    if DEBUG:
        print 'OUTPUT STATEMENTS:'
        for stmt in statements:
            print stmt

    # Split by the scalar flag of each statement
    scalar_statements = [s for s in statements if s.scalar]
    vector_statements = [s for s in statements if not s.scalar]

    # If enabled by the preference, replace the vector statements by the
    # optimised ones and append the returned scalar constants to the scalar
    # statements
    if prefs.codegen.loop_invariant_optimisations:
        scalar_constants, vector_statements = apply_loop_invariant_optimisations(vector_statements,
                                                                                 variables,
                                                                                 dtype)
        scalar_statements.extend(scalar_constants)

    return scalar_statements, vector_statements
Ejemplo n.º 6
0
def make_statements(code, variables, dtype):
    '''
    Turn a series of abstract code statements into Statement objects, inferring
    whether each line is a set/declare operation, whether the variables are
    constant or not, and handling the cacheing of subexpressions. Returns a
    list of Statement objects. For arguments, see documentation for
    :func:`translate`.

    Parameters
    ----------
    code : str
        A (multi-line) string of statements. Statements are separated by
        newlines or semicolons.
    variables : dict-like
        Maps identifier names to their descriptions. Entries that are
        instances of `Function`, `AuxiliaryVariable` or `Subexpression` are
        treated specially: functions contribute no dtype, auxiliary
        variables are not considered as already defined, and subexpressions
        are cached in generated statements.
    dtype : `dtype`
        The data type used for variables that are declared in `code` and
        for which no dtype is known from `variables`.

    Returns
    -------
    statements : list of `Statement`
        One statement per input line, in input order, each preceded by the
        statements that (re)compute the invalid subexpressions it reads.

    Raises
    ------
    SyntaxError
        If a scalar variable is written to (other than by its declaration)
        after a write to a non-scalar variable.
    '''
    code = strip_empty_lines(deindent(code))
    # statements are separated by semicolons or newlines, empty pieces are
    # dropped
    lines = re.split(r'[;\n]', code)
    lines = [LineInfo(code=line) for line in lines if len(line)]
    if DEBUG:
        print 'INPUT CODE:'
        print code
    # dtype for every known name (functions are skipped)
    dtypes = dict((name, var.dtype) for name, var in variables.iteritems()
                  if not isinstance(var, Function))
    # we will do inference to work out which lines are := and which are =
    # (auxiliary variables are left out of ``defined`` so that the first
    # assignment to one of them is turned into a declaration)
    defined = set(k for k, v in variables.iteritems()
                  if not isinstance(v, AuxiliaryVariable))
    # names whose entry in ``variables`` has a true ``scalar`` attribute
    scalars = set(k for k,v in variables.iteritems()
                  if getattr(v, 'scalar', False))
    for line in lines:
        # parse statement into "var op expr"
        var, op, expr = parse_statement(line.code)
        if op=='=':
            if var not in defined:
                # first assignment to an undefined name: this is a declaration
                op = ':='
                defined.add(var)
                if var not in dtypes:
                    dtypes[var] = dtype
                # determine whether this is a scalar variable
                identifiers = get_identifiers_recursively(expr, variables)
                # In the following we assume that all unknown identifiers are
                # scalar constants -- this should cover numerical literals and
                # e.g. "True" or "inf".
                is_scalar = all((name in scalars) or not (name in defined)
                                for name in identifiers)
                if is_scalar:
                    scalars.add(var)

        statement = Statement(var, op, expr, dtypes[var], scalar=var in scalars)
        line.statement = statement
        # for each line will give the variable being written to
        line.write = var 
        # each line will give a set of variables which are read
        line.read = get_identifiers_recursively(expr, variables)

    # All writes to scalar variables must happen before writes to vector
    # variables
    # (declarations, i.e. ``:=``, of scalar variables are exempt from this)
    scalar_write_done = False
    for line in lines:
        stmt = line.statement
        if stmt.op != ':=' and stmt.var in scalars and scalar_write_done:
            raise SyntaxError(('All writes to scalar variables in a code block '
                               'have to be made before writes to vector '
                               'variables. Illegal write to %s.') % line.write)
        elif not stmt.var in scalars:
            scalar_write_done = True

    if DEBUG:
        print 'PARSED STATEMENTS:'
        for line in lines:
            print line.statement, 'Read:'+str(line.read), 'Write:'+line.write
    
    # all variables which are written to at some point in the code block
    # used to determine whether they should be const or not
    # NOTE: within this function the set is only used for the debug output
    all_write = set(line.write for line in lines)

    if DEBUG:
        print 'ALL WRITE:', all_write
        
    # backwards compute whether or not variables will be read again
    # note that will_read for a line gives the set of variables it will read
    # on the current line or subsequent ones. will_write gives the set of
    # variables that will be written after the current line
    will_read = set()
    will_write = set()
    for line in lines[::-1]:
        will_read = will_read.union(line.read)
        # store copies, the running sets keep changing while we walk backwards
        line.will_read = will_read.copy()
        line.will_write = will_write.copy()
        # added after storing, so a line's own write is not part of its
        # ``will_write``
        will_write.add(line.write)

    if DEBUG:
        print 'WILL READ/WRITE:'
        for line in lines:
            print line.statement, 'Read:'+str(line.will_read), 'Write:'+str(line.will_write)
        
    # generate cacheing statements for common subexpressions
    # cached subexpressions need to be recomputed whenever they are to be used
    # on the next line, and currently invalid (meaning that the current value
    # stored in the subexpression variable is no longer accurate because one
    # of the variables appearing in it has changed). All subexpressions start
    # as invalid, and are invalidated whenever one of the variables appearing
    # in the RHS changes value.
    subexpressions = dict((name, val) for name, val in variables.items() if isinstance(val, Subexpression))
    # sort subexpressions into an order so that subexpressions that don't depend
    # on other subexpressions are first
    subexpr_deps = dict((name, [dep for dep in subexpr.identifiers if dep in subexpressions]) for \
                                                            name, subexpr in subexpressions.items())
    sorted_subexpr_vars = topsort(subexpr_deps)

    if DEBUG:
        print 'SUBEXPRESSIONS:', subexpressions.keys()
    statements = []
    # all start as invalid
    valid = dict((name, False) for name in subexpressions.keys())
    # none are yet defined (or declared)
    subdefined = dict((name, False) for name in subexpressions.keys())
    for line in lines:
        stmt = line.statement
        read = line.read
        # NOTE: ``write`` and ``will_read`` are not used in the rest of the
        # loop body
        write = line.write
        will_read = line.will_read
        will_write = line.will_write
        # check that all subexpressions in expr are valid, and if not
        # add a definition/set its value, and set it to be valid
        # scan through in sorted order so that recursive subexpression dependencies
        # are handled in the right order
        # (``var`` and ``op`` are re-used as scratch names in this loop and
        # restored from ``stmt`` afterwards)
        for var in sorted_subexpr_vars:
            if var not in read:
                continue
            # if subexpression, and invalid
            if not valid.get(var, True): # all non-subexpressions are valid
                subexpression = subexpressions[var]
                # if already defined/declared
                if subdefined[var]:
                    op = '='
                    constant = False
                else:
                    op = ':='
                    subdefined[var] = True
                    dtypes[var] = variables[var].dtype
                    # set to constant only if we will not write to it again
                    constant = var not in will_write
                    # check all subvariables are not written to again as well
                    if constant:
                        ids = subexpression.identifiers
                        constant = all(v not in will_write for v in ids)
                valid[var] = True
                # emitted before the statement of the current line
                statement = Statement(var, op, subexpression.expr,
                                      variables[var].dtype, constant=constant,
                                      subexpression=True, scalar=var in scalars)
                statements.append(statement)
        var, op, expr = stmt.var, stmt.op, stmt.expr
        # invalidate any subexpressions including var, recursively
        # we do this by having a set of variables that are invalid that we
        # start with the changed var and increase by any subexpressions we
        # find that have a dependency on something in the invalid set. We
        # go through in sorted subexpression order so that the invalid set
        # is increased in the right order
        invalid = set([var])
        for subvar in sorted_subexpr_vars:
            spec = subexpressions[subvar]
            spec_ids = set(spec.identifiers)
            if spec_ids.intersection(invalid):
                valid[subvar] = False
                invalid.add(subvar)
        # constant only if we are declaring a new variable and we will not
        # write to it again
        constant = op==':=' and var not in will_write
        statement = Statement(var, op, expr, dtypes[var],
                              constant=constant, scalar=var in scalars)
        statements.append(statement)

    if DEBUG:
        print 'OUTPUT STATEMENTS:'
        for stmt in statements:
            print stmt

    return statements
Ejemplo n.º 7
0
def make_statements(code, variables, dtype):
    '''
    Turn a series of abstract code statements into Statement objects, inferring
    whether each line is a set/declare operation, whether the variables are
    constant or not, and handling the cacheing of subexpressions.

    Parameters
    ----------
    code : str
        A (multi-line) string of statements.
    variables : dict-like
        A dictionary of with `Variable` and `Function` objects for every
        identifier used in the `code`.
    dtype : `dtype`
        The data type to use for temporary variables

    Returns
    -------
    scalar_statements, vector_statements : (list of `Statement`, list of `Statement`)
        Lists with statements that are to be executed once and statements that
        are to be executed once for every neuron/synapse/... (or in a vectorised
        way)

    Raises
    ------
    SyntaxError
        If a scalar variable is written to (other than by its declaration)
        after a write to a non-scalar variable.

    Notes
    -----
    The `scalar_statements` may include newly introduced scalar constants that
    have been identified as loop-invariant and have therefore been pulled out
    of the vector statements. The resulting statements will also use augmented
    assignments where possible, i.e. a statement such as ``w = w + 1`` will be
    replaced by ``w += 1``.
    '''
    code = strip_empty_lines(deindent(code))
    # statements are separated by semicolons or newlines, empty pieces are
    # dropped
    lines = re.split(r'[;\n]', code)
    lines = [LineInfo(code=line) for line in lines if len(line)]
    if DEBUG:
        print 'INPUT CODE:'
        print code
    # Do a copy so we can add stuff without altering the original dict
    variables = dict(variables)
    # we will do inference to work out which lines are := and which are =
    # (auxiliary variables are left out of ``defined`` so that the first
    # assignment to one of them is turned into a declaration)
    defined = set(k for k, v in variables.iteritems()
                  if not isinstance(v, AuxiliaryVariable))
    for line in lines:
        # stays None unless the line is rewritten as an augmented assignment
        statement = None
        # parse statement into "var op expr"
        var, op, expr, comment = parse_statement(line.code)
        if op == '=':
            if var not in defined:
                # first assignment to an undefined name: this is a declaration
                op = ':='
                defined.add(var)
                if var not in variables:
                    # completely unknown name: register it as a temporary
                    # variable in the local copy of ``variables``
                    is_scalar = is_scalar_expression(expr, variables)
                    new_var = AuxiliaryVariable(
                        var,
                        Unit(1),  # doesn't matter here
                        dtype=dtype,
                        scalar=is_scalar)
                    variables[var] = new_var
            elif not variables[var].is_boolean:
                # assignment to an existing non-boolean variable: try to
                # express it as ``var += ...`` or ``var *= ...``
                sympy_expr = str_to_sympy(expr)
                sympy_var = sympy.Symbol(var, real=True)
                try:
                    # with evaluate=False, collect is expected to return a
                    # dict keyed by the collected terms (see the sympy
                    # documentation of ``collect``)
                    collected = sympy.collect(sympy_expr,
                                              sympy_var,
                                              exact=True,
                                              evaluate=False)
                except AttributeError:
                    # If something goes wrong during collection, e.g. collect
                    # does not work for logical expressions
                    collected = {1: sympy_expr}

                if (len(collected) == 2
                        and set(collected.keys()) == {1, sympy_var}
                        and collected[sympy_var] == 1):
                    # We can replace this statement by a += assignment
                    statement = Statement(var,
                                          '+=',
                                          sympy_to_str(collected[1]),
                                          comment,
                                          dtype=variables[var].dtype,
                                          scalar=variables[var].scalar)
                elif len(collected) == 1 and sympy_var in collected:
                    # We can replace this statement by a *= assignment
                    statement = Statement(var,
                                          '*=',
                                          sympy_to_str(collected[sympy_var]),
                                          comment,
                                          dtype=variables[var].dtype,
                                          scalar=variables[var].scalar)
        if statement is None:
            # no rewrite happened: keep the statement as parsed (with ``op``
            # possibly changed to ``:=`` above)
            statement = Statement(var,
                                  op,
                                  expr,
                                  comment,
                                  dtype=variables[var].dtype,
                                  scalar=variables[var].scalar)

        line.statement = statement
        # for each line will give the variable being written to
        line.write = var
        # each line will give a set of variables which are read
        # (computed from the expression as parsed, also when the statement
        # was rewritten as an augmented assignment)
        line.read = get_identifiers_recursively([expr], variables)

    # All writes to scalar variables must happen before writes to vector
    # variables
    # (declarations, i.e. ``:=``, of scalar variables are exempt from this)
    scalar_write_done = False
    for line in lines:
        stmt = line.statement
        if stmt.op != ':=' and variables[
                stmt.var].scalar and scalar_write_done:
            raise SyntaxError(
                ('All writes to scalar variables in a code block '
                 'have to be made before writes to vector '
                 'variables. Illegal write to %s.') % line.write)
        elif not variables[stmt.var].scalar:
            scalar_write_done = True

    if DEBUG:
        print 'PARSED STATEMENTS:'
        for line in lines:
            print line.statement, 'Read:' + str(
                line.read), 'Write:' + line.write

    # all variables which are written to at some point in the code block
    # used to determine whether they should be const or not
    # NOTE: within this function the set is only used for the debug output
    all_write = set(line.write for line in lines)

    if DEBUG:
        print 'ALL WRITE:', all_write

    # backwards compute whether or not variables will be read again
    # note that will_read for a line gives the set of variables it will read
    # on the current line or subsequent ones. will_write gives the set of
    # variables that will be written after the current line
    will_read = set()
    will_write = set()
    for line in lines[::-1]:
        will_read = will_read.union(line.read)
        # store copies, the running sets keep changing while we walk backwards
        line.will_read = will_read.copy()
        line.will_write = will_write.copy()
        # added after storing, so a line's own write is not part of its
        # ``will_write``
        will_write.add(line.write)

    if DEBUG:
        print 'WILL READ/WRITE:'
        for line in lines:
            print line.statement, 'Read:' + str(
                line.will_read), 'Write:' + str(line.will_write)

    # generate cacheing statements for common subexpressions
    # cached subexpressions need to be recomputed whenever they are to be used
    # on the next line, and currently invalid (meaning that the current value
    # stored in the subexpression variable is no longer accurate because one
    # of the variables appearing in it has changed). All subexpressions start
    # as invalid, and are invalidated whenever one of the variables appearing
    # in the RHS changes value.
    subexpressions = dict((name, val) for name, val in variables.items()
                          if isinstance(val, Subexpression))
    # sort subexpressions into an order so that subexpressions that don't depend
    # on other subexpressions are first
    subexpr_deps = dict((name, [dep for dep in subexpr.identifiers if dep in subexpressions]) for \
                                                            name, subexpr in subexpressions.items())
    sorted_subexpr_vars = topsort(subexpr_deps)

    if DEBUG:
        print 'SUBEXPRESSIONS:', subexpressions.keys()
    statements = []
    # all start as invalid
    valid = dict((name, False) for name in subexpressions.keys())
    # none are yet defined (or declared)
    subdefined = dict((name, False) for name in subexpressions.keys())
    for line in lines:
        stmt = line.statement
        read = line.read
        # NOTE: ``write`` and ``will_read`` are not used in the rest of the
        # loop body
        write = line.write
        will_read = line.will_read
        will_write = line.will_write
        # check that all subexpressions in expr are valid, and if not
        # add a definition/set its value, and set it to be valid
        # scan through in sorted order so that recursive subexpression dependencies
        # are handled in the right order
        # (``var`` and ``op`` are re-used as scratch names in this loop and
        # restored from ``stmt`` afterwards)
        for var in sorted_subexpr_vars:
            if var not in read:
                continue
            # if subexpression, and invalid
            if not valid.get(var, True):  # all non-subexpressions are valid
                subexpression = subexpressions[var]
                # if already defined/declared
                if subdefined[var]:
                    op = '='
                    constant = False
                else:
                    op = ':='
                    subdefined[var] = True
                    # set to constant only if we will not write to it again
                    constant = var not in will_write
                    # check all subvariables are not written to again as well
                    if constant:
                        ids = subexpression.identifiers
                        constant = all(v not in will_write for v in ids)
                valid[var] = True
                # emitted before the statement of the current line
                statement = Statement(var,
                                      op,
                                      subexpression.expr,
                                      comment='',
                                      dtype=variables[var].dtype,
                                      constant=constant,
                                      subexpression=True,
                                      scalar=variables[var].scalar)
                statements.append(statement)
        var, op, expr, comment = stmt.var, stmt.op, stmt.expr, stmt.comment
        # invalidate any subexpressions including var, recursively
        # we do this by having a set of variables that are invalid that we
        # start with the changed var and increase by any subexpressions we
        # find that have a dependency on something in the invalid set. We
        # go through in sorted subexpression order so that the invalid set
        # is increased in the right order
        invalid = {var}
        for subvar in sorted_subexpr_vars:
            spec = subexpressions[subvar]
            spec_ids = set(spec.identifiers)
            if spec_ids.intersection(invalid):
                valid[subvar] = False
                invalid.add(subvar)
        # constant only if we are declaring a new variable and we will not
        # write to it again
        # (so statements rewritten to ``+=``/``*=`` are never constant)
        constant = op == ':=' and var not in will_write
        statement = Statement(var,
                              op,
                              expr,
                              comment,
                              dtype=variables[var].dtype,
                              constant=constant,
                              scalar=variables[var].scalar)
        statements.append(statement)

    if DEBUG:
        print 'OUTPUT STATEMENTS:'
        for stmt in statements:
            print stmt

    # split by the ``scalar`` flag of the statements, keeping relative order
    scalar_statements = [s for s in statements if s.scalar]
    vector_statements = [s for s in statements if not s.scalar]

    if prefs.codegen.loop_invariant_optimisations:
        # the optimisation returns additional scalar constants together with
        # the (replaced) vector statements
        scalar_constants, vector_statements = apply_loop_invariant_optimisations(
            vector_statements, variables, dtype)
        scalar_statements.extend(scalar_constants)

    return scalar_statements, vector_statements
Ejemplo n.º 8
0
def make_statements(code, variables, dtype):
    '''
    Turn a series of abstract code statements into Statement objects, inferring
    whether each line is a set/declare operation, whether the variables are
    constant or not, and handling the cacheing of subexpressions. Returns a
    list of Statement objects. For arguments, see documentation for
    :func:`translate`.

    Parameters
    ----------
    code : str
        A (multi-line) string of statements. Statements are separated by
        newlines or semicolons.
    variables : dict-like
        Maps identifier names to their descriptions. Entries that are
        instances of `Function`, `AuxiliaryVariable` or `Subexpression` are
        treated specially: functions contribute no dtype, auxiliary
        variables are not considered as already defined, and subexpressions
        are cached in generated statements.
    dtype : `dtype`
        The data type used for variables that are declared in `code` and
        for which no dtype is known from `variables`.

    Returns
    -------
    statements : list of `Statement`
        One statement per input line, in input order, each preceded by the
        statements that (re)compute the invalid subexpressions it reads.

    Raises
    ------
    SyntaxError
        If a scalar variable is written to (other than by its declaration)
        after a write to a non-scalar variable.
    '''
    code = strip_empty_lines(deindent(code))
    # statements are separated by semicolons or newlines, empty pieces are
    # dropped
    lines = re.split(r'[;\n]', code)
    lines = [LineInfo(code=line) for line in lines if len(line)]
    if DEBUG:
        print 'INPUT CODE:'
        print code
    # dtype for every known name (functions are skipped)
    dtypes = dict((name, var.dtype) for name, var in variables.iteritems()
                  if not isinstance(var, Function))
    # we will do inference to work out which lines are := and which are =
    # (auxiliary variables are left out of ``defined`` so that the first
    # assignment to one of them is turned into a declaration)
    defined = set(k for k, v in variables.iteritems()
                  if not isinstance(v, AuxiliaryVariable))
    # names whose entry in ``variables`` has a true ``scalar`` attribute
    scalars = set(k for k, v in variables.iteritems()
                  if getattr(v, 'scalar', False))
    for line in lines:
        # parse statement into "var op expr"
        var, op, expr, comment = parse_statement(line.code)
        if op == '=':
            if var not in defined:
                # first assignment to an undefined name: this is a declaration
                op = ':='
                defined.add(var)
                if var not in dtypes:
                    dtypes[var] = dtype
                # determine whether this is a scalar variable
                identifiers = get_identifiers_recursively([expr], variables)
                # In the following we assume that all unknown identifiers are
                # scalar constants -- this should cover numerical literals and
                # e.g. "True" or "inf".
                is_scalar = all((name in scalars) or not (name in defined)
                                for name in identifiers)
                if is_scalar:
                    scalars.add(var)

        statement = Statement(var,
                              op,
                              expr,
                              comment,
                              dtype=dtypes[var],
                              scalar=var in scalars)
        line.statement = statement
        # for each line will give the variable being written to
        line.write = var
        # each line will give a set of variables which are read
        line.read = get_identifiers_recursively([expr], variables)

    # All writes to scalar variables must happen before writes to vector
    # variables
    # (declarations, i.e. ``:=``, of scalar variables are exempt from this)
    scalar_write_done = False
    for line in lines:
        stmt = line.statement
        if stmt.op != ':=' and stmt.var in scalars and scalar_write_done:
            raise SyntaxError(
                ('All writes to scalar variables in a code block '
                 'have to be made before writes to vector '
                 'variables. Illegal write to %s.') % line.write)
        elif not stmt.var in scalars:
            scalar_write_done = True

    if DEBUG:
        print 'PARSED STATEMENTS:'
        for line in lines:
            print line.statement, 'Read:' + str(
                line.read), 'Write:' + line.write

    # all variables which are written to at some point in the code block
    # used to determine whether they should be const or not
    # NOTE: within this function the set is only used for the debug output
    all_write = set(line.write for line in lines)

    if DEBUG:
        print 'ALL WRITE:', all_write

    # backwards compute whether or not variables will be read again
    # note that will_read for a line gives the set of variables it will read
    # on the current line or subsequent ones. will_write gives the set of
    # variables that will be written after the current line
    will_read = set()
    will_write = set()
    for line in lines[::-1]:
        will_read = will_read.union(line.read)
        # store copies, the running sets keep changing while we walk backwards
        line.will_read = will_read.copy()
        line.will_write = will_write.copy()
        # added after storing, so a line's own write is not part of its
        # ``will_write``
        will_write.add(line.write)

    if DEBUG:
        print 'WILL READ/WRITE:'
        for line in lines:
            print line.statement, 'Read:' + str(
                line.will_read), 'Write:' + str(line.will_write)

    # generate cacheing statements for common subexpressions
    # cached subexpressions need to be recomputed whenever they are to be used
    # on the next line, and currently invalid (meaning that the current value
    # stored in the subexpression variable is no longer accurate because one
    # of the variables appearing in it has changed). All subexpressions start
    # as invalid, and are invalidated whenever one of the variables appearing
    # in the RHS changes value.
    subexpressions = dict((name, val) for name, val in variables.items()
                          if isinstance(val, Subexpression))
    # sort subexpressions into an order so that subexpressions that don't depend
    # on other subexpressions are first
    subexpr_deps = dict((name, [dep for dep in subexpr.identifiers if dep in subexpressions]) for \
                                                            name, subexpr in subexpressions.items())
    sorted_subexpr_vars = topsort(subexpr_deps)

    if DEBUG:
        print 'SUBEXPRESSIONS:', subexpressions.keys()
    statements = []
    # all start as invalid
    valid = dict((name, False) for name in subexpressions.keys())
    # none are yet defined (or declared)
    subdefined = dict((name, False) for name in subexpressions.keys())
    for line in lines:
        stmt = line.statement
        read = line.read
        # NOTE: ``write`` and ``will_read`` are not used in the rest of the
        # loop body
        write = line.write
        will_read = line.will_read
        will_write = line.will_write
        # check that all subexpressions in expr are valid, and if not
        # add a definition/set its value, and set it to be valid
        # scan through in sorted order so that recursive subexpression dependencies
        # are handled in the right order
        # (``var`` and ``op`` are re-used as scratch names in this loop and
        # restored from ``stmt`` afterwards)
        for var in sorted_subexpr_vars:
            if var not in read:
                continue
            # if subexpression, and invalid
            if not valid.get(var, True):  # all non-subexpressions are valid
                subexpression = subexpressions[var]
                # if already defined/declared
                if subdefined[var]:
                    op = '='
                    constant = False
                else:
                    op = ':='
                    subdefined[var] = True
                    dtypes[var] = variables[var].dtype
                    # set to constant only if we will not write to it again
                    constant = var not in will_write
                    # check all subvariables are not written to again as well
                    if constant:
                        ids = subexpression.identifiers
                        constant = all(v not in will_write for v in ids)
                valid[var] = True
                # emitted before the statement of the current line
                statement = Statement(var,
                                      op,
                                      subexpression.expr,
                                      comment='',
                                      dtype=variables[var].dtype,
                                      constant=constant,
                                      subexpression=True,
                                      scalar=var in scalars)
                statements.append(statement)
        var, op, expr, comment = stmt.var, stmt.op, stmt.expr, stmt.comment
        # invalidate any subexpressions including var, recursively
        # we do this by having a set of variables that are invalid that we
        # start with the changed var and increase by any subexpressions we
        # find that have a dependency on something in the invalid set. We
        # go through in sorted subexpression order so that the invalid set
        # is increased in the right order
        invalid = set([var])
        for subvar in sorted_subexpr_vars:
            spec = subexpressions[subvar]
            spec_ids = set(spec.identifiers)
            if spec_ids.intersection(invalid):
                valid[subvar] = False
                invalid.add(subvar)
        # constant only if we are declaring a new variable and we will not
        # write to it again
        constant = op == ':=' and var not in will_write
        statement = Statement(var,
                              op,
                              expr,
                              comment,
                              dtype=dtypes[var],
                              constant=constant,
                              scalar=var in scalars)
        statements.append(statement)

    if DEBUG:
        print 'OUTPUT STATEMENTS:'
        for stmt in statements:
            print stmt

    return statements