def __init__(self, template, template_source):
    '''
    Store the template and scan its source text for specifier blocks.

    The source is searched for ``KEYWORD { list of words }`` blocks for the
    keywords ``USES_VARIABLES``, ``ITERATE_ALL`` and
    ``WRITES_TO_READ_ONLY_VARIABLES``; the identifiers found in all blocks
    of one keyword are merged into a single set.

    Parameters
    ----------
    template : object
        The template object, stored unchanged as ``self.template``.
    template_source : str
        The source text of the template, stored as
        ``self.template_source`` and scanned for the specifier blocks.
    '''
    self.template = template
    self.template_source = template_source

    def identifiers_in_blocks(pattern):
        # ``pattern`` captures the bit inside {} of each specifier block;
        # the identifiers of all matching blocks are merged
        collected = set()
        for inner in re.findall(pattern, template_source, re.M | re.S):
            collected.update(get_identifiers(inner))
        return collected

    #: The set of variables in this template
    self.variables = identifiers_in_blocks(
        r'\bUSES_VARIABLES\b\s*\{(.*?)\}')
    #: The indices over which the template iterates completely
    self.iterate_all = identifiers_in_blocks(
        r'\bITERATE_ALL\b\s*\{(.*?)\}')
    #: Read-only variables that are changed by this template
    self.writes_read_only = identifiers_in_blocks(
        r'\bWRITES_TO_READ_ONLY_VARIABLES\b\s*\{(.*?)\}')
    #: Does this template allow writing to scalar variables?
    self.allows_scalar_write = 'ALLOWS_SCALAR_WRITE' in template_source
def update_abstract_code(self, run_namespace):
    '''
    Build the abstract code evaluating the condition of ``self.event``.

    Sets ``self.user_code`` and ``self.abstract_code``. For the ``'spike'``
    event of a group whose ``_refractory`` attribute is not ``False``, the
    condition is combined with ``not_refractory``.

    Parameters
    ----------
    run_namespace : object
        Passed on to ``self.group.resolve_all`` when resolving the
        identifiers used in the condition.

    Raises
    ------
    TypeError
        If the condition is not a string, or if it is not a boolean
        expression.
    '''
    code = self.group.events[self.event]

    if not isinstance(code, str):
        # Raise a useful error message when the user used an angela1 syntax
        given = 'a quantity' if isinstance(code, Quantity) else '%s' % type(code)
        error_msg = 'Threshold condition has to be a string, not %s.' % given
        if self.event == 'spike':
            try:
                vm_var = _guess_membrane_potential(self.group.equations)
            except AttributeError:
                # not a group with equations...
                vm_var = None
            if vm_var is not None:
                error_msg += " Probably you intended to use '%s > ...'?" % vm_var
        raise TypeError(error_msg)

    self.user_code = '_cond = ' + code

    identifiers = get_identifiers(code)
    variables = self.group.resolve_all(identifiers,
                                       run_namespace,
                                       user_identifiers=identifiers)
    if not is_boolean_expression(code, variables):
        raise TypeError(('Threshold condition "%s" is not a boolean '
                         'expression') % code)

    # Only spike events of groups with refractoriness are gated
    gated = self.group._refractory is not False and self.event == 'spike'
    if gated:
        self.abstract_code = '_cond = (%s) and not_refractory' % code
    else:
        self.abstract_code = '_cond = %s' % code
def get_identifiers_recursively(expressions, variables, include_numbers=False):
    '''
    Gets all the identifiers in a list of expressions, recursing down into
    subexpressions.

    Parameters
    ----------
    expressions : list of str
        List of expressions to check.
    variables : dict-like
        Dictionary of `Variable` objects
    include_numbers : bool, optional
        Whether to include number literals in the output. Defaults to
        ``False``.

    Returns
    -------
    identifiers : set
        The identifiers of all expressions, together with the identifiers
        of the expression of every `Subexpression` they refer to.
    '''
    # Union over all expressions (an empty list gives an empty set)
    found = set().union(*(get_identifiers(expression,
                                          include_numbers=include_numbers)
                          for expression in expressions))

    # Iterate over a snapshot, since ``found`` grows inside the loop
    for identifier in set(found):
        if identifier not in variables:
            continue
        referred = variables[identifier]
        if isinstance(referred, Subexpression):
            found |= get_identifiers_recursively(
                [referred.expr], variables, include_numbers=include_numbers)
    return found
def array_read_write(self, statements):
    '''
    Helper function, gives the sets of ArrayVariables that are read from
    and written to in the series of statements, together with the indices
    that are themselves stored as arrays.

    Parameters
    ----------
    statements : iterable
        Statement objects; their ``expr``, ``var``, ``op``, ``inplace`` and
        ``scalar`` attributes are used.

    Returns
    -------
    read, write, indices : (set, set, set)
        Names of the array variables that are read, names of the array
        variables that are written, and names of the index arrays used by
        either of them. Index arrays are removed from ``read``.

    Raises
    ------
    SyntaxError
        If a scalar variable is written in a context that does not allow
        it, or written with an expression referring to a vector variable.
    '''
    variables = self.variables
    variable_indices = self.variable_indices

    names_read = set()
    names_written = set()
    for stmt in statements:
        ids = get_identifiers(stmt.expr)
        # if the operation is inplace this counts as a read.
        if stmt.inplace:
            ids.add(stmt.var)
        names_read |= ids

        if stmt.scalar or variable_indices[stmt.var] == '0':
            if stmt.op != ':=' and not self.allows_scalar_write:
                raise SyntaxError(
                    ('Writing to scalar variable %s '
                     'not allowed in this context.' % stmt.var))
            for name in ids:
                if name not in variables:
                    continue
                if not isinstance(variables[name], ArrayVariable):
                    continue
                if (not variables[name].scalar
                        and variable_indices[name] != '0'):
                    raise SyntaxError(
                        ('Cannot write to scalar variable %s '
                         'with an expression referring to '
                         'vector variable %s') % (stmt.var, name))
        names_written.add(stmt.var)

    # Only keep the names that refer to arrays
    read = {varname for varname, var in variables.items()
            if isinstance(var, ArrayVariable) and varname in names_read}
    write = {varname for varname, var in variables.items()
             if isinstance(var, ArrayVariable) and varname in names_written}

    # Gather the indices stored as arrays (ignore _idx which is special)
    indices = set()
    for group in (read, write):
        for varname in group:
            index_name = variable_indices[varname]
            if index_name in ('_idx', '0'):
                continue
            if isinstance(variables[index_name], ArrayVariable):
                indices.add(index_name)

    # don't list arrays that are read explicitly and used as indices twice
    read -= indices
    return read, write, indices
