def find_differential_variables(self, code):
    '''
    Find the variables that were tagged _gsl_{var}_f{ind} and return
    var, ind pairs.

    `GSLStateUpdater` tagged differential variables and here we extract the
    information given in these tags.

    Parameters
    ----------
    code : list of strings
        A list of strings containing gsl tagged variables. Each string may
        contain several newline-separated statements.

    Returns
    -------
    diff_vars : dict
        A dictionary with variable names as keys and differential equation
        index (the digits following ``_f`` in the tag, as a string) as value
    '''
    diff_vars = {}
    for expr_set in code:
        for expr in expr_set.split('\n'):
            expr = expr.strip(' ')
            try:
                lhs, op, rhs, comment = parse_statement(expr)
            except ValueError:
                # Not a statement (e.g. an empty line): there is no
                # left-hand side to inspect. Skipping is required, otherwise
                # `lhs` would be undefined (first line) or stale (it would
                # still hold the previous line's left-hand side).
                continue
            m = re.search(r'_gsl_(.+?)_f([0-9]*)$', lhs)
            if m:
                diff_vars[m.group(1)] = m.group(2)
    return diff_vars
def translate_scalar_code(self, code_lines, variables_in_scalar,
                          variables_in_vector):
    '''
    Translate scalar code: if calculated variables are used in the vector_code
    their value is assigned to the corresponding attribute of the
    ``_GSL_dataholder``.

    Parameters
    ----------
    code_lines : list
        list of strings containing scalar code
    variables_in_scalar : dict
        dictionary with variable name (str), `Variable` pairs of variables
        occurring in scalar code
    variables_in_vector : dict
        dictionary with variable name (str), `Variable` pairs of variables
        occurring in vector code

    Returns
    -------
    scalar_code : str
        code fragment that should be injected in the main before the loop

    Raises
    ------
    AssertionError
        If a variable that should be stored in the ``_GSL_dataholder`` cannot
        be removed from ``self.variables_to_be_processed`` (i.e. it has been
        processed already).

    Notes
    -----
    Lines without a ``name = ...`` assignment, or that cannot be parsed as a
    statement, are kept unchanged. Assignments to variables occurring in
    scalar code are kept unchanged. An assignment that would be redirected to
    the data holder is dropped if it assigns to ``t``. Assignments to
    variables found in neither dictionary are dropped.
    '''
    code = []
    for line in code_lines:
        m = re.search(r'(\w+ = .*)', line)
        if m is None:
            # Not a plain assignment (e.g. a declaration): keep as is
            code.append(line)
            continue
        try:
            var, op, expr, comment = parse_statement(m.group(1))
        except ValueError:
            code.append(line)
            continue
        if var in variables_in_scalar:
            code.append(line)
        elif var in variables_in_vector:
            if var == 't':
                continue
            try:
                self.variables_to_be_processed.remove(var)
            except KeyError:
                raise AssertionError(("Trying to process variable named %s by "
                                      "putting its value in the _GSL_dataholder "
                                      "based on scalar code, but the variable "
                                      "has been processed already." % var))
            code.append('_GSL_dataholder.{var} {op} {expr} {comment}'.format(
                var=var, op=op, expr=expr, comment=comment))
    return '\n'.join(code)
def translate_scalar_code(self, code_lines, variables_in_scalar,
                          variables_in_vector):
    '''
    Translate scalar code: if calculated variables are used in the vector_code
    their value is assigned to the corresponding attribute of the
    ``_GSL_dataholder``.

    Parameters
    ----------
    code_lines : list
        list of strings containing scalar code
    variables_in_scalar : dict
        dictionary with variable name (str), `Variable` pairs of variables
        occurring in scalar code
    variables_in_vector : dict
        dictionary with variable name (str), `Variable` pairs of variables
        occurring in vector code

    Returns
    -------
    scalar_code : str
        code fragment that should be injected in the main before the loop

    Raises
    ------
    AssertionError
        If a variable that should be stored in the ``_GSL_dataholder`` cannot
        be removed from ``self.variables_to_be_processed`` (i.e. it has been
        processed already).

    Notes
    -----
    Lines without a ``name = ...`` assignment, or that cannot be parsed as a
    statement, are kept unchanged. Assignments to variables occurring in
    scalar code are kept unchanged. An assignment that would be redirected to
    the data holder is dropped if it assigns to ``t``. Assignments to
    variables found in neither dictionary are dropped.
    '''
    code = []
    for line in code_lines:
        m = re.search(r'(\w+ = .*)', line)
        if m is None:
            # Not a plain assignment (e.g. a declaration): keep as is
            code.append(line)
            continue
        try:
            var, op, expr, comment = parse_statement(m.group(1))
        except ValueError:
            code.append(line)
            continue
        if var in variables_in_scalar:
            code.append(line)
        elif var in variables_in_vector:
            if var == 't':
                continue
            try:
                self.variables_to_be_processed.remove(var)
            except KeyError:
                raise AssertionError(("Trying to process variable named %s by "
                                      "putting its value in the _GSL_dataholder "
                                      "based on scalar code, but the variable "
                                      "has been processed already." % var))
            code.append('_GSL_dataholder.{var} {op} {expr} {comment}'.format(
                var=var, op=op, expr=expr, comment=comment))
    return '\n'.join(code)
def make_statements(code, variables, dtype, optimise=True, blockname=''):
    '''
    Turn a series of abstract code statements into Statement objects, inferring
    whether each line is a set/declare operation, whether the variables are
    constant or not, and handling the cacheing of subexpressions.

    Parameters
    ----------
    code : str
        A (multi-line) string of statements.
    variables : dict-like
        A dictionary with `Variable` and `Function` objects for every
        identifier used in the `code`. It is copied, not modified.
    dtype : `dtype`
        The data type to use for temporary variables
    optimise : bool, optional
        Whether to optimise expressions, including pulling out loop invariant
        expressions and putting them in new scalar constants. Defaults to
        ``True``; the optimisation is only applied if the
        `codegen.loop_invariant_optimisations` preference is set as well.
        Pass ``False`` in contexts where this kind of optimisation is not
        wanted.
    blockname : str, optional
        A name for the block (used to name intermediate variables to avoid
        name clashes when multiple blocks are used together)

    Returns
    -------
    scalar_statements, vector_statements : (list of `Statement`, list of `Statement`)
        Lists with statements that are to be executed once and statements that
        are to be executed once for every neuron/synapse/... (or in a
        vectorised way)

    Raises
    ------
    SyntaxError
        If a scalar variable is written to after a vector variable has been
        written to in the same block.

    Notes
    -----
    If ``optimise`` is ``True``, then the ``scalar_statements`` may include
    newly introduced scalar constants that have been identified as
    loop-invariant and have therefore been pulled out of the vector
    statements. The resulting statements will also use augmented assignments
    where possible, i.e. a statement such as ``w = w + 1`` will be replaced by
    ``w += 1``. Also, statements involving booleans will have additional
    information added to them (see `Statement` for details) describing how the
    statement can be reformulated as a sequence of if/then statements. Calls
    `~brian2.codegen.optimisation.optimise_statements`.
    '''
    code = strip_empty_lines(deindent(code))
    lines = re.split(r'[;\n]', code)
    lines = [LineInfo(code=line) for line in lines if len(line)]
    # Do a copy so we can add stuff without altering the original dict
    variables = dict(variables)
    # we will do inference to work out which lines are := and which are =
    defined = set(k for k, v in variables.items()
                  if not isinstance(v, AuxiliaryVariable))

    for line in lines:
        statement = None
        # parse statement into "var op expr"
        var, op, expr, comment = parse_statement(line.code)
        if op == '=':
            if var not in defined:
                # first assignment to this name: it is a declaration
                op = ':='
                defined.add(var)
                if var not in variables:
                    is_scalar = is_scalar_expression(expr, variables)
                    new_var = AuxiliaryVariable(var, dtype=dtype,
                                                scalar=is_scalar)
                    variables[var] = new_var
            elif not variables[var].is_boolean:
                sympy_expr = str_to_sympy(expr, variables)
                sympy_var = sympy.Symbol(var, real=True)
                try:
                    collected = sympy.collect(sympy_expr, sympy_var,
                                              exact=True, evaluate=False)
                except AttributeError:
                    # If something goes wrong during collection, e.g. collect
                    # does not work for logical expressions
                    collected = {1: sympy_expr}

                if (len(collected) == 2 and
                        set(collected.keys()) == {1, sympy_var} and
                        collected[sympy_var] == 1):
                    # We can replace this statement by a += assignment
                    statement = Statement(var, '+=',
                                          sympy_to_str(collected[1]),
                                          comment,
                                          dtype=variables[var].dtype,
                                          scalar=variables[var].scalar)
                elif len(collected) == 1 and sympy_var in collected:
                    # We can replace this statement by a *= assignment
                    statement = Statement(var, '*=',
                                          sympy_to_str(collected[sympy_var]),
                                          comment,
                                          dtype=variables[var].dtype,
                                          scalar=variables[var].scalar)
        if statement is None:
            statement = Statement(var, op, expr, comment,
                                  dtype=variables[var].dtype,
                                  scalar=variables[var].scalar)

        line.statement = statement
        # for each line will give the variable being written to
        line.write = var
        # each line will give a set of variables which are read
        line.read = get_identifiers_recursively([expr], variables)

    # All writes to scalar variables must happen before writes to vector
    # variables
    scalar_write_done = False
    for line in lines:
        stmt = line.statement
        if (stmt.op != ':=' and variables[stmt.var].scalar and
                scalar_write_done):
            raise SyntaxError(('All writes to scalar variables in a code block '
                               'have to be made before writes to vector '
                               'variables. Illegal write to %s.') % line.write)
        elif not variables[stmt.var].scalar:
            scalar_write_done = True

    # backwards compute which variables will be written after each line
    # (used to decide whether a declared variable can be constant)
    will_write = set()
    for line in reversed(lines):
        line.will_write = will_write.copy()
        will_write.add(line.write)

    subexpressions = dict((name, val) for name, val in variables.items()
                          if isinstance(val, Subexpression))
    # sort subexpressions into an order so that subexpressions that don't
    # depend on other subexpressions are first
    subexpr_deps = dict((name, [dep for dep in subexpr.identifiers
                                if dep in subexpressions])
                        for name, subexpr in subexpressions.items())
    sorted_subexpr_vars = topsort(subexpr_deps)

    statements = []

    # none are yet defined (or declared)
    subdefined = dict((name, False) for name in subexpressions)
    for line in lines:
        stmt = line.statement
        read = line.read
        will_write = line.will_write
        # (re)compute every subexpression this line reads, in dependency order
        for subvar in sorted_subexpr_vars:
            if subvar not in read:
                continue
            subexpression = subexpressions[subvar]
            if subdefined[subvar]:
                # already declared: plain re-assignment
                op = '='
                constant = False
            else:
                op = ':='
                subdefined[subvar] = True
                # set to constant only if we will not write to it again
                constant = subvar not in will_write
                # check all subvariables are not written to again as well
                if constant:
                    constant = all(v not in will_write
                                   for v in subexpression.identifiers)
            statements.append(Statement(subvar, op, subexpression.expr,
                                        comment='',
                                        dtype=variables[subvar].dtype,
                                        constant=constant,
                                        subexpression=True,
                                        scalar=variables[subvar].scalar))

        var = stmt.var
        # constant only if we are declaring a new variable and we will not
        # write to it again
        constant = stmt.op == ':=' and var not in will_write
        statements.append(Statement(var, stmt.op, stmt.expr, stmt.comment,
                                    dtype=variables[var].dtype,
                                    constant=constant,
                                    scalar=variables[var].scalar))

    scalar_statements = [s for s in statements if s.scalar]
    vector_statements = [s for s in statements if not s.scalar]

    if optimise and prefs.codegen.loop_invariant_optimisations:
        scalar_statements, vector_statements = optimise_statements(
            scalar_statements, vector_statements, variables,
            blockname=blockname)

    return scalar_statements, vector_statements
