def __init__(self, variables, scalar_statements, extra_lio_prefix=''):
    '''
    Set up the state used to simplify expressions and collect loop invariants.

    Parameters
    ----------
    variables : dict
        Passed on to ``angelaASTRenderer.__init__`` (with
        ``copy_variables=False``) and to the `ArithmeticSimplifier`.
    scalar_statements : sequence
        Stored as ``self.scalar_statements``.
    extra_lio_prefix : str or None, optional
        Extra prefix for loop-invariant names; ``None`` counts as ``''``. A
        non-empty prefix is stored with a trailing underscore appended.
    '''
    angelaASTRenderer.__init__(self, variables, copy_variables=False)
    # Bookkeeping for the extracted loop invariants
    self.loop_invariants = OrderedDict()
    self.loop_invariant_dtypes = {}
    # Counter used to number the loop-invariant variables
    self.value = 0
    self.node_renderer = NodeRenderer()
    self.arithmetic_simplifier = ArithmeticSimplifier(variables)
    self.scalar_statements = scalar_statements
    prefix = '' if extra_lio_prefix is None else extra_lio_prefix
    self.extra_lio_prefix = prefix + '_' if len(prefix) else prefix
def cancel_identical_terms(primary, inverted):
    '''
    Cancel terms that appear both as primary and as inverted terms, e.g.
    ``a+b-a`` should be cancelled to ``b``.

    Every node is rendered into its string expression with a `NodeRenderer`.
    A stateless primary term whose expression also occurs among the inverted
    terms is removed together with the first inverted term that renders to
    the same expression.

    Parameters
    ----------
    primary : list of AST nodes
        These are the nodes that are positive with respect to the operator, e.g.
        in x*y/z it would be [x, y].
    inverted : list of AST nodes
        These are the nodes that are inverted with respect to the operator, e.g.
        in x*y/z it would be [z].

    Returns
    -------
    primary : list of AST nodes
        Primary nodes after cancellation
    inverted : list of AST nodes
        Inverted nodes after cancellation (the ``inverted`` argument itself
        when nothing was cancelled)
    '''
    renderer = NodeRenderer()
    rendered = {}
    for node in primary:
        rendered[node] = renderer.render_node(node)
    for node in inverted:
        rendered[node] = renderer.render_node(node)

    kept_primary = []
    for term in primary:
        expr = rendered[term]
        first_match = next((pos for pos, iterm in enumerate(inverted)
                            if rendered[iterm] == expr), None)
        # stateful terms (e.g. containing rand()) must never cancel
        if first_match is None or not term.stateless:
            kept_primary.append(term)
            continue
        # drop this primary term and its first identical inverted partner
        inverted = [iterm for pos, iterm in enumerate(inverted)
                    if pos != first_match]
    return kept_primary, inverted
 def render_node(self, node):
     '''
     Render ``node`` and replace it by a constant if it can be evaluated.

     Assumes that the node has already been fully processed by
     angelaASTRenderer. Nodes without a ``simplified`` attribute are first
     rendered by the parent class and then marked as simplified. Vector
     nodes, simple names/numbers and stateful nodes are returned unchanged.
     Otherwise the rendered expression is passed to ``evaluate_expr`` with
     ``self.assumptions_ns``; if that succeeds, a new constant node of the
     same dtype (with ``complexity`` 0) is returned instead.
     '''
     if not hasattr(node, 'simplified'):
         node = super(ArithmeticSimplifier, self).render_node(node)
         node.simplified = True
     # Nothing to do for vector expressions (cannot be evaluated), simple
     # names or numbers (already minimal) and stateful nodes (e.g. those
     # containing a rand() call)
     if (not node.scalar
             or node.__class__.__name__ in ['Name', 'NameConstant', 'Num',
                                            'Constant']
             or not node.stateless):
         return node
     # try fully evaluating using assumptions
     val, evaluated = evaluate_expr(NodeRenderer().render_node(node),
                                    self.assumptions_ns)
     if not evaluated:
         return node
     if node.dtype == 'boolean':
         val = bool(val)
         if hasattr(ast, 'Constant'):
             newnode = ast.Constant(val)
         elif hasattr(ast, 'NameConstant'):
             newnode = ast.NameConstant(val)
         else:
             # None is the expression context, we don't use it so we just set to None
             newnode = ast.Name(repr(val), None)
     else:
         if node.dtype == 'integer':
             val = int(val)
         else:
             val = prefs.core.default_float_dtype(val)
         newnode = ast.Constant(val) if hasattr(ast, 'Constant') else ast.Num(val)
     newnode.dtype = node.dtype
     newnode.scalar = True
     newnode.stateless = node.stateless
     newnode.complexity = 0
     return newnode
 def render_BinOp(self, node):
     '''
     Render both operands of a binary operation and annotate ``node``.

     Each operand is given a ``parent`` attribute (a `weakref.proxy` to
     ``node``) before it is rendered, so that rendering an operand can
     inspect its enclosing node (the division check below does this for
     ``node`` itself). Afterwards ``node`` gets the attributes ``dtype`` (the
     higher of the two operand dtypes according to ``dtype_hierarchy``, but
     always ``'float'`` for ``/``), ``scalar``, ``complexity`` and
     ``stateless``.

     A ``floating_point_division`` warning is logged (once) when ``/`` is
     used between two integer operands, unless the division is the direct
     argument of ``int(...)`` or ``floor(...)``.
     '''
     node.left.parent = weakref.proxy(node)
     node.right.parent = weakref.proxy(node)
     node.left = self.render_node(node.left)
     node.right = self.render_node(node.right)
     # TODO: we could capture some syntax errors here, e.g. bool+bool
     # captures, e.g. int+float->float
     newdtype = dtype_hierarchy[max(dtype_hierarchy[subnode.dtype]
                                    for subnode in [node.left, node.right])]
     if node.op.__class__.__name__ == 'Div':
         # Division turns integers into floating point values
         newdtype = 'float'
         # Give a warning if the code uses floating point division where it
         # previously might have used floor division
         if node.left.dtype == node.right.dtype == 'integer':
             # This would have led to floor division in earlier versions of
             # angela (except for the numpy target on Python 3)
             # Ignore cases where the user already took care of this by
             # wrapping the result of the division in int(...) or
             # floor(...). The called function is not necessarily a plain
             # name (and then has no ``id``), hence the getattr.
             explicitly_converted = (
                 hasattr(node, 'parent')
                 and node.parent.__class__.__name__ == 'Call'
                 and getattr(node.parent.func, 'id', None) in ['int', 'floor'])
             if not explicitly_converted:
                 rendered_expr = NodeRenderer().render_node(node)
                 msg = ('The expression "{}" divides two integer values. '
                        'In previous versions of angela, this would have '
                        'used either an integer ("flooring") or a floating '
                        'point division, depending on the Python version '
                        'and the code generation target. In the current '
                        'version, it always uses a floating point '
                        'division. Explicitly ask for an integer division '
                        '("//"), or turn one of the operands into a '
                        'floating point value (e.g. replace "1/2" by '
                        '"1.0/2") to no longer receive this '
                        'warning.'.format(rendered_expr))
                 logger.warn(msg, 'floating_point_division', once=True)
     node.dtype = newdtype
     node.scalar = node.left.scalar and node.right.scalar
     node.complexity = 1 + node.left.complexity + node.right.complexity
     node.stateless = node.left.stateless and node.right.stateless
     return node
