Beispiel #1
0
 def __init__(self, variables, scalar_statements):
     '''
     Initialise the renderer.

     Parameters
     ----------
     variables : dict
         Passed to ``BrianASTRenderer.__init__`` and to the internal
         ``ArithmeticSimplifier``.
     scalar_statements
         Stored unchanged as ``self.scalar_statements``.
     '''
     BrianASTRenderer.__init__(self, variables)
     # Helper renderers: one rendering without vectorisation indices, one
     # applying arithmetic simplifications to the given variables.
     self.node_renderer = NodeRenderer(use_vectorisation_idx=False)
     self.arithmetic_simplifier = ArithmeticSimplifier(variables)
     self.scalar_statements = scalar_statements
     # Loop-invariant bookkeeping; OrderedDict preserves insertion order.
     self.loop_invariants = OrderedDict()
     self.loop_invariant_dtypes = {}
     # Starts at zero; presumably a counter used elsewhere -- confirm.
     self.n = 0
Beispiel #2
0
 def __init__(self, variables, scalar_statements, extra_lio_prefix=''):
     '''
     Initialise the renderer.

     Parameters
     ----------
     variables : dict
         Passed to ``BrianASTRenderer.__init__`` (with
         ``copy_variables=False``) and to the internal
         ``ArithmeticSimplifier``.
     scalar_statements
         Stored unchanged as ``self.scalar_statements``.
     extra_lio_prefix : str or None, optional
         Stored as ``self.extra_lio_prefix``. ``None`` is treated like the
         empty string, and a non-empty prefix gets a trailing underscore.
     '''
     BrianASTRenderer.__init__(self, variables, copy_variables=False)
     # Helper renderers: one rendering without vectorisation indices, one
     # applying arithmetic simplifications to the given variables.
     self.node_renderer = NodeRenderer(use_vectorisation_idx=False)
     self.arithmetic_simplifier = ArithmeticSimplifier(variables)
     self.scalar_statements = scalar_statements
     # Loop-invariant bookkeeping; OrderedDict preserves insertion order.
     self.loop_invariants = OrderedDict()
     self.loop_invariant_dtypes = {}
     # Starts at zero; presumably a counter used elsewhere -- confirm.
     self.n = 0
     # Normalise the prefix: None -> '', 'abc' -> 'abc_', '' stays ''.
     prefix = '' if extra_lio_prefix is None else extra_lio_prefix
     if len(prefix) > 0:
         prefix = prefix + '_'
     self.extra_lio_prefix = prefix
Beispiel #3
0
 def __init__(self, variables, scalar_statements, extra_lio_prefix=''):
     '''
     Initialise the renderer.

     Parameters
     ----------
     variables : dict
         Passed to ``BrianASTRenderer.__init__`` (with
         ``copy_variables=False``) and to the internal
         ``ArithmeticSimplifier``.
     scalar_statements
         Stored unchanged as ``self.scalar_statements``.
     extra_lio_prefix : str or None, optional
         Stored as ``self.extra_lio_prefix``. ``None`` is treated like the
         empty string, and a non-empty prefix gets a trailing underscore.
     '''
     BrianASTRenderer.__init__(self, variables, copy_variables=False)
     # Helper renderers: one rendering without vectorisation indices, one
     # applying arithmetic simplifications to the given variables.
     self.node_renderer = NodeRenderer(use_vectorisation_idx=False)
     self.arithmetic_simplifier = ArithmeticSimplifier(variables)
     self.scalar_statements = scalar_statements
     # Loop-invariant bookkeeping; OrderedDict preserves insertion order.
     self.loop_invariants = OrderedDict()
     self.loop_invariant_dtypes = {}
     # Starts at zero; presumably a counter used elsewhere -- confirm.
     self.n = 0
     # Normalise the prefix: None -> '', 'abc' -> 'abc_', '' stays ''.
     prefix = '' if extra_lio_prefix is None else extra_lio_prefix
     if len(prefix) > 0:
         prefix = prefix + '_'
     self.extra_lio_prefix = prefix