def check_units_statements(code, variables):
    '''
    Check the units for a series of statements. Setting a model variable has
    to use the correct unit. For newly introduced temporary variables, the unit
    is determined and used to check the following statements to ensure
    consistency.

    Parameters
    ----------
    code : str
        The statements as a (multi-line) string
    variables : dict of `Variable` objects
        The information about all variables used in `code` (including
        `Constant` objects for external variables). Newly defined temporary
        variables are added to this dictionary (it is modified in place).

    Raises
    ------
    KeyError
        In case one of the identifiers cannot be resolved.
    DimensionMismatchError
        If an unit mismatch occurs during the evaluation.
    '''
    # Avoid a circular import
    from brian2.codegen.translation import analyse_identifiers
    known = set(variables.keys())
    newly_defined, _, unknown = analyse_identifiers(code, known)

    if len(unknown):
        raise AssertionError(('Encountered unknown identifiers, this should '
                              'not happen at this stage. Unknown identifiers: %s'
                              % unknown))

    for line in re.split(r'[;\n]', code):
        line = line.strip()
        if not len(line):
            continue  # skip empty lines

        varname, op, expr, comment = parse_statement(line)
        if op in ('+=', '-=', '*=', '/=', '%='):
            # Replace statements such as "w *= 2 + x" by "w = w * (2 + x)".
            # The parentheses are required: without them the right-hand side
            # would bind more loosely than intended (e.g. "w * 2 + x").
            expr = '{var} {op_first} ({expr})'.format(var=varname,
                                                      op_first=op[0],
                                                      expr=expr)
            op = '='
        elif op != '=':
            raise AssertionError('Unknown operator "%s"' % op)

        expr_unit = parse_expression_unit(expr, variables)

        if varname in variables:
            fail_for_dimension_mismatch(variables[varname].unit,
                                        expr_unit,
                                        ('Code statement "%s" does not use '
                                         'correct units' % line))
        elif varname in newly_defined:
            # note the unit for later
            variables[varname] = Variable(name=varname, unit=expr_unit,
                                          scalar=False)
        else:
            raise AssertionError(('Variable "%s" is neither in the variables '
                                  'dictionary nor in the list of undefined '
                                  'variables.' % varname))
def make_statements(code, variables, dtype, optimise=True):
    '''
    Turn a series of abstract code statements into Statement objects, inferring
    whether each line is a set/declare operation, whether the variables are
    constant or not, and handling the cacheing of subexpressions.

    Parameters
    ----------
    code : str
        A (multi-line) string of statements.
    variables : dict-like
        A dictionary with `Variable` and `Function` objects for every
        identifier used in the `code`. It is copied, not modified.
    dtype : `dtype`
        The data type to use for temporary variables
    optimise : bool, optional
        Whether to optimise expressions, including pulling out loop invariant
        expressions and putting them in new scalar constants. Defaults to
        ``True``; the optimisation is only applied if the
        `codegen.loop_invariant_optimisations` preference is set as well.
        Pass ``False`` in contexts where this kind of optimisation is not
        wanted.

    Returns
    -------
    scalar_statements, vector_statements : (list of `Statement`, list of `Statement`)
        Lists with statements that are to be executed once and statements that
        are to be executed once for every neuron/synapse/... (or in a
        vectorised way)

    Raises
    ------
    SyntaxError
        If a scalar variable is written to after a vector variable has been
        written to in the same block.

    Notes
    -----
    If ``optimise`` is ``True``, then the ``scalar_statements`` may include
    newly introduced scalar constants that have been identified as
    loop-invariant and have therefore been pulled out of the vector
    statements. The resulting statements will also use augmented assignments
    where possible, i.e. a statement such as ``w = w + 1`` will be replaced by
    ``w += 1``. Also, statements involving booleans will have additional
    information added to them (see `Statement` for details) describing how the
    statement can be reformulated as a sequence of if/then statements. Calls
    `~brian2.codegen.optimisation.optimise_statements`.
    '''
    code = strip_empty_lines(deindent(code))
    lines = re.split(r'[;\n]', code)
    lines = [LineInfo(code=line) for line in lines if len(line)]
    # Do a copy so we can add stuff without altering the original dict
    variables = dict(variables)
    # we will do inference to work out which lines are := and which are =
    defined = set(k for k, v in variables.items()
                  if not isinstance(v, AuxiliaryVariable))

    for line in lines:
        statement = None
        # parse statement into "var op expr"
        var, op, expr, comment = parse_statement(line.code)
        if op == '=':
            if var not in defined:
                # first assignment to this name: it is a declaration
                op = ':='
                defined.add(var)
                if var not in variables:
                    is_scalar = is_scalar_expression(expr, variables)
                    new_var = AuxiliaryVariable(var, Unit(1),  # doesn't matter here
                                                dtype=dtype,
                                                scalar=is_scalar)
                    variables[var] = new_var
            elif not variables[var].is_boolean:
                sympy_expr = str_to_sympy(expr, variables)
                sympy_var = sympy.Symbol(var, real=True)
                try:
                    collected = sympy.collect(sympy_expr, sympy_var,
                                              exact=True, evaluate=False)
                except AttributeError:
                    # If something goes wrong during collection, e.g. collect
                    # does not work for logical expressions
                    collected = {1: sympy_expr}

                if (len(collected) == 2 and
                        set(collected.keys()) == {1, sympy_var} and
                        collected[sympy_var] == 1):
                    # We can replace this statement by a += assignment
                    statement = Statement(var, '+=',
                                          sympy_to_str(collected[1]),
                                          comment,
                                          dtype=variables[var].dtype,
                                          scalar=variables[var].scalar)
                elif len(collected) == 1 and sympy_var in collected:
                    # We can replace this statement by a *= assignment
                    statement = Statement(var, '*=',
                                          sympy_to_str(collected[sympy_var]),
                                          comment,
                                          dtype=variables[var].dtype,
                                          scalar=variables[var].scalar)
        if statement is None:
            statement = Statement(var, op, expr, comment,
                                  dtype=variables[var].dtype,
                                  scalar=variables[var].scalar)

        line.statement = statement
        # for each line will give the variable being written to
        line.write = var
        # each line will give a set of variables which are read
        line.read = get_identifiers_recursively([expr], variables)

    # All writes to scalar variables must happen before writes to vector
    # variables
    scalar_write_done = False
    for line in lines:
        stmt = line.statement
        if (stmt.op != ':=' and variables[stmt.var].scalar and
                scalar_write_done):
            raise SyntaxError(('All writes to scalar variables in a code block '
                               'have to be made before writes to vector '
                               'variables. Illegal write to %s.') % line.write)
        elif not variables[stmt.var].scalar:
            scalar_write_done = True

    # backwards compute which variables will be written after each line
    # (used to decide whether a declared variable can be constant)
    will_write = set()
    for line in reversed(lines):
        line.will_write = will_write.copy()
        will_write.add(line.write)

    # generate cacheing statements for common subexpressions
    # cached subexpressions need to be recomputed whenever they are to be used
    # on the next line, and currently invalid (meaning that the current value
    # stored in the subexpression variable is no longer accurate because one
    # of the variables appearing in it has changed). All subexpressions start
    # as invalid, and are invalidated whenever one of the variables appearing
    # in the RHS changes value.
    subexpressions = dict((name, val) for name, val in variables.items()
                          if isinstance(val, Subexpression))
    # sort subexpressions into an order so that subexpressions that don't
    # depend on other subexpressions are first
    subexpr_deps = dict((name, [dep for dep in subexpr.identifiers
                                if dep in subexpressions])
                        for name, subexpr in subexpressions.items())
    sorted_subexpr_vars = topsort(subexpr_deps)

    statements = []

    # all start as invalid
    valid = dict((name, False) for name in subexpressions)
    # none are yet defined (or declared)
    subdefined = dict((name, False) for name in subexpressions)
    for line in lines:
        stmt = line.statement
        read = line.read
        will_write = line.will_write
        # check that all subexpressions read by this line are valid, and if
        # not add a definition/set its value, and set it to be valid. Scan in
        # sorted order so that recursive subexpression dependencies are
        # handled in the right order
        for subvar in sorted_subexpr_vars:
            if subvar not in read or valid[subvar]:
                continue
            subexpression = subexpressions[subvar]
            if subdefined[subvar]:
                # already declared: plain re-assignment
                op = '='
                constant = False
            else:
                op = ':='
                subdefined[subvar] = True
                # set to constant only if we will not write to it again
                constant = subvar not in will_write
                # check all subvariables are not written to again as well
                if constant:
                    constant = all(v not in will_write
                                   for v in subexpression.identifiers)
            valid[subvar] = True
            statements.append(Statement(subvar, op, subexpression.expr,
                                        comment='',
                                        dtype=variables[subvar].dtype,
                                        constant=constant,
                                        subexpression=True,
                                        scalar=variables[subvar].scalar))

        var = stmt.var
        # invalidate any subexpressions including var, recursively: start
        # with the changed var and grow the invalid set by every
        # subexpression depending on something already in it. Sorted order
        # guarantees a dependency chain is propagated in a single pass.
        invalid = {var}
        for subvar in sorted_subexpr_vars:
            if invalid.intersection(subexpressions[subvar].identifiers):
                valid[subvar] = False
                invalid.add(subvar)

        # constant only if we are declaring a new variable and we will not
        # write to it again
        constant = stmt.op == ':=' and var not in will_write
        statements.append(Statement(var, stmt.op, stmt.expr, stmt.comment,
                                    dtype=variables[var].dtype,
                                    constant=constant,
                                    scalar=variables[var].scalar))

    scalar_statements = [s for s in statements if s.scalar]
    vector_statements = [s for s in statements if not s.scalar]

    if optimise and prefs.codegen.loop_invariant_optimisations:
        scalar_statements, vector_statements = optimise_statements(
            scalar_statements, vector_statements, variables)

    return scalar_statements, vector_statements
