def cosine_law_crit_angle(self):
    """Use the cosine law to compute cos^2(theta) for node1---node2---node3.

    node1 and node2 are the first two values of ``self.get_input_nodes()``,
    and node3 is ``self.get_output_node()``. The angle theta is measured at
    node2: both edge vectors are taken relative to node2. The result is
    meant to be asserted less than cos^2(thetaC), where thetaC is the
    critical crossing angle.

    :returns: SMT expression for cos^2(theta) = (a . b)^2 / (|a|^2 * |b|^2),
        where a = node1 - node2 and b = node3 - node2
    """
    # dict.values() is a view on Python 3 and cannot be indexed; materialize
    # it once so both nodes come from the same snapshot.
    input_nodes = list(self.get_input_nodes().values())
    node1 = input_nodes[0]
    node2 = input_nodes[1]
    node3 = self.get_output_node()

    # Edge vectors a = node1 - node2 and b = node3 - node2
    aX = Minus(node1.get_x(), node2.get_x())
    aY = Minus(node1.get_y(), node2.get_y())
    bX = Minus(node3.get_x(), node2.get_x())
    bY = Minus(node3.get_y(), node2.get_y())

    # (a . b)^2 and |a|^2 * |b|^2; squaring avoids a square root in the SMT term
    a_dot_b_squared = Pow(Plus(Times(aX, bX), Times(aY, bY)), Real(2))
    a_squared_b_squared = Times(
        Plus(Times(aX, aX), Times(aY, aY)),
        Plus(Times(bX, bX), Times(bY, bY)),
    )
    return Div(a_dot_b_squared, a_squared_b_squared)
def test_substitution_on_functions(self):
    """Substituting into the arguments of an uninterpreted function and
    simplifying folds exactly the substituted arguments to constants."""
    i, r = FreshSymbol(INT), FreshSymbol(REAL)
    f = Symbol("f", FunctionType(BOOL, [INT, REAL]))
    int_arg = Plus(i, Int(1))
    real_arg = Minus(r, Real(2))
    phi = Function(f, [int_arg, real_arg])

    cases = (
        ({i: Int(0)}, [Int(1), real_arg]),
        ({r: Real(0)}, [int_arg, Real(-2)]),
        ({r: Real(0), i: Int(0)}, [Int(1), Real(-2)]),
    )
    for mapping, expected_args in cases:
        simplified = substitute(phi, mapping).simplify()
        self.assertEqual(simplified, Function(f, expected_args))
def test_lira(self):
    """REAL and INT terms may only be compared after an explicit ToReal."""
    real_a = Symbol("A", REAL)
    int_b = Symbol("B", INT)
    a_below_successor = LT(real_a, Plus(real_a, Real(1)))
    b_minus_one = Minus(int_b, Int(1))

    # Comparing a REAL directly against an INT term is a type error.
    with self.assertRaises(PysmtTypeError):
        And(a_below_successor, GT(real_a, b_minus_one))

    f = And(a_below_successor, GT(real_a, ToReal(b_minus_one)))
    g = Equals(real_a, ToReal(int_b))
    self.assertUnsat(And(f, g, Equals(real_a, Real(1.2))),
                     "Formula was expected to be unsat",
                     logic=QF_UFLIRA)
def test_misc(self):
    """The type checker assigns the expected type to assorted terms."""
    x, y = self.x, self.y
    p, q = self.p, self.q
    r, s = self.r, self.s

    expectations = (
        (BOOL, [
            And(x, y), Or(x, y), Not(x), x,
            Equals(p, q), GE(p, q), LE(p, q), GT(p, q), LT(p, q),
            Bool(True), Ite(x, y, x),
        ]),
        # TODO: FORALL EXISTS
        (REAL, [
            r, Real(4), Plus(r, s), Plus(r, Real(2)), Minus(s, r),
            Times(r, Real(1)), Div(r, Real(1)), Ite(x, r, s),
        ]),
        (INT, [
            p, Int(4), Plus(p, q), Plus(p, Int(2)), Minus(p, q),
            Times(p, Int(1)), Ite(x, p, q),
        ]),
    )
    for expected_type, formulae in expectations:
        for formula in formulae:
            self.assertEqual(self.tc.walk(formula), expected_type, formula)
def calculate_droplet_volume(self, h, w, wIn, epsilon, qD, qC):
    """Build an SMT expression for the droplet volume created in a T-junction.

    Model from DOI:10.1039/c002625e:

        V = h * w^2 * (V_fill_norm + alpha * qD / qC)

    Unit is volume in m^3.

    :param Symbol h: Height of channel
    :param Symbol w: Width of continuous/output channel
    :param Symbol wIn: Width of dispersed_channel
    :param Symbol epsilon: Equals 0.414*radius of rounded edge where channels join
    :param Symbol qD: Flow rate in dispersed_channel
    :param Symbol qC: Flow rate in continuous_channel
    :returns: SMT expression for the droplet volume
    """
    q_gutter = Real(0.1)
    pi = Real(math.pi)
    quarter_pi = Div(pi, Real(4))
    one_minus_quarter_pi = Minus(Real(1), quarter_pi)
    h_over_w = Div(h, w)

    # Normalized fill volume: 3pi/8 - (pi/2)(1 - pi/4)(h/w)
    v_fill_simple = Minus(
        Times(Real((3, 8)), pi),
        Times(Times(Div(pi, Real(2)), one_minus_quarter_pi), h_over_w))

    # h*w / (h + w)
    hw_parallel = Div(Times(h, w), Plus(h, w))

    # r_pinch = w + (wIn - (hw_parallel - eps))
    #             + sqrt(2 * (wIn - hw_parallel) * (w - hw_parallel))
    pinch_sqrt = Pow(
        Times(Real(2),
              Times(Minus(wIn, hw_parallel), Minus(w, hw_parallel))),
        Real(0.5))
    r_pinch = Plus(w, Plus(Minus(wIn, Minus(hw_parallel, epsilon)), pinch_sqrt))
    r_fill = w
    pinch_ratio = Div(r_pinch, w)
    fill_ratio = Div(r_fill, w)

    # alpha = (1 - pi/4) * (1 - q_gutter)^-1
    #         * ((r_pinch/w)^2 - (r_fill/w)^2 + (pi/4)(r_pinch/w - r_fill/w)(h/w))
    bracket = Plus(
        Minus(Pow(pinch_ratio, Real(2)), Pow(fill_ratio, Real(2))),
        Times(quarter_pi, Times(Minus(pinch_ratio, fill_ratio), h_over_w)))
    alpha = Times(one_minus_quarter_pi,
                  Times(Pow(Minus(Real(1), q_gutter), Real(-1)), bracket))

    scale = Times(h, Times(w, w))
    return Times(scale, Plus(v_fill_simple, Times(alpha, Div(qD, qC))))
def test_int(self):
    """Integer arithmetic and comparisons print as SMT-LIB prefix terms."""
    p = Symbol("p", INT)
    q = Symbol("q", INT)
    formula = Or(Equals(Times(p, Int(5)), Minus(p, q)),
                 LT(p, q),
                 LE(Int(6), Int(1)))
    expected = "(or (= (* p 5) (- p q)) (< p q) (<= 6 1))"
    self.assertEqual(self.print_to_string(formula), expected)
def test_minus_1(self):
    """walk_minus should not create nested Plus nodes.

    Distributes -1 * ((2 * oldx + x) - 4 * y) and checks three things.
    First, the result is equivalent to the input. Second, it is rooted at
    a Plus. Third, below that root only Times nodes, symbols and constants
    appear.
    """
    x = Symbol("x", INT)
    y = Symbol("y", INT)
    oldx = Symbol("oldx", INT)
    m_1 = Int(-1)
    i_2 = Int(2)
    i_4 = Int(4)
    src = Times(i_2, oldx)
    src = Plus(src, x)
    src = Minus(src, Times(i_4, y))
    src = Times(m_1, src)
    td = TimesDistributor()
    res = td.walk(src)
    self.assertValid(Equals(src, res))
    # root is Plus.
    self.assertTrue(res.is_plus(), "Expected summation, got: {}".format(res))
    # no other Plus in children: only Times of symbols and constants.
    stack = list(res.args())
    while stack:
        curr = stack.pop()
        if curr.is_times():
            stack.extend(curr.args())
        else:
            # Report the offending node, not the whole result.
            self.assertTrue(curr.is_symbol() or curr.is_constant(),
                            "Expected leaf, got: {}".format(curr))
def add_relu_simplex_friendly_eager(self):
    """Eagerly encode every ReLU in ``self.relus`` with linear lemmas.

    For each (relu_out, relu_in) pair, a fresh real ``f`` is bound to
    relu_out - relu_in. Two case-split implications are then appended to
    ``self.formulae``:
      * relu_in > 0  implies f == 0 (i.e. relu_out == relu_in)
      * relu_in <= 0 implies relu_out == 0
    """
    zero = Real(0)
    for relu_out, relu_in in self.relus:
        delta = FreshSymbol(REAL)
        lemmas = (
            Equals(delta, Minus(relu_out, relu_in)),
            Implies(GT(relu_in, zero), Equals(delta, zero)),
            Implies(LE(relu_in, zero), Equals(relu_out, zero)),
        )
        for lemma in lemmas:
            self.formulae.append(lemma)
