Example #1
0
def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Symbol],
                                  positive: Optional[bool] = None,
                                  replace_mixed: Optional[List[Assignment]] = None) -> sp.Expr:
    """Replaces second order mixed terms like 4*x*y by 2*( (x+y)**2 - x**2 - y**2 ).

    This makes the term longer - simplify usually is undoing these - however this
    transformation can be done to find more common sub-expressions

    Args:
        expr: input expression
        search_symbols: symbols that are searched for
                         for example, given [x,y,z] terms like x*y, x*z, z*y are replaced.
                         Any iterable is accepted (including a one-shot iterator); it is
                         materialized into a frozenset once before use.
        positive: there are two ways to do this substitution, either with term
                 (x+y)**2 or (x-y)**2 . if positive=True the first version is done,
                 if positive=False the second version is done, if positive=None the
                 sign is determined by the sign of the mixed term that is replaced
        replace_mixed: if a list is passed here, the expr x+y or x-y is replaced by a special new symbol
                       and the replacement equation is added to the list (each new symbol only once)

    Returns:
        the transformed expression. Only products whose direct factors contain exactly two
        distinct search symbols, each appearing exactly once, are rewritten. All other nodes
        are rebuilt from their (recursively transformed) arguments with ``evaluate=False``.

    Raises:
        ValueError: if ``positive`` is None and the sign of a product's remaining factors
                    (with every symbol in them set to 1) cannot be decided
    """
    # Materialize once: the membership test below runs for every factor of every
    # product and the collection is handed down the recursion, so a one-shot
    # iterator would be exhausted after the first lookups.
    if not isinstance(search_symbols, (set, frozenset)):
        search_symbols = frozenset(search_symbols)

    if expr.is_Mul:
        distinct_search_symbols = set()
        nr_of_search_terms = 0
        other_factors = sp.Integer(1)
        for factor in expr.args:
            if factor in search_symbols:
                nr_of_search_terms += 1
                distinct_search_symbols.add(factor)
            else:
                other_factors *= factor
        # Only a pure "x*y" pattern qualifies; e.g. an unevaluated x*x*y or x*y*z does not.
        if len(distinct_search_symbols) == 2 and nr_of_search_terms == 2:
            u, v = sorted(distinct_search_symbols, key=lambda symbol: symbol.name)
            if positive is None:
                # Decide the sign from the remaining factors with all symbols set to 1,
                # e.g. for -3*a*b*x*y the coefficient becomes -3 -> use (x - y)**2.
                coefficient = other_factors.subs({s: 1 for s in other_factors.atoms(sp.Symbol)})
                positive = coefficient.is_positive
                if positive is None:
                    raise ValueError(f"Cannot determine the sign of the factor {other_factors} "
                                     f"multiplying {u}*{v}; pass 'positive' explicitly")
            sign = 1 if positive else -1
            if replace_mixed is not None:
                new_symbol_str = 'P' if positive else 'M'
                mixed_symbol_name = u.name + new_symbol_str + v.name
                mixed_symbol = sp.Symbol(mixed_symbol_name.replace("_", ""))
                # Register the helper symbol only once, also across calls sharing the list.
                if mixed_symbol not in {e.lhs for e in replace_mixed}:
                    replace_mixed.append(Assignment(mixed_symbol, u + sign * v))
            else:
                mixed_symbol = u + sign * v
            # (u + sign*v)**2 - u**2 - v**2 == 2*sign*u*v, hence the factor sign/2.
            return sp.Rational(1, 2) * sign * other_factors * (mixed_symbol ** 2 - u ** 2 - v ** 2)

    # Not a matching product: transform the arguments and rebuild the node unevaluated
    # so the introduced terms are kept as written. Atoms (no args) are returned as-is.
    param_list = [replace_second_order_products(a, search_symbols, positive, replace_mixed) for a in expr.args]
    return expr.func(*param_list, evaluate=False) if param_list else expr