def parse_expression_dimensions(expr, variables, orig_expr=None):
    '''
    Returns the unit value of an expression, and checks its validity

    Parameters
    ----------
    expr : str or `ast.AST`
        The expression to check. A string is parsed with ``ast.parse`` (in
        ``'eval'`` mode); an AST node is checked directly (this is used for
        the recursive calls on sub-expressions).
    variables : dict
        Dictionary of all variables used in the `expr` (including `Constant`
        objects for external variables)
    orig_expr : str, optional
        The original expression string. It is only used for the location
        information of raised `SyntaxError` objects and is set automatically
        when `expr` is a string.

    Returns
    -------
    dim : `Dimension`
        The dimensions of the expression's result (``DIMENSIONLESS`` for
        numbers, booleans, comparisons and logical expressions).

    Raises
    ------
    SyntaxError
        If the expression cannot be parsed, uses an unsupported operation,
        uses a function like a variable, calls an unknown function or calls
        a function with the wrong number of arguments, or uses floor
        division on dimensionful values. For ``a**b`` with a dimensionful
        ``a``, the exponent is evaluated by ``_get_value_from_expression``,
        which presumably rejects anything but a constant number.
    DimensionMismatchError
        If any part of the expression is dimensionally inconsistent.
    KeyError
        If the expression uses an unknown identifier.
    ValueError
        For keyword or variable arguments in function calls, for functions
        that do not specify how they deal with units, and for
        ``NameConstant`` values other than ``True``/``False``.
    TypeError
        If a function argument that should be a boolean is not.
    '''

    # If we are working on a string, convert to the top level node
    if isinstance(expr, str):
        orig_expr = expr
        mod = ast.parse(expr, mode='eval')
        expr = mod.body
    # The deprecated node classes ``NameConstant`` and ``Num`` are recognized
    # by their class name: accessing them as ``ast.NameConstant``/``ast.Num``
    # emits a DeprecationWarning since Python 3.12 and fails in Python 3.14,
    # where these aliases have been removed (since Python 3.8 the parser only
    # creates ``ast.Constant`` nodes).
    node_class = expr.__class__.__name__
    if node_class == 'NameConstant':
        # new class for True, False, None in Python 3.4
        value = expr.value
        if value is True or value is False:
            return DIMENSIONLESS
        else:
            raise ValueError('Do not know how to handle value %s' % value)
    if expr.__class__ is ast.Name:
        name = expr.id
        # Raise an error if a function is called as if it were a variable
        # (most of the time this happens for a TimedArray)
        if name in variables and isinstance(variables[name], Function):
            raise SyntaxError('%s was used like a variable/constant, but it is '
                              'a function.' % name,
                              ("<string>",
                               expr.lineno,
                               expr.col_offset + 1,
                               orig_expr)
                              )
        if name in variables:
            return get_dimensions(variables[name])
        elif name in ['True', 'False']:
            return DIMENSIONLESS
        else:
            raise KeyError('Unknown identifier %s' % name)
    elif node_class in ('Num', 'Constant'):
        return DIMENSIONLESS
    elif expr.__class__ is ast.BoolOp:
        # check that the units are valid in each subexpression
        for node in expr.values:
            parse_expression_dimensions(node, variables, orig_expr=orig_expr)
        # but the result is a bool, so we just return 1 as the unit
        return DIMENSIONLESS
    elif expr.__class__ is ast.Compare:
        # check that the units are consistent in each subexpression
        subexprs = [expr.left] + expr.comparators
        subunits = [parse_expression_dimensions(node, variables, orig_expr=orig_expr)
                    for node in subexprs]
        # Chained comparisons (e.g. "a < b < c") are checked pair-wise; the
        # error message names the pair that is actually inconsistent
        for left, right, left_dim, right_dim in zip(subexprs[:-1], subexprs[1:],
                                                    subunits[:-1], subunits[1:]):
            if not have_same_dimensions(left_dim, right_dim):
                msg = ('Comparison of expressions with different units. Expression '
                       '"{}" has unit ({}), while expression "{}" has units ({})').format(
                            NodeRenderer().render_node(left), get_dimensions(left_dim),
                            NodeRenderer().render_node(right), get_dimensions(right_dim))
                raise DimensionMismatchError(msg)
        # but the result is a bool, so we just return 1 as the unit
        return DIMENSIONLESS
    elif expr.__class__ is ast.Call:
        if len(expr.keywords):
            raise ValueError("Keyword arguments not supported.")
        elif getattr(expr, 'starargs', None) is not None:
            raise ValueError("Variable number of arguments not supported")
        elif getattr(expr, 'kwargs', None) is not None:
            raise ValueError("Keyword arguments not supported")

        func = variables.get(expr.func.id, None)
        if func is None:
            raise SyntaxError('Unknown function %s' % expr.func.id,
                              ("<string>",
                               expr.lineno,
                               expr.col_offset + 1,
                               orig_expr)
                              )
        if not hasattr(func, '_arg_units') or not hasattr(func, '_return_unit'):
            raise ValueError(('Function %s does not specify how it '
                              'deals with units.') % expr.func.id)

        if len(func._arg_units) != len(expr.args):
            raise SyntaxError('Function %s was called with %d parameters, '
                              'needs %d.' % (expr.func.id,
                                             len(expr.args),
                                             len(func._arg_units)),
                              ("<string>",
                               expr.lineno,
                               expr.col_offset + len(expr.func.id) + 1,
                               orig_expr))

        # Dimensions of all arguments, collected once and reused below for
        # functions whose return unit depends on the arguments
        arg_units = []
        for idx, (arg, expected_unit) in enumerate(zip(expr.args,
                                                       func._arg_units)):
            arg_unit = parse_expression_dimensions(arg, variables,
                                                   orig_expr=orig_expr)
            arg_units.append(arg_unit)
            # A "None" in func._arg_units means: No matter what unit
            if expected_unit is None:
                continue
            # A string means: same unit as other argument
            elif isinstance(expected_unit, str):
                arg_idx = func._arg_names.index(expected_unit)
                expected_unit = parse_expression_dimensions(expr.args[arg_idx],
                                                              variables,
                                                              orig_expr=orig_expr)
                if not have_same_dimensions(arg_unit, expected_unit):
                    msg = (f'Argument number {idx + 1} for function '
                           f'{expr.func.id} was supposed to have the '
                           f'same units as argument number {arg_idx + 1}, but '
                           f'\'{NodeRenderer().render_node(arg)}\' has unit '
                           f'{get_unit_for_display(arg_unit)}, while '
                           f'\'{NodeRenderer().render_node(expr.args[arg_idx])}\' '
                           f'has unit {get_unit_for_display(expected_unit)}')
                    raise DimensionMismatchError(msg)
            elif expected_unit == bool:
                if not is_boolean_expression(arg, variables):
                    raise TypeError(('Argument number %d for function %s was '
                                     'expected to be a boolean value, but is '
                                     '"%s".') % (idx + 1, expr.func.id,
                                                 NodeRenderer().render_node(arg)))
            else:
                if not have_same_dimensions(arg_unit, expected_unit):
                    msg = ('Argument number {} for function {} does not have the '
                           'correct units. Expression "{}" has units ({}), but '
                           'should be ({}).').format(
                        idx+1, expr.func.id,
                        NodeRenderer().render_node(arg),
                        get_dimensions(arg_unit), get_dimensions(expected_unit))
                    raise DimensionMismatchError(msg)

        if func._return_unit == bool:
            return DIMENSIONLESS
        elif isinstance(func._return_unit, (Unit, int)):
            # Function always returns the same unit
            return getattr(func._return_unit, 'dim', DIMENSIONLESS)
        else:
            # Function returns a unit that depends on the arguments
            return func._return_unit(*arg_units).dim

    elif expr.__class__ is ast.BinOp:
        op = expr.op.__class__.__name__
        left_dim = parse_expression_dimensions(expr.left, variables, orig_expr=orig_expr)
        right_dim = parse_expression_dimensions(expr.right, variables, orig_expr=orig_expr)
        if op in ['Add', 'Sub', 'Mod']:
            # dimensions should be the same
            if left_dim is not right_dim:
                op_symbol = {'Add': '+', 'Sub': '-', 'Mod': '%'}.get(op)
                left_str = NodeRenderer().render_node(expr.left)
                right_str = NodeRenderer().render_node(expr.right)
                left_unit = get_unit_for_display(left_dim)
                right_unit = get_unit_for_display(right_dim)
                error_msg = ('Expression "{left} {op} {right}" uses '
                             'inconsistent units ("{left}" has unit '
                             '{left_unit}; "{right}" '
                             'has unit {right_unit})').format(left=left_str,
                                                             right=right_str,
                                                             op=op_symbol,
                                                             left_unit=left_unit,
                                                             right_unit=right_unit)
                raise DimensionMismatchError(error_msg)
            u = left_dim
        elif op == 'Mult':
            u = left_dim*right_dim
        elif op == 'Div':
            u = left_dim/right_dim
        elif op == 'FloorDiv':
            if not (left_dim is DIMENSIONLESS and right_dim is DIMENSIONLESS):
                # point the error location at the dimensionful operand
                if left_dim is DIMENSIONLESS:
                    col_offset = expr.right.col_offset + 1
                else:
                    col_offset = expr.left.col_offset + 1
                raise SyntaxError('Floor division can only be used on '
                                  'dimensionless values.',
                                  ("<string>",
                                   expr.lineno,
                                   col_offset,
                                   orig_expr)
                                  )
            u = DIMENSIONLESS
        elif op == 'Pow':
            if left_dim is DIMENSIONLESS and right_dim is DIMENSIONLESS:
                return DIMENSIONLESS
            n = _get_value_from_expression(expr.right, variables)
            u = left_dim**n
        else:
            raise SyntaxError("Unsupported operation "+op,
                              ("<string>",
                               expr.lineno,
                               getattr(expr.left, 'end_col_offset',
                                       len(NodeRenderer().render_node(expr.left))) + 1,
                               orig_expr)
                              )
        return u
    elif expr.__class__ is ast.UnaryOp:
        op = expr.op.__class__.__name__
        # check validity of operand and get its unit
        u = parse_expression_dimensions(expr.operand, variables, orig_expr=orig_expr)
        if op == 'Not':
            return DIMENSIONLESS
        else:
            return u
    else:
        raise SyntaxError('Unsupported operation ' + str(expr.__class__.__name__),
                          ("<string>",
                           expr.lineno,
                           expr.col_offset + 1,
                           orig_expr)
                          )