def test_minus_0(self):
    """Distributing over (x - (x + y)) < 0 preserves equivalence."""
    x, y = Symbol("x", INT), Symbol("y", INT)
    zero = Int(0)
    src = LT(Minus(x, Plus(x, y)), zero)
    distributed = TimesDistributor().walk(src)
    self.assertValid(Iff(src, distributed))
def test_eq(self):
    """Over the integers, B - 1 < A < B + 1 is equivalent to A == B."""
    a = Symbol("At", INT)
    b = Symbol("Bt", INT)
    one = Int(1)
    window = And(LT(a, Plus(b, one)), GT(a, Minus(b, one)))
    self.assertValid(Iff(window, Equals(a, b)),
                     "Formulae were expected to be equivalent",
                     logic=QF_LIA)
def test_boolean(self):
    """An Iff between two arithmetic-constraint formulae has type BOOL."""
    a, b = Symbol("At", INT), Symbol("Bt", INT)
    one = Int(1)
    lhs = And(LT(a, Plus(b, one)), GT(a, Minus(b, one)))
    rhs = Equals(a, b)
    checker = get_env().stc
    self.assertEqual(checker.walk(Iff(lhs, rhs)), BOOL)
def test_int(self):
    """to_smtlib renders integer formulae both flat and as let-bindings."""
    p, q = Symbol("p", INT), Symbol("q", INT)
    formula = Or(Equals(Times(p, Int(5)), Minus(p, q)),
                 LT(p, q),
                 LE(Int(6), Int(1)))
    flat = "(or (= (* p 5) (- p q)) (< p q) (<= 6 1))"
    dag = "(let ((.def_0 (<= 6 1))) (let ((.def_1 (< p q))) (let ((.def_2 (- p q))) (let ((.def_3 (* p 5))) (let ((.def_4 (= .def_3 .def_2))) (let ((.def_5 (or .def_4 .def_1 .def_0))) .def_5))))))"
    self.assertEqual(formula.to_smtlib(daggify=False), flat)
    self.assertEqual(formula.to_smtlib(daggify=True), dag)
def test_qe_eq(self):
    """Quantifier elimination over LRA yields an equivalent formula."""
    eliminator = QuantifierEliminator(logic=LRA)
    bool_a, bool_b = Symbol("A", BOOL), Symbol("B", BOOL)
    real_at, real_bt = Symbol("At", REAL), Symbol("Bt", REAL)
    diff = Minus(real_at, real_bt)
    body = And(Iff(bool_a, GE(diff, Real(0))),
               Iff(bool_b, LT(diff, Real(1))))
    quantified = Exists([real_bt, bool_a], body)
    eliminated = eliminator.eliminate_quantifiers(quantified)
    try:
        self.assertValid(Iff(eliminated, quantified), logic=LRA,
                         msg="The two formulas should be equivalent.")
    except SolverReturnedUnknownResultError:
        # An "unknown" answer from the solver is treated as inconclusive,
        # not as a test failure.
        pass
def main():
    """Build two sample formulas over reals x, y and pass each to run_test
    with ``[y]`` as the variable list."""
    x, y = (Symbol(name, REAL) for name in "xy")
    zero, ten = Real(0), Real(10)

    f_sat = Implies(And(GT(y, zero), LT(y, ten)),
                    LT(Minus(y, Times(x, Real(2))), Real(7)))
    f_incomplete = And(
        GT(x, zero),
        LE(x, ten),
        Implies(And(GT(y, zero), LE(y, ten), Not(Equals(x, y))),
                GT(y, x)))

    for formula in (f_sat, f_incomplete):
        run_test([y], formula)
def test_times_distributivity(self):
    """TimesDistributor output is equivalent to its input, and a product of
    sums simplifies to a flat Plus."""
    r = Symbol("r", REAL)
    s = Symbol("s", REAL)
    distributor = TimesDistributor()
    one, three = Real(1), Real(3)

    products = (
        Times(Plus(r, one), three),
        Times(Plus(r, one), s),
        Times(Plus(r, one, s), three),
        Times(Minus(r, one), three),
        Times(Minus(r, one), s),
        Times(Minus(one, s), three),
        Times(Minus(r, one), Plus(r, s)),
    )
    for product in products:
        distributed = distributor.walk(product)
        self.assertValid(Equals(product, distributed), (product, distributed))

    # (r + 1) * (s-1) = r*s + (-r) + s - 1
    product = Times(Plus(r, one), Minus(s, one))
    flattened = distributor.walk(product).simplify()
    target = Plus(Times(r, s), Times(r, Real(-1)), s, Real(-1))
    self.assertValid(Equals(flattened, target), flattened)
    self.assertTrue(flattened.is_plus(), flattened)
def calculate_channel_output_pressure(self):
    """Calculate the pressure at the output of the channel.

    Uses P_out = P_in - R * Q, where P_in is the pressure at the channel's
    input port, R is the channel resistance and Q is the channel flow rate.
    Unit for pressure is Pascals - kg/(m*s^2).

    :returns: SMT expression for P_in - R * Q
    """
    P_in = self.get_port_in().get_pressure()
    R = self.get_resistance()
    Q = self.get_flow_rate()
    return Minus(P_in, Times(R, Q))
def add_relu_simplex_friendly_OA(self):
    """Add an over-approximating linear encoding of every ReLU.

    For each (relu_out, relu_in) pair in ``self.relus``, a fresh real ``f``
    is bound to relu_out - relu_in. The following are then appended to
    ``self.formulae``:
      * the MAX abstraction: f >= 0 and relu_out >= 0
      * case-based upper bounds: relu_in > 0 implies f <= 0, and
        relu_in <= 0 implies relu_out <= 0
    """
    zero = Real(0)
    for relu_out, relu_in in self.relus:
        slack = FreshSymbol(REAL)
        constraints = (
            Equals(slack, Minus(relu_out, relu_in)),
            # MAX abstraction
            GE(slack, zero),
            GE(relu_out, zero),
            # MAX - case based upper bound
            Implies(GT(relu_in, zero), LE(slack, zero)),
            Implies(LE(relu_in, zero), LE(relu_out, zero)),
        )
        for constraint in constraints:
            self.formulae.append(constraint)
def test_infix_extended(self):
    """Infix operators on FNodes build the same terms as the shortcuts.

    Infix notation is a global environment setting. It is enabled only for
    the duration of this test and restored in a ``finally`` block. This
    keeps a failing assertion from leaking the setting into other tests.
    """
    p, r, x, y = self.p, self.r, self.x, self.y
    env = get_env()
    previous = env.enable_infix_notation
    env.enable_infix_notation = True
    try:
        self.assertEqual(Plus(p, Int(1)), p + 1)
        self.assertEqual(Plus(r, Real(1)), r + 1)
        self.assertEqual(Times(r, Real(1)), r * 1)

        self.assertEqual(Minus(p, Int(1)), p - 1)
        self.assertEqual(Minus(r, Real(1)), r - 1)

        self.assertEqual(Plus(r, Real(1.5)), r + 1.5)
        self.assertEqual(Minus(r, Real(1.5)), r - 1.5)
        self.assertEqual(Times(r, Real(1.5)), r * 1.5)

        # Reflected operators with a Python number on the left.
        self.assertEqual(Plus(r, Real(1.5)), 1.5 + r)
        self.assertEqual(Times(r, Real(1.5)), 1.5 * r)

        # An INT term cannot be combined with a float literal.
        with self.assertRaises(TypeError):
            p + 1.5

        self.assertEqual(Not(x), ~x)
        self.assertEqual(Times(r, Real(-1)), -r)
        self.assertEqual(Times(p, Int(-1)), -p)
        self.assertEqual(Xor(x, y), x ^ y)
        self.assertEqual(And(x, y), x & y)
        self.assertEqual(Or(x, y), x | y)
        self.assertEqual(Or(x, TRUE()), x | True)
        self.assertEqual(Or(x, TRUE()), True | x)
        self.assertEqual(And(x, TRUE()), x & True)
        self.assertEqual(And(x, TRUE()), True & x)
    finally:
        env.enable_infix_notation = previous
def test_subst(self):
    """The module-level substitute() agrees with FNode.substitute()."""
    a, b = Symbol("At", INT), Symbol("Bt", INT)
    one = Int(1)
    h = Iff(And(LT(a, Plus(b, one)), GT(a, Minus(b, one))), Equals(a, b))
    for replacement in (b, one):
        self.assertEqual(substitute(h, subs={a: replacement}),
                         h.substitute({a: replacement}))
