Ejemplo n.º 1
0
def simplify_and(
        x: sympy.Basic,
        gen: typing.Optional[sympy.Symbol] = None,
        extra_conditions: typing.Union[sympy.Basic, bool] = True) -> sympy.Basic:
    """
    Simplify the ``sympy.And`` sub-expressions of *x* with some extra rules,
    because SymPy currently does not automatically simplify them.

    The variable the conjunctions are solved for (``_c``) is the first symbol
    found in this order: *gen* (if given), then symbols appearing directly on
    one side of an ``Eq`` in *x*, then the remaining free symbols of *x*
    (sorted, for reproducibility). The second such symbol, if any, is used as
    ``n`` by the hard-coded rules at the end.

    For every conjunction whose parts are all relations involving ``_c``:
    the parts are solved for ``_c`` (via ``_solve_inequality``), grouped by
    operator, and redundant bounds are dropped. The conjunction is replaced
    by ``False`` when the bounds or equalities contradict each other.

    :param x: expression to simplify.
    :param gen: the symbol to solve for; defaults to one picked from *x*.
    :param extra_conditions: extra facts (e.g. ``n >= 1``) that are assumed
      when two bounds on ``_c`` are compared against each other.
    :return: the simplified expression, or *x* itself if it has no symbols.
    :raises NotImplementedError: if two equality values for ``_c`` cannot be
      decided to be equal or different.
    """
    assert isinstance(x, sympy.Basic), "type x: %r" % type(x)
    from sympy.solvers.inequalities import reduce_rational_inequalities
    from sympy.core.relational import Relational

    # ``find`` and ``free_symbols`` return sets, which iterate in hash order.
    # Symbol hashes derive from string hashes and so vary between processes,
    # so sort to make the choice of ``_c``/``n`` (and rule order) reproducible.
    sort_key = sympy.default_sort_key

    syms = []
    if gen is not None:
        syms.append(gen)

    w1 = sympy.Wild("w1")
    w2 = sympy.Wild("w2")
    eq_pattern = sympy.Eq(w1, w2)
    for sub_expr in sorted(x.find(eq_pattern), key=sort_key):
        m = sub_expr.match(eq_pattern)
        for w_ in (m[w1], m[w2]):
            if isinstance(w_, sympy.Symbol) and w_ not in syms:
                syms.append(w_)
    for w_ in sorted(x.free_symbols, key=sort_key):
        if w_ not in syms:
            syms.append(w_)

    if not syms:
        return x
    _c = syms[0]
    # Without a second symbol, ``n`` is a Wild, so the ``replace`` pattern
    # just below can match any expression in its place.
    n = syms[1] if len(syms) >= 2 else sympy.Wild("n")

    x = x.replace(((_c - 2 * n >= -1) & (_c - 2 * n <= -1)),
                  sympy.Eq(_c, 2 * n - 1))  # probably not needed anymore...

    # ``is_integer`` is True/False/None; None (nothing known about ``_c``)
    # is treated as "not known to be an integer" instead of raising.
    c_is_integer = bool(_c.is_integer)

    apply_rules = True
    while apply_rules:
        apply_rules = False
        for and_expr in sorted(x.find(sympy.And), key=sort_key):
            assert isinstance(and_expr, sympy.And)

            and_expr_ = reduce_rational_inequalities([and_expr.args], _c)
            if and_expr_ != and_expr:
                x = x.replace(and_expr, and_expr_)
                and_expr = and_expr_
                # A result that is not a conjunction (False, True, a single
                # relation, an Or, ...) is already fully reduced; the rules
                # below only apply to conjunctions.
                if not isinstance(and_expr, sympy.And):
                    continue

            and_expr_args = list(and_expr.args)
            if all(
                    isinstance(part, Relational) and _c in part.free_symbols
                    for part in and_expr_args):
                # Equalities are normally resolved by the reduction above,
                # but keep a bucket for any that remain.
                rel_ops = ["==", ">=", "<="]
                if not c_is_integer:
                    rel_ops.extend(["<", ">"])
                # Maps operator -> list of right-hand sides of ``_c <op> rhs``.
                rhs_by_c = {op: [] for op in rel_ops}
                contradiction = False
                for part in and_expr_args:
                    assert isinstance(part, Relational)
                    part = _solve_inequality(part, _c)
                    assert isinstance(part, Relational)
                    assert part.lhs == _c
                    rel_op, rhs = part.rel_op, part.rhs
                    if c_is_integer:
                        # For an integer _c, strict bounds become non-strict.
                        if rel_op == "<":
                            rhs = rhs - 1
                            rel_op = "<="
                        elif rel_op == ">":
                            rhs = rhs + 1
                            rel_op = ">="
                    assert rel_op in rhs_by_c, "x: %r, _c: %r, and expr: %r, part %r" % (
                        x, _c, and_expr, part)
                    other_rhs = rhs_by_c[rel_op]
                    need_to_add = True
                    for rhs_ in other_rhs:
                        cmp = Relational.ValidRelationOperator[rel_op](rhs,
                                                                       rhs_)
                        if simplify_and(
                                sympy.And(sympy.Not(cmp),
                                          extra_conditions)) == sympy.sympify(
                                              False):
                            # ``cmp`` always holds: the new bound is at least
                            # as tight as the old one, so it replaces it.
                            other_rhs.remove(rhs_)
                            break
                        elif simplify_and(sympy.And(
                                cmp,
                                extra_conditions)) == sympy.sympify(False):
                            # ``cmp`` never holds. For an inequality the new
                            # bound is weaker and can be dropped; for "==" it
                            # means _c would have to equal two different
                            # values, so the whole conjunction is False.
                            if rel_op == "==":
                                contradiction = True
                            need_to_add = False
                            break
                        # Otherwise the bounds cannot be compared; both are kept.
                    if need_to_add:
                        other_rhs.append(rhs)
                    if contradiction:
                        break
                if contradiction:
                    x = x.replace(and_expr, False)
                    continue

                if rhs_by_c[">="] and rhs_by_c["<="]:
                    all_false = False
                    for lower in rhs_by_c[">="]:
                        for upper in rhs_by_c["<="]:
                            # lower <= _c <= upper is empty only if lower is
                            # definitely greater than upper. Equal bounds are
                            # satisfiable and pin _c to that value.
                            if sympy.Le(lower, upper) == sympy.sympify(False):
                                all_false = True
                            if sympy.Eq(lower, upper) == sympy.sympify(True):
                                rhs_by_c["=="].append(lower)
                    if all_false:
                        x = x.replace(and_expr, False)
                        continue

                if rhs_by_c["=="]:
                    all_false = False
                    # Collapse all equality values into one (or detect a clash).
                    while len(rhs_by_c["=="]) >= 2:
                        lhs, rhs = rhs_by_c["=="][:2]
                        same = sympy.Eq(lhs, rhs)
                        if same == sympy.sympify(False):
                            all_false = True
                            break
                        elif same == sympy.sympify(True):
                            rhs_by_c["=="].pop(1)
                        else:
                            raise NotImplementedError(
                                "cannot cmp %r == %r. rhs_by_c %r" %
                                (lhs, rhs, rhs_by_c))
                    if all_false:
                        x = x.replace(and_expr, False)
                        continue
                    # _c is fixed: substitute it into every remaining bound.
                    new_parts = [sympy.Eq(_c, rhs_by_c["=="][0])]
                    for op in rel_ops:
                        for part in rhs_by_c[op]:
                            new_parts.append(
                                Relational.ValidRelationOperator[op](
                                    rhs_by_c["=="][0], part).simplify())
                else:  # no "=="
                    new_parts = []
                    for op in rel_ops:
                        for part in rhs_by_c[op]:
                            new_parts.append(
                                Relational.ValidRelationOperator[op](_c, part))
                    assert new_parts
                and_expr_ = sympy.And(*new_parts)
                x = x.replace(and_expr, and_expr_)
                and_expr = and_expr_

            # Probably all the remaining hard-coded rules are not needed anymore with the more generic code above...
            if sympy.Eq(_c, 2 * n) in and_expr.args:
                if (_c - 2 * n <= -1) in and_expr.args:
                    x = x.replace(and_expr, False)
                    continue
                if sympy.Eq(_c - 2 * n, -1) in and_expr.args:
                    x = x.replace(and_expr, False)
                    continue
                if (_c - n <= -1) in and_expr.args:
                    x = x.replace(and_expr, False)
                    continue
            if (_c >= n) in and_expr.args and (_c - n <= -1) in and_expr.args:
                x = x.replace(and_expr, False)
                continue
            if sympy.Eq(_c - 2 * n, -1) in and_expr.args:  # assume n>=1
                if (_c >= n) in and_expr.args:
                    x = x.replace(
                        and_expr,
                        sympy.And(
                            *
                            [arg for arg in and_expr.args
                             if arg != (_c >= n)]))
                    apply_rules = True
                    break
                if (_c - n >= -1) in and_expr.args:
                    x = x.replace(
                        and_expr,
                        sympy.And(*[
                            arg for arg in and_expr.args
                            if arg != (_c - n >= -1)
                        ]))
                    apply_rules = True
                    break
            if (_c >= n) in and_expr.args:
                if (_c - n >= -1) in and_expr.args:
                    x = x.replace(
                        and_expr,
                        sympy.And(*[
                            arg for arg in and_expr.args
                            if arg != (_c - n >= -1)
                        ]))
                    apply_rules = True
                    break
            if (_c - n >= -1) in and_expr.args and (_c - n <=
                                                    -1) in and_expr.args:
                args = list(and_expr.args)
                args.remove((_c - n >= -1))
                args.remove((_c - n <= -1))
                args.append(sympy.Eq(_c - n, -1))
                if (_c - 2 * n <= -1) in args:
                    args.remove((_c - 2 * n <= -1))
                x = x.replace(and_expr, sympy.And(*args))
                apply_rules = True
                break
    return x