def render_Num(self, node):
    """Annotate a numeric literal node in place and return it.

    The node gets a ``complexity`` of 0, a ``dtype`` derived from its
    literal value via ``brian_dtype_from_value``, and is flagged as both
    ``scalar`` and ``stateless``.
    """
    node.complexity = 0
    literal_value = get_node_value(node)
    node.dtype = brian_dtype_from_value(literal_value)
    node.scalar = node.stateless = True
    return node
def render_BinOp(self, node):
    """Render a binary operation, dropping trivial arithmetic where safe.

    Steps, in order:

    1. A float-typed ``Mult``/``Div``/``Add``/``Sub`` node that has no
       ``collected`` attribute yet is passed through ``collect``, rendered
       by ``self.bast_renderer``, tagged as collected and rendered again
       by this renderer; that result is returned directly.
    2. Otherwise both operands are rendered and stored back on the node,
       and the node is handed to the parent class' ``render_BinOp``.
    3. Identities are simplified: ``0*x``, ``x*0``, ``0/x`` and ``0//x``
       are replaced via ``_replace_with_zero`` (only if the node is
       stateless); ``1*x``, ``x*1``, ``x/1``, ``0+x``, ``x+0`` and ``x-0``
       return the other operand (only if removing the literal does not
       change the dtype); ``x//1`` returns ``x`` only for integer operands.
    4. For remaining float-typed ``Mult``/``Add``/``Sub``/``Div`` nodes,
       non-boolean literal operands are converted to the default float
       dtype.

    Returns the simplified node, which may be one of the operands or a
    replacement node rather than the node passed in.
    """
    literal_names = ('Num', 'Constant')

    def is_literal(candidate):
        # Literal nodes are recognised by class name only.
        return candidate.__class__.__name__ in literal_names

    def keeps_dtype(literal, expression):
        # Dropping ``literal`` is only safe if the operation would not have
        # cast ``expression`` to a higher dtype because of it.
        return (dtype_hierarchy[literal.dtype]
                <= dtype_hierarchy[expression.dtype])

    # Only float type nodes are collected, and only once per node.
    if (node.dtype == 'float'
            and node.op.__class__.__name__ in ('Mult', 'Div', 'Add', 'Sub')
            and not hasattr(node, 'collected')):
        collected_node = self.bast_renderer.render_node(collect(node))
        collected_node.collected = True
        return self.render_node(collected_node)

    lhs = self.render_node(node.left)
    node.left = lhs
    rhs = self.render_node(node.right)
    node.right = rhs
    node = super(ArithmeticSimplifier, self).render_BinOp(node)
    op_name = node.op.__class__.__name__

    if op_name == 'Mult':
        # Multiplication by 0 or 1, with the literal on either side.
        for literal, expression in ((lhs, rhs), (rhs, lhs)):
            if not is_literal(literal):
                continue
            literal_value = get_node_value(literal)
            # Stateful functions must not be removed, so only a stateless
            # node is replaced with zero.
            if literal_value == 0 and node.stateless:
                return _replace_with_zero(literal, node)
            if literal_value == 1 and keeps_dtype(literal, expression):
                return expression
    elif op_name == 'Div':
        # 0/x -- again only if no stateful function would be removed.
        if is_literal(lhs) and get_node_value(lhs) == 0 and node.stateless:
            return _replace_with_zero(lhs, node)
        # x/1
        if (is_literal(rhs) and get_node_value(rhs) == 1
                and keeps_dtype(rhs, lhs)):
            return lhs
    elif op_name == 'FloorDiv':
        # 0//x -- again only if no stateful function would be removed.
        if is_literal(lhs) and get_node_value(lhs) == 0 and node.stateless:
            return _replace_with_zero(lhs, node)
        # x//1 is only dropped when both operands are integers: for
        # floating point values floor division by 1 changes the value, and
        # division by 1.0 can change the type of an integer value.
        if (lhs.dtype == rhs.dtype == 'integer'
                and is_literal(rhs)
                and get_node_value(rhs) == 1):
            return lhs
    elif op_name == 'Add':
        # Addition of 0, with the literal on either side.
        for literal, expression in ((lhs, rhs), (rhs, lhs)):
            if (is_literal(literal)
                    and get_node_value(literal) == 0
                    and keeps_dtype(literal, expression)):
                return expression
    elif op_name == 'Sub':
        # Subtraction of 0 (only x - 0; 0 - x is left untouched).
        if (is_literal(rhs) and get_node_value(rhs) == 0
                and keeps_dtype(rhs, lhs)):
            return lhs

    # Make e.g. 2*float explicit as 2.0*float. Not strictly necessary, but
    # might be useful for some codegen targets. Boolean literals are left
    # alone.
    if node.dtype == 'float' and op_name in ('Mult', 'Add', 'Sub', 'Div'):
        for operand in (node.left, node.right):
            if not is_literal(operand):
                continue
            if (get_node_value(operand) is True
                    or get_node_value(operand) is False):
                continue
            operand.dtype = 'float'
            operand.value = prefs.core.default_float_dtype(
                get_node_value(operand))
    return node