def ISE(model1, model2, seed, sample_count, engine='pa'):
    """Compute the integrated squared error (ISE) between two models.

    The integration domain is the union of the two supports. The integrand
    depends on which supports hold:
      * both supports: (w1 - w2)^2
      * only support1: w1^2
      * otherwise: w2^2

    This amounts to treating a model's weight as zero outside its support.

    :param model1: first model; must range over the same variables as model2
    :param model2: second model
    :param seed: random seed, used only by the 'rej' engine
    :param sample_count: initial sample count, used only by the 'rej'
        engine; it is doubled after every compute_volume() attempt
    :param engine: 'pa' (PredicateAbstractionEngine) or 'rej'
        (RejectionEngine)
    :returns: the value returned by the engine's compute_volume()
    :raises NotImplementedError: if engine is neither 'pa' nor 'rej'
    """
    assert set(model1.get_vars()) == set(model2.get_vars()), \
        "M1 vars: {}\n M2 vars: {}".format(model1.get_vars(), model2.get_vars())

    support1, weightfun1 = model1.support, model1.weightfun
    support2, weightfun2 = model2.support, model2.weightfun
    support_d = Or(support1, support2)
    weight_diff = Minus(weightfun1, weightfun2)
    weight_d = Ite(And(support1, support2),
                   Times(weight_diff, weight_diff),
                   Ite(support1,
                       Times(weightfun1, weightfun1),
                       Times(weightfun2, weightfun2)))

    domain, _ = merged_domain(model1, model2)

    if engine == 'pa':
        # The engine object must be bound to ``solver``; it is what the
        # compute_volume() call below uses.
        solver = PredicateAbstractionEngine(domain, support_d, weight_d)
        result = solver.compute_volume()
    elif engine == 'rej':
        result = None
        solver = RejectionEngine(domain, support_d, weight_d,
                                 sample_count=sample_count, seed=seed)
        # Retry with twice as many samples until compute_volume() returns
        # a value (presumably None means too few samples -- confirm).
        while result is None:
            result = solver.compute_volume()
            solver.sample_count *= 2
    else:
        raise NotImplementedError()

    return result
def pythagorean_length(self):  # TODO: How do I test this!!
    """Assert the channel length via the Pythagorean theorem.

    The channel is treated as the hypotenuse of a right triangle. Its legs
    are the x and y distances between the channel's input and output ports.
    Constraining length^2 == dx^2 + dy^2 ties the channel length to the
    port positions.

    :returns: SMT expression of the equality between the squared side
        lengths and the squared channel length
    """
    port_in = self.get_port_in()
    port_out = self.get_port_out()
    # Legs of the triangle: coordinate differences between the two ports.
    # The sign does not matter because both legs are squared.
    side_a = Minus(port_out.get_x(), port_in.get_x())
    side_b = Minus(port_out.get_y(), port_in.get_y())
    a_squared = Pow(side_a, Real(2))
    b_squared = Pow(side_b, Real(2))
    a_squared_plus_b_squared = Plus(a_squared, b_squared)
    c_squared = Pow(self.get_length(), Real(2))
    return Equals(a_squared_plus_b_squared, c_squared)
def channels_in_straight_line(self):
    """Assert that two channels lie on a straight line.

    The check is that the triangle formed by the two end nodes and the
    middle node has zero area. node1 and node2 are the first two values of
    ``self.get_input_nodes()``, and node3 is ``self.get_output_node()``.

    :returns: Expression asserting the area of the triangle formed by the
        three nodes to be 0
    """
    # TODO: verify that node1-node2 and node2-node3 are actually connected
    # before asserting that they are collinear.

    # dict.values() is a view on Python 3 and cannot be indexed; materialize
    # it once so both nodes come from the same snapshot.
    input_nodes = list(self.get_input_nodes().values())
    node1 = input_nodes[0]
    node2 = input_nodes[1]
    node3 = self.get_output_node()

    # Signed triangle area from three points:
    #   (x1 (y3 - y2) + x3 (y2 - y1) + x2 (y1 - y3)) / 2
    # The sign convention is irrelevant because it is compared against 0.
    return Equals(
        Real(0),
        Div(
            Plus(
                Times(node1.get_x(), Minus(node3.get_y(), node2.get_y())),
                Plus(
                    Times(node3.get_x(), Minus(node2.get_y(), node1.get_y())),
                    Times(node2.get_x(), Minus(node1.get_y(), node3.get_y())))),
            Real(2)))
def simple_pressure_flow(self):
    """Assert that the channel's pressure drop equals its flow rate times
    its resistance.

    A more detailed calculation is planned via an analytical_pressure_flow
    method (TBD).

    :returns: SMT expression of equality between P_in - P_out and Q*R
    """
    port_in = self.get_port_in()
    port_out = self.get_port_out()
    p1 = port_in.get_pressure()
    p2 = port_out.get_pressure()
    Q = self.get_flow_rate()
    R = self.get_resistance()
    return Equals(Minus(p1, p2), Times(Q, R))
def virtual_link_constraints(constraints, frameSet, vlinkSet, g):
    """Append frame-ordering constraints along every virtual link.

    The producer's virtual frame precedes the network link frames, which in
    turn precede the consumer's virtual frame. For each pair of adjacent
    links (link_i, link_j) on a virtual link, the first frame on link_j may
    start only after the last frame on link_i has ended, plus link_i's
    delay and the synchronization precision g:

        link_j.macrotick * first_j.offset - (link_i.delay + g)
            >= link_i.macrotick * (last_i.offset + last_i.L)

    :param constraints: list to which the generated SMT constraints are
        appended
    :param frameSet: frame collection, a two-level dict
        vlid_t -> link -> [frame]
    :param vlinkSet: mapping vlid_t -> virtual-link object whose ``vl``
        attribute is the ordered sequence of links
    :param g: time synchronization precision of the model, in microseconds
    :returns: True
    """
    # Handle all frames belonging to the same virtual link together.
    for vlid_t, frames_by_link in frameSet.items():
        links = vlinkSet[vlid_t].vl
        # Every pair of distinct adjacent links on the virtual link.
        for link_i, link_j in zip(links, links[1:]):
            # Last frame on the earlier link, first frame on the later one.
            last_frame_i = frames_by_link[link_i][-1]
            first_frame_j = frames_by_link[link_j][0]
            constraints.append(
                GE(
                    Minus(Times(Int(link_j.macrotick), first_frame_j.offset),
                          Int(link_i.delay + g)),
                    Times(Int(link_i.macrotick),
                          Plus(last_frame_i.offset, Int(last_frame_i.L)))))
    return True
def test_precedences(self):
    """The HR parser applies the correct precedence to |, &, <=, + and -."""
    parser = HRParser()
    a, b, c = (Symbol(name) for name in "abc")
    x, y = (Symbol(name, REAL) for name in "xy")

    f1 = LE(Plus(Plus(x, y), Real(5)), Real(7))
    f2 = LE(Plus(x, Real(5)), Minus(Real(7), y))

    expectations = [
        ("a | b & c", Or(a, And(b, c))),
        ("a & b | c", Or(And(a, b), c)),
        ("x + y + 5.0 <= 7.0", f1),
        ("x + 5.0 <= 7.0 - y", f2),
        ("x + y + 5.0 <= 7.0 & x + 5.0 <= 7.0 - y", And(f1, f2)),
        ("x + 5.0 <= 7.0 - y | x + y + 5.0 <= 7.0 & x + 5.0 <= 7.0 - y",
         Or(f2, And(f1, f2))),
    ]
    for text, expected in expectations:
        self.assertEqual(parser.parse(text), expected)
        self.assertEqual(parse(text), expected)
def calculate_resistance(self):
    """Calculate the flow resistance of the channel.

    Uses R = (12 * mu * L) / (w * h^3 * (1 - 0.630 * (h/w))), where mu is
    the viscosity, L the channel length, w its width and h its height.
    The formula assumes channel height < width, so an assertion of that
    condition is returned first.
    Unit for resistance is kg/(m^4*s).

    :returns: tuple -- two SMT expressions: the first asserts that the
        channel height is less than its width, the second is R
    """
    w = self.get_width()
    h = self.get_height()
    mu = self.get_viscosity()
    chL = self.get_length()
    numerator = Times(Real(12), Times(mu, chL))
    # w * h^3 * (1 - 0.63 * h / w)
    denominator = Times(
        w,
        Times(Pow(h, Real(3)),
              Minus(Real(1), Times(Real(0.63), Div(h, w)))))
    return (LT(h, w), Div(numerator, denominator))
