def test_minus_1(self):
    """walk_minus should not create nested Plus nodes.

    Builds ``-1 * ((2*oldx + x) - 4*y)``, distributes the products with
    TimesDistributor and checks that:
      * the result is valid-equivalent to the input,
      * the root of the result is a Plus,
      * below the root only Times nodes, symbols and constants appear
        (i.e. no nested Plus/Minus survived the rewrite).
    """
    x = Symbol("x", INT)
    y = Symbol("y", INT)
    oldx = Symbol("oldx", INT)
    m_1 = Int(-1)
    i_2 = Int(2)
    i_4 = Int(4)
    src = Times(i_2, oldx)
    src = Plus(src, x)
    src = Minus(src, Times(i_4, y))
    src = Times(m_1, src)
    td = TimesDistributor()
    res = td.walk(src)
    self.assertValid(Equals(src, res))
    # root is Plus.
    self.assertTrue(res.is_plus(), "Expected summation, got: {}".format(res))
    # no other Plus in children: only Times of symbs and constants.
    stack = list(res.args())
    while stack:
        curr = stack.pop()
        if curr.is_times():
            stack.extend(curr.args())
        else:
            # Report the offending sub-term (not the whole result) so a
            # failure points directly at the node that broke the invariant.
            self.assertTrue(curr.is_symbol() or curr.is_constant(),
                            "Expected leaf, got: {}".format(curr))
def test_minus_0(self):
    """Distributing over ``x - (x + y) < 0`` must give an equivalent formula."""
    x, y = Symbol("x", INT), Symbol("y", INT)
    zero = Int(0)
    difference = Minus(x, Plus(x, y))
    formula = LT(difference, zero)
    distributed = TimesDistributor().walk(formula)
    # The formula is boolean (an LT atom), so equivalence is checked via Iff.
    self.assertValid(Iff(formula, distributed))
def test_times_distributivity_smtlib_nra(self):
    """Check every Times rewrite on the QF_LRA/QF_NRA SMT-LIB test set.

    For each formula, a fresh TimesDistributor walks it. Every memoized
    Times node whose rewrite is a different object must be valid-equivalent
    to that rewrite, which is checked with the z3 solver.
    """
    from pysmt.test.smtlib.parser_utils import formulas_from_smtlib_test_set
    test_set = formulas_from_smtlib_test_set(logics=[QF_LRA, QF_NRA])
    for _, _fname, formula, _ in test_set:
        distributor = TimesDistributor()
        distributor.walk(formula)
        for old, new in distributor.memoization.items():
            # Skip non-products and products the walker left untouched.
            if old.is_times() and old is not new:
                self.assertValid(Equals(old, new), (old, new), solver_name="z3")
def test_times_distributivity(self):
    """Check that TimesDistributor preserves equivalence and distributes.

    The distributor must preserve equivalence on products of sums and
    differences. It must also fully distribute ``(r + 1) * (s - 1)`` into a
    top-level summation.
    """
    r = Symbol("r", REAL)
    s = Symbol("s", REAL)
    td = TimesDistributor()

    # A single distributor instance is reused across all cases.
    cases = [
        Times(Plus(r, Real(1)), Real(3)),
        Times(Plus(r, Real(1)), s),
        Times(Plus(r, Real(1), s), Real(3)),
        Times(Minus(r, Real(1)), Real(3)),
        Times(Minus(r, Real(1)), s),
        Times(Minus(Real(1), s), Real(3)),
        Times(Minus(r, Real(1)), Plus(r, s)),
    ]
    for formula in cases:
        rewritten = td.walk(formula)
        self.assertValid(Equals(formula, rewritten), (formula, rewritten))

    # (r + 1) * (s-1) = r*s + (-r) + s - 1
    product = Times(Plus(r, Real(1)), Minus(s, Real(1)))
    distributed = td.walk(product).simplify()
    expected = Plus(Times(r, s), Times(r, Real(-1)), s, Real(-1))
    self.assertValid(Equals(distributed, expected), distributed)
    self.assertTrue(distributed.is_plus(), distributed)