def remove_conditionals(f):
    """ Return a copy of f in which ->, <- and <-> are expressed with & and |.

    p -> q   is rewritten as   -p | q
    p <- q   is rewritten as   p | -q
    p <-> q  is rewritten as   (p & q) | (-p & -q)
    Quantifiers and every other operator are rebuilt around their
    recursively rewritten arguments.

    """
    rewrite = remove_conditionals
    op = f.op
    if op in (OP_IMPLIES, OP_IMPLIED_BY, OP_EQUIVALENT):
        lhs, rhs = f.args[0], f.args[1]
        if op == OP_IMPLIES:
            return rewrite(~lhs) | rewrite(rhs)
        if op == OP_IMPLIED_BY:
            return rewrite(lhs) | rewrite(~rhs)
        both_hold = rewrite(lhs) & rewrite(rhs)
        neither_holds = rewrite(~lhs) & rewrite(~rhs)
        return both_hold | neither_holds
    if is_quantified(f):
        return Quantifier(op, f.vars, rewrite(f.args[0]))
    return Expr(op, *map(rewrite, f.args))
def push_neg(f):
    """ Return a copy of f with negations moved as close to the terms as possible.

    Negated conjunctions/disjunctions are rewritten with De Morgan's laws,
    double negations cancel out, and a negated quantifier becomes the dual
    quantifier over the negated body.
    Unlike nnf(), push_neg() leaves conditionals as they are: a negated
    ->, <- or <-> keeps its negation and only its operands are processed.

    """
    if f.op != OP_NOT:
        if f.op in (OP_AND, OP_OR):
            return Expr(f.op, *[push_neg(x) for x in f.args])
        if f.op in (OP_IMPLIES, OP_IMPLIED_BY, OP_EQUIVALENT):
            left = push_neg(f.args[0])
            right = push_neg(f.args[1])
            if f.op == OP_IMPLIES:
                return left >> right
            if f.op == OP_IMPLIED_BY:
                return left << right
            return left ** right
        if is_quantified(f):
            return Quantifier(f.op, f.vars, push_neg(f.args[0]))
        return f

    negated = f.args[0]
    kind = negated.op
    if kind == OP_NOT:
        # --p == p
        return push_neg(negated.args[0])
    if kind in (OP_AND, OP_OR):
        # De Morgan: -(a & b) == -a | -b and -(a | b) == -a & -b
        dual = OP_OR if kind == OP_AND else OP_AND
        return Expr(dual, *[push_neg(Expr(OP_NOT, x)) for x in negated.args])
    if kind in (OP_IMPLIES, OP_IMPLIED_BY, OP_EQUIVALENT):
        # the conditional itself stays negated; only its operands change
        return ~push_neg(negated)
    if kind in (OP_FORALL, OP_EXISTS):
        # -forall x.p == exists x.-p and -exists x.p == forall x.-p
        dual = OP_EXISTS if kind == OP_FORALL else OP_FORALL
        return Quantifier(dual, negated.vars,
                          push_neg(Expr(OP_NOT, negated.args[0])))
    return f
 def helper(f, used, substs):
     """ Return a copy of f in which no quantifier re-binds an already used variable.

     A variable bound by a quantifier that is already in `used` is renamed
     with variant(), and its occurrences inside that quantifier's scope are
     replaced by the new name.

     f      -- the formula or term to process
     used   -- set of variables bound so far; updated in place with every
               variable bound in f, including newly created names
     substs -- mapping from variables to their replacements in the current
               scope; it is never modified, so a renaming does not leak out
               of the quantifier that introduced it

     """
     if is_variable(f):
         if f in substs:
             return substs[f]
         return f
     if not is_quantified(f):
         return Expr(f.op, *[helper(x, used, substs) for x in f.args])
     # does this quantifier use a variable that is already used?
     clashing = used & set(f.vars)  # set intersection
     # register this quantifier's own variables before choosing fresh names,
     # so a fresh name can never coincide with one of them
     used.update(f.vars)
     if not clashing:
         return Quantifier(f.op, f.vars, helper(f.args[0], used, substs))
     # rename in a private copy: the renaming only applies inside this scope
     scope = dict(substs)
     new_vars = []
     for var in f.vars:
         if var in clashing:
             # NOTE(review): assumes `vars` is the module's function returning
             # the variables of a term (not the builtin vars()) -- confirm
             existing = used.union(*map(vars, scope.values()))
             var2 = variant(var, existing)
             scope[var] = var2
             # reserve the new name so a nested quantifier cannot capture it
             used.add(var2)
             new_vars.append(var2)
         else:
             new_vars.append(var)
     # use the recursively processed body so nested quantifiers are renamed too
     return Quantifier(f.op, new_vars, helper(f.args[0], used, scope))
 def drop(f):
     """ Return f with all of its leading quantifiers stripped off.

     A formula that is not quantified is returned unchanged.

     """
     body = f
     while is_quantified(body):
         body = body.args[0]
     return body
    def pushquants(f):
        """ Return a copy of f with quantifiers pushed as far inwards as possible.

        A quantifier binding several variables is first split into nested
        single-variable quantifiers (Qx,y.phi == Qx.Qy.phi), so each rule
        below deals with exactly one bound variable and none can be lost.
        A quantifier that binds no variable is removed; one whose variable
        does not occur free in its scope is dropped and a warning is logged.
        Conjunctions and disjunctions may have any number of operands.

        """
        if is_quantified(f):
            if not f.vars:
                # nothing is bound, so the quantifier has no effect
                return pushquants(f.args[0])
            if len(f.vars) > 1:
                inner = Quantifier(f.op, f.vars[1:], f.args[0])
                return pushquants(Quantifier(f.op, f.vars[:1], inner))
            arg = f.args[0]
            variable = f.vars[0]
            if arg.op == OP_NOT:
                # Qx.-phi == -Q'x.phi, where Q' is the opposite quantifier
                return ~pushquants(Quantifier(opposite(f.op), f.vars, arg.args[0]))
            elif is_quantified(arg):
                arg1 = pushquants(arg)
                # did pushquants have any effect?
                if arg1 == arg:
                    return Quantifier(f.op, f.vars, arg1)
                return pushquants(Quantifier(f.op, f.vars, arg1))
            elif arg.op in [OP_AND, OP_OR]:
                # for each operand: does the bound variable occur free in it?
                free = [variable in fvars(x) for x in arg.args]
                if not any(free):
                    get_log().warning('Dropped quantifier "{0} : {1}" from "{2}"'
                                      .format(f.op, f.vars, f))
                    return pushquants(arg)
                # forall distributes over &, exists distributes over |
                distributes = ((arg.op == OP_AND and f.op == OP_FORALL) or
                               (arg.op == OP_OR and f.op == OP_EXISTS))
                if sum(free) > 1 and not distributes:
                    return Quantifier(f.op, f.vars, pushquants(arg))
                # move the quantifier onto every operand that mentions the
                # variable; the other operands are kept as they are
                return pushquants(Expr(arg.op, *[
                    Quantifier(f.op, f.vars, x) if has_var else x
                    for x, has_var in zip(arg.args, free)]))
            elif arg.op == OP_IMPLIES:
                arg1, arg2 = arg.args[0], arg.args[1]
                in1 = variable in fvars(arg1)
                in2 = variable in fvars(arg2)
                if in1 and in2:
                    return Quantifier(f.op, f.vars, pushquants(arg))
                elif in1:
                    # Qx.(A -> B) == (Q'x.A) -> B when x is not free in B
                    return (pushquants(Quantifier(opposite(f.op), f.vars, arg1)) >>
                            pushquants(arg2))
                elif in2:
                    # Qx.(A -> B) == A -> (Qx.B) when x is not free in A
                    return (pushquants(arg1) >>
                            pushquants(Quantifier(f.op, f.vars, arg2)))
                else:
                    get_log().warning('Dropped quantifier "{0} : {1}" from "{2}"'
                                      .format(f.op, f.vars, f))
                    return (pushquants(arg1) >> pushquants(arg2))
            elif arg.op == OP_IMPLIED_BY:
                arg1, arg2 = arg.args[0], arg.args[1]
                in1 = variable in fvars(arg1)
                in2 = variable in fvars(arg2)
                if in1 and in2:
                    return Quantifier(f.op, f.vars, pushquants(arg))
                elif in1:
                    # Qx.(A <- B) == (Qx.A) <- B when x is not free in B
                    return (pushquants(Quantifier(f.op, f.vars, arg1)) <<
                            pushquants(arg2))
                elif in2:
                    # Qx.(A <- B) == A <- (Q'x.B) when x is not free in A
                    return (pushquants(arg1) <<
                            pushquants(Quantifier(opposite(f.op), f.vars, arg2)))
                else:
                    get_log().warning('Dropped quantifier "{0} : {1}" from "{2}"'
                                      .format(f.op, f.vars, f))
                    return (pushquants(arg1) << pushquants(arg2))
            else:
                return Quantifier(f.op, f.vars, pushquants(arg))
        elif f.op == OP_NOT:
            return ~pushquants(f.args[0])
        elif f.op in BINARY_LOGIC_OPS:
            return Expr(f.op, *[pushquants(x) for x in f.args])
        else:
            return f
 def pullquants(f):
     """ Return a copy of f with its quantifiers pulled to the front.

     Each binary connective is rebuilt through rename_and_pull(), which is
     given f, the quantifier to place in front, the variable bound on the
     left and on the right (None for a side that is not quantified) and the
     two operands with that quantifier removed.
     Pulling a quantifier out of the antecedent of ->, out of the right-hand
     side of <-, or out of a negation turns forall into exists and vice
     versa.  <-> is expanded to (p -> q) & (p <- q) first.

     """
     def peel(q):
         # Split the first variable off quantifier q.  Any further variables
         # stay bound in the returned body instead of being discarded.
         if len(q.vars) > 1:
             return q.vars[0], Quantifier(q.op, q.vars[1:], q.args[0])
         return q.vars[0], q.args[0]

     if f.op in (OP_AND, OP_OR) and len(f.args) != 2:
         if len(f.args) > 2:
             # regroup a op b op c ... as a op (b op c ...) so that no
             # operand is lost by the binary rules below
             return pullquants(Expr(f.op, f.args[0], Expr(f.op, *f.args[1:])))
         if len(f.args) == 1:
             return pullquants(f.args[0])
         return f
     if f.op == OP_AND:
         arg1 = pullquants(f.args[0])
         arg2 = pullquants(f.args[1])
         if arg1.op == OP_FORALL and arg2.op == OP_FORALL:
             var1, body1 = peel(arg1)
             var2, body2 = peel(arg2)
             return rename_and_pull(f, OP_FORALL, var1, var2, body1, body2)
         elif is_quantified(arg1):
             var1, body1 = peel(arg1)
             return rename_and_pull(f, arg1.op, var1, None, body1, arg2)
         elif is_quantified(arg2):
             var2, body2 = peel(arg2)
             return rename_and_pull(f, arg2.op, None, var2, arg1, body2)
         else:
             return (arg1 & arg2)
     elif f.op == OP_OR:
         arg1 = pullquants(f.args[0])
         arg2 = pullquants(f.args[1])
         if arg1.op == OP_EXISTS and arg2.op == OP_EXISTS:
             var1, body1 = peel(arg1)
             var2, body2 = peel(arg2)
             return rename_and_pull(f, OP_EXISTS, var1, var2, body1, body2)
         elif is_quantified(arg1):
             var1, body1 = peel(arg1)
             return rename_and_pull(f, arg1.op, var1, None, body1, arg2)
         elif is_quantified(arg2):
             var2, body2 = peel(arg2)
             return rename_and_pull(f, arg2.op, None, var2, arg1, body2)
         else:
             return (arg1 | arg2)
     elif f.op == OP_IMPLIES:
         arg1 = pullquants(f.args[0])
         arg2 = pullquants(f.args[1])
         if is_quantified(arg1):
             # (Qx.A) -> B == Q'x.(A -> B)
             var1, body1 = peel(arg1)
             return rename_and_pull(f, opposite(arg1.op), var1, None, body1, arg2)
         elif is_quantified(arg2):
             var2, body2 = peel(arg2)
             return rename_and_pull(f, arg2.op, None, var2, arg1, body2)
         else:
             return (arg1 >> arg2)
     elif f.op == OP_IMPLIED_BY:
         arg1 = pullquants(f.args[0])
         arg2 = pullquants(f.args[1])
         if is_quantified(arg1):
             var1, body1 = peel(arg1)
             return rename_and_pull(f, arg1.op, var1, None, body1, arg2)
         elif is_quantified(arg2):
             # A <- (Qx.B) == Q'x.(A <- B)
             var2, body2 = peel(arg2)
             return rename_and_pull(f, opposite(arg2.op), None, var2, arg1, body2)
         else:
             return (arg1 << arg2)
     elif f.op == OP_EQUIVALENT:
         arg1 = pullquants(f.args[0])
         arg2 = pullquants(f.args[1])
         return pullquants((arg1 >> arg2) & (arg1 << arg2))
     elif f.op == OP_NOT:
         arg = pullquants(f.args[0])
         if is_quantified(arg):
             # -Qx.phi == Q'x.-phi; the variables are those of the quantifier
             # being negated, and the negated body may hold more quantifiers
             return Quantifier(opposite(arg.op), arg.vars,
                               pullquants(~(arg.args[0])))
         else:
             return (~arg)
     elif is_quantified(f):
         # keep the quantifier in front and pull from its body as well
         return Quantifier(f.op, f.vars, pullquants(f.args[0]))
     else:
         return f