def check_units_statements(code, variables):
    '''
    Check the units for a series of statements. Setting a model variable has
    to use the correct unit. For newly introduced temporary variables, the unit
    is determined and used to check the following statements to ensure
    consistency.

    Parameters
    ----------
    code : str
        The statements as a (multi-line) string
    variables : dict of `Variable` objects
        The information about all variables used in `code` (including
        `Constant` objects for external variables). Newly defined temporary
        variables are added to this dictionary (it is modified in place).

    Raises
    ------
    KeyError
        In case one of the identifiers cannot be resolved.
    DimensionMismatchError
        If an unit mismatch occurs during the evaluation.
    '''
    # Avoid a circular import
    from brian2.codegen.translation import analyse_identifiers
    known = set(variables.keys())
    newly_defined, _, unknown = analyse_identifiers(code, known)

    if len(unknown):
        raise AssertionError(('Encountered unknown identifiers, this should '
                              'not happen at this stage. Unknown identifiers: %s'
                              % unknown))

    for line in re.split(r'[;\n]', code):
        line = line.strip()
        if not len(line):
            continue  # skip empty lines

        varname, op, expr, comment = parse_statement(line)
        if op in ('+=', '-=', '*=', '/=', '%='):
            # Replace statements such as "w *= 2 + x" by "w = w * (2 + x)".
            # The parentheses are required: without them the right-hand side
            # would bind more loosely than intended (e.g. "w * 2 + x").
            expr = '{var} {op_first} ({expr})'.format(var=varname,
                                                      op_first=op[0],
                                                      expr=expr)
            op = '='
        elif op != '=':
            raise AssertionError('Unknown operator "%s"' % op)

        expr_unit = parse_expression_unit(expr, variables)

        if varname in variables:
            expected_unit = variables[varname].unit
            fail_for_dimension_mismatch(expr_unit, expected_unit,
                                        ('The right-hand-side of code '
                                         'statement "%s" does not have the '
                                         'expected unit %r') % (line,
                                                                expected_unit))
        elif varname in newly_defined:
            # note the unit for later
            variables[varname] = Variable(name=varname, unit=expr_unit,
                                          scalar=False)
        else:
            raise AssertionError(('Variable "%s" is neither in the variables '
                                  'dictionary nor in the list of undefined '
                                  'variables.' % varname))
def make_statements(code, variables, dtype):
    '''
    Turn a series of abstract code statements into Statement objects, inferring
    whether each line is a set/declare operation, whether the variables are
    constant or not, and handling the cacheing of subexpressions.

    Parameters
    ----------
    code : str
        A (multi-line) string of statements.
    variables : dict-like
        A dictionary with `Variable` and `Function` objects for every
        identifier used in the `code`. It is copied, not modified.
    dtype : `dtype`
        The data type to use for temporary variables

    Returns
    -------
    scalar_statements, vector_statements : (list of `Statement`, list of `Statement`)
        Lists with statements that are to be executed once and statements that
        are to be executed once for every neuron/synapse/... (or in a
        vectorised way)

    Raises
    ------
    SyntaxError
        If a scalar variable is written to after a vector variable has been
        written to in the same block.

    Notes
    -----
    If the `codegen.loop_invariant_optimisations` preference is set, the
    ``scalar_statements`` may include newly introduced scalar constants that
    have been identified as loop-invariant and have therefore been pulled out
    of the vector statements.
    '''
    code = strip_empty_lines(deindent(code))
    lines = re.split(r'[;\n]', code)
    lines = [LineInfo(code=line) for line in lines if len(line)]
    if DEBUG:
        print('INPUT CODE:')
        print(code)
    # Do a copy so we can add stuff without altering the original dict
    variables = dict(variables)
    # we will do inference to work out which lines are := and which are =
    defined = set(k for k, v in variables.items()
                  if not isinstance(v, AuxiliaryVariable))

    for line in lines:
        # parse statement into "var op expr"
        var, op, expr, comment = parse_statement(line.code)
        if op == '=':
            if var not in defined:
                # first assignment to this name: it is a declaration
                op = ':='
                defined.add(var)
                if var not in variables:
                    is_scalar = is_scalar_expression(expr, variables)
                    new_var = AuxiliaryVariable(var, Unit(1),  # doesn't matter here
                                                dtype=dtype,
                                                scalar=is_scalar)
                    variables[var] = new_var
        line.statement = Statement(var, op, expr, comment,
                                   dtype=variables[var].dtype,
                                   scalar=variables[var].scalar)
        # for each line will give the variable being written to
        line.write = var
        # each line will give a set of variables which are read
        line.read = get_identifiers_recursively([expr], variables)

    # All writes to scalar variables must happen before writes to vector
    # variables
    scalar_write_done = False
    for line in lines:
        stmt = line.statement
        if (stmt.op != ':=' and variables[stmt.var].scalar and
                scalar_write_done):
            raise SyntaxError(('All writes to scalar variables in a code block '
                               'have to be made before writes to vector '
                               'variables. Illegal write to %s.') % line.write)
        elif not variables[stmt.var].scalar:
            scalar_write_done = True

    # Single-argument print calls produce the same output as the former
    # Python 2 print statements and also parse under Python 3
    if DEBUG:
        print('PARSED STATEMENTS:')
        for line in lines:
            print('%s Read:%s Write:%s' % (line.statement, line.read,
                                           line.write))
        # all variables which are written to at some point in the code block
        all_write = set(line.write for line in lines)
        print('ALL WRITE: %s' % (all_write,))

    # backwards compute whether or not variables will be read again
    # note that will_read for a line gives the set of variables it will read
    # on the current line or subsequent ones. will_write gives the set of
    # variables that will be written after the current line
    will_read = set()
    will_write = set()
    for line in reversed(lines):
        will_read = will_read.union(line.read)
        line.will_read = will_read.copy()
        line.will_write = will_write.copy()
        will_write.add(line.write)

    if DEBUG:
        print('WILL READ/WRITE:')
        for line in lines:
            print('%s Read:%s Write:%s' % (line.statement, line.will_read,
                                           line.will_write))

    # generate cacheing statements for common subexpressions
    # cached subexpressions need to be recomputed whenever they are to be used
    # on the next line, and currently invalid (meaning that the current value
    # stored in the subexpression variable is no longer accurate because one
    # of the variables appearing in it has changed). All subexpressions start
    # as invalid, and are invalidated whenever one of the variables appearing
    # in the RHS changes value.
    subexpressions = dict((name, val) for name, val in variables.items()
                          if isinstance(val, Subexpression))
    # sort subexpressions into an order so that subexpressions that don't
    # depend on other subexpressions are first
    subexpr_deps = dict((name, [dep for dep in subexpr.identifiers
                                if dep in subexpressions])
                        for name, subexpr in subexpressions.items())
    sorted_subexpr_vars = topsort(subexpr_deps)
    if DEBUG:
        print('SUBEXPRESSIONS: %s' % (list(subexpressions.keys()),))

    statements = []
    # all start as invalid
    valid = dict((name, False) for name in subexpressions)
    # none are yet defined (or declared)
    subdefined = dict((name, False) for name in subexpressions)
    for line in lines:
        stmt = line.statement
        read = line.read
        will_write = line.will_write
        # check that all subexpressions read by this line are valid, and if
        # not add a definition/set its value, and set it to be valid. Scan in
        # sorted order so that recursive subexpression dependencies are
        # handled in the right order
        for subvar in sorted_subexpr_vars:
            if subvar not in read or valid[subvar]:
                continue
            subexpression = subexpressions[subvar]
            if subdefined[subvar]:
                # already declared: plain re-assignment
                op = '='
                constant = False
            else:
                op = ':='
                subdefined[subvar] = True
                # set to constant only if we will not write to it again
                constant = subvar not in will_write
                # check all subvariables are not written to again as well
                if constant:
                    constant = all(v not in will_write
                                   for v in subexpression.identifiers)
            valid[subvar] = True
            statements.append(Statement(subvar, op, subexpression.expr,
                                        comment='',
                                        dtype=variables[subvar].dtype,
                                        constant=constant,
                                        subexpression=True,
                                        scalar=variables[subvar].scalar))

        var = stmt.var
        # invalidate any subexpressions including var, recursively: start
        # with the changed var and grow the invalid set by every
        # subexpression depending on something already in it. Sorted order
        # guarantees a dependency chain is propagated in a single pass.
        invalid = {var}
        for subvar in sorted_subexpr_vars:
            if invalid.intersection(subexpressions[subvar].identifiers):
                valid[subvar] = False
                invalid.add(subvar)

        # constant only if we are declaring a new variable and we will not
        # write to it again
        constant = stmt.op == ':=' and var not in will_write
        statements.append(Statement(var, stmt.op, stmt.expr, stmt.comment,
                                    dtype=variables[var].dtype,
                                    constant=constant,
                                    scalar=variables[var].scalar))

    if DEBUG:
        print('OUTPUT STATEMENTS:')
        for stmt in statements:
            print(stmt)

    scalar_statements = [s for s in statements if s.scalar]
    vector_statements = [s for s in statements if not s.scalar]

    if prefs.codegen.loop_invariant_optimisations:
        scalar_constants, vector_statements = apply_loop_invariant_optimisations(
            vector_statements, variables, dtype)
        scalar_statements.extend(scalar_constants)

    return scalar_statements, vector_statements
