def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Symbol],
                                  positive: Optional[bool] = None,
                                  replace_mixed: Optional[List[Assignment]] = None) -> sp.Expr:
    """Replaces second order mixed terms like 4*x*y by 2*( (x+y)**2 - x**2 - y**2 ).

    This makes the term longer - simplify usually is undoing these - however this transformation
    can be done to find more common sub-expressions.

    The rewrite uses the identities
        u*v =  1/2 * ((u + v)**2 - u**2 - v**2)    (positive variant)
        u*v = -1/2 * ((u - v)**2 - u**2 - v**2)    (negative variant)
    where u, v are the two search symbols of the product, ordered by name.

    Args:
        expr: input expression
        search_symbols: symbols that are searched for, for example, given [x, y, z] terms like
                        x*y, x*z, z*y are replaced. Any iterable is accepted (including one-shot
                        iterators such as generators); it is consumed exactly once.
        positive: there are two ways to do this substitution, either with the term (x+y)**2 or
                  (x-y)**2. If positive=True the first version is used, if positive=False the
                  second version is used. If positive=None the sign is chosen per product from the
                  sign of its remaining (non-search-symbol) factors, evaluated with every symbol
                  in them set to 1.
        replace_mixed: if a list is passed here, the expression x+y or x-y is replaced by a
                       special new symbol (e.g. ``xPy`` / ``xMy``, underscores removed) and the
                       replacement equation is appended to the list. Each new symbol is appended
                       at most once, also taking into account entries already in the list.

    Returns:
        the rewritten expression

    Raises:
        ValueError: if positive is None and the sign of a product's remaining factors cannot be
                    determined
    """
    # Materialize once: membership is tested for every factor at every recursion level, which
    # would silently exhaust a generator / one-shot iterator after the first few checks.
    search_symbol_set = frozenset(search_symbols)
    # Built once and shared by the whole recursion instead of being rebuilt at every node.
    mixed_symbols_replaced = {e.lhs for e in replace_mixed} if replace_mixed is not None else set()
    return _replace_second_order_products(expr, search_symbol_set, positive, replace_mixed,
                                          mixed_symbols_replaced)


def _replace_second_order_products(expr, search_symbols, positive, replace_mixed, mixed_symbols_replaced):
    """Recursive worker of replace_second_order_products.

    `search_symbols` is an already-materialized set; `mixed_symbols_replaced` is the shared set of
    left-hand sides already present in `replace_mixed` and is updated together with that list.
    """
    if expr.is_Mul:
        distinct_search_symbols = set()
        nr_of_search_terms = 0
        other_factors = sp.Integer(1)
        for factor in expr.args:
            if factor in search_symbols:
                nr_of_search_terms += 1
                distinct_search_symbols.add(factor)
            else:
                other_factors *= factor

        # Only products with exactly two distinct search-symbol factors, each appearing once.
        if len(distinct_search_symbols) == 2 and nr_of_search_terms == 2:
            # Sort by name so the generated symbol / replacement is deterministic.
            u, v = sorted(distinct_search_symbols, key=lambda symbol: symbol.name)

            term_positive = positive
            if term_positive is None:
                # Pick the variant from the sign of the remaining factors with all symbols set to 1.
                other_factors_without_symbols = other_factors
                for s in other_factors.atoms(sp.Symbol):
                    other_factors_without_symbols = other_factors_without_symbols.subs(s, 1)
                term_positive = other_factors_without_symbols.is_positive
                if term_positive is None:
                    raise ValueError(f"Cannot determine the sign of {other_factors}; "
                                     f"pass positive=True or positive=False explicitly")
            sign = 1 if term_positive else -1

            if replace_mixed is not None:
                new_symbol_str = 'P' if term_positive else 'M'
                mixed_symbol_name = u.name + new_symbol_str + v.name
                mixed_symbol = sp.Symbol(mixed_symbol_name.replace("_", ""))
                if mixed_symbol not in mixed_symbols_replaced:
                    mixed_symbols_replaced.add(mixed_symbol)
                    replace_mixed.append(Assignment(mixed_symbol, u + sign * v))
            else:
                mixed_symbol = u + sign * v
            return sp.Rational(1, 2) * sign * other_factors * (mixed_symbol ** 2 - u ** 2 - v ** 2)

    param_list = [_replace_second_order_products(a, search_symbols, positive, replace_mixed,
                                                 mixed_symbols_replaced)
                  for a in expr.args]
    # Leaves (no args) are returned unchanged; evaluate=False keeps the expanded form from collapsing.
    return expr.func(*param_list, evaluate=False) if param_list else expr