def ufunc_at_vectorisation(self, statement, variables, indices,
                           conditional_write_vars, created_vars,
                           used_variables):
    '''
    Render a single statement as a line of vectorised code, using
    ``ufunc.at`` for in-place operations on variables that are not indexed
    with ``_idx``.

    Parameters
    ----------
    statement : statement object
        The statement to render (its ``expr``, ``var``, ``op``, ``inplace``
        and ``comment`` attributes are used).
    variables : dict
        Mapping from variable names to variable objects.
    indices : dict
        Mapping from variable names to the name of their index.
    conditional_write_vars, created_vars
        Passed on unchanged to ``self.conditional_write``.
    used_variables : set
        Updated in place with the variables used by this statement that
        have an index other than ``_idx``.

    Returns
    -------
    line : str
        The rendered line of code.

    Raises
    ------
    VectorisationError
        If this kind of vectorisation is switched off, if the written
        variable was used before (see ``used_variables``), or if the
        in-place operator has no ufunc equivalent.
    '''
    if not self._use_ufunc_at_vectorisation:
        raise VectorisationError()
    # Avoids circular import
    from angela2.devices.device import device

    # See https://github.com/angela-team/angela2/pull/531 for explanation
    in_expression = set(get_identifiers(statement.expr))
    differently_indexed = {name for name in variables.keys()
                           if name in indices and indices[name] != '_idx'}
    used_variables.update(in_expression & differently_indexed)
    if statement.var in used_variables:
        raise VectorisationError()

    renderer = NumpyNodeRenderer(auto_vectorise=self.auto_vectorise)
    expr = renderer.render_expr(statement.expr)

    plain_assignment = (statement.op == ':='
                        or indices[statement.var] == '_idx'
                        or not statement.inplace)
    if plain_assignment:
        op = '=' if statement.op == ':=' else statement.op
        line = '{var} {op} {expr}'.format(var=statement.var, op=op,
                                          expr=expr)
    else:
        # Only in-place statements get here
        ufunc_names = {'+=': '_numpy.add',
                       '*=': '_numpy.multiply',
                       '/=': '_numpy.divide',
                       '-=': '_numpy.subtract'}
        if statement.op not in ufunc_names:
            raise VectorisationError()
        ufunc_name = ufunc_names[statement.op]
        line = '{ufunc_name}.at({array_name}, {idx}, {expr})'.format(
            ufunc_name=ufunc_name,
            array_name=device.get_array_name(variables[statement.var]),
            idx=indices[statement.var],
            expr=expr)
        line = self.conditional_write(
            line, statement, variables,
            conditional_write_vars=conditional_write_vars,
            created_vars=created_vars)

    if len(statement.comment):
        line += ' # ' + statement.comment
    return line
def before_run(self, run_namespace=None):
    '''
    Check the units of a ``rates`` expression, then delegate to the parent
    class.

    The check is only carried out if the ``rates`` variable is a
    `Subexpression`; its expression has to have units of Hz.

    Parameters
    ----------
    run_namespace : object, optional
        Used to resolve the identifiers of the expression and passed on to
        the parent class' ``before_run``.
    '''
    rates_var = self.variables['rates']
    needs_unit_check = isinstance(rates_var, Subexpression)
    if needs_unit_check:
        # Check that the units of the expression make sense
        identifiers = get_identifiers(rates_var.expr)
        resolved = self.resolve_all(identifiers, run_namespace,
                                    user_identifiers=identifiers)
        dims = parse_expression_dimensions(rates_var.expr, resolved)
        fail_for_dimension_mismatch(dims, Hz,
                                    "The expression provided for "
                                    "PoissonGroup's 'rates' "
                                    "argument, has to have units "
                                    "of Hz")
    super(PoissonGroup, self).before_run(run_namespace)
def _get_refractory_code(self, run_namespace):
    '''
    Return the abstract code that updates ``not_refractory``.

    The code depends on the ``_refractory`` attribute of the group:

    * ``False``: no refractoriness, the code is empty.
    * a `Quantity`: a fixed refractory period (has to be in seconds).
    * anything else: treated as a string expression that either evaluates
      to a timespan or to a boolean value.

    Parameters
    ----------
    run_namespace : object
        Passed on to ``self.group.resolve_all`` when resolving the
        identifiers of a string expression.

    Returns
    -------
    abstract_code : str
        A (possibly empty) statement assigning to ``not_refractory``.

    Raises
    ------
    TypeError
        If a string expression is dimensionless but not boolean, or has
        dimensions other than seconds.
    '''
    ref = self.group._refractory
    if ref is False:
        # No refractoriness
        return ''

    legacy_timing = prefs.legacy.refractory_timing

    if isinstance(ref, Quantity):
        fail_for_dimension_mismatch(ref, second, ('Refractory period has to '
                                                  'be specified in units '
                                                  'of seconds but got '
                                                  '{value}'), value=ref)
        # NOTE(review): '%f' writes six decimal places, so periods with a
        # sub-microsecond component are rounded in the generated code --
        # confirm that this is acceptable
        if legacy_timing:
            return 'not_refractory = (t - lastspike) > %f\n' % ref
        return ('not_refractory = timestep(t - lastspike, dt) >= '
                'timestep(%f, dt)\n' % ref)

    # String expression: find out what it evaluates to
    identifiers = get_identifiers(ref)
    variables = self.group.resolve_all(identifiers,
                                       run_namespace,
                                       user_identifiers=identifiers)
    dims = parse_expression_dimensions(str(ref), variables)

    if dims is second.dim:
        if legacy_timing:
            # The assignment target was missing here, i.e. the generated
            # code was a bare comparison that never set not_refractory
            return 'not_refractory = (t - lastspike) > %s\n' % ref
        return ('not_refractory = timestep(t - lastspike, dt) >= '
                'timestep(%s, dt)\n' % ref)

    if dims is DIMENSIONLESS:
        if not is_boolean_expression(str(ref), variables):
            raise TypeError(('Refractory expression is dimensionless '
                             'but not a boolean value. It needs to '
                             'either evaluate to a timespan or to a '
                             'boolean value.'))
        # boolean condition
        # we have to be a bit careful here, we can't just use the given
        # condition as it is, because we only want to *leave*
        # refractoriness, based on the condition
        return 'not_refractory = not_refractory or not (%s)\n' % ref

    raise TypeError(('Refractory expression has to evaluate to a '
                     'timespan or a boolean value, expression'
                     '"%s" has units %s instead') % (ref, dims))