def get_full_example_formulae(environment=None):
    """Return a list of Examples using the given environment.

    :param environment: pySMT environment in which all symbols and formulae
        are created; defaults to ``get_env()`` when None.
    :returns: list of ``Example`` records. Each one carries:
        * hr -- expected human-readable rendering of ``expr`` (presumably
          as produced by the HR printer -- confirm against the tests that
          consume it)
        * expr -- the formula itself
        * is_valid, is_sat -- the expected solver verdicts
        * logic -- the pySMT logic the example is tagged with
    """
    if environment is None:
        environment = get_env()

    # All symbols and formulae below are built inside this environment.
    with environment:
        x = Symbol("x", BOOL)
        y = Symbol("y", BOOL)
        p = Symbol("p", INT)
        q = Symbol("q", INT)
        r = Symbol("r", REAL)
        s = Symbol("s", REAL)
        aii = Symbol("aii", ARRAY_INT_INT)
        ari = Symbol("ari", ArrayType(REAL, INT))
        arb = Symbol("arb", ArrayType(REAL, BV8))
        abb = Symbol("abb", ArrayType(BV8, BV8))
        nested_a = Symbol("a_arb_aii",
                          ArrayType(ArrayType(REAL, BV8), ARRAY_INT_INT))

        rf = Symbol("rf", FunctionType(REAL, [REAL, REAL]))
        rg = Symbol("rg", FunctionType(REAL, [REAL]))
        ih = Symbol("ih", FunctionType(INT, [REAL, INT]))
        # NOTE: ``ig`` is declared but not used by any example below.
        ig = Symbol("ig", FunctionType(INT, [INT]))
        bf = Symbol("bf", FunctionType(BOOL, [BOOL]))
        bg = Symbol("bg", FunctionType(BOOL, [BOOL]))

        # The Python names differ from the symbol names: bv8 is "bv1" and
        # bv16 is "bv2", which is what the hr strings show.
        bv8 = Symbol("bv1", BV8)
        bv16 = Symbol("bv2", BV16)

        result = [
            # Fields: hr, expr, is_valid, is_sat, logic.
            # In hr strings, GT(a, b) is rendered as (b < a) and GE(a, b)
            # as (b <= a).
            Example(hr="(x & y)",
                    expr=And(x, y),
                    is_valid=False,
                    is_sat=True,
                    logic=pysmt.logics.QF_BOOL),
            Example(hr="(x <-> y)",
                    expr=Iff(x, y),
                    is_valid=False,
                    is_sat=True,
                    logic=pysmt.logics.QF_BOOL),
            Example(hr="((x | y) & (! (x | y)))",
                    expr=And(Or(x, y), Not(Or(x, y))),
                    is_valid=False,
                    is_sat=False,
                    logic=pysmt.logics.QF_BOOL),
            Example(hr="(x & (! y))",
                    expr=And(x, Not(y)),
                    is_valid=False,
                    is_sat=True,
                    logic=pysmt.logics.QF_BOOL),
            Example(hr="(False -> True)",
                    expr=Implies(FALSE(), TRUE()),
                    is_valid=True,
                    is_sat=True,
                    logic=pysmt.logics.QF_BOOL),

            #
            # LIA
            #
            Example(hr="((q < p) & (x -> y))",
                    expr=And(GT(p, q), Implies(x, y)),
                    is_valid=False,
                    is_sat=True,
                    logic=pysmt.logics.QF_IDL),
            Example(hr="(((p + q) = 5) & (q < p))",
                    expr=And(Equals(Plus(p, q), Int(5)), GT(p, q)),
                    is_valid=False,
                    is_sat=True,
                    logic=pysmt.logics.QF_LIA),
            Example(hr="((q <= p) | (p <= q))",
                    expr=Or(GE(p, q), LE(p, q)),
                    is_valid=True,
                    is_sat=True,
                    logic=pysmt.logics.QF_IDL),
            Example(hr="(! (p < (q * 2)))",
                    expr=Not(LT(p, Times(q, Int(2)))),
                    is_valid=False,
                    is_sat=True,
                    logic=pysmt.logics.QF_LIA),
            Example(hr="(p < (p - (5 - 2)))",
                    expr=GT(Minus(p, Minus(Int(5), Int(2))), p),
                    is_valid=False,
                    is_sat=False,
                    logic=pysmt.logics.QF_IDL),
            Example(hr="((x ? 7 : ((p + -1) * 3)) = q)",
                    expr=Equals(
                        Ite(x, Int(7), Times(Plus(p, Int(-1)), Int(3))), q),
                    is_valid=False,
                    is_sat=True,
                    logic=pysmt.logics.QF_LIA),
            Example(hr="(p < (q + 1))",
                    expr=LT(p, Plus(q, Int(1))),
                    is_valid=False,
                    is_sat=True,
                    logic=pysmt.logics.QF_LIA),

            #
            # LRA
            #
            Example(hr="((s < r) & (x -> y))",
                    expr=And(GT(r, s), Implies(x, y)),
                    is_valid=False,
                    is_sat=True,
                    logic=pysmt.logics.QF_RDL),
            Example(hr="(((r + s) = 28/5) & (s < r))",
                    expr=And(Equals(Plus(r, s), Real(Fraction("5.6"))),
                             GT(r, s)),
                    is_valid=False,
                    is_sat=True,
                    logic=pysmt.logics.QF_LRA),
            Example(hr="((s <= r) | (r <= s))",
                    expr=Or(GE(r, s), LE(r, s)),
                    is_valid=True,
                    is_sat=True,
                    logic=pysmt.logics.QF_RDL),
            # hr shows (r * 2.0) for Div(r, 1/2) -- presumably because Div by
            # a constant is rewritten as multiplication by its inverse; confirm.
            Example(hr="(! ((r * 2.0) < (s * 2.0)))",
                    expr=Not(LT(Div(r, Real((1, 2))), Times(s, Real(2)))),
                    is_valid=False,
                    is_sat=True,
                    logic=pysmt.logics.QF_LRA),
            Example(hr="(! (r < (r - (5.0 - 2.0))))",
                    expr=Not(GT(Minus(r, Minus(Real(5), Real(2))), r)),
                    is_valid=True,
                    is_sat=True,
                    logic=pysmt.logics.QF_RDL),
            Example(hr="((x ? 7.0 : ((s + -1.0) * 3.0)) = r)",
                    expr=Equals(
                        Ite(x, Real(7), Times(Plus(s, Real(-1)), Real(3))), r),
                    is_valid=False,
                    is_sat=True,
                    logic=pysmt.logics.QF_LRA),

            #
            # EUF
            #
            Example(hr="(bf(x) <-> bg(x))",
                    expr=Iff(Function(bf, (x, )), Function(bg, (x, ))),
                    is_valid=False,
                    is_sat=True,
                    logic=pysmt.logics.QF_UF),
            Example(hr="(rf(5.0, rg(r)) = 0.0)",
                    expr=Equals(Function(rf, (Real(5), Function(rg, (r, )))),
                                Real(0)),
                    is_valid=False,
                    is_sat=True,
                    logic=pysmt.logics.QF_UFLRA),
            Example(hr="((rg(r) = (5.0 + 2.0)) <-> (rg(r) = 7.0))",
                    expr=Iff(Equals(Function(rg, [r]), Plus(Real(5), Real(2))),
                             Equals(Function(rg, [r]), Real(7))),
                    is_valid=True,
                    is_sat=True,
                    logic=pysmt.logics.QF_UFLRA),
            Example(
                hr="((r = (s + 1.0)) & (rg(s) = 5.0) & (rg((r - 1.0)) = 7.0))",
                expr=And([
                    Equals(r, Plus(s, Real(1))),
                    Equals(Function(rg, [s]), Real(5)),
                    Equals(Function(rg, [Minus(r, Real(1))]), Real(7))
                ]),
                is_valid=False,
                is_sat=False,
                logic=pysmt.logics.QF_UFLRA),

            #
            # BV
            #
            Example(hr="((1_32 & 0_32) = 0_32)",
                    expr=Equals(BVAnd(BVOne(32), BVZero(32)), BVZero(32)),
                    is_valid=True,
                    is_sat=True,
                    logic=pysmt.logics.QF_BV),
            Example(hr="((! 2_3) = 5_3)",
                    expr=Equals(BVNot(BV("010")), BV("101")),
                    is_valid=True,
                    is_sat=True,
                    logic=pysmt.logics.QF_BV),
            Example(hr="((7_3 xor 0_3) = 0_3)",
                    expr=Equals(BVXor(BV("111"), BV("000")), BV("000")),
                    is_valid=False,
                    is_sat=False,
                    logic=pysmt.logics.QF_BV),
            Example(hr="((bv1::bv1) u< 0_16)",
                    expr=BVULT(BVConcat(bv8, bv8), BVZero(16)),
                    is_valid=False,
                    is_sat=False,
                    logic=pysmt.logics.QF_BV),
            Example(hr="(1_32[0:7] = 1_8)",
                    expr=Equals(BVExtract(BVOne(32), end=7), BVOne(8)),
                    is_valid=True,
                    is_sat=True,
                    logic=pysmt.logics.QF_BV),
            Example(hr="(0_8 u< (((bv1 + 1_8) * 5_8) u/ 5_8))",
                    expr=BVUGT(
                        BVUDiv(BVMul(BVAdd(bv8, BVOne(8)), BV(5, width=8)),
                               BV(5, width=8)), BVZero(8)),
                    is_valid=False,
                    is_sat=True,
                    logic=pysmt.logics.QF_BV),
            Example(hr="(0_16 u<= bv2)",
                    expr=BVUGE(bv16, BVZero(16)),
                    is_valid=True,
                    is_sat=True,
                    logic=pysmt.logics.QF_BV),
            Example(hr="(0_16 s<= bv2)",
                    expr=BVSGE(bv16, BVZero(16)),
                    is_valid=False,
                    is_sat=True,
                    logic=pysmt.logics.QF_BV),
            Example(
                hr="((0_32 u< (5_32 u% 2_32)) & ((5_32 u% 2_32) u<= 1_32))",
                expr=And(
                    BVUGT(BVURem(BV(5, width=32), BV(2, width=32)),
                          BVZero(32)),
                    BVULE(BVURem(BV(5, width=32), BV(2, width=32)),
                          BVOne(32))),
                is_valid=True,
                is_sat=True,
                logic=pysmt.logics.QF_BV),
            Example(hr="((((1_32 + (- 1_32)) << 1_32) >> 1_32) = 1_32)",
                    expr=Equals(
                        BVLShr(BVLShl(BVAdd(BVOne(32), BVNeg(BVOne(32))), 1),
                               1), BVOne(32)),
                    is_valid=False,
                    is_sat=False,
                    logic=pysmt.logics.QF_BV),
            Example(hr="((1_32 - 1_32) = 0_32)",
                    expr=Equals(BVSub(BVOne(32), BVOne(32)), BVZero(32)),
                    is_valid=True,
                    is_sat=True,
                    logic=pysmt.logics.QF_BV),
            # Rotations
            Example(hr="(((1_32 ROL 1) ROR 1) = 1_32)",
                    expr=Equals(BVRor(BVRol(BVOne(32), 1), 1), BVOne(32)),
                    is_valid=True,
                    is_sat=True,
                    logic=pysmt.logics.QF_BV),
            # Extensions
            Example(hr="((0_5 ZEXT 11) = (0_1 SEXT 15))",
                    expr=Equals(BVZExt(BVZero(5), 11),
                                BVSExt(BVZero(1), 15)),
                    is_valid=True,
                    is_sat=True,
                    logic=pysmt.logics.QF_BV),
            Example(hr="((bv2 - bv2) = 0_16)",
                    expr=Equals(BVSub(bv16, bv16), BVZero(16)),
                    is_valid=True,
                    is_sat=True,
                    logic=pysmt.logics.QF_BV),
            Example(hr="((bv2 - bv2)[0:7] = bv1)",
                    expr=Equals(BVExtract(BVSub(bv16, bv16), 0, 7), bv8),
                    is_valid=False,
                    is_sat=True,
                    logic=pysmt.logics.QF_BV),
            Example(hr="((bv2[0:7] bvcomp bv1) = 1_1)",
                    expr=Equals(BVComp(BVExtract(bv16, 0, 7), bv8), BVOne(1)),
                    is_valid=False,
                    is_sat=True,
                    logic=pysmt.logics.QF_BV),
            Example(hr="((bv2 bvcomp bv2) = 0_1)",
                    expr=Equals(BVComp(bv16, bv16), BVZero(1)),
                    is_valid=False,
                    is_sat=False,
                    logic=pysmt.logics.QF_BV),
            Example(hr="(bv2 s< bv2)",
                    expr=BVSLT(bv16, bv16),
                    is_valid=False,
                    is_sat=False,
                    logic=pysmt.logics.QF_BV),
            Example(hr="(bv2 s< 0_16)",
                    expr=BVSLT(bv16, BVZero(16)),
                    is_valid=False,
                    is_sat=True,
                    logic=pysmt.logics.QF_BV),
            Example(hr="((bv2 s< 0_16) | (0_16 s<= bv2))",
                    expr=Or(BVSGT(BVZero(16), bv16), BVSGE(bv16, BVZero(16))),
                    is_valid=True,
                    is_sat=True,
                    logic=pysmt.logics.QF_BV),
            Example(hr="(bv2 u< bv2)",
                    expr=BVULT(bv16, bv16),
                    is_valid=False,
                    is_sat=False,
                    logic=pysmt.logics.QF_BV),
            Example(hr="(bv2 u< 0_16)",
                    expr=BVULT(bv16, BVZero(16)),
                    is_valid=False,
                    is_sat=False,
                    logic=pysmt.logics.QF_BV),
            Example(hr="((bv2 | 0_16) = bv2)",
                    expr=Equals(BVOr(bv16, BVZero(16)), bv16),
                    is_valid=True,
                    is_sat=True,
                    logic=pysmt.logics.QF_BV),
            Example(hr="((bv2 & 0_16) = 0_16)",
                    expr=Equals(BVAnd(bv16, BVZero(16)), BVZero(16)),
                    is_valid=True,
                    is_sat=True,
                    logic=pysmt.logics.QF_BV),
            Example(hr="((0_16 s< bv2) & ((bv2 s/ 65535_16) s< 0_16))",
                    expr=And(BVSLT(BVZero(16), bv16),
                             BVSLT(BVSDiv(bv16, SBV(-1, 16)), BVZero(16))),
                    is_valid=False,
                    is_sat=True,
                    logic=pysmt.logics.QF_BV),
            Example(hr="((0_16 s< bv2) & ((bv2 s% 1_16) s< 0_16))",
                    expr=And(BVSLT(BVZero(16), bv16),
                             BVSLT(BVSRem(bv16, BVOne(16)), BVZero(16))),
                    is_valid=False,
                    is_sat=False,
                    logic=pysmt.logics.QF_BV),
            Example(hr="((bv2 u% 1_16) = 0_16)",
                    expr=Equals(BVURem(bv16, BVOne(16)), BVZero(16)),
                    is_valid=True,
                    is_sat=True,
                    logic=pysmt.logics.QF_BV),
            Example(hr="((bv2 s% 1_16) = 0_16)",
                    expr=Equals(BVSRem(bv16, BVOne(16)), BVZero(16)),
                    is_valid=True,
                    is_sat=True,
                    logic=pysmt.logics.QF_BV),
            Example(hr="((bv2 s% (- 1_16)) = 0_16)",
                    expr=Equals(BVSRem(bv16, BVNeg(BVOne(16))), BVZero(16)),
                    is_valid=True,
                    is_sat=True,
                    logic=pysmt.logics.QF_BV),
            Example(hr="((bv2 a>> 0_16) = bv2)",
                    expr=Equals(BVAShr(bv16, BVZero(16)), bv16),
                    is_valid=True,
                    is_sat=True,
                    logic=pysmt.logics.QF_BV),
            Example(hr="((0_16 s<= bv2) & ((bv2 a>> 1_16) = (bv2 >> 1_16)))",
                    expr=And(
                        BVSLE(BVZero(16), bv16),
                        Equals(BVAShr(bv16, BVOne(16)),
                               BVLShr(bv16, BVOne(16)))),
                    is_valid=False,
                    is_sat=True,
                    logic=pysmt.logics.QF_BV),

            #
            # Quantification
            #
            Example(hr="(forall y . (x -> y))",
                    expr=ForAll([y], Implies(x, y)),
                    is_valid=False,
                    is_sat=True,
                    logic=pysmt.logics.BOOL),
            Example(hr="(forall p, q . ((p + q) = 0))",
                    expr=ForAll([p, q], Equals(Plus(p, q), Int(0))),
                    is_valid=False,
                    is_sat=False,
                    logic=pysmt.logics.LIA),
            Example(
                hr="(forall r, s . (((0.0 < r) & (0.0 < s)) -> ((r - s) < r)))",
                expr=ForAll([r, s],
                            Implies(And(GT(r, Real(0)), GT(s, Real(0))),
                                    (LT(Minus(r, s), r)))),
                is_valid=True,
                is_sat=True,
                logic=pysmt.logics.LRA),
            Example(hr="(exists x, y . (x -> y))",
                    expr=Exists([x, y], Implies(x, y)),
                    is_valid=True,
                    is_sat=True,
                    logic=pysmt.logics.BOOL),
            Example(hr="(exists p, q . ((p + q) = 0))",
                    expr=Exists([p, q], Equals(Plus(p, q), Int(0))),
                    is_valid=True,
                    is_sat=True,
                    logic=pysmt.logics.LIA),
            Example(hr="(exists r . (forall s . (r < (r - s))))",
                    expr=Exists([r], ForAll([s], GT(Minus(r, s), r))),
                    is_valid=False,
                    is_sat=False,
                    logic=pysmt.logics.LRA),
            Example(hr="(forall r . (exists s . (r < (r - s))))",
                    expr=ForAll([r], Exists([s], GT(Minus(r, s), r))),
                    is_valid=True,
                    is_sat=True,
                    logic=pysmt.logics.LRA),
            Example(hr="(x & (forall r . ((r + s) = 5.0)))",
                    expr=And(x, ForAll([r], Equals(Plus(r, s), Real(5)))),
                    is_valid=False,
                    is_sat=False,
                    logic=pysmt.logics.LRA),
            Example(hr="(exists x . ((x <-> (5.0 < s)) & (s < 3.0)))",
                    expr=Exists([x], (And(Iff(x, GT(s, Real(5))),
                                          LT(s, Real(3))))),
                    is_valid=False,
                    is_sat=True,
                    logic=pysmt.logics.LRA),

            #
            # UFLIRA
            #
            Example(hr="((p < ih(r, q)) & (x -> y))",
                    expr=And(GT(Function(ih, (r, q)), p), Implies(x, y)),
                    is_valid=False,
                    is_sat=True,
                    logic=pysmt.logics.QF_UFLIRA),
            Example(
                hr=
                "(((p - 3) = q) -> ((p < ih(r, (q + 3))) | (ih(r, p) <= p)))",
                expr=Implies(
                    Equals(Minus(p, Int(3)), q),
                    Or(GT(Function(ih, (r, Plus(q, Int(3)))), p),
                       LE(Function(ih, (r, p)), p))),
                is_valid=True,
                is_sat=True,
                logic=pysmt.logics.QF_UFLIRA),
            Example(
                hr=
                "(((ToReal((p - 3)) = r) & (ToReal(q) = r)) -> ((p < ih(ToReal((p - 3)), (q + 3))) | (ih(r, p) <= p)))",
                expr=Implies(
                    And(Equals(ToReal(Minus(p, Int(3))), r),
                        Equals(ToReal(q), r)),
                    Or(
                        GT(
                            Function(
                                ih,
                                (ToReal(Minus(p, Int(3))), Plus(q, Int(3)))),
                            p),
                        LE(Function(ih, (r, p)), p))),
                is_valid=True,
                is_sat=True,
                logic=pysmt.logics.QF_UFLIRA),
            Example(
                hr=
                "(! (((ToReal((p - 3)) = r) & (ToReal(q) = r)) -> ((p < ih(ToReal((p - 3)), (q + 3))) | (ih(r, p) <= p))))",
                expr=Not(
                    Implies(
                        And(Equals(ToReal(Minus(p, Int(3))), r),
                            Equals(ToReal(q), r)),
                        Or(
                            GT(
                                Function(ih, (ToReal(Minus(
                                    p, Int(3))), Plus(q, Int(3)))), p),
                            LE(Function(ih, (r, p)), p)))),
                is_valid=False,
                is_sat=False,
                logic=pysmt.logics.QF_UFLIRA),
            # Symbol names are arbitrary strings, including spaces, '#' and '|'.
            Example(
                hr=
                """("Did you know that any string works? #yolo" & "10" & "|#somesolverskeepthe||" & " ")""",
                expr=And(Symbol("Did you know that any string works? #yolo"),
                         Symbol("10"), Symbol("|#somesolverskeepthe||"),
                         Symbol(" ")),
                is_valid=False,
                is_sat=True,
                logic=pysmt.logics.QF_BOOL),

            #
            # Arrays
            #
            Example(hr="((q = 0) -> (aii[0 := 0] = aii[0 := q]))",
                    expr=Implies(
                        Equals(q, Int(0)),
                        Equals(Store(aii, Int(0), Int(0)),
                               Store(aii, Int(0), q))),
                    is_valid=True,
                    is_sat=True,
                    logic=pysmt.logics.QF_ALIA),
            Example(hr="(aii[0 := 0][0] = 0)",
                    expr=Equals(Select(Store(aii, Int(0), Int(0)), Int(0)),
                                Int(0)),
                    is_valid=True,
                    is_sat=True,
                    logic=pysmt.logics.QF_ALIA),
            Example(hr="((Array{Int, Int}(0)[1 := 1] = aii) & (aii[1] = 0))",
                    expr=And(Equals(Array(INT, Int(0), {Int(1): Int(1)}), aii),
                             Equals(Select(aii, Int(1)), Int(0))),
                    is_valid=False,
                    is_sat=False,
                    logic=pysmt.logics.get_logic_by_name("QF_ALIA*")),
            Example(hr="((Array{Int, Int}(0)[1 := 3] = aii) & (aii[1] = 3))",
                    expr=And(Equals(Array(INT, Int(0), {Int(1): Int(3)}), aii),
                             Equals(Select(aii, Int(1)), Int(3))),
                    is_valid=False,
                    is_sat=True,
                    logic=pysmt.logics.get_logic_by_name("QF_ALIA*")),
            Example(hr="((Array{Real, Int}(10) = ari) & (ari[6/5] = 0))",
                    expr=And(Equals(Array(REAL, Int(10)), ari),
                             Equals(Select(ari, Real((6, 5))), Int(0))),
                    is_valid=False,
                    is_sat=False,
                    logic=pysmt.logics.get_logic_by_name("QF_AUFBVLIRA*")),
            Example(
                hr=
                "((Array{Real, Int}(0)[1.0 := 10][2.0 := 20][3.0 := 30][4.0 := 40] = ari) & (! ((ari[0.0] = 0) & (ari[1.0] = 10) & (ari[2.0] = 20) & (ari[3.0] = 30) & (ari[4.0] = 40))))",
                expr=And(
                    Equals(
                        Array(
                            REAL, Int(0), {
                                Real(1): Int(10),
                                Real(2): Int(20),
                                Real(3): Int(30),
                                Real(4): Int(40)
                            }), ari),
                    Not(
                        And(Equals(Select(ari, Real(0)), Int(0)),
                            Equals(Select(ari, Real(1)), Int(10)),
                            Equals(Select(ari, Real(2)), Int(20)),
                            Equals(Select(ari, Real(3)), Int(30)),
                            Equals(Select(ari, Real(4)), Int(40))))),
                is_valid=False,
                is_sat=False,
                logic=pysmt.logics.get_logic_by_name("QF_AUFBVLIRA*")),
            Example(
                hr=
                "((Array{Real, Int}(0)[1.0 := 10][2.0 := 20][3.0 := 30][4.0 := 40][5.0 := 50] = ari) & (! ((ari[0.0] = 0) & (ari[1.0] = 10) & (ari[2.0] = 20) & (ari[3.0] = 30) & (ari[4.0] = 40) & (ari[5.0] = 50))))",
                expr=And(
                    Equals(
                        Array(
                            REAL, Int(0), {
                                Real(1): Int(10),
                                Real(2): Int(20),
                                Real(3): Int(30),
                                Real(4): Int(40),
                                Real(5): Int(50)
                            }), ari),
                    Not(
                        And(Equals(Select(ari, Real(0)), Int(0)),
                            Equals(Select(ari, Real(1)), Int(10)),
                            Equals(Select(ari, Real(2)), Int(20)),
                            Equals(Select(ari, Real(3)), Int(30)),
                            Equals(Select(ari, Real(4)), Int(40)),
                            Equals(Select(ari, Real(5)), Int(50))))),
                is_valid=False,
                is_sat=False,
                logic=pysmt.logics.get_logic_by_name("QF_AUFBVLIRA*")),
            Example(
                hr=
                "((a_arb_aii = Array{Array{Real, BV{8}}, Array{Int, Int}}(Array{Int, Int}(7))) -> (a_arb_aii[arb][42] = 7))",
                expr=Implies(
                    Equals(nested_a,
                           Array(ArrayType(REAL, BV8), Array(INT, Int(7)))),
                    Equals(Select(Select(nested_a, arb), Int(42)), Int(7))),
                is_valid=True,
                is_sat=True,
                logic=pysmt.logics.get_logic_by_name("QF_AUFBVLIRA*")),
            Example(hr="(abb[bv1 := y_][bv1 := z_] = abb[bv1 := z_])",
                    expr=Equals(
                        Store(Store(abb, bv8, Symbol("y_", BV8)), bv8,
                              Symbol("z_", BV8)),
                        Store(abb, bv8, Symbol("z_", BV8))),
                    is_valid=True,
                    is_sat=True,
                    logic=pysmt.logics.QF_ABV),

            # Non-linear arithmetic (and products that remain linear)
            Example(hr="((r / s) = (r * s))",
                    expr=Equals(Div(r, s), Times(r, s)),
                    is_valid=False,
                    is_sat=True,
                    logic=pysmt.logics.QF_NRA),
            Example(hr="(2.0 = (r * r))",
                    expr=Equals(Real(2), Times(r, r)),
                    is_valid=False,
                    is_sat=True,
                    logic=pysmt.logics.QF_NRA),
            Example(hr="((p ^ 2) = 0)",
                    expr=Equals(Pow(p, Int(2)), Int(0)),
                    is_valid=False,
                    is_sat=True,
                    logic=pysmt.logics.QF_NIA),
            Example(hr="((r ^ 2.0) = 0.0)",
                    expr=Equals(Pow(r, Real(2)), Real(0)),
                    is_valid=False,
                    is_sat=True,
                    logic=pysmt.logics.QF_NRA),
            Example(hr="((r * r * r) = 25.0)",
                    expr=Equals(Times(r, r, r), Real(25)),
                    is_valid=False,
                    is_sat=True,
                    logic=pysmt.logics.QF_NRA),
            Example(hr="((5.0 * r * 5.0) = 25.0)",
                    expr=Equals(Times(Real(5), r, Real(5)), Real(25)),
                    is_valid=False,
                    is_sat=True,
                    logic=pysmt.logics.QF_LRA),
            Example(hr="((p * p * p) = 25)",
                    expr=Equals(Times(p, p, p), Int(25)),
                    is_valid=False,
                    is_sat=False,
                    logic=pysmt.logics.QF_NIA),
            Example(hr="((5 * p * 5) = 25)",
                    expr=Equals(Times(Int(5), p, Int(5)), Int(25)),
                    is_valid=False,
                    is_sat=True,
                    logic=pysmt.logics.QF_LIA),
            Example(hr="(((1 - 1) * p * 1) = 0)",
                    expr=Equals(Times(Minus(Int(1), Int(1)), p, Int(1)),
                                Int(0)),
                    is_valid=True,
                    is_sat=True,
                    logic=pysmt.logics.QF_LIA),
            # Huge Fractions:
            # 2**200 appears in hr as its full decimal expansion.
            Example(
                hr=
                "((r * 1606938044258990275541962092341162602522202993782792835301376/7) = -20480000000000000000000000.0)",
                expr=Equals(Times(r, Real(Fraction(2**200, 7))),
                            Real(-200**11)),
                is_valid=False,
                is_sat=True,
                logic=pysmt.logics.QF_LRA),
            Example(hr="(((r + 5.0 + s) * (s + 2.0 + r)) = 0.0)",
                    expr=Equals(
                        Times(Plus(r, Real(5), s), Plus(s, Real(2), r)),
                        Real(0)),
                    is_valid=False,
                    is_sat=True,
                    logic=pysmt.logics.QF_NRA),
            Example(
                hr=
                "(((p + 5 + q) * (p - (q - 5))) = ((p * p) + (10 * p) + 25 + (-1 * q * q)))",
                expr=Equals(
                    Times(Plus(p, Int(5), q), Minus(p, Minus(q, Int(5)))),
                    Plus(Times(p, p), Times(Int(10), p), Int(25),
                         Times(Int(-1), q, q))),
                is_valid=True,
                is_sat=True,
                logic=pysmt.logics.QF_NIA),
        ]
        return result