def make_statements(code, variables, dtype):
    '''
    Turn a series of abstract code statements into Statement objects, inferring
    whether each line is a set/declare operation, whether the variables are
    constant or not, and handling the cacheing of subexpressions.

    Returns a list of Statement objects. For arguments, see documentation for
    :func:`translate`.

    Subexpressions read by a line are (re)computed in dependency order, i.e. a
    subexpression is always set after the subexpressions it refers to. Writing
    to a variable invalidates every subexpression depending on it, directly or
    through other subexpressions.
    '''

    def ordered_subexpression_names(subexpressions):
        # Depth-first ordering: each name comes after the names of the
        # subexpressions it refers to. Names are visited in sorted order so
        # the result is deterministic; a name is never visited twice, which
        # also stops the recursion on cyclic dependencies.
        ordered = []
        visited = set()

        def visit(name):
            if name in visited:
                return
            visited.add(name)
            for dep in subexpressions[name].identifiers:
                if dep in subexpressions:
                    visit(dep)
            ordered.append(name)

        for name in sorted(subexpressions):
            visit(name)
        return ordered

    code = strip_empty_lines(deindent(code))
    lines = re.split(r'[;\n]', code)
    lines = [LineInfo(code=line) for line in lines if len(line)]
    # Single-argument print calls produce the same output as the former
    # Python 2 print statements and also parse under Python 3
    if DEBUG:
        print('INPUT CODE:')
        print(code)
    dtypes = dict((name, var.dtype) for name, var in variables.items())
    # we will do inference to work out which lines are := and which are =
    defined = set(variables.keys())
    for line in lines:
        # parse statement into "var op expr"
        var, op, expr = parse_statement(line.code)
        if op == '=' and var not in defined:
            # first assignment to this name: it is a declaration
            op = ':='
            defined.add(var)
            if var not in dtypes:
                dtypes[var] = dtype
        line.statement = Statement(var, op, expr, dtypes[var])
        # for each line will give the variable being written to
        line.write = var
        # each line will give a set of variables which are read
        line.read = get_identifiers_recursively(expr, variables)

    if DEBUG:
        print('PARSED STATEMENTS:')
        for line in lines:
            print('%s Read:%s Write:%s' % (line.statement, line.read,
                                           line.write))
        # all variables which are written to at some point in the code block
        all_write = set(line.write for line in lines)
        print('ALL WRITE: %s' % (all_write,))

    # backwards compute whether or not variables will be read again
    # note that will_read for a line gives the set of variables it will read
    # on the current line or subsequent ones. will_write gives the set of
    # variables that will be written after the current line
    will_read = set()
    will_write = set()
    for line in reversed(lines):
        will_read = will_read.union(line.read)
        line.will_read = will_read.copy()
        line.will_write = will_write.copy()
        will_write.add(line.write)

    if DEBUG:
        print('WILL READ/WRITE:')
        for line in lines:
            print('%s Read:%s Write:%s' % (line.statement, line.will_read,
                                           line.will_write))

    # generate cacheing statements for common subexpressions
    # cached subexpressions need to be recomputed whenever they are to be used
    # on the next line, and currently invalid (meaning that the current value
    # stored in the subexpression variable is no longer accurate because one
    # of the variables appearing in it has changed). All subexpressions start
    # as invalid, and are invalidated whenever one of the variables appearing
    # in the RHS (directly or via another subexpression) changes value.
    subexpressions = dict((name, val) for name, val in variables.items()
                          if isinstance(val, Subexpression))
    ordered_subexprs = ordered_subexpression_names(subexpressions)
    if DEBUG:
        print('SUBEXPRESSIONS: %s' % (list(subexpressions.keys()),))

    statements = []
    # all start as invalid
    valid = dict((name, False) for name in subexpressions)
    # none are yet defined (or declared)
    subdefined = dict((name, False) for name in subexpressions)
    for line in lines:
        stmt = line.statement
        will_write = line.will_write
        # (re)compute every invalid subexpression read by this line, with
        # dependencies before the subexpressions that use them
        for subvar in ordered_subexprs:
            if subvar not in line.read or valid[subvar]:
                continue
            if subdefined[subvar]:
                # already declared: plain re-assignment
                op = '='
                constant = False
            else:
                op = ':='
                subdefined[subvar] = True
                dtypes[subvar] = dtype  # default dtype
                # set to constant only if we will not write to it again
                constant = subvar not in will_write
                # check all subvariables are not written to again as well
                if constant:
                    ids = subexpressions[subvar].identifiers
                    constant = all(v not in will_write for v in ids)
            valid[subvar] = True
            statements.append(Statement(subvar, op,
                                        subexpressions[subvar].expr, dtype,
                                        constant=constant,
                                        subexpression=True))

        var, op, expr = stmt.var, stmt.op, stmt.expr
        # invalidate any subexpressions including var, recursively: grow the
        # invalid set by every subexpression depending on something already in
        # it. The dependency order propagates a chain in a single pass.
        invalid = set([var])
        for subvar in ordered_subexprs:
            if invalid.intersection(subexpressions[subvar].identifiers):
                valid[subvar] = False
                invalid.add(subvar)

        # constant only if we are declaring a new variable and we will not
        # write to it again
        constant = op == ':=' and var not in will_write
        statements.append(Statement(var, op, expr, dtypes[var],
                                    constant=constant))

    if DEBUG:
        print('OUTPUT STATEMENTS:')
        for stmt in statements:
            print(stmt)

    return statements