def find_used_variables(self, statements, other_variables):
    '''
    Find all the variables used in the right hand side of the given
    expressions.

    Parameters
    ----------
    statements : list
        list of statement objects
    other_variables : dict
        Fallback mapping, consulted for names that are not part of
        ``self.variables``.

    Returns
    -------
    used_variables : dict
        dictionary of variables that are used as variable name (str),
        `Variable` pairs.
    '''
    own_variables = self.variables
    used_variables = {}
    for statement in statements:
        for name in get_identifiers(statement.expr):
            if name in self.function_names:
                continue
            try:
                found = own_variables[name]
            except KeyError:
                found = other_variables[name]
            # save as object because this has all needed info
            # (dtype, name, isarray)
            used_variables[name] = found

    # I don't know a nicer way to do this, the above way misses write
    # variables (e.g. not_refractory)..
    read, write, _ = self.array_read_write(statements)
    for name in (read | write):
        if name in used_variables:
            continue
        # will always be array and thus exist in variables
        used_variables[name] = own_variables[name]
    return used_variables
def get_map(environ_var, relrootdir, pattern, the_map, path_exclusions=None):
    '''
    Fill (once) a mapping from identifiers to the files that use them.

    Parameters
    ----------
    environ_var : str
        Name of an environment variable; if it is set, its value is used as
        the root directory.
    relrootdir : str
        Root directory relative to the directory of this file, used when
        ``environ_var`` is not set.
    pattern : str
        File name suffix; files matching ``'*' + pattern`` are read and the
        suffix is removed from the dotted name.
    the_map : dict
        The mapping to fill. If it is non-empty it is treated as already
        filled and returned unchanged.
    path_exclusions : iterable of str, optional
        Files whose path contains one of these strings are skipped.
        Defaults to no exclusions.

    Returns
    -------
    the_map : dict
        The same object that was passed in, mapping each identifier to a
        list of ``(relative path with forward slashes, dotted name)``
        tuples.
    '''
    if the_map:
        # Already filled by an earlier call
        return the_map

    # None instead of a mutable [] default
    exclusions = () if path_exclusions is None else tuple(path_exclusions)

    if environ_var in os.environ:
        rootdir = os.environ[environ_var]
    else:
        rootdir, _ = os.path.split(__file__)
        rootdir = os.path.normpath(os.path.join(rootdir, relrootdir))

    for fname in GlobDirectoryWalker(rootdir, '*' + pattern):
        if any(exclude in fname for exclude in exclusions):
            continue
        shortfname = os.path.relpath(fname, rootdir)
        exname = (shortfname.replace('/', '.')
                            .replace('\\', '.')
                            .replace(pattern, ''))
        with open(fname, 'r') as f:
            source_text = f.read()
        entry = (shortfname.replace('\\', '/'), exname)
        for identifier in get_identifiers(source_text):
            # setdefault works for plain dicts as well as defaultdicts
            the_map.setdefault(identifier, []).append(entry)
    return the_map
def variables_to_namespace(self):
    '''
    Fill ``self.namespace`` with the values of the variables and record the
    values that have to be refreshed later in ``self.nonconstant_values``.

    Entries that are not mentioned in ``self.code.run`` are removed from
    the namespace again.
    '''
    # Variables can refer to values that are either constant (e.g. dt)
    # or change every timestep (e.g. t). We add the values of the
    # constant variables here and add the names of non-constant variables
    # to a list

    # A list containing tuples of name and a function giving the value
    self.nonconstant_values = []

    for name, var in self.variables.items():
        if isinstance(var, Function):
            self._insert_func_namespace(var)
        if isinstance(var, (AuxiliaryVariable, Subexpression)):
            continue
        try:
            value = var.get_value()
        except (TypeError, AttributeError):
            # A dummy Variable without value or a function
            self.namespace[name] = var
            continue

        if not isinstance(var, ArrayVariable):
            self.namespace[name] = value
        else:
            array_name = self.device.get_array_name(var, self.variables)
            self.namespace[array_name] = value
            self.namespace['_num' + name] = var.get_len()
            if var.scalar and var.constant:
                self.namespace[name] = value.item()

        if isinstance(var, DynamicArrayVariable):
            dyn_array_name = self.generator_class.get_array_name(
                var, access_data=False)
            self.namespace[dyn_array_name] = self.device.get_value(
                var, access_data=False)

        # Also provide the Variable object itself in the namespace (can be
        # necessary for resize operations, for example)
        self.namespace['_var_' + name] = var

    # Get all identifiers in the code -- note that this is not a smart
    # function, it will get identifiers from strings, comments, etc. This
    # is not a problem here, since we only use this list to filter out
    # things. If we include something incorrectly, this only means that we
    # will pass something into the namespace unnecessarily.
    needed = get_identifiers(self.code.run)
    # Filter out all unneeded objects
    self.namespace = {key: obj
                      for key, obj in self.namespace.items()
                      if key in needed}

    # There is one type of objects that we have to inject into the
    # namespace with their current value at each time step: dynamic
    # arrays that change in size during runs, where the size change is not
    # initiated by the template itself
    for name, var in self.variables.items():
        if not isinstance(var, DynamicArrayVariable):
            continue
        if not var.needs_reference_update:
            continue
        array_name = self.device.get_array_name(var, self.variables)
        if array_name in self.namespace:
            self.nonconstant_values.append((array_name, var.get_value))
        size_name = '_num' + name
        if size_name in self.namespace:
            self.nonconstant_values.append((size_name, var.get_len))
