def seteq(a, b, eq2=eq):
    """ Set Equality

    For example (1, 2, 3) set equates to (2, 1, 3)

    >>> from logpy import var, run, seteq
    >>> x = var()
    >>> run(0, x, seteq(x, (1, 2)))
    ((1, 2), (2, 1))

    >>> run(0, x, seteq((2, 1, x), (3, 1, 2)))
    (3,)

    Parameters
    ----------
    a, b : tuple or logic variable
        The collections to relate.  At most one of them may be a logic
        variable; the other must then be a tuple.
    eq2 : goal constructor, optional
        Goal used to equate individual elements (defaults to ``eq``).

    Returns ``success``, ``fail`` or a goal.  Raises ``EarlyGoalError`` when
    both arguments are logic variables (nothing to enumerate yet) and
    ``TypeError`` for any other unsupported combination of argument types.
    """
    if isinstance(a, tuple) and isinstance(b, tuple):
        if set(a) == set(b):
            return success
        elif len(a) != len(b):
            return fail
        else:
            # Sorting only influences search order.  Tuples that mix logic
            # variables with ground values are typically unorderable in
            # Python 3 (TypeError), so keep the given order in that case.
            try:
                c, d = tuple(sorted(a)), tuple(sorted(b))
            except TypeError:
                c, d = a, b
            if len(c) == 1:
                return (eq2, c[0], d[0])
            # Try matching d[0] against each element of c, then recurse on
            # the remaining elements of both tuples.
            return condeseq((
                ((eq2, c[i], d[0]), (seteq, c[0:i] + c[i + 1:], d[1:], eq2))
                for i in range(len(c))))
    if isvar(a) and isvar(b):
        raise EarlyGoalError()
    if isvar(a) and isinstance(b, tuple):
        c, d = a, b
    elif isvar(b) and isinstance(a, tuple):
        c, d = b, a
    else:
        # Previously this fell through to an UnboundLocalError on c/d.
        raise TypeError(
            "seteq expects tuples or logic variables, got %s and %s"
            % (type(a).__name__, type(b).__name__))
    # Bind the variable to each ordering of the tuple in turn.
    return (condeseq, ([eq(c, perm)] for perm in it.permutations(d, len(d))))
def seteq(a, b, eq2=eq):
    """ Set Equality

    For example (1, 2, 3) set equates to (2, 1, 3)

    >>> from logpy import var, run, seteq
    >>> x = var()
    >>> run(0, x, seteq(x, (1, 2)))
    ((1, 2), (2, 1))

    >>> run(0, x, seteq((2, 1, x), (3, 1, 2)))
    (3,)

    Parameters
    ----------
    a, b : tuple or logic variable
        The collections to relate.  At most one of them may be a logic
        variable; the other must then be a tuple.
    eq2 : goal constructor, optional
        Goal used to equate individual elements (defaults to ``eq``).

    Returns ``success``, ``fail`` or a goal.  Raises ``EarlyGoalError`` when
    both arguments are logic variables and ``TypeError`` for any other
    unsupported combination of argument types.
    """
    if isinstance(a, tuple) and isinstance(b, tuple):
        if set(a) == set(b):
            return success
        elif len(a) != len(b):
            return fail
        else:
            c, d = a, b
            if len(c) == 1:
                return (eq2, c[0], d[0])
            # One conde clause per element of c: equate it with d[0], then
            # recurse on what remains of both tuples.
            return (conde,) + tuple(
                ((eq2, c[i], d[0]), (seteq, c[0:i] + c[i + 1:], d[1:], eq2))
                for i in range(len(c)))
    if isvar(a) and isvar(b):
        raise EarlyGoalError()
    if isvar(a) and isinstance(b, tuple):
        c, d = a, b
    elif isvar(b) and isinstance(a, tuple):
        c, d = b, a
    else:
        # Previously this fell through to an UnboundLocalError on c/d.
        raise TypeError(
            "seteq expects tuples or logic variables, got %s and %s"
            % (type(a).__name__, type(b).__name__))
    # Bind the variable to each ordering of the tuple in turn.
    return (condeseq, ([eq(c, perm)] for perm in it.permutations(d, len(d))))
