def reflected_binary_operator(op):
    """
    Build the reflected (right-hand) method for the binary operator `op`.

    The returned function is named "reflected_binary_operator" here and is
    then renamed via `with_name(method_name_for_op(op, commute=True))`; it
    is suitable for implementing methods like __radd__, where `self` is the
    right-hand operand and `other` is the left-hand operand.

    Parameters
    ----------
    op : str
        Operator symbol.  Must not be a comparison operator; this is
        checked with an `assert` when the factory is called.

    Returns
    -------
    reflected_binary_operator : callable
        A two-argument method `(self, other)` that returns a NumExprFactor
        or raises BadBinaryOperator when neither supported case applies.
    """
    assert not is_comparison(op)

    # `other` is passed through `numbers_to_float64` before the body runs
    # (presumably so that plain numbers expose a `.dtype` -- confirm against
    # the definition of `numbers_to_float64`).
    @preprocess(other=numbers_to_float64)
    @with_name(method_name_for_op(op, commute=True))
    def reflected_binary_operator(self, other):
        if isinstance(self, NumericalExpression):
            own_expr, operand_expr, merged_inputs = self.build_binary_op(
                op,
                other,
            )
            # Reflected form: the other operand goes on the LEFT of `op`.
            expression = "({left}) {op} ({right})".format(
                left=operand_expr,
                right=own_expr,
                op=op,
            )
            return NumExprFactor(
                expression,
                merged_inputs,
                dtype=binop_return_dtype(op, other.dtype, self.dtype),
            )

        # Only have to handle the numeric case because in all other valid
        # cases the corresponding left-binding method will be called.
        if isinstance(other, Number):
            # The constant is written on the left; `self` is the single
            # bound input (referred to as x_0 in the expression string).
            expression = "{constant} {op} x_0".format(op=op, constant=other)
            return NumExprFactor(
                expression,
                binds=(self,),
                dtype=binop_return_dtype(op, other.dtype, self.dtype),
            )

        raise BadBinaryOperator(op, other, self)

    return reflected_binary_operator
def reflected_binary_operator(op):
    """
    Build the reflected (right-hand) method for the binary operator `op`.

    Returns a function, "reflected_binary_operator", suitable for
    implementing methods like __radd__, where `self` is the right-hand
    operand and `other` is the left-hand operand.

    Parameters
    ----------
    op : str
        Operator symbol.  Must not be a comparison operator; this is
        checked with an `assert` when the factory is called.

    Returns
    -------
    reflected_binary_operator : callable
        A two-argument method `(self, other)` that returns a NumExprFactor
        or raises BadBinaryOperator when neither supported case applies.
    """
    assert not is_comparison(op)

    def reflected_binary_operator(self, other):
        if isinstance(self, NumericalExpression):
            own_expr, operand_expr, merged_inputs = self.build_binary_op(
                op,
                other,
            )
            # Reflected form: the other operand goes on the LEFT of `op`.
            expression = "({left}) {op} ({right})".format(
                left=operand_expr,
                right=own_expr,
                op=op,
            )
            return NumExprFactor(expression, merged_inputs)

        # Only have to handle the numeric case because in all other valid
        # cases the corresponding left-binding method will be called.
        if isinstance(other, Number):
            # The constant is written on the left; `self` is the single
            # bound input (referred to as x_0 in the expression string).
            expression = "{constant} {op} x_0".format(op=op, constant=other)
            return NumExprFactor(expression, binds=(self,))

        raise BadBinaryOperator(op, other, self)

    return reflected_binary_operator
def binop_return_dtype(op, left, right):
    """
    Determine the dtype produced by applying `op` to the given operand dtypes.

    Parameters
    ----------
    op : str
        Operator symbol, (e.g. '+', '-', ...).
    left : numpy.dtype
        Dtype of left hand side.
    right : numpy.dtype
        Dtype of right hand side.

    Returns
    -------
    outdtype : numpy.dtype
        `bool_dtype` for comparison operators, `float64_dtype` for all
        other operators.

    Raises
    ------
    TypeError
        If `op` is a comparison and the operand dtypes differ, or if `op`
        is not a comparison and either operand dtype is not
        `float64_dtype`.
    """
    if not is_comparison(op):
        # Arithmetic: both sides must be float64.
        if left != float64_dtype or right != float64_dtype:
            arithmetic_template = (
                "Don't know how to compute {left} {op} {right}.\n"
                "Arithmetic operators are only supported on Factors of "
                "dtype 'float64'."
            )
            raise TypeError(
                arithmetic_template.format(
                    left=left.name,
                    op=op,
                    right=right.name,
                )
            )
        return float64_dtype

    # Comparison: operands only need to share a dtype.
    if left != right:
        comparison_template = (
            "Don't know how to compute {left} {op} {right}.\n"
            "Comparisons are only supported between Factors of equal "
            "dtypes."
        )
        raise TypeError(
            comparison_template.format(left=left, op=op, right=right)
        )
    return bool_dtype
def binop_return_dtype(op, left, right):
    """
    Determine the dtype produced by applying `op` to the given operand dtypes.

    Parameters
    ----------
    op : str
        Operator symbol, (e.g. '+', '-', ...).
    left : numpy.dtype
        Dtype of left hand side.
    right : numpy.dtype
        Dtype of right hand side.

    Returns
    -------
    outdtype : numpy.dtype
        `bool_dtype` for comparison operators, `float64_dtype` for all
        other operators.

    Raises
    ------
    TypeError
        If `op` is a comparison and the operand dtypes differ, or if `op`
        is not a comparison and either operand dtype is not
        `float64_dtype`.
    """
    if not is_comparison(op):
        # Arithmetic: both sides must be float64.
        if left != float64_dtype or right != float64_dtype:
            arithmetic_template = (
                "Don't know how to compute {left} {op} {right}.\n"
                "Arithmetic operators are only supported between Factors of "
                "dtype 'float64'."
            )
            raise TypeError(
                arithmetic_template.format(
                    left=left.name,
                    op=op,
                    right=right.name,
                )
            )
        return float64_dtype

    # Comparison: operands only need to share a dtype.
    if left != right:
        comparison_template = (
            "Don't know how to compute {left} {op} {right}.\n"
            "Comparisons are only supported between Factors of equal "
            "dtypes."
        )
        raise TypeError(
            comparison_template.format(left=left, op=op, right=right)
        )
    return bool_dtype
def binop_return_type(op):
    """
    Pick the result class for the binary operator `op`.

    Parameters
    ----------
    op : str
        Operator symbol.

    Returns
    -------
    type
        `NumExprFilter` when `is_comparison(op)` is true, otherwise
        `NumExprFactor`.
    """
    return NumExprFilter if is_comparison(op) else NumExprFactor