Beispiel #4
0
 def __init__(self, variables):
     '''
     Initialise the simplifier.

     Parameters
     ----------
     variables : dict
         Passed (with ``copy_variables=False``) to the base-class
         constructor and to the plain ``BrianASTRenderer`` kept in
         ``self.bast_renderer``.
     '''
     BrianASTRenderer.__init__(self, variables, copy_variables=False)
     # Non-simplifying renderer sharing the same variables.
     self.bast_renderer = BrianASTRenderer(variables, copy_variables=False)
     # Evaluation namespace: a private copy of the default namespace.
     self.assumptions_ns = dict(defaults_ns)
     self.assumptions = list()
Beispiel #5
0
class ArithmeticSimplifier(BrianASTRenderer):
    '''
    Carries out the following arithmetic simplifications:

    1. Constant evaluation (e.g. exp(0)=1) by attempting to evaluate the expression in an "assumptions namespace"
    2. Binary operators, e.g. 0*x=0, 1*x=x, 0//x=0, etc. You have to take care that the dtypes match here, e.g.
       if x is an integer, then 1.0*x shouldn't be replaced with x but left as 1.0*x.

    Parameters
    ----------
    variables : dict of (str, Variable)
        Usual definition of variables.

    Attributes
    ----------
    assumptions : list of str
        Additional assumptions that can be used in simplification, each
        assumption is a string statement (these might be the scalar
        statements, for example). Initialised empty; it is not a constructor
        argument and is not read by the methods of this class.
    assumptions_ns : dict
        Namespace in which scalar, stateless sub-expressions are evaluated.
        Initialised as a copy of ``defaults_ns``; presumably extended by
        callers -- confirm.
    bast_renderer : BrianASTRenderer
        Plain (non-simplifying) renderer used to re-annotate the expressions
        produced by ``collect``.
    '''
    def __init__(self, variables):
        BrianASTRenderer.__init__(self, variables, copy_variables=False)
        self.assumptions = []
        self.assumptions_ns = dict(defaults_ns)
        self.bast_renderer = BrianASTRenderer(variables, copy_variables=False)

    def render_node(self, node):
        '''
        Render ``node`` and, if possible, replace it by a constant.

        Nodes not yet processed by this simplifier (no ``simplified``
        attribute) are first rendered by ``BrianASTRenderer.render_node`` and
        marked as simplified. Scalar, stateless nodes other than plain
        names/numbers are then rendered to a string and evaluated in
        ``self.assumptions_ns``; on success they are replaced by a constant
        node of the same dtype.

        Returns
        -------
        node
            Either the (possibly re-rendered) original node or a new constant
            node with ``dtype``, ``scalar``, ``stateless`` and ``complexity``
            set.
        '''
        if not hasattr(node, 'simplified'):
            node = super(ArithmeticSimplifier, self).render_node(node)
            node.simplified = True
        # can't evaluate vector expressions, so abandon in this case
        if not node.scalar:
            return node
        # No evaluation necessary for simple names or numbers
        if node.__class__.__name__ in ['Name', 'NameConstant', 'Num']:
            return node
        # Don't evaluate stateful nodes (e.g. those containing a rand() call)
        if not node.stateless:
            return node
        # try fully evaluating using assumptions
        expr = NodeRenderer().render_node(node)
        val, evaluated = evaluate_expr(expr, self.assumptions_ns)
        if evaluated:
            if node.dtype == 'boolean':
                val = bool(val)
                if hasattr(ast, 'NameConstant'):
                    newnode = ast.NameConstant(val)
                else:
                    # None is the expression context, we don't use it so we just set to None
                    newnode = ast.Name(repr(val), None)
            elif node.dtype == 'integer':
                val = int(val)
            else:
                val = float(val)
            if node.dtype != 'boolean':
                newnode = ast.Num(val)
            newnode.dtype = node.dtype
            newnode.scalar = True
            newnode.stateless = node.stateless
            newnode.complexity = 0
            return newnode
        return node

    def render_BinOp(self, node):
        '''
        Simplify a binary operation.

        Float-valued ``*``, ``/``, ``+`` and ``-`` nodes are first passed
        through ``collect`` (only once, tracked via the ``collected``
        attribute) and re-rendered. Otherwise both operands are simplified
        and the following identities are applied, provided they do not
        change the resulting dtype: ``0*x``/``x*0`` -> 0, ``1*x``/``x*1`` -> x,
        ``0/x`` -> 0, ``x/1`` -> x, ``0//x`` -> 0, ``x//1`` -> x (integers
        only), ``0+x``/``x+0`` -> x and ``x-0`` -> x. Operations with a zero
        that would drop an operand are only simplified for stateless nodes,
        so that calls such as ``rand()`` are kept.
        '''
        if node.dtype == 'float': # only try to collect float type nodes
            if node.op.__class__.__name__ in ['Mult', 'Div', 'Add', 'Sub'] and not hasattr(node, 'collected'):
                newnode = self.bast_renderer.render_node(collect(node))
                newnode.collected = True
                return self.render_node(newnode)
        left = node.left = self.render_node(node.left)
        right = node.right = self.render_node(node.right)
        node = super(ArithmeticSimplifier, self).render_BinOp(node)
        op = node.op
        # Handle multiplication by 0 or 1
        if op.__class__.__name__ == 'Mult':
            for operand, other in [(left, right),
                                   (right, left)]:
                if operand.__class__.__name__ == 'Num':
                    if operand.n == 0:
                        # Do not remove stateful functions
                        if node.stateless:
                            return _replace_with_zero(operand, node)
                    if operand.n==1:
                        # only simplify this if the type wouldn't be cast by the operation
                        if dtype_hierarchy[operand.dtype] <= dtype_hierarchy[other.dtype]:
                            return other
        # Handle division by 1, or 0/x
        elif op.__class__.__name__ == 'Div':
            if left.__class__.__name__ == 'Num' and left.n == 0:  # 0/x
                if node.stateless:
                    # Do not remove stateful functions
                    return _replace_with_zero(left, node)
            if right.__class__.__name__ == 'Num' and right.n == 1:  # x/1
                # only simplify this if the type wouldn't be cast by the operation
                if dtype_hierarchy[right.dtype] <= dtype_hierarchy[left.dtype]:
                    return left
        # Handle floor division of 0, or integer floor division by 1
        elif op.__class__.__name__ == 'FloorDiv':
            if left.__class__.__name__ == 'Num' and left.n == 0:  # 0//x
                if node.stateless:
                    # Do not remove stateful functions
                    return _replace_with_zero(left, node)
            # Only optimise floor division by 1 if both numbers are integers:
            # for floating point values floor division by 1 changes the value,
            # and division by 1.0 can change the type for an integer value
            if (left.dtype == right.dtype == 'integer'
                    and right.__class__.__name__ == 'Num'
                    and right.n == 1):  # x//1
                return left
        # Handle addition of 0
        elif op.__class__.__name__ == 'Add':
            for operand, other in [(left, right),
                                   (right, left)]:
                if operand.__class__.__name__ == 'Num' and operand.n == 0:
                    # only simplify this if the type wouldn't be cast by the operation
                    if dtype_hierarchy[operand.dtype]<=dtype_hierarchy[other.dtype]:
                        return other
        # Handle subtraction of 0
        elif op.__class__.__name__ == 'Sub':
            if right.__class__.__name__ == 'Num' and right.n == 0:
                # only simplify this if the type wouldn't be cast by the operation
                if dtype_hierarchy[right.dtype]<=dtype_hierarchy[left.dtype]:
                    return left

        # simplify e.g. 2*float to 2.0*float to make things more explicit: not strictly necessary
        # but might be useful for some codegen targets
        if node.dtype=='float' and op.__class__.__name__ in ['Mult', 'Add', 'Sub', 'Div']:
            for subnode in [node.left, node.right]:
                if subnode.__class__.__name__ == 'Num':
                    subnode.dtype = 'float'
                    subnode.n = float(subnode.n)
        return node