def kleene(f):
    """ Take a FOL formula and try to simplify it.

    The following rewrites are applied recursively until nothing changes:
    - variables a quantifier binds but its body does not use are removed,
      and a quantifier left with no variables is removed altogether;
    - --p becomes p, -TRUE becomes FALSE, -FALSE becomes TRUE, and a negated
      equality/inequality becomes the opposite relation;
    - TRUE/FALSE operands of &, |, ->, <- and <-> are evaluated away and
      repeated operands of & and | are removed;
    - an & with no operands left is TRUE, an | with no operands left is FALSE.
    Any other formula is returned unchanged.

    """
    if is_quantified(f):
        # remove the variables that do not occur free in the body
        needed = set(f.vars) & fvars(f.args[0])
        if not needed:
            return kleene(f.args[0])
        arg = kleene(f.args[0])
        # keep the same order; dict.fromkeys also drops repeated variables,
        # which would otherwise make the length test below succeed forever
        variables = [v for v in dict.fromkeys(f.vars) if v in needed]
        if (arg != f.args[0]) or (len(variables) < len(f.vars)):
            return kleene(Quantifier(f.op, variables, arg))
        return f
    elif f.op == OP_NOT:
        arg = f.args[0]
        if arg.op == OP_NOT:  # double neg
            return kleene(arg.args[0])
        elif is_true(arg):  # -TRUE --> FALSE
            return Expr(OP_FALSE)
        elif is_false(arg):  # -FALSE --> TRUE
            return Expr(OP_TRUE)
        elif arg.op == OP_EQUALS:
            return Expr(OP_NOTEQUALS,
                        kleene(arg.args[0]),
                        kleene(arg.args[1]))
        elif arg.op == OP_NOTEQUALS:
            return Expr(OP_EQUALS,
                        kleene(arg.args[0]),
                        kleene(arg.args[1]))
        else:
            arg2 = kleene(arg)
            if arg2 != arg:
                return kleene(Expr(OP_NOT, arg2))
            return f
    elif f.op == OP_AND:
        # TRUE is the identity of &, FALSE absorbs it
        return _kleene_junction(f, OP_TRUE, OP_FALSE, is_true, is_false)
    elif f.op == OP_OR:
        # FALSE is the identity of |, TRUE absorbs it
        return _kleene_junction(f, OP_FALSE, OP_TRUE, is_false, is_true)
    elif f.op == OP_IMPLIES:  # p -> q
        if is_false(f.args[0]) or is_true(f.args[1]):
            return Expr(OP_TRUE)
        elif is_true(f.args[0]):
            return kleene(f.args[1])
        elif is_false(f.args[1]):
            return kleene(Expr(OP_NOT, kleene(f.args[0])))
        else:
            arg1 = kleene(f.args[0])
            arg2 = kleene(f.args[1])
            if arg1 != f.args[0] or arg2 != f.args[1]:
                return kleene(Expr(OP_IMPLIES, arg1, arg2))
            return f
    elif f.op == OP_IMPLIED_BY:  # p <- q, i.e. q -> p
        if is_false(f.args[1]) or is_true(f.args[0]):
            return Expr(OP_TRUE)
        elif is_true(f.args[1]):
            return kleene(f.args[0])
        elif is_false(f.args[0]):
            return kleene(Expr(OP_NOT, kleene(f.args[1])))
        else:
            arg1 = kleene(f.args[0])
            arg2 = kleene(f.args[1])
            if arg1 != f.args[0] or arg2 != f.args[1]:
                # rebuild with the same operator; OP_IMPLIES would reverse it
                return kleene(Expr(OP_IMPLIED_BY, arg1, arg2))
            return f
    elif f.op == OP_EQUIVALENT:
        if is_true(f.args[0]):
            return kleene(f.args[1])
        elif is_true(f.args[1]):
            return kleene(f.args[0])
        elif is_false(f.args[0]):
            return kleene(Expr(OP_NOT, kleene(f.args[1])))
        elif is_false(f.args[1]):
            return kleene(Expr(OP_NOT, kleene(f.args[0])))
        else:
            arg1 = kleene(f.args[0])
            arg2 = kleene(f.args[1])
            if arg1 != f.args[0] or arg2 != f.args[1]:
                return kleene(Expr(OP_EQUIVALENT, arg1, arg2))
            return f
    else:
        return f


def _kleene_junction(f, unit_op, zero_op, is_unit, is_zero):
    """ Simplify the n-ary & or | expression f (helper of kleene()).

    unit_op / is_unit describe the identity element of the connective
    (TRUE for &, FALSE for |); zero_op / is_zero describe its absorbing
    element (FALSE for &, TRUE for |).

    """
    if any(map(is_zero, f.args)):  # one absorbing operand decides the result
        return Expr(zero_op)
    if len(f.args) == 1:
        return kleene(f.args[0])
    # drop identity operands, simplify the rest, remove duplicates (keep order)
    args = [kleene(x) for x in f.args if not is_unit(x)]
    unique = list(dict.fromkeys(args))
    if not unique:
        # every operand was the identity element (or there were none)
        return Expr(unit_op)
    if list(f.args) != unique:
        return kleene(Expr(f.op, *unique))
    return f