def make_statements(code, variables, dtype):
    '''
    Turn a series of abstract code statements into Statement objects, inferring
    whether each line is a set/declare operation, whether the variables are
    constant or not, and handling the cacheing of subexpressions. Returns a
    list of Statement objects. For arguments, see documentation for
    :func:`translate`.

    Parameters
    ----------
    code : str
        Abstract code; it is deindented, empty lines are stripped and
        statements are split on newlines and ``;``.
    variables : dict-like
        Mapping from names to variable objects. Entries that are not
        `Function` objects must provide a ``dtype``; an optional ``scalar``
        attribute marks scalar variables; `Subexpression` entries are cached
        as described below.
    dtype : dtype
        The dtype given to newly declared temporary variables that have no
        entry in `variables`.

    Returns
    -------
    statements : list of `Statement`
        The statements, including inserted subexpression-caching statements.

    Raises
    ------
    SyntaxError
        If a scalar variable is written to (other than by a ``:=``
        declaration) after a vector variable has been written to.
    '''
    code = strip_empty_lines(deindent(code))
    lines = re.split(r'[;\n]', code)
    lines = [LineInfo(code=line) for line in lines if len(line)]
    if DEBUG:
        print 'INPUT CODE:'
        print code
    # dtype of every known non-function variable; extended below for
    # temporaries and subexpressions
    dtypes = dict((name, var.dtype) for name, var in variables.iteritems()
                  if not isinstance(var, Function))

    # we will do inference to work out which lines are := and which are =
    defined = set(k for k, v in variables.iteritems()
                  if not isinstance(v, AuxiliaryVariable))
    # names already known to be scalar (via a truthy ``scalar`` attribute)
    scalars = set(k for k,v in variables.iteritems()
                  if getattr(v, 'scalar', False))

    for line in lines:
        # parse statement into "var op expr"
        var, op, expr = parse_statement(line.code)
        if op=='=':
            if var not in defined:
                # first assignment to an unknown name declares it
                op = ':='
                defined.add(var)
                if var not in dtypes:
                    dtypes[var] = dtype
                # determine whether this is a scalar variable
                identifiers = get_identifiers_recursively(expr, variables)
                # In the following we assume that all unknown identifiers are
                # scalar constants -- this should cover numerical literals and
                # e.g. "True" or "inf".
                is_scalar = all((name in scalars) or not (name in defined)
                                for name in identifiers)
                if is_scalar:
                    scalars.add(var)

        statement = Statement(var, op, expr, dtypes[var],
                              scalar=var in scalars)
        line.statement = statement
        # for each line will give the variable being written to
        line.write = var
        # each line will give a set of variables which are read
        line.read = get_identifiers_recursively(expr, variables)

    # All writes to scalar variables must happen before writes to vector
    # variables
    scalar_write_done = False
    for line in lines:
        stmt = line.statement
        if stmt.op != ':=' and stmt.var in scalars and scalar_write_done:
            raise SyntaxError(('All writes to scalar variables in a code block '
                               'have to be made before writes to vector '
                               'variables. Illegal write to %s.') % line.write)
        elif not stmt.var in scalars:
            scalar_write_done = True

    if DEBUG:
        print 'PARSED STATEMENTS:'
        for line in lines:
            print line.statement, 'Read:'+str(line.read), 'Write:'+line.write

    # all variables which are written to at some point in the code block
    # used to determine whether they should be const or not
    all_write = set(line.write for line in lines)

    if DEBUG:
        print 'ALL WRITE:', all_write

    # backwards compute whether or not variables will be read again
    # note that will_read for a line gives the set of variables it will read
    # on the current line or subsequent ones. will_write gives the set of
    # variables that will be written after the current line
    will_read = set()
    will_write = set()
    for line in lines[::-1]:
        will_read = will_read.union(line.read)
        line.will_read = will_read.copy()
        line.will_write = will_write.copy()
        will_write.add(line.write)

    if DEBUG:
        print 'WILL READ/WRITE:'
        for line in lines:
            print line.statement, 'Read:'+str(line.will_read), 'Write:'+str(line.will_write)

    # generate cacheing statements for common subexpressions
    # cached subexpressions need to be recomputed whenever they are to be used
    # on the next line, and currently invalid (meaning that the current value
    # stored in the subexpression variable is no longer accurate because one
    # of the variables appearing in it has changed). All subexpressions start
    # as invalid, and are invalidated whenever one of the variables appearing
    # in the RHS changes value.
    subexpressions = dict((name, val) for name, val in variables.items() if isinstance(val, Subexpression))

    # sort subexpressions into an order so that subexpressions that don't depend
    # on other subexpressions are first
    subexpr_deps = dict((name, [dep for dep in subexpr.identifiers if dep in subexpressions]) for \
                        name, subexpr in subexpressions.items())
    sorted_subexpr_vars = topsort(subexpr_deps)

    if DEBUG:
        print 'SUBEXPRESSIONS:', subexpressions.keys()
    statements = []
    # all start as invalid
    valid = dict((name, False) for name in subexpressions.keys())
    # none are yet defined (or declared)
    subdefined = dict((name, False) for name in subexpressions.keys())
    for line in lines:
        stmt = line.statement
        read = line.read
        write = line.write
        will_read = line.will_read
        will_write = line.will_write
        # check that all subexpressions in expr are valid, and if not
        # add a definition/set its value, and set it to be valid
        # scan through in sorted order so that recursive subexpression dependencies
        # are handled in the right order
        for var in sorted_subexpr_vars:
            if var not in read:
                continue
            # if subexpression, and invalid
            if not valid.get(var, True): # all non-subexpressions are valid
                subexpression = subexpressions[var]
                # if already defined/declared
                if subdefined[var]:
                    op = '='
                    constant = False
                else:
                    op = ':='
                    subdefined[var] = True
                    dtypes[var] = variables[var].dtype
                    # set to constant only if we will not write to it again
                    constant = var not in will_write
                    # check all subvariables are not written to again as well
                    if constant:
                        ids = subexpression.identifiers
                        constant = all(v not in will_write for v in ids)
                valid[var] = True
                statement = Statement(var, op, subexpression.expr,
                                      variables[var].dtype, constant=constant,
                                      subexpression=True, scalar=var in scalars)
                statements.append(statement)
        # the subexpression loop above reused var/op; restore them from the
        # statement of the current line
        var, op, expr = stmt.var, stmt.op, stmt.expr
        # invalidate any subexpressions including var, recursively
        # we do this by having a set of variables that are invalid that we
        # start with the changed var and increase by any subexpressions we
        # find that have a dependency on something in the invalid set. We
        # go through in sorted subexpression order so that the invalid set
        # is increased in the right order
        invalid = set([var])
        for subvar in sorted_subexpr_vars:
            spec = subexpressions[subvar]
            spec_ids = set(spec.identifiers)
            if spec_ids.intersection(invalid):
                valid[subvar] = False
                invalid.add(subvar)
        # constant only if we are declaring a new variable and we will not
        # write to it again
        constant = op==':=' and var not in will_write
        statement = Statement(var, op, expr, dtypes[var], constant=constant,
                              scalar=var in scalars)
        statements.append(statement)

    if DEBUG:
        print 'OUTPUT STATEMENTS:'
        for stmt in statements:
            print stmt

    return statements