def f(s):
    """Match the goal's arguments against stored facts under substitution ``s``.

    ``self`` (the relation) and ``args`` (the goal's arguments) are free
    variables supplied by the enclosing scope.

    Returns a lazy generator with one result per fact whose ground positions
    equal the reified arguments.  Each result is ``merge`` applied to ``s``
    and a dict mapping each still-unbound variable to the fact's value at
    that position (presumably an extended substitution).
    """
    args2 = reify(args, s)
    # self.index is keyed by (position, value) pairs; use it to shrink the
    # pool of candidate facts before the element-wise comparison below.
    subsets = [self.index[key] for key in enumerate(args) if key in self.index]
    if subsets:  # we are able to reduce the pool early
        # Smallest subsets first, presumably so the intersection stays cheap.
        facts = intersection(*sorted(subsets, key=len))
    else:
        facts = self.facts
    var_positions = [i for i, arg in enumerate(args2) if isvar(arg)]
    val_positions = [i for i, arg in enumerate(args2) if not isvar(arg)]
    unbound = index(args2, var_positions)
    ground = index(args2, val_positions)
    # Invariant: variables left after reification are not bound in s.
    assert not any(v in s for v in unbound)
    # Iterate the reduced pool; the original scanned self.facts here, which
    # discarded the index lookup above.
    return (
        merge(dict(zip(unbound, index(fact, var_positions))), s)
        for fact in facts
        if ground == index(fact, val_positions)
    )
def funco(inputs, out):
    """Goal equating ``func`` applied to ``inputs`` with ``out``.

    ``func`` is a free variable supplied by the enclosing scope.  A tuple or
    list of inputs is splatted into ``func``; any other value is passed as a
    single argument.

    Raises ``EarlyGoalError`` when ``inputs`` is a logic variable, or is a
    tuple/list containing one, because ``func`` cannot be evaluated until
    those values are known.
    """
    if isvar(inputs):
        raise EarlyGoalError()
    if isinstance(inputs, (tuple, list)):
        # Calling func on unbound logic variables would compute garbage or
        # crash; defer the goal instead.
        if any(isvar(arg) for arg in inputs):
            raise EarlyGoalError()
        return (eq, func(*inputs), out)
    return (eq, func(inputs), out)
def permuteq(a, b, eq2=eq):
    """ Equality under permutation

    For example (1, 2, 2) equates to (2, 1, 2) under permutation

    >>> from logpy import var, run, permuteq
    >>> x = var()
    >>> run(0, x, permuteq(x, (1, 2)))
    ((1, 2), (2, 1))

    >>> run(0, x, permuteq((2, 1, x), (2, 1, 2)))
    (2,)

    Parameters
    ----------
    a, b : tuple or logic variable
        The sequences to relate.  At most one of them may be a logic
        variable; the other must then be a tuple.
    eq2 : goal constructor, optional
        Goal used to equate individual elements (defaults to ``eq``).

    Returns ``success``, ``fail`` or a goal.  Raises ``EarlyGoalError`` when
    both arguments are logic variables and ``TypeError`` for any other
    unsupported combination of argument types.
    """
    if isinstance(a, tuple) and isinstance(b, tuple):
        if len(a) != len(b):
            return fail
        elif set(a) == set(b) and len(set(a)) == len(a):
            # Same distinct elements and no duplicates: a permutation.
            return success
        else:
            c, d = a, b
            # Sorting only influences search order; unorderable contents
            # (e.g. logic variables mixed with values) raise TypeError in
            # Python 3, in which case the given order is kept.
            try:
                c, d = tuple(sorted(c)), tuple(sorted(d))
            except TypeError:
                pass
            if len(c) == 1:
                return (eq2, c[0], d[0])
            # Try matching d[0] against each element of c, then recurse on
            # the remaining elements of both tuples.
            return condeseq((
                ((eq2, c[i], d[0]), (permuteq, c[0:i] + c[i + 1:], d[1:], eq2))
                for i in range(len(c))))
    if isvar(a) and isvar(b):
        raise EarlyGoalError()
    if isvar(a) and isinstance(b, tuple):
        c, d = a, b
    elif isvar(b) and isinstance(a, tuple):
        c, d = b, a
    else:
        # Previously this fell through to an UnboundLocalError on c/d.
        raise TypeError(
            "permuteq expects tuples or logic variables, got %s and %s"
            % (type(a).__name__, type(b).__name__))
    # Bind the variable to each distinct ordering of the tuple.
    return (condeseq, ([eq(c, perm)] for perm in unique(it.permutations(d, len(d)))))
def seteq(a, b, eq2=eq):
    """ Set Equality

    For example (1, 2, 3) set equates to (2, 1, 3)

    >>> from logpy import var, run, seteq
    >>> x = var()
    >>> run(0, x, seteq(x, (1, 2)))
    ((1, 2), (2, 1))

    >>> run(0, x, seteq((2, 1, x), (3, 1, 2)))
    (3,)

    Every argument that is not a logic variable is reduced to a tuple of its
    distinct elements.  The result is then delegated to ``permuteq`` together
    with ``eq2``.
    """
    def _distinct(x):
        # Logic variables pass through untouched; concrete collections lose
        # their duplicates so set equality becomes permutation equality.
        return x if isvar(x) else tuple(set(x))

    return permuteq(_distinct(a), _distinct(b), eq2)