def check_for_order_independence(statements, variables, indices):
    '''
    Check that the sequence of statements doesn't depend on the order in
    which the indices are iterated through.

    Parameters
    ----------
    statements : sequence
        Statement objects; their ``expr``, ``var`` and ``op`` attributes
        are used.
    variables : dict
        Mapping from names to variable objects. The mapping itself is not
        modified, the function works on a copy.
    indices : dict-like
        Mapping from variable names to the name of their index (``'_idx'``
        for the main index, ``'0'`` for scalars).

    Raises
    ------
    OrderDependenceError
        If a used function is not stateless, or if a statement is found
        whose result may depend on the order of iteration.
    '''
    # Remove stateless functions from variables (only bother with ones that are used)
    all_used_vars = set()
    for statement in statements:
        all_used_vars.update(get_identifiers(statement.expr))

    # Work on a copy, so that deleting entries does not affect the caller
    variables = variables.copy()
    for var in set(variables.keys()).intersection(all_used_vars):
        val = variables[var]
        if isinstance(val, Function):
            if val.stateless:
                del variables[var]
            else:
                raise OrderDependenceError(
                    "Function %s may have internal state, "
                    "which can lead to order dependence." % var)

    # Names of everything that is not a Constant
    all_variables = [
        v for v in variables if not isinstance(variables[v], Constant)
    ]

    # Main index variables are those whose index corresponds to the main index being iterated through. By
    # assumption/definition, these indices are unique, and any order-dependence cannot come from their values,
    # only from the values of the derived indices. In the most common case of Synapses, the main index would be
    # the synapse index, and the derived indices would be pre and postsynaptic indices (which can be repeated).
    unique_index = lambda v: (indices[v] != '0' and getattr(
        variables[indices[v]], 'unique', False))
    main_index_variables = {
        v
        for v in all_variables if indices[v] == '_idx' or unique_index(v)
    }
    different_index_variables = set(all_variables) - main_index_variables

    # At the start, we assume all the different/derived index variables are permutation independent and we continue
    # to scan through the list of statements checking whether or not permutation-dependence has been introduced
    # until the permutation_independent set has stopped changing.
    permutation_independent = list(different_index_variables)
    permutation_dependent_aux_vars = set()
    changed_permutation_independent = True

    # Variables created by the statements themselves are treated like
    # main index variables
    for statement in statements:
        if statement.op == ':=' and statement.var not in all_variables:
            main_index_variables.add(statement.var)
            all_variables.append(statement.var)

    # Repeat the scan until a full pass changes nothing
    while changed_permutation_independent:
        changed_permutation_independent = False
        for statement in statements:
            vars_in_expr = get_identifiers(
                statement.expr).intersection(all_variables)
            # any time a statement involves a LHS and RHS which only depend on itself, this doesn't change anything
            if {statement.var} == vars_in_expr:
                continue
            nonsyn_vars_in_expr = vars_in_expr.intersection(
                different_index_variables)
            # The expression is permutation dependent if it refers to a
            # derived-index variable that is no longer permutation
            # independent, or to a permutation-dependent auxiliary variable
            permdep = any(var not in permutation_independent
                          for var in nonsyn_vars_in_expr)
            permdep = permdep or any(var in permutation_dependent_aux_vars
                                     for var in vars_in_expr)
            if statement.op == ':=':  # auxiliary variable created
                if permdep:
                    if statement.var not in permutation_dependent_aux_vars:
                        permutation_dependent_aux_vars.add(statement.var)
                        changed_permutation_independent = True
                continue
            elif statement.var in main_index_variables:
                if permdep:
                    raise OrderDependenceError()
            elif statement.var in different_index_variables:
                if statement.op in ('+=', '*=', '-=', '/='):
                    if permdep:
                        raise OrderDependenceError()
                    # The in-place update makes the target itself
                    # permutation dependent from now on
                    if statement.var in permutation_independent:
                        permutation_independent.remove(statement.var)
                        changed_permutation_independent = True
                elif statement.op == '=':
                    # Variables with an index that is neither the index of
                    # the target, nor the main index, nor scalar
                    otheridx = [
                        v for v in variables
                        if indices[v] not in (indices[statement.var], '_idx',
                                              '0')
                    ]
                    if any(var in otheridx for var in vars_in_expr):
                        raise OrderDependenceError()
                    if permdep:
                        raise OrderDependenceError()
                    if any(var in main_index_variables
                           for var in vars_in_expr):
                        raise OrderDependenceError()
                else:
                    raise OrderDependenceError()
            else:
                raise AssertionError('Should never get here...')