def make_statements(code, variables, dtype):
    '''
    Turn a series of abstract code statements into Statement objects, inferring
    whether each line is a set/declare operation, whether the variables are
    constant or not, and handling the cacheing of subexpressions.

    Parameters
    ----------
    code : str
        A (multi-line) string of statements. Statements may be separated by
        newlines or ``;``.
    variables : dict-like
        A dictionary with `Variable` and `Function` objects for every
        identifier used in the `code`. It is copied, so the caller's
        dictionary is not modified when temporary variables are added.
    dtype : `dtype`
        The data type to use for temporary variables

    Returns
    -------
    scalar_statements, vector_statements : (list of `Statement`, list of `Statement`)
        Lists with statements that are to be executed once and statements that
        are to be executed once for every neuron/synapse/... (or in a
        vectorised way)

    Raises
    ------
    SyntaxError
        If a scalar variable is written to (other than by a ``:=``
        declaration) after a vector variable has been written to.

    Notes
    -----
    The `scalar_statements` may include newly introduced scalar constants that
    have been identified as loop-invariant and have therefore been pulled out
    of the vector statements (only if
    ``prefs.codegen.loop_invariant_optimisations`` is set). The resulting
    statements will also use augmented assignments where possible, i.e. a
    statement such as ``w = w + 1`` will be replaced by ``w += 1`` and a
    statement such as ``w = 2*w`` by ``w *= 2``. This rewriting is skipped
    for boolean variables.
    '''
    code = strip_empty_lines(deindent(code))
    lines = re.split(r'[;\n]', code)
    lines = [LineInfo(code=line) for line in lines if len(line)]
    if DEBUG:
        print 'INPUT CODE:'
        print code
    # Do a copy so we can add stuff without altering the original dict
    variables = dict(variables)
    # we will do inference to work out which lines are := and which are =
    defined = set(k for k, v in variables.iteritems()
                  if not isinstance(v, AuxiliaryVariable))
    for line in lines:
        # stays None unless an augmented assignment is substituted below
        statement = None
        # parse statement into "var op expr"
        var, op, expr, comment = parse_statement(line.code)
        if op == '=':
            if var not in defined:
                # first assignment to an unknown name declares it
                op = ':='
                defined.add(var)
                if var not in variables:
                    # register a temporary so later lines can look up its
                    # dtype and scalar-ness
                    is_scalar = is_scalar_expression(expr, variables)
                    new_var = AuxiliaryVariable(var, Unit(1),  # doesn't matter here
                                                dtype=dtype, scalar=is_scalar)
                    variables[var] = new_var
            elif not variables[var].is_boolean:
                # Try to rewrite "var = <expr>" as "var += ..." or "var *= ..."
                # by collecting the terms of expr in var
                sympy_expr = str_to_sympy(expr)
                sympy_var = sympy.Symbol(var, real=True)
                try:
                    collected = sympy.collect(sympy_expr, sympy_var,
                                              exact=True, evaluate=False)
                except AttributeError:
                    # If something goes wrong during collection, e.g. collect
                    # does not work for logical expressions
                    collected = {1: sympy_expr}

                if (len(collected) == 2 and
                        set(collected.keys()) == {1, sympy_var} and
                        collected[sympy_var] == 1):
                    # We can replace this statement by a += assignment
                    statement = Statement(var, '+=',
                                          sympy_to_str(collected[1]),
                                          comment,
                                          dtype=variables[var].dtype,
                                          scalar=variables[var].scalar)
                elif len(collected) == 1 and sympy_var in collected:
                    # We can replace this statement by a *= assignment
                    statement = Statement(var, '*=',
                                          sympy_to_str(collected[sympy_var]),
                                          comment,
                                          dtype=variables[var].dtype,
                                          scalar=variables[var].scalar)
        if statement is None:
            statement = Statement(var, op, expr, comment,
                                  dtype=variables[var].dtype,
                                  scalar=variables[var].scalar)

        line.statement = statement
        # for each line will give the variable being written to
        line.write = var
        # each line will give a set of variables which are read
        line.read = get_identifiers_recursively([expr], variables)

    # All writes to scalar variables must happen before writes to vector
    # variables
    scalar_write_done = False
    for line in lines:
        stmt = line.statement
        if stmt.op != ':=' and variables[
            stmt.var].scalar and scalar_write_done:
            raise SyntaxError(
                ('All writes to scalar variables in a code block '
                 'have to be made before writes to vector '
                 'variables. Illegal write to %s.') % line.write)
        elif not variables[stmt.var].scalar:
            scalar_write_done = True

    if DEBUG:
        print 'PARSED STATEMENTS:'
        for line in lines:
            print line.statement, 'Read:' + str(
                line.read), 'Write:' + line.write

    # all variables which are written to at some point in the code block
    # used to determine whether they should be const or not
    all_write = set(line.write for line in lines)

    if DEBUG:
        print 'ALL WRITE:', all_write

    # backwards compute whether or not variables will be read again
    # note that will_read for a line gives the set of variables it will read
    # on the current line or subsequent ones. will_write gives the set of
    # variables that will be written after the current line
    will_read = set()
    will_write = set()
    for line in lines[::-1]:
        will_read = will_read.union(line.read)
        line.will_read = will_read.copy()
        line.will_write = will_write.copy()
        will_write.add(line.write)

    if DEBUG:
        print 'WILL READ/WRITE:'
        for line in lines:
            print line.statement, 'Read:' + str(
                line.will_read), 'Write:' + str(line.will_write)

    # generate cacheing statements for common subexpressions
    # cached subexpressions need to be recomputed whenever they are to be used
    # on the next line, and currently invalid (meaning that the current value
    # stored in the subexpression variable is no longer accurate because one
    # of the variables appearing in it has changed). All subexpressions start
    # as invalid, and are invalidated whenever one of the variables appearing
    # in the RHS changes value.
    subexpressions = dict((name, val) for name, val in variables.items() if isinstance(val, Subexpression))

    # sort subexpressions into an order so that subexpressions that don't depend
    # on other subexpressions are first
    subexpr_deps = dict((name, [dep for dep in subexpr.identifiers if dep in subexpressions]) for \
                        name, subexpr in subexpressions.items())
    sorted_subexpr_vars = topsort(subexpr_deps)

    if DEBUG:
        print 'SUBEXPRESSIONS:', subexpressions.keys()
    statements = []
    # all start as invalid
    valid = dict((name, False) for name in subexpressions.keys())
    # none are yet defined (or declared)
    subdefined = dict((name, False) for name in subexpressions.keys())
    for line in lines:
        stmt = line.statement
        read = line.read
        write = line.write
        will_read = line.will_read
        will_write = line.will_write
        # check that all subexpressions in expr are valid, and if not
        # add a definition/set its value, and set it to be valid
        # scan through in sorted order so that recursive subexpression dependencies
        # are handled in the right order
        for var in sorted_subexpr_vars:
            if var not in read:
                continue
            # if subexpression, and invalid
            if not valid.get(var, True):  # all non-subexpressions are valid
                subexpression = subexpressions[var]
                # if already defined/declared
                if subdefined[var]:
                    op = '='
                    constant = False
                else:
                    op = ':='
                    subdefined[var] = True
                    # set to constant only if we will not write to it again
                    constant = var not in will_write
                    # check all subvariables are not written to again as well
                    if constant:
                        ids = subexpression.identifiers
                        constant = all(v not in will_write for v in ids)
                valid[var] = True
                statement = Statement(var, op, subexpression.expr, comment='',
                                      dtype=variables[var].dtype,
                                      constant=constant,
                                      subexpression=True,
                                      scalar=variables[var].scalar)
                statements.append(statement)
        # the subexpression loop above reused var/op; restore them from the
        # statement of the current line
        var, op, expr, comment = stmt.var, stmt.op, stmt.expr, stmt.comment
        # invalidate any subexpressions including var, recursively
        # we do this by having a set of variables that are invalid that we
        # start with the changed var and increase by any subexpressions we
        # find that have a dependency on something in the invalid set. We
        # go through in sorted subexpression order so that the invalid set
        # is increased in the right order
        invalid = {var}
        for subvar in sorted_subexpr_vars:
            spec = subexpressions[subvar]
            spec_ids = set(spec.identifiers)
            if spec_ids.intersection(invalid):
                valid[subvar] = False
                invalid.add(subvar)
        # constant only if we are declaring a new variable and we will not
        # write to it again
        constant = op == ':=' and var not in will_write
        statement = Statement(var, op, expr, comment,
                              dtype=variables[var].dtype,
                              constant=constant,
                              scalar=variables[var].scalar)
        statements.append(statement)

    if DEBUG:
        print 'OUTPUT STATEMENTS:'
        for stmt in statements:
            print stmt

    # split by the ``scalar`` flag of each generated Statement
    scalar_statements = [s for s in statements if s.scalar]
    vector_statements = [s for s in statements if not s.scalar]

    if prefs.codegen.loop_invariant_optimisations:
        # hoist loop-invariant parts of the vector code into extra scalar
        # statements
        scalar_constants, vector_statements = apply_loop_invariant_optimisations(vector_statements,
                                                                                 variables,
                                                                                 dtype)
        scalar_statements.extend(scalar_constants)

    return scalar_statements, vector_statements