Beispiel #6
0
 def __init__(self, variables):
     '''
     Initialise the simplifier.

     Parameters
     ----------
     variables : dict
         Passed (with ``copy_variables=False``) to the base-class
         constructor and to the plain ``BrianASTRenderer`` kept in
         ``self.bast_renderer``.
     '''
     BrianASTRenderer.__init__(self, variables, copy_variables=False)
     # Non-simplifying renderer sharing the same variables.
     self.bast_renderer = BrianASTRenderer(variables, copy_variables=False)
     # Evaluation namespace: a private copy of the default namespace.
     self.assumptions_ns = dict(defaults_ns)
     self.assumptions = list()
Beispiel #7
0
class ArithmeticSimplifier(BrianASTRenderer):
    '''
    Carries out the following arithmetic simplifications:

    1. Constant evaluation (e.g. exp(0)=1) by attempting to evaluate the expression in an "assumptions namespace"
    2. Binary operators, e.g. 0*x=0, 1*x=x, etc. You have to take care that the dtypes match here, e.g.
       if x is an integer, then 1.0*x shouldn't be replaced with x but left as 1.0*x.

    Parameters
    ----------
    variables : dict of (str, Variable)
        Usual definition of variables.

    Attributes
    ----------
    assumptions : list of str
        Additional assumptions that can be used in simplification, each
        assumption is a string statement (these might be the scalar
        statements, for example). Initialised empty; it is not a constructor
        argument and is not read by the methods of this class.
    assumptions_ns : dict
        Namespace in which scalar, stateless sub-expressions are evaluated.
        Initialised as a copy of ``defaults_ns``; presumably extended by
        callers -- confirm.
    bast_renderer : BrianASTRenderer
        Plain (non-simplifying) renderer used to re-annotate the expressions
        produced by ``collect``.
    '''
    def __init__(self, variables):
        BrianASTRenderer.__init__(self, variables, copy_variables=False)
        self.assumptions = []
        self.assumptions_ns = dict(defaults_ns)
        self.bast_renderer = BrianASTRenderer(variables, copy_variables=False)

    def render_node(self, node):
        '''
        Render ``node`` and, if possible, replace it by a constant.

        Nodes not yet processed by this simplifier (no ``simplified``
        attribute) are first rendered by ``BrianASTRenderer.render_node`` and
        marked as simplified. Scalar, stateless nodes other than plain
        names/numbers are then rendered to a string and evaluated in
        ``self.assumptions_ns``; on success they are replaced by a constant
        node of the same dtype (float results are converted with
        ``prefs.core.default_float_dtype``).

        Returns
        -------
        node
            Either the (possibly re-rendered) original node or a new constant
            node with ``dtype``, ``scalar``, ``stateless`` and ``complexity``
            set.
        '''
        if not hasattr(node, 'simplified'):
            node = super(ArithmeticSimplifier, self).render_node(node)
            node.simplified = True
        # can't evaluate vector expressions, so abandon in this case
        if not node.scalar:
            return node
        # No evaluation necessary for simple names or numbers
        if node.__class__.__name__ in ['Name', 'NameConstant', 'Num']:
            return node
        # Don't evaluate stateful nodes (e.g. those containing a rand() call)
        if not node.stateless:
            return node
        # try fully evaluating using assumptions
        expr = NodeRenderer().render_node(node)
        val, evaluated = evaluate_expr(expr, self.assumptions_ns)
        if evaluated:
            if node.dtype == 'boolean':
                val = bool(val)
                if hasattr(ast, 'NameConstant'):
                    newnode = ast.NameConstant(val)
                else:
                    # None is the expression context, we don't use it so we just set to None
                    newnode = ast.Name(repr(val), None)
            elif node.dtype == 'integer':
                val = int(val)
            else:
                val = prefs.core.default_float_dtype(val)
            if node.dtype != 'boolean':
                newnode = ast.Num(val)
            newnode.dtype = node.dtype
            newnode.scalar = True
            newnode.stateless = node.stateless
            newnode.complexity = 0
            return newnode
        return node

    def render_BinOp(self, node):
        '''
        Simplify a binary operation.

        Float-valued ``*``, ``/``, ``+`` and ``-`` nodes are first passed
        through ``collect`` (only once, tracked via the ``collected``
        attribute) and re-rendered. Otherwise both operands are simplified
        and the following identities are applied, provided they do not
        change the resulting dtype: ``0*x``/``x*0`` -> 0, ``1*x``/``x*1`` -> x,
        ``0/x`` -> 0, ``x/1`` -> x, ``0//x`` -> 0, ``x//1`` -> x (integers
        only), ``0+x``/``x+0`` -> x and ``x-0`` -> x. Operations with a zero
        that would drop an operand are only simplified for stateless nodes,
        so that calls such as ``rand()`` are kept.
        '''
        if node.dtype == 'float':  # only try to collect float type nodes
            if node.op.__class__.__name__ in [
                    'Mult', 'Div', 'Add', 'Sub'
            ] and not hasattr(node, 'collected'):
                newnode = self.bast_renderer.render_node(collect(node))
                newnode.collected = True
                return self.render_node(newnode)
        left = node.left = self.render_node(node.left)
        right = node.right = self.render_node(node.right)
        node = super(ArithmeticSimplifier, self).render_BinOp(node)
        op = node.op
        # Handle multiplication by 0 or 1
        if op.__class__.__name__ == 'Mult':
            for operand, other in [(left, right), (right, left)]:
                if operand.__class__.__name__ == 'Num':
                    if operand.n == 0:
                        # Do not remove stateful functions
                        if node.stateless:
                            return _replace_with_zero(operand, node)
                    if operand.n == 1:
                        # only simplify this if the type wouldn't be cast by the operation
                        if dtype_hierarchy[operand.dtype] <= dtype_hierarchy[
                                other.dtype]:
                            return other
        # Handle division by 1, or 0/x
        elif op.__class__.__name__ == 'Div':
            if left.__class__.__name__ == 'Num' and left.n == 0:  # 0/x
                if node.stateless:
                    # Do not remove stateful functions
                    return _replace_with_zero(left, node)
            if right.__class__.__name__ == 'Num' and right.n == 1:  # x/1
                # only simplify this if the type wouldn't be cast by the operation
                if dtype_hierarchy[right.dtype] <= dtype_hierarchy[left.dtype]:
                    return left
        # Handle floor division of 0, or integer floor division by 1
        elif op.__class__.__name__ == 'FloorDiv':
            if left.__class__.__name__ == 'Num' and left.n == 0:  # 0//x
                if node.stateless:
                    # Do not remove stateful functions
                    return _replace_with_zero(left, node)
            # Only optimise floor division by 1 if both numbers are integers,
            # for floating point values, floor division by 1 changes the value,
            # and division by 1.0 can change the type for an integer value
            if (left.dtype == right.dtype == 'integer'
                    and right.__class__.__name__ == 'Num'
                    and right.n == 1):  # x//1
                return left
        # Handle addition of 0
        elif op.__class__.__name__ == 'Add':
            for operand, other in [(left, right), (right, left)]:
                if operand.__class__.__name__ == 'Num' and operand.n == 0:
                    # only simplify this if the type wouldn't be cast by the operation
                    if dtype_hierarchy[operand.dtype] <= dtype_hierarchy[
                            other.dtype]:
                        return other
        # Handle subtraction of 0
        elif op.__class__.__name__ == 'Sub':
            if right.__class__.__name__ == 'Num' and right.n == 0:
                # only simplify this if the type wouldn't be cast by the operation
                if dtype_hierarchy[right.dtype] <= dtype_hierarchy[left.dtype]:
                    return left

        # simplify e.g. 2*float to 2.0*float to make things more explicit: not strictly necessary
        # but might be useful for some codegen targets
        if node.dtype == 'float' and op.__class__.__name__ in [
                'Mult', 'Add', 'Sub', 'Div'
        ]:
            for subnode in [node.left, node.right]:
                if subnode.__class__.__name__ == 'Num':
                    subnode.dtype = 'float'
                    subnode.n = prefs.core.default_float_dtype(subnode.n)
        return node