def dimensions_and_type_from_string(unit_string):
    '''
    Returns the physical dimensions that results from evaluating a string
    like "siemens / metre ** 2", allowing for the special string "1" to
    signify dimensionless units, the string "boolean" for a boolean and
    "integer" for an integer variable.

    Parameters
    ----------
    unit_string : str
        The string that should evaluate to a unit

    Returns
    -------
    d, type : (`Dimension`, {FLOAT, INTEGER or BOOLEAN})
        The resulting physical dimensions and the type of the variable.

    Raises
    ------
    ValueError
        If the string refers to something that is not a base unit, or if
        it cannot be evaluated to a unit.
    TypeError
        If the string is "bool" (instead of "boolean").
    '''
    # Lazy import to avoid circular dependency
    from angela2.core.namespace import DEFAULT_UNITS
    global _base_units_with_alternatives
    global _base_units

    if _base_units_with_alternatives is None:
        # First call: build the table of allowed base units (and the text
        # listing them for error messages) and cache both at module level
        base_units_for_dims = {}
        for unit_name, unit in reversed(DEFAULT_UNITS.items()):
            if float(unit) == 1.0 and repr(unit)[-1] not in ['2', '3']:
                if unit.dim in base_units_for_dims:
                    if unit_name not in base_units_for_dims[unit.dim]:
                        base_units_for_dims[unit.dim].append(unit_name)
                else:
                    base_units_for_dims[unit.dim] = [repr(unit)]
                    if unit_name != repr(unit):
                        base_units_for_dims[unit.dim].append(unit_name)

        alternatives = sorted(
            [tuple(values) for values in base_units_for_dims.values()])
        _base_units = dict([(v, DEFAULT_UNITS[v]) for values in alternatives
                            for v in values])
        # Create a string that lists all allowed base units
        alternative_strings = []
        for units in alternatives:
            string = units[0]
            if len(units) > 1:
                string += ' ({other_units})'.format(
                    other_units=', '.join(units[1:]))
            alternative_strings.append(string)
        _base_units_with_alternatives = ', '.join(alternative_strings)

    unit_string = unit_string.strip()

    # Special case: dimensionless unit
    if unit_string == '1':
        return DIMENSIONLESS, FLOAT

    # Another special case: boolean variable
    if unit_string == 'boolean':
        return DIMENSIONLESS, BOOLEAN
    if unit_string == 'bool':
        raise TypeError("Use 'boolean' not 'bool' as the unit for a boolean "
                        "variable.")

    # Yet another special case: integer variable
    if unit_string == 'integer':
        return DIMENSIONLESS, INTEGER

    # Check first whether the expression only refers to base units
    identifiers = get_identifiers(unit_string)
    for identifier in identifiers:
        if identifier in _base_units:
            continue
        if identifier in DEFAULT_UNITS:
            # A known unit, but not a base unit
            base_unit = get_unit(DEFAULT_UNITS[identifier].dim)
            if not repr(base_unit) in _base_units:
                # Make sure that we don't suggest a unit that is not allowed
                # (should not happen, normally)
                base_unit = Unit(1, dim=base_unit.dim)
            raise ValueError(
                ('Unit specification refers to '
                 '"{identifier}", but this is not a base '
                 'unit. Use "{base_unit}" '
                 'instead.').format(identifier=identifier,
                                    base_unit=repr(base_unit)))
        # Not a known unit
        raise ValueError(
            ('Unit specification refers to '
             '"{identifier}", but this is not a base '
             'unit. The following base units are '
             'allowed: '
             '{allowed_units}.').format(
                 identifier=identifier,
                 allowed_units=_base_units_with_alternatives))

    try:
        # eval() inserts a '__builtins__' entry into the globals dictionary
        # it is given. Evaluate against a copy, so that the cached table of
        # base units (used for the membership check above) stays untouched.
        evaluated_unit = eval(unit_string, dict(_base_units))
    except Exception as ex:
        raise ValueError(('Could not interpret "%s" as a unit specification: '
                          '%s') % (unit_string, ex)) from ex

    # Check whether the result is a unit
    if not isinstance(evaluated_unit, Unit):
        if isinstance(evaluated_unit, Quantity):
            raise ValueError(
                ('"%s" does not evaluate to a unit but to a '
                 'quantity -- make sure to only use units, e.g. '
                 '"siemens/metre**2" and not "1 * siemens/metre**2"') %
                unit_string)
        else:
            raise ValueError(
                ('"%s" does not evaluate to a unit, the result '
                 'has type %s instead.' % (unit_string, type(evaluated_unit))))

    # No error has been raised, all good
    return evaluated_unit.dim, FLOAT
def __init__(self, code):
    '''
    Store the code string and the identifiers it contains.

    Parameters
    ----------
    code : str
        The code string. It is stored with surrounding whitespace removed,
        the identifiers are extracted from the string as given.
    '''
    stripped_code = code.strip()
    self._code = stripped_code

    #: Set of identifiers in the code string
    self.identifiers = get_identifiers(code)
def optimise_statements(scalar_statements, vector_statements, variables,
                        blockname=''):
    '''
    Optimise a sequence of scalar and vector statements

    Performs the following optimisations:

    1. Constant evaluations (e.g. exp(0) to 1). See `evaluate_expr`.
    2. Arithmetic simplifications (e.g. 0*x to 0). See
       `ArithmeticSimplifier`, `collect`.
    3. Pulling out loop invariants (e.g. v*exp(-dt/tau) to a=exp(-dt/tau)
       outside the loop and v*a inside). See `Simplifier`.
    4. Boolean simplifications (allowing the replacement of expressions
       with booleans with a sequence of if/thens). See `Simplifier`.

    Parameters
    ----------
    scalar_statements : sequence of Statement
        Statements that only involve scalar values and should be evaluated
        in the scalar block.
    vector_statements : sequence of Statement
        Statements that involve vector values and should be evaluated in
        the vector block.
    variables : dict of (str, Variable)
        Definition of the types of the variables.
    blockname : str, optional
        Name of the block (used for LIO constant prefixes to avoid name
        clashes)

    Returns
    -------
    new_scalar_statements : sequence of Statement
        As above but with loop invariants pulled out from vector statements
    new_vector_statements : sequence of Statement
        Simplified/optimised versions of statements
    '''
    boolvars = {name: var
                for name, var in variables.items()
                if hasattr(var, 'dtype')
                and angela_dtype_from_dtype(var.dtype) == 'boolean'}

    # We use the Simplifier class by rendering each expression, which
    # generates new scalar statements stored in the Simplifier object, and
    # these are then added to the scalar statements.
    simplifier = Simplifier(variables, scalar_statements,
                            extra_lio_prefix=blockname)

    new_vector_statements = []
    for stmt in vector_statements:
        # Carry out constant evaluation, arithmetic simplification and loop
        # invariants
        new_expr = simplifier.render_expr(stmt.expr)
        new_stmt = Statement(stmt.var, stmt.op, new_expr, stmt.comment,
                             dtype=stmt.dtype,
                             constant=stmt.constant,
                             subexpression=stmt.subexpression,
                             scalar=stmt.scalar)

        # Now check if boolean simplification can be carried out
        complexity_std = expression_complexity(new_expr,
                                               simplifier.variables)
        idents = get_identifiers(new_expr)
        used_boolvars = [name for name in boolvars if name in idents]
        if used_boolvars:
            expanded_expressions = {}
            complexities = {}
            # We want to iterate over all the possible assignments of
            # boolean variables to values in (True, False)
            all_assignments = itertools.product([False, True],
                                                repeat=len(used_boolvars))
            for bool_vals in all_assignments:
                key = tuple(zip(used_boolvars, bool_vals))
                # substitute those values into the expr and simplify
                # (including potentially pulling out new loop invariants)
                subs = {name: str(val) for name, val in key}
                curexpr = simplifier.render_expr(
                    word_substitute(new_expr, subs))
                expanded_expressions[key] = curexpr
                complexities[key] = expression_complexity(
                    curexpr, simplifier.variables)
            # See Statement for details on these
            new_stmt.used_boolean_variables = used_boolvars
            new_stmt.boolean_simplified_expressions = expanded_expressions
            new_stmt.complexity_std = complexity_std
            new_stmt.complexities = complexities
        new_vector_statements.append(new_stmt)

    # Generate additional scalar statements for the loop invariants
    new_scalar_statements = copy.copy(scalar_statements)
    python_dtypes = {'boolean': bool, 'integer': int}
    for expr, name in simplifier.loop_invariants.items():
        dtype_name = simplifier.loop_invariant_dtypes[name]
        if dtype_name in python_dtypes:
            dtype = python_dtypes[dtype_name]
        else:
            dtype = prefs.core.default_float_dtype
        new_scalar_statements.append(
            Statement(name, ':=', expr, '',
                      dtype=dtype,
                      constant=True,
                      subexpression=False,
                      scalar=True))

    return new_scalar_statements, new_vector_statements