class Simplifier(angelaASTRenderer):
    '''
    Carry out arithmetic simplifications (see `ArithmeticSimplifier`) and
    pull out loop invariants.

    Parameters
    ----------
    variables : dict of (str, Variable)
        Usual definition of variables. An `AuxiliaryVariable` is registered in
        ``self.variables`` for every loop invariant that is pulled out.
    scalar_statements : sequence of Statement
        Predefined scalar statements that can be used as part of simplification
    extra_lio_prefix : str, optional
        Additional prefix for the names of the loop-invariant variables.
        ``None`` is treated like an empty string.

    Notes
    -----

    After calling `render_expr` on a sequence of expressions (typically coming
    from vector statements), this object will have some new attributes:

    ``loop_invariants`` : OrderedDict of (expression, varname)
        The expressions are strings that correspond to scalar-only
        expressions that can be evaluated outside of the vector block. The
        variable names have the form ``_lio_N`` (or ``_lio_{prefix}_N`` when
        a non-empty ``extra_lio_prefix`` was given), where ``N`` is an
        increasing integer.
    ``loop_invariant_dtypes`` : dict of (varname, dtypename)
        dtypename will be one of ``'boolean'``, ``'integer'``, ``'float'``.
    '''
    def __init__(self, variables, scalar_statements, extra_lio_prefix=''):
        angelaASTRenderer.__init__(self, variables, copy_variables=False)
        self.scalar_statements = scalar_statements
        self.loop_invariants = OrderedDict()
        self.loop_invariant_dtypes = {}
        # Counter used to number the loop-invariant variables
        self.value = 0
        self.node_renderer = NodeRenderer()
        self.arithmetic_simplifier = ArithmeticSimplifier(variables)
        prefix = '' if extra_lio_prefix is None else extra_lio_prefix
        self.extra_lio_prefix = prefix + '_' if len(prefix) else prefix

    def render_expr(self, expr):
        '''
        Parse, simplify and render ``expr``, pulling out loop invariants.
        '''
        simplified = self.arithmetic_simplifier.render_node(
            angela_ast(expr, self.variables))
        return self.node_renderer.render_node(self.render_node(simplified))

    def _register_loop_invariant(self, expr, dtype):
        '''
        Create a new loop-invariant variable for ``expr`` and return its name.
        '''
        self.value += 1
        name = '_lio_' + self.extra_lio_prefix + str(self.value)
        self.loop_invariants[expr] = name
        self.loop_invariant_dtypes[name] = dtype
        numpy_dtype = {
            'boolean': bool,
            'integer': int,
            'float': prefs.core.default_float_dtype
        }[dtype]
        self.variables[name] = AuxiliaryVariable(name,
                                                 dtype=numpy_dtype,
                                                 scalar=True)
        return name

    def render_node(self, node):
        '''
        Assumes that the node has already been fully processed by angelaASTRenderer

        Scalar nodes with a non-zero complexity are replaced by a name
        referring to a loop-invariant variable (identical expressions share
        one variable); all other nodes are rendered as usual.
        '''
        if not node.scalar or node.complexity <= 0:
            # not a candidate for pulling out, render node as usual
            return super(Simplifier, self).render_node(node)
        expr = self.node_renderer.render_node(
            self.arithmetic_simplifier.render_node(node))
        if expr not in self.loop_invariants:
            self._register_loop_invariant(expr, node.dtype)
        # None is the expression context, we don't use it so we just set to None
        newnode = ast.Name(self.loop_invariants[expr], None)
        newnode.scalar = True
        newnode.dtype = node.dtype
        newnode.complexity = 0
        newnode.stateless = node.stateless
        return newnode