def solve_smt_problem(max_outputs, max_unique=None, timeout=None):
    """Build and solve an SMT model of a CoinJoin transaction's outputs.

    The model is built from module-level example data (``example_inputs``,
    ``example_txfees``, ``example_cjfee``, ``example_taker``, ``example_amt``,
    ``parties``, ``num_inputs``) and tuning constants (``min_output_amt``,
    ``min_output_amt_delta``, ``min_feerate``, ``max_feerate``), and is then
    handed to a z3 solver through pySMT.

    :param max_outputs: Number of output slots in the model. A slot whose
        ``output_party`` is ``-1`` is unused and must have amount ``0``.
    :param max_unique: If not None, upper bound on the number of uniquely
        identifiable outputs.
    :param timeout: Value passed as the solver's ``timeout`` option. When
        None, the module-level ``solver_iteration_timeout`` is used.
    :returns: ``([num_outputs, num_unique_outputs], parsed_model)`` if the
        problem is satisfiable; None if it is unsatisfiable or the solver
        reports an unknown result.
    """
    # constraints:
    input_constraints = set()
    output_constraints = set()
    anonymityset_constraints = set()
    txfee_constraints = set()
    invariants = set()

    # variables:
    total_in = Symbol("total_in", INT)  # total satoshis from inputs
    total_out = Symbol("total_out", INT)  # total satoshis sent to outputs
    num_outputs = Symbol("num_outputs", INT)  # num outputs actually used in the tx
    max_outputs_sym = Symbol("max_outputs", INT)  # the symbolic variable for max_outputs
    num_unique_outputs = Symbol("num_unique_outputs", INT)  # num uniquely identifiable outputs
    # estimated tx size in vbytes, given the number of inputs and outputs in the tx
    txsize = Symbol("txsize", INT)
    txfee = Symbol("txfee", INT)  # estimated tx fee, given the supplied feerate
    party_gives = {}  # party ID -> total satoshis on inputs contributed by that party
    party_gets = {}  # party ID -> total satoshis on outputs belonging to that party
    party_txfee = {}  # party ID -> satoshis contributed by that party towards the txfee
    party_cjfee = {}  # party ID -> satoshis earned by that party as a cjfee
    input_party = {}  # index into inputs -> party ID that contributed that input
    input_amt = {}  # index into inputs -> satoshis contributed by that input
    output_party = {}  # index into outputs -> party ID to whom the output belongs
    output_amt = {}  # index into outputs -> satoshis sent to that output
    # index into outputs -> 1 if the output is NOT uniquely identifiable
    # (it shares its amount with another party's output, or is unused), else 0
    output_not_unique = {}
    # satoshi size of the outputs in the biggest anonymity set including all parties
    main_cj_amt = Symbol("main_cj_amt", INT)

    for (party, _) in example_txfees:
        party_gives[party] = Symbol("party_gives[%d]" % party, INT)
        party_gets[party] = Symbol("party_gets[%d]" % party, INT)
        party_txfee[party] = Symbol("party_txfee[%d]" % party, INT)
        party_cjfee[party] = Symbol("party_cjfee[%d]" % party, INT)
    for i in range(0, num_inputs):
        input_party[i] = Symbol("input_party[%d]" % i, INT)
        input_amt[i] = Symbol("input_amt[%d]" % i, INT)
    for i in range(0, max_outputs):
        output_party[i] = Symbol("output_party[%d]" % i, INT)
        output_amt[i] = Symbol("output_amt[%d]" % i, INT)
        output_not_unique[i] = Symbol("output_not_unique[%d]" % i, INT)

    # constraint construction:
    # party_txfee and party_cjfee bindings and (for the taker) constraints:
    for (party, fee_contribution) in example_txfees:
        if party != example_taker:
            txfee_constraints.add(
                Equals(party_txfee[party], Int(fee_contribution)))
        else:
            # Sum of every other party's fixed txfee contribution.
            other_party_txfees = sum(
                x[1] if x[0] != party else 0 for x in example_txfees)
            txfee_constraints.add(
                Equals(Minus(party_txfee[party], party_cjfee[party]),
                       Minus(party_gives[party], party_gets[party])))
            txfee_constraints.add(
                Equals(txfee,
                       Plus(party_txfee[party], Int(other_party_txfees))))
    for (party, fee) in example_cjfee:
        if party != example_taker:
            txfee_constraints.add(Equals(party_cjfee[party], Int(fee)))

    # input_party and input_amt bindings:
    for i in range(0, num_inputs):
        input_constraints.add(
            Equals(input_party[i], Int(example_inputs[i][0])))
        input_constraints.add(
            Equals(input_amt[i], Int(example_inputs[i][1])))

    # constraints on output_party and output_amt:
    #  - either output_party[i] == -1 and output_amt[i] == 0 (unused slot)
    #  - or else output_amt[i] > max(0, min_output_amt - 1), and its amount
    #    either equals or differs by at least min_output_amt_delta from
    #    every other slot's amount
    output_unused = []
    for i in range(0, max_outputs):
        output_is_unused = Equals(output_party[i], Int(-1))
        output_unused.append(output_is_unused)
        min_delta_satisfied = Or(
            output_is_unused,
            And([
                Or(Equals(output_amt[i], output_amt[j]),
                   Or(GE(output_amt[j],
                         Plus(output_amt[i], Int(min_output_amt_delta))),
                      LE(output_amt[j],
                         Minus(output_amt[i], Int(min_output_amt_delta)))))
                for j in range(0, max_outputs) if j != i
            ]))
        output_constraints.add(min_delta_satisfied)
        output_constraints.add(
            Ite(output_is_unused,
                Equals(output_amt[i], Int(0)),
                GT(output_amt[i], Int(max(0, min_output_amt - 1)))))

    # calculate num_outputs and bind max_outputs:
    output_constraints.add(
        Equals(num_outputs,
               Plus([bool_to_int(Not(x)) for x in output_unused])))
    output_constraints.add(Equals(max_outputs_sym, Int(max_outputs)))

    # txfee, party_gets, and party_gives calculation/constraints/binding:
    for party in parties:
        # party_gives and input constraint/invariant
        input_constraints.add(
            Equals(party_gives[party],
                   Plus([Int(a) for (p, a) in example_inputs if p == party])))
        # txfee calculations:
        if party != example_taker:
            txfee_constraints.add(
                Equals(
                    party_gets[party],
                    Plus(party_cjfee[party],
                         Minus(party_gives[party], party_txfee[party]))))
        else:
            makers = [p for p in parties if p != example_taker]
            fee_contributions = Plus([party_txfee[p] for p in makers])
            cjfees = Plus([party_cjfee[p] for p in makers])
            txfee_constraints.add(
                Equals(
                    party_gets[party],
                    Plus(fee_contributions,
                         Minus(party_gives[party], Plus(cjfees, txfee)))))
        # party_gets and output constraint/invariant:
        output_constraints.add(
            Equals(
                party_gets[party],
                Plus([
                    Ite(Equals(output_party[i], Int(party)), output_amt[i],
                        Int(0)) for i in range(0, max_outputs)
                ])))

    # build anonymity set constraints:
    # first, no matter what, we retain the core CoinJoin with the biggest
    # anonymity set:
    num_outputs_at_main_cj_amt = Plus(
        [bool_to_int(Equals(v, main_cj_amt)) for v in output_amt.values()])
    anonymityset_constraints.add(
        Equals(
            main_cj_amt,
            Int(example_amt) if example_amt != 0 else
            party_gets[example_taker]))
    anonymityset_constraints.add(
        GE(num_outputs_at_main_cj_amt, Int(len(parties))))

    def belongs_and_unique(idx, owner):
        # Output idx belongs to `owner` and every other output with the
        # same amount also belongs to `owner`.
        disequal = [
            Or(Not(Equals(v, output_amt[idx])),
               Equals(output_party[k], Int(owner)))
            for (k, v) in output_amt.items() if k != idx
        ]
        return And(Equals(output_party[idx], Int(owner)), And(disequal))

    # also, each party should only have at most one output not part of any
    # anonymity set:
    for party in parties:
        unique_amt_count = Plus(
            [bool_to_int(belongs_and_unique(k, party)) for k in output_amt])
        anonymityset_constraints.add(LE(unique_amt_count, Int(1)))

    # calculate how many outputs are uniquely identifiable (unused outputs
    # count as "not unique", so they are excluded from the total):
    for (idx, amt) in output_amt.items():
        not_unique = Or(
            Equals(output_party[idx], Int(-1)),
            Or([
                And(Equals(v, amt),
                    Not(Equals(output_party[k], output_party[idx])))
                for (k, v) in output_amt.items() if k != idx
            ]))
        anonymityset_constraints.add(
            Equals(output_not_unique[idx], bool_to_int(not_unique)))
    anonymityset_constraints.add(
        Equals(
            num_unique_outputs,
            Minus(max_outputs_sym, Plus(list(output_not_unique.values())))))

    # constrain (if set) the number of uniquely-identifiable outputs
    # (i.e. those not in an anonymity set with cardinality > 1):
    if max_unique is not None:
        anonymityset_constraints.add(LE(num_unique_outputs, Int(max_unique)))

    # set transaction invariants:
    invariants.add(Equals(total_in, Plus(total_out, txfee)))
    invariants.add(Equals(total_in, Plus(list(input_amt.values()))))
    invariants.add(Equals(total_in, Plus(list(party_gives.values()))))
    invariants.add(Equals(total_out, Plus(list(output_amt.values()))))
    invariants.add(Equals(total_out, Plus(list(party_gets.values()))))

    # build txfee calculation constraint:
    # 11 + 68 * num_inputs + 31 * num_outputs
    txfee_constraints.add(
        Equals(txsize,
               Plus(Int(11 + 68 * num_inputs), Times(Int(31), num_outputs))))
    txfee_constraints.add(GE(txfee, Times(txsize, Int(min_feerate))))
    txfee_constraints.add(LE(txfee, Times(txsize, Int(max_feerate))))

    # finish problem construction:
    constraints = [
        c for group in (input_constraints, invariants, txfee_constraints,
                        output_constraints, anonymityset_constraints)
        for c in group
    ]
    problem = And(constraints)

    # Honour the caller's timeout; the global is only a fallback (previously
    # the `timeout` argument was silently ignored).
    solver_timeout = solver_iteration_timeout if timeout is None else timeout
    with Solver(name='z3', solver_options={'timeout': solver_timeout}) as s:
        try:
            if s.solve([problem]):
                model_lines = sorted(
                    str(s.get_model()).replace("'", "").split('\n'))
                result = ([
                    s.get_py_value(num_outputs),
                    s.get_py_value(num_unique_outputs)
                ], parse_model_lines(model_lines))
                return result
            else:
                return None
        except SolverReturnedUnknownResultError:
            return None