def make_statements(code, variables, dtype, optimise=True, blockname=''):
    '''
    make_statements(code, variables, dtype, optimise=True, blockname='')

    Turn a series of abstract code statements into Statement objects,
    inferring whether each line is a set/declare operation, whether the
    variables are constant or not, and handling the cacheing of
    subexpressions.

    Parameters
    ----------
    code : str
        A (multi-line) string of statements.
    variables : dict-like
        A dictionary with `Variable` and `Function` objects for every
        identifier used in the `code`.
    dtype : `dtype`
        The data type to use for temporary variables
    optimise : bool, optional
        Whether to optimise expressions, including
        pulling out loop invariant expressions and putting them in new
        scalar constants. Defaults to ``True``; the optimisation is only
        carried out if the `codegen.loop_invariant_optimisations`
        preference is set as well. Callers that are not interested in
        this kind of optimisation pass ``False``.
    blockname : str, optional
        A name for the block (used to name intermediate variables to avoid
        name clashes when multiple blocks are used together)

    Returns
    -------
    scalar_statements, vector_statements : (list of `Statement`, list of `Statement`)
        Lists with statements that are to be executed once and statements that
        are to be executed once for every neuron/synapse/... (or in a vectorised
        way)

    Raises
    ------
    SyntaxError
        If a line writes to a subexpression, if a scalar variable is
        written to after a vector variable has been written, or if a scalar
        subexpression refers to an implicitly vectorised function.

    Notes
    -----
    If ``optimise`` is ``True``, then the
    ``scalar_statements`` may include newly introduced scalar constants that
    have been identified as loop-invariant and have therefore been pulled out
    of the vector statements. The resulting statements will also use augmented
    assignments where possible, i.e. a statement such as ``w = w + 1`` will be
    replaced by ``w += 1``. Also, statements involving booleans will have
    additional information added to them (see `Statement` for details)
    describing how the statement can be reformulated as a sequence of if/then
    statements. Calls `~angela2.codegen.optimisation.optimise_statements`.
    '''
    code = strip_empty_lines(deindent(code))
    # Statements are separated by newlines or semicolons
    lines = re.split(r'[;\n]', code)
    lines = [LineInfo(code=line) for line in lines if len(line)]
    # Do a copy so we can add stuff without altering the original dict
    variables = dict(variables)
    # we will do inference to work out which lines are := and which are =
    defined = set(k for k, v in variables.items()
                  if not isinstance(v, AuxiliaryVariable))
    for line in lines:
        statement = None
        # parse statement into "var op expr"
        var, op, expr, comment = parse_statement(line.code)
        if var in variables and isinstance(variables[var], Subexpression):
            raise SyntaxError("Illegal line '{line}' in abstract code. "
                              "Cannot write to subexpression "
                              "'{var}'.".format(line=line.code, var=var))
        if op == '=':
            if var not in defined:
                # First assignment to an unknown (or auxiliary) variable:
                # this is a declaration
                op = ':='
                defined.add(var)
                if var not in variables:
                    # Infer dtype and scalar-ness of the new variable from
                    # the expression
                    annotated_ast = angela_ast(expr, variables)
                    is_scalar = annotated_ast.scalar
                    if annotated_ast.dtype == 'boolean':
                        use_dtype = bool
                    elif annotated_ast.dtype == 'integer':
                        use_dtype = int
                    else:
                        use_dtype = dtype
                    new_var = AuxiliaryVariable(var,
                                                dtype=use_dtype,
                                                scalar=is_scalar)
                    variables[var] = new_var
            elif not variables[var].is_boolean:
                # Try to rewrite "var = ..." as an augmented assignment by
                # collecting the terms of the expression in var
                sympy_expr = str_to_sympy(expr, variables)
                if variables[var].is_integer:
                    sympy_var = sympy.Symbol(var, integer=True)
                else:
                    sympy_var = sympy.Symbol(var, real=True)
                try:
                    collected = sympy.collect(sympy_expr,
                                              sympy_var,
                                              exact=True,
                                              evaluate=False)
                except AttributeError:
                    # If something goes wrong during collection, e.g. collect
                    # does not work for logical expressions
                    collected = {1: sympy_expr}

                if (len(collected) == 2
                        and set(collected.keys()) == {1, sympy_var}
                        and collected[sympy_var] == 1):
                    # We can replace this statement by a += assignment
                    statement = Statement(var,
                                          '+=',
                                          sympy_to_str(collected[1]),
                                          comment,
                                          dtype=variables[var].dtype,
                                          scalar=variables[var].scalar)
                elif len(collected) == 1 and sympy_var in collected:
                    # We can replace this statement by a *= assignment
                    statement = Statement(var,
                                          '*=',
                                          sympy_to_str(collected[sympy_var]),
                                          comment,
                                          dtype=variables[var].dtype,
                                          scalar=variables[var].scalar)
        if statement is None:
            # No rewrite took place, use the statement as parsed
            statement = Statement(var,
                                  op,
                                  expr,
                                  comment,
                                  dtype=variables[var].dtype,
                                  scalar=variables[var].scalar)

        line.statement = statement
        # for each line will give the variable being written to
        line.write = var
        # each line will give a set of variables which are read
        line.read = get_identifiers_recursively([expr], variables)

    # All writes to scalar variables must happen before writes to vector
    # variables
    scalar_write_done = False
    for line in lines:
        stmt = line.statement
        if stmt.op != ':=' and variables[
                stmt.var].scalar and scalar_write_done:
            raise SyntaxError(
                ('All writes to scalar variables in a code block '
                 'have to be made before writes to vector '
                 'variables. Illegal write to %s.') % line.write)
        elif not variables[stmt.var].scalar:
            scalar_write_done = True

    # all variables which are written to at some point in the code block
    # used to determine whether they should be const or not
    # (not referred to again in this function)
    all_write = set(line.write for line in lines)

    # backwards compute whether or not variables will be read again
    # note that will_read for a line gives the set of variables it will read
    # on the current line or subsequent ones. will_write gives the set of
    # variables that will be written after the current line
    will_read = set()
    will_write = set()
    for line in lines[::-1]:
        will_read = will_read.union(line.read)
        line.will_read = will_read.copy()
        line.will_write = will_write.copy()
        will_write.add(line.write)

    subexpressions = dict((name, val) for name, val in variables.items()
                          if isinstance(val, Subexpression))
    # Check that no scalar subexpression refers to a vectorised function
    # (e.g. rand()) -- otherwise it would be differently interpreted depending
    # on whether it is used in a scalar or a vector context (i.e., even though
    # the subexpression is supposed to be scalar, it would be vectorised when
    # used as part of non-scalar expressions)
    for name, subexpr in subexpressions.items():
        if subexpr.scalar:
            identifiers = get_identifiers(subexpr.expr)
            for identifier in identifiers:
                if (identifier in variables and getattr(
                        variables[identifier], 'auto_vectorise', False)):
                    raise SyntaxError(('The scalar subexpression {} refers to '
                                       'the implicitly vectorised function {} '
                                       '-- this is not allowed since it leads '
                                       'to different interpretations of this '
                                       'subexpression depending on whether it '
                                       'is used in a scalar or vector '
                                       'context.').format(name, identifier))

    # sort subexpressions into an order so that subexpressions that don't depend
    # on other subexpressions are first
    subexpr_deps = dict(
        (name, [dep for dep in subexpr.identifiers if dep in subexpressions])
        for name, subexpr in subexpressions.items())
    sorted_subexpr_vars = topsort(subexpr_deps)

    statements = []

    # none are yet defined (or declared)
    subdefined = dict((name, None) for name in subexpressions)
    for line in lines:
        stmt = line.statement
        read = line.read
        write = line.write
        will_read = line.will_read
        will_write = line.will_write
        # update/define all subexpressions needed by this statement
        for var in sorted_subexpr_vars:
            if var not in read:
                continue

            subexpression = subexpressions[var]
            # if already defined/declared
            if subdefined[var] == 'constant':
                continue
            elif subdefined[var] == 'variable':
                # Declared before, but its value may have changed:
                # re-evaluate it
                op = '='
                constant = False
            else:
                op = ':='
                # check if the referred variables ever change
                ids = subexpression.identifiers
                constant = all(v not in will_write for v in ids)
                subdefined[var] = 'constant' if constant else 'variable'

            statement = Statement(var,
                                  op,
                                  subexpression.expr,
                                  comment='',
                                  dtype=variables[var].dtype,
                                  constant=constant,
                                  subexpression=True,
                                  scalar=variables[var].scalar)
            statements.append(statement)

        var, op, expr, comment = stmt.var, stmt.op, stmt.expr, stmt.comment

        # constant only if we are declaring a new variable and we will not
        # write to it again
        constant = op == ':=' and var not in will_write
        statement = Statement(var,
                              op,
                              expr,
                              comment,
                              dtype=variables[var].dtype,
                              constant=constant,
                              scalar=variables[var].scalar)
        statements.append(statement)

    scalar_statements = [s for s in statements if s.scalar]
    vector_statements = [s for s in statements if not s.scalar]

    # The optimisation needs both the argument and the preference
    if optimise and prefs.codegen.loop_invariant_optimisations:
        scalar_statements, vector_statements = optimise_statements(
            scalar_statements,
            vector_statements,
            variables,
            blockname=blockname)

    return scalar_statements, vector_statements