def parse_synapse_generator(expr):
    '''
    Returns a parsed form of a synapse generator expression.

    The general form is:

    ``element for iteration_variable in iterator_func(...)``

    or

    ``element for iteration_variable in iterator_func(...) if if_expression``

    Returns a dictionary with keys:

    ``original_expression``
        The original expression as a string.
    ``element``
        As above, a string expression.
    ``iteration_variable``
        A variable name, as above.
    ``iterator_func``
        String. Either ``range`` or ``sample``.
    ``if_expression``
        String expression; ``'True'`` if the generator has no ``if`` clause.
    ``iterator_kwds``
        Dictionary of key/value pairs representing the keywords. See
        `handle_range` and `handle_sample`.

    Raises
    ------
    SyntaxError
        If the expression is not a single-variable generator expression over
        one of the functions in ``iterator_function_handlers``, uses star
        arguments (``*args``/``**kwargs``), has more than one ``if`` clause,
        or if the iterator handler rejects the arguments.
    '''
    nr = NodeRenderer()
    parse_error = ("Error parsing expression '%s'. Expression must have "
                   "generator syntax, for example 'k for k in range(i-10, "
                   "i+10)'." % expr)
    try:
        # Wrapping in brackets turns the generator into a list comprehension
        node = ast.parse('[%s]' % expr, mode='eval').body
    except Exception as e:
        raise SyntaxError(parse_error + " Error encountered was %s" % e) from e
    if _cname(node) != 'ListComp':
        raise SyntaxError(parse_error + " Expression is not a generator "
                          "expression.")
    element = node.elt
    if len(node.generators) != 1:
        raise SyntaxError(parse_error + " Generator expression must involve "
                          "only one iterator.")
    generator = node.generators[0]
    target = generator.target
    if _cname(target) != 'Name':
        raise SyntaxError(parse_error +
                          " Generator must iterate over a single "
                          "variable (not tuple, etc.).")
    iteration_variable = target.id
    iterator = generator.iter
    if _cname(iterator) != 'Call' or _cname(iterator.func) != 'Name':
        raise SyntaxError(parse_error + " Iterator expression must be one of "
                          "the supported functions: " +
                          str(list(iterator_function_handlers)))
    iterator_funcname = iterator.func.id
    if iterator_funcname not in iterator_function_handlers:
        raise SyntaxError(parse_error + " Iterator expression must be one of "
                          "the supported functions: " +
                          str(list(iterator_function_handlers)))
    # Python >= 3.5 represents "*args" as a Starred node in ``args`` and
    # "**kwargs" as a keyword without a name; the ``starargs``/``kwargs``
    # attributes only exist in older Python versions.
    if (getattr(iterator, 'starargs', None) is not None
            or getattr(iterator, 'kwargs', None) is not None
            or any(_cname(argnode) == 'Starred' for argnode in iterator.args)
            or any(kwdnode.arg is None for kwdnode in iterator.keywords)):
        raise SyntaxError(parse_error + " Star arguments not supported.")
    args = [nr.render_node(argnode) for argnode in iterator.args]
    keywords = {kwdnode.arg: nr.render_node(kwdnode.value)
                for kwdnode in iterator.keywords}
    try:
        iterator_handler = iterator_function_handlers[iterator_funcname]
        iterator_kwds = iterator_handler(*args, **keywords)
    except SyntaxError as exc:
        # ``msg`` is None for a SyntaxError raised without arguments
        detail = exc.msg if exc.msg is not None else str(exc)
        raise SyntaxError(parse_error + " " + detail) from exc
    if len(generator.ifs) == 0:
        condition = ast.parse('True', mode='eval').body
    elif len(generator.ifs) > 1:
        raise SyntaxError(parse_error + " Generator must have at most one if "
                          "statement.")
    else:
        condition = generator.ifs[0]
    parsed = {
        'original_expression': expr,
        'element': nr.render_node(element),
        'iteration_variable': iteration_variable,
        'iterator_func': iterator_funcname,
        'iterator_kwds': iterator_kwds,
        'if_expression': nr.render_node(condition),
    }
    return parsed