def make_statements(code, variables, dtype):
    '''
    Turn a series of abstract code statements into Statement objects, inferring
    whether each line is a set/declare operation, whether the variables are
    constant or not, and handling the cacheing of subexpressions. Returns a
    list of Statement objects. For arguments, see documentation for
    :func:`translate`.

    Parameters
    ----------
    code : str
        Abstract code; it is deindented, empty lines are stripped and
        statements are split on newlines and ``;``. Each statement may carry
        a trailing comment, which is preserved in the resulting `Statement`.
    variables : dict-like
        Mapping from names to variable objects. Entries that are not
        `Function` objects must provide a ``dtype``; an optional ``scalar``
        attribute marks scalar variables; `Subexpression` entries are cached
        as described below.
    dtype : dtype
        The dtype given to newly declared temporary variables that have no
        entry in `variables`.

    Returns
    -------
    statements : list of `Statement`
        The statements, including inserted subexpression-caching statements.

    Raises
    ------
    SyntaxError
        If a scalar variable is written to (other than by a ``:=``
        declaration) after a vector variable has been written to.
    '''
    code = strip_empty_lines(deindent(code))
    lines = re.split(r'[;\n]', code)
    lines = [LineInfo(code=line) for line in lines if len(line)]
    if DEBUG:
        print 'INPUT CODE:'
        print code
    # dtype of every known non-function variable; extended below for
    # temporaries and subexpressions
    dtypes = dict((name, var.dtype) for name, var in variables.iteritems()
                  if not isinstance(var, Function))

    # we will do inference to work out which lines are := and which are =
    defined = set(k for k, v in variables.iteritems()
                  if not isinstance(v, AuxiliaryVariable))
    # names already known to be scalar (via a truthy ``scalar`` attribute)
    scalars = set(k for k, v in variables.iteritems()
                  if getattr(v, 'scalar', False))

    for line in lines:
        # parse statement into "var op expr"
        var, op, expr, comment = parse_statement(line.code)
        if op == '=':
            if var not in defined:
                # first assignment to an unknown name declares it
                op = ':='
                defined.add(var)
                if var not in dtypes:
                    dtypes[var] = dtype
                # determine whether this is a scalar variable
                identifiers = get_identifiers_recursively([expr], variables)
                # In the following we assume that all unknown identifiers are
                # scalar constants -- this should cover numerical literals and
                # e.g. "True" or "inf".
                is_scalar = all((name in scalars) or not (name in defined)
                                for name in identifiers)
                if is_scalar:
                    scalars.add(var)

        statement = Statement(var, op, expr, comment,
                              dtype=dtypes[var],
                              scalar=var in scalars)
        line.statement = statement
        # for each line will give the variable being written to
        line.write = var
        # each line will give a set of variables which are read
        line.read = get_identifiers_recursively([expr], variables)

    # All writes to scalar variables must happen before writes to vector
    # variables
    scalar_write_done = False
    for line in lines:
        stmt = line.statement
        if stmt.op != ':=' and stmt.var in scalars and scalar_write_done:
            raise SyntaxError(
                ('All writes to scalar variables in a code block '
                 'have to be made before writes to vector '
                 'variables. Illegal write to %s.') % line.write)
        elif not stmt.var in scalars:
            scalar_write_done = True

    if DEBUG:
        print 'PARSED STATEMENTS:'
        for line in lines:
            print line.statement, 'Read:' + str(
                line.read), 'Write:' + line.write

    # all variables which are written to at some point in the code block
    # used to determine whether they should be const or not
    all_write = set(line.write for line in lines)

    if DEBUG:
        print 'ALL WRITE:', all_write

    # backwards compute whether or not variables will be read again
    # note that will_read for a line gives the set of variables it will read
    # on the current line or subsequent ones. will_write gives the set of
    # variables that will be written after the current line
    will_read = set()
    will_write = set()
    for line in lines[::-1]:
        will_read = will_read.union(line.read)
        line.will_read = will_read.copy()
        line.will_write = will_write.copy()
        will_write.add(line.write)

    if DEBUG:
        print 'WILL READ/WRITE:'
        for line in lines:
            print line.statement, 'Read:' + str(
                line.will_read), 'Write:' + str(line.will_write)

    # generate cacheing statements for common subexpressions
    # cached subexpressions need to be recomputed whenever they are to be used
    # on the next line, and currently invalid (meaning that the current value
    # stored in the subexpression variable is no longer accurate because one
    # of the variables appearing in it has changed). All subexpressions start
    # as invalid, and are invalidated whenever one of the variables appearing
    # in the RHS changes value.
    subexpressions = dict((name, val) for name, val in variables.items() if isinstance(val, Subexpression))

    # sort subexpressions into an order so that subexpressions that don't depend
    # on other subexpressions are first
    subexpr_deps = dict((name, [dep for dep in subexpr.identifiers if dep in subexpressions]) for \
                        name, subexpr in subexpressions.items())
    sorted_subexpr_vars = topsort(subexpr_deps)

    if DEBUG:
        print 'SUBEXPRESSIONS:', subexpressions.keys()
    statements = []
    # all start as invalid
    valid = dict((name, False) for name in subexpressions.keys())
    # none are yet defined (or declared)
    subdefined = dict((name, False) for name in subexpressions.keys())
    for line in lines:
        stmt = line.statement
        read = line.read
        write = line.write
        will_read = line.will_read
        will_write = line.will_write
        # check that all subexpressions in expr are valid, and if not
        # add a definition/set its value, and set it to be valid
        # scan through in sorted order so that recursive subexpression dependencies
        # are handled in the right order
        for var in sorted_subexpr_vars:
            if var not in read:
                continue
            # if subexpression, and invalid
            if not valid.get(var, True):  # all non-subexpressions are valid
                subexpression = subexpressions[var]
                # if already defined/declared
                if subdefined[var]:
                    op = '='
                    constant = False
                else:
                    op = ':='
                    subdefined[var] = True
                    dtypes[var] = variables[var].dtype
                    # set to constant only if we will not write to it again
                    constant = var not in will_write
                    # check all subvariables are not written to again as well
                    if constant:
                        ids = subexpression.identifiers
                        constant = all(v not in will_write for v in ids)
                valid[var] = True
                statement = Statement(var, op, subexpression.expr, comment='',
                                      dtype=variables[var].dtype,
                                      constant=constant,
                                      subexpression=True,
                                      scalar=var in scalars)
                statements.append(statement)
        # the subexpression loop above reused var/op; restore them from the
        # statement of the current line
        var, op, expr, comment = stmt.var, stmt.op, stmt.expr, stmt.comment
        # invalidate any subexpressions including var, recursively
        # we do this by having a set of variables that are invalid that we
        # start with the changed var and increase by any subexpressions we
        # find that have a dependency on something in the invalid set. We
        # go through in sorted subexpression order so that the invalid set
        # is increased in the right order
        invalid = set([var])
        for subvar in sorted_subexpr_vars:
            spec = subexpressions[subvar]
            spec_ids = set(spec.identifiers)
            if spec_ids.intersection(invalid):
                valid[subvar] = False
                invalid.add(subvar)
        # constant only if we are declaring a new variable and we will not
        # write to it again
        constant = op == ':=' and var not in will_write
        statement = Statement(var, op, expr, comment,
                              dtype=dtypes[var],
                              constant=constant,
                              scalar=var in scalars)
        statements.append(statement)

    if DEBUG:
        print 'OUTPUT STATEMENTS:'
        for stmt in statements:
            print stmt

    return statements
def make_statements(code, variables, dtype):
    '''
    Turn a series of abstract code statements into Statement objects, inferring
    whether each line is a set/declare operation, whether the variables are
    constant or not, and handling the cacheing of subexpressions. Returns a
    list of Statement objects. For arguments, see documentation for
    :func:`translate`.

    Parameters
    ----------
    code : str
        Abstract code; it is deindented, empty lines are stripped and
        statements are split on newlines and ``;``.
    variables : dict-like
        Mapping from names to variable objects. Entries that are not
        `Function` objects must provide a ``dtype``; `Subexpression` entries
        are cached as described below.
    dtype : dtype
        The dtype given to newly declared temporary variables that have no
        entry in `variables`.

    Returns
    -------
    statements : list of `Statement`
        The statements, including inserted subexpression-caching statements.
    '''
    code = strip_empty_lines(deindent(code))
    lines = re.split(r'[;\n]', code)
    lines = [LineInfo(code=line) for line in lines if len(line)]
    if DEBUG:
        print 'INPUT CODE:'
        print code
    # dtype of every known non-function variable; extended below for
    # temporaries and subexpressions
    dtypes = dict((name, var.dtype) for name, var in variables.iteritems() if not isinstance(var, Function))
    # we will do inference to work out which lines are := and which are =
    defined = set(k for k, v in variables.iteritems() if not isinstance(v, AuxiliaryVariable))

    for line in lines:
        # parse statement into "var op expr"
        var, op, expr = parse_statement(line.code)
        if op=='=' and var not in defined:
            # first assignment to an unknown name declares it
            op = ':='
            defined.add(var)
            if var not in dtypes:
                dtypes[var] = dtype
        statement = Statement(var, op, expr, dtypes[var])
        line.statement = statement
        # for each line will give the variable being written to
        line.write = var
        # each line will give a set of variables which are read
        line.read = get_identifiers_recursively(expr, variables)

    if DEBUG:
        print 'PARSED STATEMENTS:'
        for line in lines:
            print line.statement, 'Read:'+str(line.read), 'Write:'+line.write

    # all variables which are written to at some point in the code block
    # used to determine whether they should be const or not
    all_write = set(line.write for line in lines)

    if DEBUG:
        print 'ALL WRITE:', all_write

    # backwards compute whether or not variables will be read again
    # note that will_read for a line gives the set of variables it will read
    # on the current line or subsequent ones. will_write gives the set of
    # variables that will be written after the current line
    will_read = set()
    will_write = set()
    for line in lines[::-1]:
        will_read = will_read.union(line.read)
        line.will_read = will_read.copy()
        line.will_write = will_write.copy()
        will_write.add(line.write)

    if DEBUG:
        print 'WILL READ/WRITE:'
        for line in lines:
            print line.statement, 'Read:'+str(line.will_read), 'Write:'+str(line.will_write)

    # generate cacheing statements for common subexpressions
    # cached subexpressions need to be recomputed whenever they are to be used
    # on the next line, and currently invalid (meaning that the current value
    # stored in the subexpression variable is no longer accurate because one
    # of the variables appearing in it has changed). All subexpressions start
    # as invalid, and are invalidated whenever one of the variables appearing
    # in the RHS changes value.
    #subexpressions = get_all_subexpressions()
    subexpressions = dict((name, val) for name, val in variables.items() if isinstance(val, Subexpression))

    subexpressions = translate_subexpressions(subexpressions, variables)

    if DEBUG:
        print 'SUBEXPRESSIONS:', subexpressions.keys()
    statements = []
    # all start as invalid
    valid = dict((name, False) for name in subexpressions.keys())
    # none are yet defined (or declared)
    subdefined = dict((name, False) for name in subexpressions.keys())
    for line in lines:
        stmt = line.statement
        read = line.read
        write = line.write
        will_read = line.will_read
        will_write = line.will_write
        # check that all subexpressions in expr are valid, and if not
        # add a definition/set its value, and set it to be valid
        # NOTE(review): if `read` is a set, its iteration order is arbitrary,
        # so a subexpression that depends on another subexpression may be
        # emitted before its dependency -- TODO confirm whether
        # translate_subexpressions rules this out
        for var in read:
            # if subexpression, and invalid
            if not valid.get(var, True): # all non-subexpressions are valid
                # if already defined/declared
                if subdefined[var]:
                    op = '='
                    constant = False
                else:
                    op = ':='
                    subdefined[var] = True
                    dtypes[var] = variables[var].dtype
                    # set to constant only if we will not write to it again
                    constant = var not in will_write
                    # check all subvariables are not written to again as well
                    if constant:
                        ids = subexpressions[var].identifiers
                        constant = all(v not in will_write for v in ids)
                valid[var] = True
                statement = Statement(var, op, subexpressions[var].expr,
                                      variables[var].dtype, constant=constant,
                                      subexpression=True)
                statements.append(statement)
        # the subexpression loop above reused var/op; restore them from the
        # statement of the current line
        var, op, expr = stmt.var, stmt.op, stmt.expr
        # invalidate any subexpressions including var
        # NOTE(review): only direct identifiers are checked, so a
        # subexpression that depends on var only through another
        # subexpression stays valid -- confirm identifiers are pre-expanded
        for subvar, spec in subexpressions.items():
            if var in spec.identifiers:
                valid[subvar] = False
        # constant only if we are declaring a new variable and we will not
        # write to it again
        constant = op==':=' and var not in will_write
        statement = Statement(var, op, expr, dtypes[var],
                              constant=constant)
        statements.append(statement)

    if DEBUG:
        print 'OUTPUT STATEMENTS:'
        for stmt in statements:
            print stmt

    return statements