def analyse_identifiers(code, variables, recursive=False):
    '''
    Sort the identifiers appearing in a block of abstract code into
    categories.

    Some names in a code block have to be supplied from outside, others are
    introduced by the block itself. A line such as::

        a = b+c

    either creates a new variable ``a`` or updates an existing one, depending
    on whether ``a`` was already known beforehand.

    Parameters
    ----------
    code : str
        The code string, a sequence of statements one per line.
    variables : dict of `Variable`, set of names
        Specifiers for the model variables, or simply a collection of known
        names (in which case a placeholder `Variable` with dtype
        ``np.float64`` is created for each name).
    recursive : bool, optional
        Whether to recurse down into subexpressions (defaults to ``False``).

    Returns
    -------
    newly_defined : set
        Names that are created by the code block (targets of ``:=``).
    used_known : set
        Names used by the code block that were already known (taken from
        ``variables``), excluding the standard identifiers.
    unknown : set
        Names used by the code block that are neither defined by it nor
        previously known. Should correspond to variables in the external
        namespace.
    '''
    if isinstance(variables, Mapping):
        # NOTE(review): this filter tests the key ``k`` and not the value
        # ``v``; if the keys are plain name strings it never excludes
        # anything. Kept unchanged to preserve behaviour -- confirm intent.
        known = {k for k, v in variables.items()
                 if not isinstance(k, AuxiliaryVariable)}
    else:
        known = set(variables)
        # Only names were given: build placeholder variables so that the
        # statements can still be generated below
        variables = {k: Variable(name=k, dtype=np.float64) for k in known}
    known |= STANDARD_IDENTIFIERS

    scalar_part, vector_part = make_statements(code, variables, np.float64,
                                               optimise=False)
    all_statements = scalar_part + vector_part
    newly_defined = {s.var for s in all_statements if s.op == ':='}

    if not all_statements:
        mentioned = set()
    elif recursive:
        # NOTE(review): ``variables`` was rebound to a dict above in the
        # non-mapping case, so this guard looks unreachable if ``Mapping``
        # is the collections ABC -- confirm.
        if not isinstance(variables, Mapping):
            raise TypeError('Have to specify a variables dictionary.')
        expressions = [s.expr for s in all_statements]
        mentioned = (get_identifiers_recursively(expressions, variables)
                     | {s.var for s in all_statements})
    else:
        per_statement = [get_identifiers(s.expr) for s in all_statements]
        mentioned = (set.union(*per_statement)
                     | {s.var for s in all_statements})

    unknown = mentioned.difference(newly_defined, known)
    used_known = mentioned.intersection(known) - STANDARD_IDENTIFIERS
    return newly_defined, used_known, unknown
