def simplify_and(
        x: sympy.Basic,
        gen: typing.Optional[sympy.Symbol] = None,
        extra_conditions: typing.Optional[sympy.Basic] = True) -> sympy.Basic:
    """
    Simplify the conjunctions (``sympy.And``) inside ``x`` with some rules
    that SymPy currently does not apply automatically.

    The variable to simplify for (``_c`` below) is the first of: ``gen`` (if
    given), symbols appearing directly on either side of an ``Eq`` inside
    ``x``, then the remaining free symbols of ``x``. The second such symbol
    (if any) is used as ``n`` by the hard-coded rules at the end; otherwise
    ``n`` is a ``sympy.Wild`` placeholder.

    For every ``And`` in ``x``:

    * it is reduced with ``reduce_rational_inequalities`` w.r.t. ``_c``; if
      the result is no longer an ``And`` (e.g. ``false``/``true``, a single
      relation, or an ``Or``), nothing more is done for it;
    * if all its arguments are relations containing ``_c``, each is rewritten
      as ``_c <op> rhs`` (via ``_solve_inequality``; for integer ``_c`` strict
      bounds become non-strict), dominated bounds of the same kind are dropped
      (dominance is decided by recursive calls, assuming ``extra_conditions``),
      and contradictory bounds turn the ``And`` into ``False``;
    * finally, some hard-coded rules involving ``_c`` and ``n`` are applied.

    :param x: expression to simplify.
    :param gen: symbol to simplify for; takes priority over detected symbols.
    :param extra_conditions: boolean expression assumed to hold when deciding
        whether one bound makes another redundant. Defaults to ``True``.
    :return: the simplified expression; ``x`` unchanged if no candidate symbol
        is found.
    """
    assert isinstance(x, sympy.Basic), "type x: %r" % type(x)
    from sympy.solvers.inequalities import reduce_rational_inequalities
    from sympy.core.relational import Relational

    # Candidate symbols in priority order. Sets are sorted with default_sort_key
    # so the choice of _c / n does not depend on hash-randomized set order.
    syms = []
    if gen is not None:
        syms.append(gen)
    w1 = sympy.Wild("w1")
    w2 = sympy.Wild("w2")
    eq_pattern = sympy.Eq(w1, w2)
    for sub_expr in sorted(x.find(eq_pattern), key=sympy.default_sort_key):
        m = sub_expr.match(eq_pattern)
        for w_ in (m[w1], m[w2]):
            if isinstance(w_, sympy.Symbol) and w_ not in syms:
                syms.append(w_)
    for w_ in sorted(x.free_symbols, key=sympy.default_sort_key):
        if w_ not in syms:
            syms.append(w_)
    if not syms:
        return x
    _c = syms[0]
    n = syms[1] if len(syms) >= 2 else sympy.Wild("n")

    # probably not needed anymore...
    x = x.replace(((_c - 2 * n >= -1) & (_c - 2 * n <= -1)), sympy.Eq(_c, 2 * n - 1))

    # ``is_integer`` is True/False/None. Indexing ``assumptions0["integer"]``
    # would raise KeyError for symbols created without an explicit integer
    # assumption.
    c_is_integer = bool(_c.is_integer)

    apply_rules = True
    while apply_rules:
        apply_rules = False
        for and_expr in x.find(sympy.And):
            assert isinstance(and_expr, sympy.And)

            and_expr_ = reduce_rational_inequalities([and_expr.args], _c)
            if and_expr_ != and_expr:
                x = x.replace(and_expr, and_expr_)
                and_expr = and_expr_
            # The reduction may yield false/true, a single relation or an Or;
            # the rules below only handle conjunctions.
            if not isinstance(and_expr, sympy.And):
                continue

            and_expr_args = list(and_expr.args)
            if all(isinstance(part, Relational) and _c in part.free_symbols
                   for part in and_expr_args):
                # Bucket the right-hand sides of ``_c <op> rhs`` by operator.
                # Strict operators only occur for non-integer _c; for integer
                # _c they are converted to non-strict ones below.
                rel_ops = ["==", ">=", "<="]
                if not c_is_integer:
                    rel_ops.extend(["<", ">"])
                rhs_by_c = {op: [] for op in rel_ops}
                contradiction = False
                for part in and_expr_args:
                    part = _solve_inequality(part, _c)
                    assert isinstance(part, Relational)
                    assert part.lhs == _c
                    rel_op, rhs = part.rel_op, part.rhs
                    if c_is_integer:
                        # For integer _c: c < r <=> c <= r - 1; c > r <=> c >= r + 1.
                        if rel_op == "<":
                            rhs = rhs - 1
                            rel_op = "<="
                        elif rel_op == ">":
                            rhs = rhs + 1
                            rel_op = ">="
                    assert rel_op in rhs_by_c, "x: %r, _c: %r, and expr: %r, part %r" % (
                        x, _c, and_expr, part)
                    other_rhs = rhs_by_c[rel_op]
                    need_to_add = True
                    for rhs_ in other_rhs:
                        cmp = Relational.ValidRelationOperator[rel_op](rhs, rhs_)
                        if simplify_and(
                                sympy.And(sympy.Not(cmp), extra_conditions)) == sympy.sympify(False):
                            # cmp always holds: the new bound makes rhs_ redundant.
                            other_rhs.remove(rhs_)
                            break
                        elif simplify_and(
                                sympy.And(cmp, extra_conditions)) == sympy.sympify(False):
                            if rel_op == "==":
                                # _c cannot equal two values that are never equal.
                                contradiction = True
                            # Otherwise the existing bound makes the new one redundant.
                            need_to_add = False
                            break
                        # Otherwise the bounds are incomparable; keep both.
                    if contradiction:
                        break
                    if need_to_add:
                        other_rhs.append(rhs)
                if contradiction:
                    x = x.replace(and_expr, False)
                    continue

                if rhs_by_c[">="] and rhs_by_c["<="]:
                    all_false = False
                    for lhs in rhs_by_c[">="]:
                        for rhs in rhs_by_c["<="]:
                            # lhs <= _c <= rhs is empty only if lhs > rhs. Equal
                            # bounds are satisfiable (by _c == lhs).
                            if sympy.Le(lhs, rhs) == sympy.sympify(False):
                                all_false = True
                            if sympy.Eq(lhs, rhs) == sympy.sympify(True):
                                rhs_by_c["=="].append(lhs)
                    if all_false:
                        x = x.replace(and_expr, False)
                        continue

                if rhs_by_c["=="]:
                    # Reduce the "==" values to a single one, or detect a conflict.
                    all_false = False
                    while len(rhs_by_c["=="]) >= 2:
                        lhs, rhs = rhs_by_c["=="][:2]
                        if sympy.Eq(lhs, rhs) == sympy.sympify(False):
                            all_false = True
                            break
                        elif sympy.Eq(lhs, rhs) == sympy.sympify(True):
                            rhs_by_c["=="].pop(1)
                        else:
                            raise NotImplementedError(
                                "cannot cmp %r == %r. rhs_by_c %r" % (lhs, rhs, rhs_by_c))
                    if all_false:
                        x = x.replace(and_expr, False)
                        continue
                    # _c is fixed: substitute its value into all other bounds.
                    new_parts = [sympy.Eq(_c, rhs_by_c["=="][0])]
                    for op in rel_ops:
                        for part in rhs_by_c[op]:
                            new_parts.append(
                                Relational.ValidRelationOperator[op](
                                    rhs_by_c["=="][0], part).simplify())
                else:  # no "=="
                    new_parts = []
                    for op in rel_ops:
                        for part in rhs_by_c[op]:
                            new_parts.append(Relational.ValidRelationOperator[op](_c, part))
                assert new_parts
                and_expr_ = sympy.And(*new_parts)
                x = x.replace(and_expr, and_expr_)
                and_expr = and_expr_

            # Probably all the remaining hard-coded rules are not needed anymore
            # with the more generic code above...
            if sympy.Eq(_c, 2 * n) in and_expr.args:
                if (_c - 2 * n <= -1) in and_expr.args:
                    x = x.replace(and_expr, False)
                    continue
                if sympy.Eq(_c - 2 * n, -1) in and_expr.args:
                    x = x.replace(and_expr, False)
                    continue
                if (_c - n <= -1) in and_expr.args:
                    x = x.replace(and_expr, False)
                    continue
            if (_c >= n) in and_expr.args and (_c - n <= -1) in and_expr.args:
                x = x.replace(and_expr, False)
                continue
            if sympy.Eq(_c - 2 * n, -1) in and_expr.args:  # assume n>=1
                if (_c >= n) in and_expr.args:
                    x = x.replace(
                        and_expr,
                        sympy.And(*[arg for arg in and_expr.args if arg != (_c >= n)]))
                    apply_rules = True
                    break
                if (_c - n >= -1) in and_expr.args:
                    x = x.replace(
                        and_expr,
                        sympy.And(*[arg for arg in and_expr.args if arg != (_c - n >= -1)]))
                    apply_rules = True
                    break
            if (_c >= n) in and_expr.args:
                if (_c - n >= -1) in and_expr.args:
                    x = x.replace(
                        and_expr,
                        sympy.And(*[arg for arg in and_expr.args if arg != (_c - n >= -1)]))
                    apply_rules = True
                    break
            if (_c - n >= -1) in and_expr.args and (_c - n <= -1) in and_expr.args:
                args = list(and_expr.args)
                args.remove((_c - n >= -1))
                args.remove((_c - n <= -1))
                args.append(sympy.Eq(_c - n, -1))
                if (_c - 2 * n <= -1) in args:
                    args.remove((_c - 2 * n <= -1))
                x = x.replace(and_expr, sympy.And(*args))
                apply_rules = True
                break
    return x