def _op_raw_sub(self, *args):
    """Return a pySMT ``Minus`` node built from the positional operands."""
    operands = args
    return Minus(*operands)
def expr_to_pysmt(context: TranslationContext,
                  expr: Expr,
                  *,
                  is_expectation: bool = False,
                  allow_infinity: bool = False) -> FNode:
    """
    Translate a pGCL expression to a pySMT formula.

    Note that substitution expressions are not allowed here (they are not
    supported in pySMT).

    You can pass in the optional `is_expectation` parameter to have all integer
    values converted to real values. The flag is passed unchanged into both
    operands of every binary operator (including comparisons and boolean
    connectives), but is reset to `False` for the operand of a unary operator.

    If `allow_infinity` is `True`, then infinity expressions will be mapped
    directly to the `infinity` variable of the given
    :py:class:`TranslationContext`. Take care to appropriately constrain the
    `infinity` variable! Arithmetic expressions should not contain infinity
    (to avoid expressions like `infinity - infinity`), but this function does
    not currently enforce that: `allow_infinity` is passed unchanged into the
    operands of arithmetic operators as well.

    Subtraction is truncated: `a - b` is translated to `0` whenever `a <= b`.

    :raises Exception: on infinity when `allow_infinity` is `False`, on
        substitution expressions, and on unsupported expressions/operators.

    .. doctest::

        >>> from probably.pgcl.parser import parse_expr
        >>> from pysmt.shortcuts import Symbol
        >>> from pysmt.typing import INT

        >>> expr = parse_expr("x + 4 * 13")
        >>> context = TranslationContext({"x": Symbol("x", INT)})
        >>> expr_to_pysmt(context, expr)
        (x + (4 * 13))
    """
    if isinstance(expr, BoolLitExpr):
        return TRUE() if expr.value else FALSE()
    elif isinstance(expr, NatLitExpr):
        if is_expectation:
            return ToReal(Int(expr.value))
        else:
            return Int(expr.value)
    elif isinstance(expr, FloatLitExpr):
        if expr.is_infinite():
            if not allow_infinity:
                raise Exception(
                    f"Infinity is not allowed in this expression: {expr}")
            return context.infinity
        else:
            return Real(Fraction(expr.value))
    elif isinstance(expr, VarExpr):
        var = context.variables[expr.var]
        if is_expectation and get_type(var) == INT:
            var = ToReal(var)
        return var
    elif isinstance(expr, UnopExpr):
        # The operand of a unary operator is never translated as an
        # expectation (no int -> real conversion inside it).
        operand = expr_to_pysmt(context,
                                expr.expr,
                                is_expectation=False,
                                allow_infinity=allow_infinity)
        if expr.operator == Unop.NEG:
            return Not(operand)
        elif expr.operator == Unop.IVERSON:
            return Ite(operand, Real(1), Real(0))
    elif isinstance(expr, BinopExpr):
        # TODO: `is_expectation` should probably be disabled for
        # non-arithmetic operators (e.g. `x == y`), and `allow_infinity` for
        # the arithmetic ones (PLUS, MINUS, TIMES), because calculations with
        # infinity are hard to make sense of. Both flags are currently passed
        # through unchanged.
        lhs = expr_to_pysmt(context,
                            expr.lhs,
                            is_expectation=is_expectation,
                            allow_infinity=allow_infinity)
        rhs = expr_to_pysmt(context,
                            expr.rhs,
                            is_expectation=is_expectation,
                            allow_infinity=allow_infinity)
        if expr.operator == Binop.OR:
            return Or(lhs, rhs)
        elif expr.operator == Binop.AND:
            return And(lhs, rhs)
        elif expr.operator == Binop.LEQ:
            return LE(lhs, rhs)
        elif expr.operator == Binop.LE:
            # `Binop.LE` is translated to strict less-than.
            return LT(lhs, rhs)
        elif expr.operator == Binop.EQ:
            return EqualsOrIff(lhs, rhs)
        elif expr.operator == Binop.PLUS:
            return Plus(lhs, rhs)
        elif expr.operator == Binop.MINUS:
            # Truncated subtraction: clamp at zero of the operand's type.
            return Ite(LE(lhs, rhs),
                       (Int(0) if get_type(lhs) == INT else Real(0)),
                       Minus(lhs, rhs))
        elif expr.operator == Binop.TIMES:
            return Times(lhs, rhs)
    elif isinstance(expr, SubstExpr):
        raise Exception("Substitution expression is not allowed here.")

    # Reached for unknown expression classes and for unary/binary operators
    # without a translation above.
    raise Exception("unreachable")