def substitute_abstract_code_functions(code, funcs):
    '''
    Performs inline substitution of all the functions in the code

    The substitution is repeated until the rendered code no longer changes,
    so that functions referring to other functions are expanded as well.

    Parameters
    ----------
    code : str
        The abstract code to make inline substitutions into.
    funcs : list, tuple, set, frozenset or dict of AbstractCodeFunction
        The function substitutions to use, note in the case of a dict, the
        keys are ignored and the function name is used.

    Returns
    -------
    code : str
        The code with inline substitutions performed.
    '''
    if isinstance(funcs, (list, set, tuple, frozenset)):
        funcs = {f.name: f for f in funcs}

    code = deindent(code)
    lines = ast.parse(code, mode='exec').body

    # This is a slightly nasty hack, but basically we just check by looking at
    # the existing identifiers how many inline operations have already been
    # performed by previous calls to this function
    ids = get_identifiers(code)
    funcstarts = {}
    for func in funcs.values():
        prefix = '_inline_' + func.name + '_'
        used_counters = []
        for identifier in ids:
            if not identifier.startswith(prefix):
                continue
            # What follows the prefix is expected to look like ``<n>`` or
            # ``<n>_<name>``; only the leading field is of interest
            counter = identifier[len(prefix):].partition('_')[0]
            try:
                used_counters.append(int(counter))
            except ValueError:
                # Not a call counter, e.g. the identifier belongs to another
                # function whose name merely starts with ``func.name + '_'``
                continue
        # Continue numbering after the highest counter already in use
        funcstarts[func.name] = max(used_counters, default=-1) + 1

    # Now we rewrite all the lines, replacing each line with a sequence of
    # lines performing the inlining
    newlines = []
    for line in lines:
        for func in funcs.values():
            rw = FunctionRewriter(func, funcstarts[func.name])
            line = rw.visit(line)
            # Statements produced by the rewriter go before the rewritten line
            newlines.extend(rw.pre)
            # Remember where to continue counting for the next line
            funcstarts[func.name] = rw.numcalls
        newlines.append(line)

    # Now we render to a code string
    nr = NodeRenderer()
    newcode = '\n'.join(nr.render_node(line) for line in newlines)

    # We recurse until no changes in the code to ensure that all functions
    # are expanded if one function refers to another, etc.
    if newcode == code:
        return newcode
    return substitute_abstract_code_functions(newcode, funcs)
def vectorise_code(self, statements, variables, variable_indices, index='_idx'):
    '''
    Translate statements into lines of code, vectorised where possible.

    Every statement is first handled on its own: the arrays it needs are
    read, the statement is translated with ``ufunc_at_vectorisation`` and
    the results are written back. If a `VectorisationError` is raised at any
    point, all lines generated so far are discarded and the complete list of
    statements is instead emitted as the body of an explicit loop over
    ``_idx``.

    Parameters
    ----------
    statements : sequence of statement objects
        The statements to translate; their ``var``, ``op``, ``expr`` and
        ``inplace`` attributes are read here.
    variables : dict
        Passed on unchanged to the read/write/translation helpers.
    variable_indices : dict
        Looked up by variable name to get the name of the index used for
        that variable.
    index : str, optional
        Not used by this method (the index name ``'_idx'`` is hard-coded).

    Returns
    -------
    lines : list of str
        The generated lines of code.
    '''
    created_vars = {stmt.var for stmt in statements if stmt.op == ':='}
    try:
        lines = []
        used_variables = set()
        for statement in statements:
            lines.append('# Abstract code: {var} {op} {expr}'.format(var=statement.var,
                                                                     op=statement.op,
                                                                     expr=statement.expr))
            # We treat every statement individually with its own read and
            # write code to be on the safe side
            read, write, indices, conditional_write_vars = self.arrays_helper([statement])
            # In-place operation on a variable that is not indexed by _idx
            inplace_other_index = (statement.inplace and
                                   variable_indices[statement.var] != '_idx')
            # We make sure that we only add code to `lines` after it went
            # through completely
            ufunc_lines = []
            # No need to load a variable if it is only in read because of
            # the in-place operation
            if (inplace_other_index and
                    statement.var not in get_identifiers(statement.expr)):
                read = read - {statement.var}
            ufunc_lines.extend(self.read_arrays(read, write, indices,
                                                variables, variable_indices))
            ufunc_lines.append(self.ufunc_at_vectorisation(statement,
                                                           variables,
                                                           variable_indices,
                                                           conditional_write_vars,
                                                           created_vars,
                                                           used_variables,
                                                           ))
            # Do not write back such values, the ufuncs have modified the
            # underlying array already
            if inplace_other_index:
                write = write - {statement.var}
            ufunc_lines.extend(self.write_arrays([statement], read, write,
                                                 variables, variable_indices))
            lines.extend(ufunc_lines)
    except VectorisationError:
        if self._use_ufunc_at_vectorisation:
            logger.info("Failed to vectorise code, falling back on Python loop: note that "
                        "this will be very slow! Switch to another code generation target for "
                        "best performance (e.g. cython). First line is: "+str(statements[0]),
                        once=True)
        # Start from scratch: everything is placed in the body of a loop.
        # The first body line goes through indent() like all the others, so
        # that the indentation of the generated loop body is consistent.
        lines = ['_full_idx = _idx',
                 'for _idx in _full_idx:',
                 indent('_vectorisation_idx = _idx')]
        read, write, indices, conditional_write_vars = self.arrays_helper(statements)
        lines.extend(indent(code) for code in
                     self.read_arrays(read, write, indices,
                                      variables, variable_indices))
        for statement in statements:
            line = self.translate_statement(statement)
            if statement.var in conditional_write_vars:
                # Guard the statement with its write condition, one level
                # deeper than the rest of the loop body
                lines.append(indent('if {}:'.format(conditional_write_vars[statement.var])))
                lines.append(indent(line, 2))
            else:
                lines.append(indent(line))
        lines.extend(indent(code) for code in
                     self.write_arrays(statements, read, write,
                                       variables, variable_indices))
    return lines