def check_units_statements(code, variables):
    """
    Check the units for a series of statements. Setting a model variable has to
    use the correct unit. For newly introduced temporary variables, the unit
    is determined and used to check the following statements to ensure
    consistency.

    Augmented assignments (``+=``, ``-=``, ``*=``, ``/=``, ``%=``) are checked
    by rewriting them as plain assignments. For example, ``w *= a + b`` is
    checked as ``w = w * (a + b)``. The right-hand side is parenthesized so
    that operator precedence (and therefore the resulting unit) is preserved.

    Parameters
    ----------
    code : str
        The statements as a (multi-line) string. Statements may be separated
        by newlines or semicolons.
    variables : dict of `Variable` objects
        The information about all variables used in `code` (including
        `Constant` objects for external variables). The dictionary passed in
        is not modified; newly defined temporaries are added to a local copy.

    Raises
    ------
    AssertionError
        If `code` contains unknown identifiers or an unsupported assignment
        operator. Also raised if it assigns to a variable that is neither
        known nor newly defined.
    KeyError
        In case one of the identifiers cannot be resolved (presumably raised
        while parsing the expression).
    DimensionMismatchError
        If a unit mismatch occurs during the evaluation.
    """
    # Copy so that temporaries defined below do not leak into the caller's dict
    variables = dict(variables)
    # Avoid a circular import
    from brian2.codegen.translation import analyse_identifiers

    newly_defined, _, unknown = analyse_identifiers(code, variables)

    if len(unknown):
        raise AssertionError(
            f"Encountered unknown identifiers, this should not "
            f"happen at this stage. Unknown identifiers: {unknown}")

    code = re.split(r'[;\n]', code)
    for line in code:
        line = line.strip()
        if not len(line):
            continue  # skip empty lines

        varname, op, expr, _comment = parse_statement(line)
        if op in ('+=', '-=', '*=', '/=', '%='):
            # Replace statements such as "w *= a + b" by "w = w * (a + b)".
            # The parentheses are essential: without them, e.g. "w /= a * b"
            # would be checked as "w / a * b", which has different units.
            expr = f'{varname} {op[0]} ({expr})'
        elif op != '=':
            raise AssertionError(f'Unknown operator "{op}"')

        expr_unit = parse_expression_dimensions(expr, variables)

        if varname in variables:
            expected_unit = variables[varname].dim
            # The message is %-formatted with the statement here; the
            # {expected} placeholder is filled in via the keyword argument.
            fail_for_dimension_mismatch(expr_unit, expected_unit,
                                        ('The right-hand-side of code '
                                         'statement "%s" does not have the '
                                         'expected unit {expected}') % line,
                                        expected=expected_unit)
        elif varname in newly_defined:
            # note the unit for later
            variables[varname] = Variable(name=varname,
                                          dimensions=get_dimensions(expr_unit),
                                          scalar=False)
        else:
            raise AssertionError(
                f"Variable '{varname}' is neither in the variables "
                f"dictionary nor in the list of undefined "
                f"variables.")
def check_units_statements(code, namespace, variables):
    '''
    Check the units for a series of statements. Setting a model variable has to
    use the correct unit. For newly introduced temporary variables, the unit
    is determined and used to check the following statements to ensure
    consistency.

    Augmented assignments (``+=``, ``-=``, ``*=``, ``/=``, ``%=``) are checked
    by rewriting them as plain assignments. For example, ``w *= a + b`` is
    checked as ``w = w * (a + b)``. The right-hand side is parenthesized so
    that operator precedence (and therefore the resulting unit) is preserved.

    Parameters
    ----------
    code : str
        The statements as a (multi-line) string. Statements may be separated
        by newlines or semicolons.
    namespace : dict-like
        The namespace of external variables.
    variables : dict of `Variable` objects
        The information about the internal variables. The dictionary passed
        in is not modified; newly defined temporaries are added to a local
        copy.

    Raises
    ------
    AssertionError
        If `code` contains unknown identifiers or an unsupported assignment
        operator. Also raised if it assigns to a variable that is neither
        known nor newly defined.
    KeyError
        In case one of the identifiers cannot be resolved (presumably raised
        while parsing the expression).
    DimensionMismatchError
        If a unit mismatch occurs during the evaluation.
    '''
    known = set(variables.keys()) | set(namespace.keys())
    newly_defined, _, unknown = analyse_identifiers(code, known)

    if len(unknown):
        raise AssertionError(
            ('Encountered unknown identifiers, this should '
             'not happen at this stage. Unkown identifiers: %s' % unknown))

    # We want to add newly defined variables to the variables dictionary so we
    # make a copy now
    variables = dict(variables)

    code = re.split(r'[;\n]', code)
    for line in code:
        line = line.strip()
        if not len(line):
            continue  # skip empty lines

        varname, op, expr = parse_statement(line)
        if op in ('+=', '-=', '*=', '/=', '%='):
            # Replace statements such as "w *= a + b" by "w = w * (a + b)".
            # The parentheses are essential: without them, e.g. "w /= a * b"
            # would be checked as "w / a * b", which has different units.
            expr = '{var} {op_first} ({expr})'.format(var=varname,
                                                      op_first=op[0],
                                                      expr=expr)
        elif op != '=':
            raise AssertionError('Unknown operator "%s"' % op)

        expr_unit = parse_expression_unit(expr, namespace, variables)

        if varname in variables:
            fail_for_dimension_mismatch(variables[varname].unit,
                                        expr_unit,
                                        ('Code statement "%s" does not use '
                                         'correct units' % line))
        elif varname in newly_defined:
            # note the unit for later
            variables[varname] = Variable(expr_unit, is_bool=False,
                                          scalar=False)
        else:
            raise AssertionError(('Variable "%s" is neither in the variables '
                                  'dictionary nor in the list of undefined '
                                  'variables.' % varname))
def check_units_statements(code, namespace, variables):
    '''
    Check the units for a series of statements. Setting a model variable has to
    use the correct unit. For newly introduced temporary variables, the unit
    is determined and used to check the following statements to ensure
    consistency.

    Augmented assignments (``+=``, ``-=``, ``*=``, ``/=``, ``%=``) are checked
    by rewriting them as plain assignments. For example, ``w *= a + b`` is
    checked as ``w = w * (a + b)``. The right-hand side is parenthesized so
    that operator precedence (and therefore the resulting unit) is preserved.

    Parameters
    ----------
    code : str
        The statements as a (multi-line) string. Statements may be separated
        by newlines or semicolons.
    namespace : dict-like
        The namespace of external variables.
    variables : dict of `Variable` objects
        The information about the internal variables. The dictionary passed
        in is not modified; newly defined temporaries are added to a local
        copy.

    Raises
    ------
    AssertionError
        If `code` contains unknown identifiers or an unsupported assignment
        operator. Also raised if it assigns to a variable that is neither
        known nor newly defined.
    KeyError
        In case one of the identifiers cannot be resolved (presumably raised
        while parsing the expression).
    DimensionMismatchError
        If a unit mismatch occurs during the evaluation.
    '''
    known = set(variables.keys()) | set(namespace.keys())
    newly_defined, _, unknown = analyse_identifiers(code, known)

    if len(unknown):
        raise AssertionError(('Encountered unknown identifiers, this should '
                              'not happen at this stage. Unkown identifiers: %s'
                              % unknown))

    # We want to add newly defined variables to the variables dictionary so we
    # make a copy now
    variables = dict(variables)

    code = re.split(r'[;\n]', code)
    for line in code:
        line = line.strip()
        if not len(line):
            continue  # skip empty lines

        varname, op, expr = parse_statement(line)
        if op in ('+=', '-=', '*=', '/=', '%='):
            # Replace statements such as "w *= a + b" by "w = w * (a + b)".
            # The parentheses are essential: without them, e.g. "w /= a * b"
            # would be checked as "w / a * b", which has different units.
            expr = '{var} {op_first} ({expr})'.format(var=varname,
                                                      op_first=op[0],
                                                      expr=expr)
        elif op != '=':
            raise AssertionError('Unknown operator "%s"' % op)

        expr_unit = parse_expression_unit(expr, namespace, variables)

        if varname in variables:
            fail_for_dimension_mismatch(variables[varname].unit,
                                        expr_unit,
                                        ('Code statement "%s" does not use '
                                         'correct units' % line))
        elif varname in newly_defined:
            # note the unit for later
            variables[varname] = Variable(expr_unit, is_bool=False,
                                          scalar=False)
        else:
            raise AssertionError(('Variable "%s" is neither in the variables '
                                  'dictionary nor in the list of undefined '
                                  'variables.' % varname))