def _accept_word(self, ts_enc, ts, word, final_states):
    """Return whether ``ts`` accepts ``word`` ending in ``final_states``.

    The transition system is unrolled with BMC for ``len(word)`` steps. The
    final-state condition is asserted at the last step, and the i-th letter
    of ``word`` is pinned on the input variables at step i. The value
    returned is the z3 satisfiability verdict.
    """
    # the error condition is encoded in the final state
    unroller = BMC(ts_enc.helper, ts, TRUE())
    smt = Solver(name='z3', logic=QF_BOOL)

    depth = len(word)
    tracked = set(ts.state_vars) | set(ts.input_vars)
    unroller.encode_up_to_k(smt, tracked, depth)
    smt.add_assertion(
        unroller.helper.get_formula_at_i(tracked, final_states, depth))

    # constrain the inputs of every step so that they spell out the word
    for step, letter in enumerate(word):
        letter_eq = ts_enc.r2a.get_msg_eq(letter)
        smt.add_assertion(
            unroller.helper.get_formula_at_i(ts.input_vars, letter_eq, step))

    return smt.solve()
def test_btor_does_not_support_int_arrays(self):
    """Boolector must refuse a formula over Int-indexed, Int-valued arrays."""
    arr = Symbol("a", ARRAY_INT_INT)
    ten = Int(10)
    stored = Store(arr, ten, Int(100))
    formula = Equals(Select(stored, ten), Int(100))

    btor = Solver(name="btor")
    with self.assertRaises(ConvertExpressionError):
        btor.add_assertion(formula)
def _compute_WMI_PA_no_boolean(self, lab_formula, pa_vars, labels,
                               other_assignments=None):
    """Finds all the assignments that satisfy the given formula using AllSAT.

    Args:
        lab_formula (FNode): The labelled pysmt formula to examine.
        pa_vars (iterable): Terms converted to MathSAT and passed as the
            term list of ``msat_all_sat``.
        labels (dict): The dictionary containing the correlation between each
            label and their true value; labels found in an assignment are
            replaced by the value they map to.
        other_assignments (dict, optional): Extra key/value pairs merged into
            every yielded assignment, overriding clashing keys. Defaults to
            no extra pairs.

    Yields:
        dict: One of the assignments that satisfies the formula.

    """
    # ``None`` sentinel instead of a shared mutable ``{}`` default
    if other_assignments is None:
        other_assignments = {}

    solver = Solver(name="msat",
                    solver_options={"dpll.allsat_minimize_model": "true"})
    converter = solver.converter
    solver.add_assertion(lab_formula)

    # the callback appends every model found by AllSAT to lra_assignments
    lra_assignments = []
    mathsat.msat_all_sat(
        solver.msat_env(),
        [converter.convert(v) for v in pa_vars],
        lambda model: WMI._callback(model, converter, lra_assignments))

    for mu_lra in lra_assignments:
        assignments = {}
        for atom, value in WMI._get_assignments(mu_lra).items():
            # translate labelling variables back to what they stand for
            if atom in labels:
                atom = labels[atom]
            assignments[atom] = value
        assignments.update(other_assignments)
        yield assignments
def _compute_TTAs(self, formula):
    """Computes the total truth assignments of the given formula.

    This method first labels the formula (its atoms are replaced by fresh
    boolean variables via ``self.label_formula``) and then uses the MathSAT
    functionality called AllSAT to retrieve all the total truth assignments
    over the labelling variables.

    Args:
        formula (FNode): The pysmt formula to examine.

    Returns:
        list: The list of all the total truth assignments.
        dict: The dictionary containing all the correspondence between the
            labels and their true value.

    """
    # Label LRA atoms with fresh boolean variables
    labelled_formula, pa_vars, labels = self.label_formula(
        formula, formula.get_atoms())

    # Perform AllSMT on the labelled formula; the callback appends each
    # model to ``models``
    solver = Solver(name="msat")
    converter = solver.converter
    solver.add_assertion(labelled_formula)
    models = []
    mathsat.msat_all_sat(
        solver.msat_env(),
        [converter.convert(v) for v in pa_vars],
        lambda model: WMI._callback(model, converter, models))
    return models, labels
def configure(inpcode, options):
    """Parse ``inpcode`` and load its last formula into the global ``solver``.

    Solver selection, in priority order:

    * ``options["pysmt"]``: name of a solver known to pySMT.
    * ``options["smtpipe"]``: path of an external SMT-LIB solver. It is
      registered with the global environment as ``"custom: <basename>"`` and
      declared for the single logic named by ``options["pipe_logic"]``.
    * neither: the module-level ``solver_name`` already in place is used.

    Side effects: sets the module globals ``solver`` and (when chosen from
    ``options``) ``solver_name``. Progress is reported via
    ``utils.verbprint``.
    """
    global solver
    global solver_name

    # Remember which branch was taken so the solver is constructed the same
    # way below. Re-testing ``"smtpipe" in options`` picked the pipe
    # constructor even when "pysmt" had priority, and then failed on the
    # unset ``solver_logic``.
    use_pipe = False
    if "pysmt" in options:
        solver_name = options["pysmt"]
    elif "smtpipe" in options:
        use_pipe = True
        solver_name = "custom: " + options["smtpipe"].split("/")[-1]
        solver_cmd = options["smtpipe"]
        # NOTE(review): eval() executes arbitrary code taken from the options;
        # acceptable only while options come from a trusted source.
        solver_logic = eval(options["pipe_logic"])
        solver_logics = [solver_logic]
        env = get_env()
        env.factory.add_generic_solver(solver_name, solver_cmd, solver_logics)

    verbprint = utils.verbprint

    verbprint(2, "Parsing the input...", False)
    parser = SmtLibParser()
    script = parser.get_script(inpcode)
    formula = script.get_last_formula()
    verbprint(2, "done.", True)

    verbprint(2, "Initializing the solver...", False)
    if use_pipe:
        solver = Solver(name=solver_name, logic=solver_logic)
    else:
        solver = Solver(name=solver_name)
    verbprint(2, "done.", True)

    verbprint(2, "Loading the input file in the solver...", False)
    solver.add_assertion(formula)
    verbprint(2, "done.", True)
def main(path, solver_name):
    """Solve the last formula of the SMT-LIB file ``path`` and print the verdict."""
    checker = Solver(solver_name)
    script = SmtLibParser().get_script_fname(path)
    checker.add_assertion(script.get_last_formula())
    print(checker.solve())
def test_btor_options(self):
    """Boolector configured with custom options still solves every QF_BV example correctly."""
    for (formula, _, expected, logic) in get_example_formulae():
        if logic != QF_BV:
            continue
        btor = Solver(name="btor",
                      solver_options={"rewrite-level": 0,
                                      "fun:dual-prop": 1,
                                      "eliminate-slices": 1})
        btor.add_assertion(formula)
        self.assertEqual(btor.solve(), expected)
def test_examples_solving(self):
    """The BDD solver returns the expected verdict on every BOOL example."""
    for example in EXAMPLE_FORMULAS:
        if example.logic != pysmt.logics.BOOL:
            continue
        bdd = Solver(logic=pysmt.logics.BOOL, name='bdd')
        bdd.add_assertion(example.expr)
        check = self.assertTrue if example.is_sat else self.assertFalse
        check(bdd.solve())
def test_examples_solving(self):
    """Every BOOL example gets its expected SAT/UNSAT verdict from the BDD solver."""
    for case in get_example_formulae():
        if case.logic != BOOL:
            continue
        engine = Solver(logic=BOOL, name='bdd')
        engine.add_assertion(case.expr)
        verify = self.assertTrue if case.is_sat else self.assertFalse
        verify(engine.solve())
def test_msat_partial_model(self):
    """MathSAT's model contains the asserted symbol x but not the unused y."""
    engine = Solver(name="msat")
    x, y = Symbol("x"), Symbol("y")
    engine.add_assertion(x)
    self.assertTrue(engine.solve())

    model = engine.get_model()
    self.assertNotIn(y, model)
    self.assertIn(x, model)
    engine.exit()
def get_value(path="get_value.smt2"):
    """Solve an SMT-LIB script with z3 and evaluate its ``(get-value ...)`` terms.

    Args:
        path: SMT-LIB file to read; defaults to ``"get_value.smt2"``.

    Returns:
        tuple: ``(is_sat, values)``. ``is_sat`` is the z3 verdict on the
        script's last formula. ``values`` is the result of
        ``solver.get_values`` for all terms of the script's ``get-value``
        commands, or ``None`` when the formula is unsatisfiable (there is no
        model to query).
    """
    solver = Solver('z3')
    parser = SmtLibParser()
    script = parser.get_script_fname(path)

    exprs = []
    for get_val_cmd in script.filter_by_command_name("get-value"):
        exprs.extend(get_val_cmd.args)

    formula = script.get_last_formula()
    solver.add_assertion(formula)
    is_sat = solver.solve()
    # values can only be read from a model, i.e. after a SAT answer
    values = solver.get_values(exprs) if is_sat else None
    return is_sat, values
def solve(formula, n, max_models=None, solver="msat"):
    """Enumerate models of ``formula``, yielding each as ``to_bn(model, n)``.

    After each model, a blocking clause over the variables from
    ``variables(n)`` is asserted, so every subsequent model differs from it
    on at least one of them.

    Args:
        formula: formula whose models are enumerated.
        n: size parameter forwarded to ``variables`` and ``to_bn``.
        max_models: stop after this many models; ``None`` means no limit.
        solver: pySMT solver name.

    Yields:
        ``to_bn(model, n)`` for each model found.
    """
    s = Solver(name=solver)
    vs = [x for xs in variables(n) for x in xs]
    s.add_assertion(formula)
    k = 0
    # Test the budget before solving, so no solver call is wasted once the
    # limit is reached. ``is None`` (not truthiness) keeps max_models=0
    # meaning "no models" rather than "unlimited". The first solve() also
    # serves as the satisfiability check.
    while (max_models is None or k < max_models) and s.solve():
        k += 1
        model = s.get_model()
        s.add_assertion(Not(And([EqualsOrIff(v, model[v]) for v in vs])))
        yield to_bn(model, n)
def get_prepared_solver(self, logic, solver_name=None):
    """Returns a solver initialized with the sudoku constraints and a
    matrix of SMT variables, each representing a cell of the game.

    Returns:
        tuple: ``(solver, var_table)``. ``var_table`` is a ``size**2`` by
        ``size**2`` list of lists of fresh symbols of type
        ``self.get_type()``.
    """
    sq_size = self.size**2
    ty = self.get_type()
    # ``range`` instead of the Python-2-only ``xrange`` (works on 2 and 3)
    var_table = [[FreshSymbol(ty) for _ in range(sq_size)]
                 for _ in range(sq_size)]

    solver = Solver(logic=logic, name=solver_name)
    # Sudoku constraints
    # every variable takes one of the values 1..sq_size
    for row in var_table:
        for var in row:
            solver.add_assertion(Or([Equals(var, self.const(i))
                                     for i in range(1, sq_size + 1)]))
    # each row and each column contains all different numbers
    for i in range(sq_size):
        solver.add_assertion(AllDifferent(var_table[i]))
        solver.add_assertion(AllDifferent([x[i] for x in var_table]))

    # each square contains all different numbers
    for sx in range(self.size):
        for sy in range(self.size):
            square = [var_table[i + sx * self.size][j + sy * self.size]
                      for i in range(self.size) for j in range(self.size)]
            solver.add_assertion(AllDifferent(square))

    return solver, var_table
def get_prepared_solver(self, logic, solver_name=None):
    """Build a solver loaded with the Sudoku rules.

    Returns:
        tuple: ``(solver, var_table)``. ``var_table`` is a ``size**2`` by
        ``size**2`` grid (list of lists) of fresh symbols, one per cell.
    """
    side = self.size ** 2
    cell_type = self.get_type()
    grid = [[FreshSymbol(cell_type) for _ in range(side)] for _ in range(side)]
    solver = Solver(logic=logic, name=solver_name)

    # every cell holds one of the values 1..side
    for cell in (c for row in grid for c in row):
        solver.add_assertion(
            Or([Equals(cell, self.const(v)) for v in range(1, side + 1)]))

    # no value repeats within a row or within a column
    for k in range(side):
        column = [r[k] for r in grid]
        solver.add_assertion(AllDifferent(grid[k]))
        solver.add_assertion(AllDifferent(column))

    # no value repeats within a box
    box = self.size
    for bx in range(box):
        for by in range(box):
            cells = [grid[r + bx * box][c + by * box]
                     for r in range(box) for c in range(box)]
            solver.add_assertion(AllDifferent(cells))

    return solver, grid
def configure(inpcode):
    """Parse ``inpcode`` and load its last formula into a new global MathSAT solver."""
    global solver
    say = utils.verbprint

    say(2, "Parsing the input...", False)
    formula = SmtLibParser().get_script(inpcode).get_last_formula()
    say(2, "done.", True)

    say(2, "Initializing the solver...", False)
    solver = Solver(name="msat")
    say(2, "done.", True)

    say(2, "Loading the input file in the solver...", False)
    solver.add_assertion(formula)
    say(2, "done.", True)
def solve_soduku_model(model):
    """Solve a Sudoku ``model``; return the filled table or ``None`` if unsolvable.

    ``model`` must provide ``n``, ``extractConstraints()`` and
    ``getSymbol(x, y)``. The returned table is indexed ``solution[y][x]``.
    """
    solver = Solver()
    solver.add_assertion(model.extractConstraints())
    if not solver.solve():
        return None  # the constraints admit no solution

    solution = tabletools.generateEmptyTable(model.n)
    # copy every cell's value out of the satisfying model
    for y in range(model.n):
        for x in range(model.n):
            solution[y][x] = solver.get_value(model.getSymbol(x, y))
    return solution
def test_create_and_solve(self):
    """A contradiction built by substitution is UNSAT and simplifies to False."""
    solver = Solver(logic=QF_BOOL)
    var_a, var_b = Symbol("A", BOOL), Symbol("B", BOOL)

    # substituting A for B turns A & !B into A & !A
    contradiction = And(var_a, Not(var_b)).substitute({var_b: var_a})
    solver.add_assertion(contradiction)
    self.assertFalse(solver.solve(), "Formula was expected to be UNSAT")

    with_false = And(contradiction, Bool(False))
    self.assertEqual(with_false.simplify(), Bool(False))
def test_create_and_solve(self):
    """Substituting A for B in A & !B yields an UNSAT formula; conjoining False simplifies to False."""
    engine = Solver(logic=QF_BOOL)
    a = Symbol("A", BOOL)
    b = Symbol("B", BOOL)

    unsat_formula = And(a, Not(b)).substitute({b: a})
    engine.add_assertion(unsat_formula)
    verdict = engine.solve()
    self.assertFalse(verdict, "Formula was expected to be UNSAT")

    simplified = And(unsat_formula, Bool(False)).simplify()
    self.assertEqual(simplified, Bool(False))
def k_sequence_WH(m, K, K_seq_len=100, count=100):
    """Yield ``count`` 0/1 sequences of length ``K_seq_len`` where every window
    of ``K`` consecutive entries sums to at most ``m``.

    Each yielded value is a numpy bool array read from a solver model. That
    model is then blocked, so the next one differs. When the current solver
    finds no model, a fresh z3 solver loaded with only the base formula (no
    blocking clauses) takes over, so sequences may repeat from then on.
    """
    bits = [Symbol('x_%i' % i, INT) for i in range(K_seq_len)]
    zero_one = And([Or(Equals(b, Int(0)), Equals(b, Int(1))) for b in bits])
    n_windows = max(1, K_seq_len - K + 1)
    windows = And([LE(Plus(bits[t:min(K_seq_len, t + K)]), Int(m))
                   for t in range(n_windows)])
    formula = And(zero_one, windows)

    solver = Solver(name='yices', incremental=True,
                    random_seed=randint(2 << 30))
    solver.add_assertion(formula)
    for _ in range(count):
        if not solver.solve():
            # no model left: restart from the unblocked formula with z3
            solver = Solver(name='z3', incremental=True,
                            random_seed=randint(2 << 30))
            solver.add_assertion(formula)
            solver.solve()
        smt_model = solver.get_model()
        seq = array([smt_model.get_py_value(b) for b in bits], dtype=bool)
        yield seq
        # forbid this exact sequence from now on
        solver.add_assertion(
            Or([NotEquals(b, Int(v)) for b, v in zip(bits, seq)]))
def test_msat_preferred_variable(self):
    """MathSAT's preferred-variable polarity determines the value chosen for ``a``."""
    a, b, c = (Symbol(name) for name in "abc")
    na, nb, nc = Not(a), Not(b), Not(c)
    f = And(Implies(a, And(b, c)), Implies(na, And(nb, nc)))

    prefer_true = Solver("msat")
    prefer_true.add_assertion(f)
    prefer_true.set_preferred_var(a, True)
    self.assertTrue(prefer_true.solve())
    self.assertTrue(prefer_true.get_value(a).is_true())

    prefer_false = Solver("msat")
    prefer_false.add_assertion(f)
    prefer_false.set_preferred_var(a, False)
    self.assertTrue(prefer_false.solve())
    self.assertTrue(prefer_false.get_value(a).is_false())

    # Without a polarity only the split order is affected, which is hard to
    # observe; just check that the call is accepted.
    prefer_true.set_preferred_var(a)
def _model_iterator_base(formula):
    """Enumerate models of ``formula`` with pairwise-distinct atom assignments.

    Args:
        formula (FNode): The pysmt formula to examine.

    Yields:
        model: A MathSAT model whose truth assignment to the atoms of
        ``formula`` differs from that of every previously yielded model.
    """
    solver = Solver(name="msat")
    solver.add_assertion(formula)
    while solver.solve():
        model = solver.get_model()
        yield model
        # block the atom assignment just produced
        assignment = [(atom, model.get_value(atom))
                      for atom in formula.get_atoms()]
        solver.add_assertion(
            Not(And([Iff(atom, val) for atom, val in assignment])))
def test_msat_preferred_variable(self):
    """Setting a preferred polarity for ``a`` in MathSAT fixes its value in the model."""
    syms = [Symbol(ch) for ch in "abc"]
    a, b, c = syms
    na, nb, nc = [Not(s) for s in syms]
    f = And(Implies(a, And(b, c)), Implies(na, And(nb, nc)))

    s1 = Solver("msat")
    s1.add_assertion(f)
    s1.set_preferred_var(a, True)
    self.assertTrue(s1.solve())
    value_a = s1.get_value(a)
    self.assertTrue(value_a.is_true())

    s2 = Solver("msat")
    s2.add_assertion(f)
    s2.set_preferred_var(a, False)
    self.assertTrue(s2.solve())
    value_a = s2.get_value(a)
    self.assertTrue(value_a.is_false())

    # The polarity-less form only guides which variable is split first, so
    # it is merely exercised here, not checked.
    s1.set_preferred_var(a)
def k_sequence_WH_worst_case(m, K, K_seq_len=100, count=100):
    """Yield up to ``count`` distinct "worst case" 0/1 sequences for an (m, K) window bound.

    Every window of ``K`` consecutive entries sums to at most ``m`` and to
    more than ``m - 1`` (with 0/1 entries: exactly ``m``). Every window of
    ``K + right_shift`` entries must sum to more than ``m``. ``right_shift``
    starts at 1. Whenever z3 answers UNSAT or fails (each check has a
    5-minute timeout), it is increased and a fresh solver is built that
    keeps the blocking clauses of all sequences yielded so far. Generation
    stops once ``right_shift + K`` reaches ``K_seq_len`` without a model.

    Yields:
        numpy bool arrays of length ``K_seq_len``.
    """
    k_seq = [Symbol('x_%i' % i, INT) for i in range(K_seq_len)]
    domain = And([Or(Equals(x, Int(0)), Equals(x, Int(1))) for x in k_seq])
    K_window = And([
        LE(Plus(k_seq[t:min(K_seq_len, t + K)]), Int(m))
        for t in range(max(1, K_seq_len - K + 1))
    ])
    violate_up = And([
        GT(Plus(k_seq[t:min(K_seq_len, t + K)]), Int(m - 1))
        for t in range(max(1, K_seq_len - K + 1))
    ])

    def violate_right_generator(n):
        """Every window of ``K + n`` entries sums to more than ``m``."""
        return And([
            GT(Plus(k_seq[t:min(K_seq_len, t + K + n)]), Int(m))
            for t in range(max(1, K_seq_len - (K + n) + 1))
        ])

    right_shift = 1
    formula = And(domain, K_window, violate_up,
                  violate_right_generator(right_shift))

    solver = Solver(name='z3', incremental=True, random_seed=randint(2 << 30))
    solver.add_assertion(formula)
    solver.z3.set('timeout', 5 * 60 * 1000)
    # conjunction of blocking clauses for every sequence yielded so far
    solutions = And()
    for _ in range(count):
        while right_shift + K < K_seq_len:
            # ``Exception``, not ``BaseException``: solver errors and
            # timeouts count as "no result", but KeyboardInterrupt and
            # SystemExit must still propagate.
            try:
                result = solver.solve()
            except Exception:
                result = None
            if not result:
                # retry with a longer right-hand window, keeping the blocks
                solver = Solver(name='z3', incremental=True,
                                random_seed=randint(2 << 30))
                right_shift += 1
                solver.z3.set('timeout', 5 * 60 * 1000)
                solver.add_assertion(
                    And(solutions, domain, K_window, violate_up,
                        violate_right_generator(right_shift)))
            else:
                break
        try:
            smt_model = solver.get_model()
        except Exception:
            # no model available (shifts exhausted): stop generating
            break
        model = array([smt_model.get_py_value(x) for x in k_seq], dtype=bool)
        yield model
        solution = Or(
            [NotEquals(k_seq[i], Int(model[i])) for i in range(K_seq_len)])
        solutions = And(solutions, solution)
        solver.add_assertion(solution)
class SATSolver:
    """QF_BOOL pySMT solver wrapper with clause push/pop and model enumeration.

    Every formula handed to the underlying solver is also recorded (and
    discarded again on pop), so that :meth:`unsat_core` can report and
    analyse the currently active assertions.
    """

    def __init__(self, solver_name):
        self.solver = Solver(solver_name, logic=QF_BOOL)
        # formulas currently asserted, in assertion order
        self._asserted = []
        # len(self._asserted) at each pushed level, innermost last
        self._marks = []

    @property
    def fml(self):
        """Conjunction of the currently active assertions."""
        return And(self._asserted)

    def _assert(self, formula):
        """Assert ``formula`` on the solver and record it."""
        self.solver.add_assertion(formula)
        self._asserted.append(formula)

    def get_model(self):
        """Solve the current assertions; return a model, or None if UNSAT."""
        m = None
        if self.solver.solve():
            m = self.solver.get_model()
        return m

    def enum_model(self, blocking_cls):
        """Permanently forbid the conjunction of ``blocking_cls`` and solve again.

        Returns a model, or None if UNSAT.
        """
        self._assert(Not(And(blocking_cls)))
        return self.get_model()

    def assert_cls(self, _cls):
        """Assert ``_cls`` at the current level."""
        self._assert(_cls)

    def add_cls(self, _cls, level):
        """Push ``level`` backtracking points, then assert ``_cls``."""
        self.solver.push(level)
        self._marks.extend([len(self._asserted)] * level)
        self._assert(_cls)

    def remove_cls(self, level):
        """Pop ``level`` backtracking points together with their assertions."""
        self.solver.pop(level)
        for _ in range(level):
            del self._asserted[self._marks.pop():]

    def unsat_core(self):
        """If the active assertions are UNSAT, print them and an unsat core.

        The core is computed over the top-level conjuncts of :attr:`fml`.
        Returns the core, or None when the assertions are satisfiable.
        (The original read ``self.fml`` without ever setting it.)
        """
        if self.solver.solve():
            return None
        fml = self.fml
        print('Assertions:', fml)
        conj = conjunctive_partition(fml)
        ucore = get_unsat_core(conj)
        print("UNSAT-Core size '%d'" % len(ucore))
        for f in ucore:
            print(f.serialize())
        return ucore
class Reluzy:
    """Lazy abstraction-refinement satisfiability check of a ReLU network query.

    ``self.solver`` starts with only the ``Nnet2Smt`` formulae and no ReLU
    semantics. For each ReLU pair ``(r1, r2)`` (``r1 == max(r2, 0)``, as
    encoded exactly in ``self.sat_checker``), linear lemmas violated by the
    current model are added until the abstraction is UNSAT or no lemma is
    violated.
    """

    def __init__(self, filename, violationfile, logger):
        self.logger = logger
        self.nnet2smt = Nnet2Smt(filename, violationfile)
        self.nnet2smt.convert(True)
        self.input_vars = self.nnet2smt.input_vars
        self.output_vars = self.nnet2smt.output_vars
        self.formulae = self.nnet2smt.formulae
        # (r1, r2) pairs: r1 is the ReLU output for input r2
        self.relus = self.nnet2smt.relus
        # groups of such pairs; presumably one group per layer -- TODO confirm
        self.relus_level = self.nnet2smt.relus_level
        # abstraction that gets refined with lemmas
        self.solver = Solver(name='yices')
        # exact encoding: formulae plus r1 == max(r2, 0)
        self.sat_checker = Solver(name='yices')
        self.init()

    def init(self):
        """Load the formulae into both solvers and the exact ReLU encoding into sat_checker."""
        self.solver.add_assertion(And(self.formulae))
        self.sat_checker.add_assertion(And(self.formulae))
        for r1, r2 in self.relus:
            self.sat_checker.add_assertion(Equals(r1, Max(r2, Real(0))))
        #lemmas = self.refine_zero_lb(False)
        #self.solver.add_assertion(And(lemmas))
        #lemmas = self.refine_slope_lb(False)
        #self.solver.add_assertion(And(lemmas))

    def solve(self):
        """Refinement loop.

        Prints 'unsat' once the abstraction is UNSAT, or 'sat' once its
        model violates no lemma.
        """
        while True:
            self.logger.info('Solving')
            res = self.solver.solve()
            if not res:
                print('unsat')
                break
            lemmas = self.refine()
            if not lemmas:
                print('sat')
                break
            self.solver.add_assertion(And(lemmas))

    def check_sat(self):
        """Check whether the abstraction's current input values give a real model.

        Pins every input of ``sat_checker`` to its value in ``solver``'s
        model. On success, prints each input with its value and returns True.
        Otherwise pops the pinning frame and returns False.
        """
        self.logger.info('Checking for Sat')
        self.sat_checker.push()
        for v in self.input_vars:
            self.sat_checker.add_assertion(Equals(v, self.solver.get_value(v)))
        res = self.sat_checker.solve()
        if res:
            # label each value with its own variable (previously the stale
            # loop variable ``v`` was printed for every input)
            for x in self.input_vars:
                print(x, self.sat_checker.get_value(x))
            # NOTE(review): the frame pushed above is not popped on success,
            # so a later call would still see these pins.
            return True
        else:
            self.sat_checker.pop()
            return False

    def refine(self):
        """Return the first non-empty lemma family violated by the current model.

        Order: zero_lb, slope_lb, zero_ub, slope_ub. Returns an empty list
        when none is violated, after printing the input values.
        """
        self.logger.info('Refining')
        lemmas = self.refine_zero_lb()
        if not lemmas:
            lemmas = self.refine_slope_lb()
        if not lemmas:
            lemmas = self.refine_zero_ub()
        if not lemmas:
            lemmas = self.refine_slope_ub()
        if not lemmas:
            for v in self.input_vars:
                print(v, self.solver.get_value(v))
        #elif self.check_sat():
        #    return []
        return lemmas

    def refine_zero_lb(self, check=True):
        """Lemmas ``r1 >= 0``.

        If ``check``, only those false in the current model; otherwise all
        of them.
        """
        lemmas = []
        zero = Real(0)
        for r1, _ in self.relus:
            lemma = GE(r1, zero)
            if check:
                tval = self.solver.get_value(lemma)
                if tval.is_false():
                    lemmas.append(lemma)
                    self.logger.debug('Adding %s' % lemma)
            else:
                lemmas.append(lemma)
        return lemmas

    def refine_zero_ub(self):
        """Violated lemmas ``r2 <= 0 -> r1 <= 0`` from the first group that has any."""
        lemmas = []
        zero = Real(0)
        for s in self.relus_level:
            for r1, r2 in s:
                lemma = Implies(LE(r2, zero), LE(r1, zero))
                tval = self.solver.get_value(lemma)
                if tval.is_false():
                    lemmas.append(lemma)
                    self.logger.debug('Adding %s' % lemma)
            if lemmas:
                break
        return lemmas

    def refine_slope_lb(self, check=True):
        """Lemmas ``r1 >= r2``.

        If ``check``, only those false in the current model; otherwise all
        of them.
        """
        lemmas = []
        for r1, r2 in self.relus:
            lemma = GE(r1, r2)
            if check:
                tval = self.solver.get_value(lemma)
                if tval.is_false():
                    lemmas.append(lemma)
                    self.logger.debug('Adding %s' % lemma)
            else:
                lemmas.append(lemma)
        return lemmas

    def refine_slope_ub(self):
        """Violated lemmas ``r2 >= 0 -> r1 <= r2`` from the first group that has any."""
        lemmas = []
        zero = Real(0)
        for s in self.relus_level:
            for r1, r2 in s:
                lemma = Implies(GE(r2, zero), LE(r1, r2))
                tval = self.solver.get_value(lemma)
                if tval.is_false():
                    lemmas.append(lemma)
                    self.logger.debug('Adding %s' % lemma)
            if lemmas:
                break
        return lemmas
def solve_with_dreal(formula):
    """Check ``formula`` with the external dReal solver and return the verdict.

    dReal is registered with the global environment's solver factory as
    ``DREAL_NAME`` on first use (command ``[DREAL_PATH, DREAL_ARGS]``,
    logics ``DREAL_LOGICS``); later calls reuse that registration.

    Returns:
        The result of ``solver.solve()`` (previously discarded).
    """
    env = get_env()
    # pySMT's factory refuses to register the same solver name twice, so a
    # second call used to fail here.
    if DREAL_NAME not in env.factory.all_solvers():
        env.factory.add_generic_solver(DREAL_NAME, [DREAL_PATH, DREAL_ARGS],
                                       DREAL_LOGICS)
    # use the registered name rather than a separate hard-coded 'dreal'
    solver = Solver(DREAL_NAME)
    solver.add_assertion(formula)
    return solver.solve()
def comb_attack(self):
    """Iterative SAT-based key recovery for ``self.obf_cir`` using ``self.oracle_cir``.

    Loop, as implemented below:
      1. ``solver_obf`` (holding ``dip_gen_ckt`` and ``key_inequality_ckt``
         from ``FormulaGenerator``) is solved. The values of the obfuscated
         circuit's input wires form a DIP (presumably a distinguishing input
         pattern -- naming assumption).
      2. For each oracle output wire, its formula is solved under the DIP
         literals as assumptions, giving TRUE/FALSE per output.
      3. ``dip_chk1``/``dip_chk2`` are constrained to equal those outputs,
         with the inputs substituted by the DIP values. The result is added
         to both ``solver_obf`` and ``solver_key``.
    When ``solver_obf`` becomes UNSAT, ``solver_key`` is solved and the key
    bits ``keyinput{i}_0`` are printed as a 0/1 string (or an UNSAT error is
    logged), and the method returns.
    """
    # dis generator
    solver_name = 'btor'
    solver_obf = Solver(name=solver_name)
    solver_key = Solver(name=solver_name)
    solver_oracle = Solver(name=solver_name)

    attack_formulas = FormulaGenerator(self.oracle_cir, self.obf_cir)
    f = attack_formulas.dip_gen_ckt
    # f = simplify(f)
    solver_obf.add_assertion(f)

    f = attack_formulas.key_inequality_ckt
    # f = simplify(f)
    solver_obf.add_assertion(f)

    iteration = 0
    while 1:
        # query dip generator
        if solver_obf.solve():
            # read the input-wire values of the model as literals (for the
            # oracle query) and as constants (for substitution below)
            dip_formula = []
            dip_boolean = []
            for l in self.obf_cir.input_wires:
                t = Symbol(l)
                if solver_obf.get_py_value(t):
                    dip_formula.append(t)
                    dip_boolean.append(TRUE())
                else:
                    dip_formula.append(Not(t))
                    dip_boolean.append(FALSE())
            logging.info(dip_formula)

            # query oracle
            # each output formula is checked alone, with the DIP literals
            # passed as solve() assumptions
            dip_out = []
            for l in self.oracle_cir.output_wires:
                t = self.oracle_cir.wire_objs[l].formula
                solver_oracle.reset_assertions()
                solver_oracle.add_assertion(t)
                if solver_oracle.solve(dip_formula):
                    dip_out.append(TRUE())
                else:
                    dip_out.append(FALSE())
            logging.info(dip_out)

            # add dip checker
            # both circuit copies must reproduce the oracle outputs
            f = []
            for i in range(len(attack_formulas.dip_chk1)):
                f.append(And(Iff(dip_out[i], attack_formulas.dip_chk1[i]),
                             Iff(dip_out[i], attack_formulas.dip_chk2[i])))
            f = And(f)

            # fix the input wires to the DIP values
            subs = {}
            for i in range(len(self.obf_cir.input_wires)):
                subs[Symbol(self.obf_cir.input_wires[i])] = dip_boolean[i]
            # f = simplify(f)
            f = substitute(f, subs)
            solver_obf.add_assertion(f)
            solver_key.add_assertion(f)
            iteration += 1
            logging.warning('iteration: {}'.format(iteration))
        else:
            # no further DIP exists: extract a key from solver_key
            logging.warning('print keys')
            if solver_key.solve():
                key = ''
                for i in range(len(self.obf_cir.key_wires)):
                    k = 'keyinput{}_0'.format(i)
                    if solver_key.get_py_value(Symbol(k)):
                        key += '1'
                    else:
                        key += '0'
                print("key=%s" % key)
            else:
                logging.critical('key solver returned UNSAT')
            return
def test_btor_does_not_support_const_arryas(self):
    """Boolector must refuse an equality involving a constant array."""
    with self.assertRaises(ConvertExpressionError):
        btor = Solver(name="btor")
        const_array = Array(BV8, BV(0, 8))
        sym_array = FreshSymbol(ArrayType(BV8, BV8))
        btor.add_assertion(Equals(const_array, sym_array))
class SMTValidator(object):
    """
        Validating Anchor's explanations using SMT solving.

        The oracle is loaded once with the model encoding. Each prepared
        sample gets its own Boolean selector (``self.selv``) guarding all
        sample-specific constraints; the previous selector is permanently
        negated when the next sample is prepared.
    """

    def __init__(self, formula, feats, nof_classes, xgb):
        """
            Constructor.

            formula     -- model encoding, asserted once on the oracle
            feats       -- original feature names; their positions are the
                           feature ids used in explanations
            nof_classes -- number of classes (one score variable each)
            xgb         -- booster wrapper supplying options, feature names
                           and sample (inverse) transformations
        """

        self.ftids = {f: i for i, f in enumerate(feats)}
        self.nofcl = nof_classes
        self.idmgr = IDPool()
        self.optns = xgb.options

        # xgbooster will also be needed
        self.xgb = xgb

        self.verbose = self.optns.verb
        self.oracle = Solver(name=self.xgb.options.solver)

        # input (feature value) variables: names containing '_' are Boolean
        # (presumably one-hot categorical values), all others are reals
        self.inps = []
        for f in self.xgb.extended_feature_names_as_array_strings:
            if '_' not in f:
                self.inps.append(Symbol(f, typename=REAL))
            else:
                self.inps.append(Symbol(f, typename=BOOL))

        self.outs = []  # output (class score) variables
        for c in range(self.nofcl):
            self.outs.append(Symbol('class{0}_score'.format(c), typename=REAL))

        # theory
        self.oracle.add_assertion(formula)

        # current selector
        self.selv = None

    def prepare(self, sample, expl):
        """
            Prepare the oracle for validating an explanation given a sample.

            Under a fresh selector, adds one relaxed hypothesis per input
            (input equals its sample value, enabled by a per-feature
            selector). It then determines the class predicted for the
            sample and asserts that some other class scores strictly
            higher. Finally, ``self.rhypos`` keeps only the selectors of
            features whose id is in ``expl``.

            Raises AssertionError if the sample was prepared before or if
            the formula is unsatisfiable for the sample.
        """

        if self.selv:
            # disable the previous assumption if any
            self.oracle.add_assertion(Not(self.selv))

        # creating a fresh selector for a new sample
        sname = ','.join([str(v).strip() for v in sample])

        # the samples should not repeat; otherwise, they will be
        # inconsistent with the previously introduced selectors
        # (explicit raise: an ``assert`` disappears under ``python -O``)
        if sname in self.idmgr.obj2id:
            raise AssertionError(
                'this sample has been considered before (sample {0})'.format(
                    self.idmgr.id(sname)))
        self.selv = Symbol('sample{0}_selv'.format(self.idmgr.id(sname)),
                           typename=BOOL)

        self.rhypos = []  # relaxed hypotheses

        # transformed sample
        self.sample = list(self.xgb.transform(sample)[0])

        # preparing the selectors: inputs sharing a name prefix (text before
        # the first '_') share one selector
        for inp, _ in zip(self.inps, self.sample):
            feat = inp.symbol_name().split('_')[0]
            self.rhypos.append(Symbol('selv_{0}'.format(feat)))

        # adding relaxed hypotheses to the oracle
        for inp, val, sel in zip(self.inps, self.sample, self.rhypos):
            if '_' not in inp.symbol_name():
                hypo = Implies(self.selv,
                               Implies(sel, Equals(inp, Real(float(val)))))
            else:
                hypo = Implies(self.selv,
                               Implies(sel, inp if val else Not(inp)))

            self.oracle.add_assertion(hypo)

        # propagating the true observation
        if not self.oracle.solve([self.selv] + self.rhypos):
            # explicit raise: ``assert 0`` would be stripped under -O and
            # leave ``model`` unbound below
            raise AssertionError(
                'Formula is unsatisfiable under given assumptions')
        model = self.oracle.get_model()

        # choosing the maximum
        outvals = [float(model.get_py_value(o)) for o in self.outs]
        maxoval = max(zip(outvals, range(len(outvals))))

        # correct class id (corresponds to the maximum computed)
        true_output = maxoval[1]

        # forcing a misclassification, i.e. a wrong observation
        disj = []
        for i in range(len(self.outs)):
            if i != true_output:
                disj.append(GT(self.outs[i], self.outs[true_output]))
        self.oracle.add_assertion(Implies(self.selv, Or(disj)))

        # removing all hypotheses except for those in the explanation
        hypos = []
        for i, hypo in enumerate(self.rhypos):
            j = self.ftids[self.xgb.transform_inverse_by_index(i)[0]]
            if j in expl:
                hypos.append(hypo)
        self.rhypos = hypos

        if self.verbose:
            inpvals = self.xgb.readable_sample(sample)

            preamble = []
            for f, v in zip(self.xgb.feature_names, inpvals):
                if f not in v:
                    preamble.append('{0} = {1}'.format(f, v))
                else:
                    preamble.append(v)

            print(' explanation for: "IF {0} THEN {1}"'.format(
                ' AND '.join(preamble), self.xgb.target_name[true_output]))

    def validate(self, sample, expl):
        """
            Make an effort to show that the explanation is too optimistic.

            Returns ``self.coex``: None if no counterexample exists under
            the explanation's hypotheses, otherwise a tuple of the
            (inverse-transformed) counterexample inputs and the class id it
            produces. ``self.time`` records the CPU time spent.
        """

        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
                resource.getrusage(resource.RUSAGE_SELF).ru_utime

        # adapt the solver to deal with the current sample
        self.prepare(sample, expl)

        # if satisfiable, then there is a counterexample
        if self.oracle.solve([self.selv] + self.rhypos):
            model = self.oracle.get_model()
            inpvals = [float(model.get_py_value(i)) for i in self.inps]
            outvals = [float(model.get_py_value(o)) for o in self.outs]
            maxoval = max(zip(outvals, range(len(outvals))))

            inpvals = self.xgb.transform_inverse(np.array(inpvals))[0]
            self.coex = tuple([inpvals, maxoval[1]])
            inpvals = self.xgb.readable_sample(inpvals)

            if self.verbose:
                preamble = []
                for f, v in zip(self.xgb.feature_names, inpvals):
                    if f not in v:
                        preamble.append('{0} = {1}'.format(f, v))
                    else:
                        preamble.append(v)

                print(' explanation is incorrect')
                print(' counterexample: "IF {0} THEN {1}"'.format(
                    ' AND '.join(preamble), self.xgb.target_name[maxoval[1]]))
        else:
            self.coex = None

            if self.verbose:
                print(' explanation is correct')

        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
                resource.getrusage(resource.RUSAGE_SELF).ru_utime - self.time

        if self.verbose:
            print(' time: {0:.2f}'.format(self.time))

        return self.coex
def test_btor_does_not_support_const_arryas(self):
    """Converting a constant-array equality for Boolector must fail."""
    with self.assertRaises(ConvertExpressionError):
        solver = Solver(name="btor")
        lhs = Array(BV8, BV(0, 8))
        rhs = FreshSymbol(ArrayType(BV8, BV8))
        equality = Equals(lhs, rhs)
        solver.add_assertion(equality)
class SMTExplainer(object): """ An SMT-inspired minimal explanation extractor for XGBoost models. """ def __init__(self, formula, intvs, imaps, ivars, feats, nof_classes, options, xgb): """ Constructor. """ self.feats = feats self.intvs = intvs self.imaps = imaps self.ivars = ivars self.nofcl = nof_classes self.optns = options self.idmgr = IDPool() # saving XGBooster self.xgb = xgb self.verbose = self.optns.verb self.oracle = Solver(name=options.solver) self.inps = [] # input (feature value) variables for f in self.xgb.extended_feature_names_as_array_strings: if '_' not in f: self.inps.append(Symbol(f, typename=REAL)) else: self.inps.append(Symbol(f, typename=BOOL)) self.outs = [] # output (class score) variables for c in range(self.nofcl): self.outs.append(Symbol('class{0}_score'.format(c), typename=REAL)) # theory self.oracle.add_assertion(formula) # current selector self.selv = None # save and use dual explanations whenever needed self.dualx = [] # number of oracle calls involved self.calls = 0 def prepare(self, sample): """ Prepare the oracle for computing an explanation. 
""" if self.selv: # disable the previous assumption if any self.oracle.add_assertion(Not(self.selv)) # creating a fresh selector for a new sample sname = ','.join([str(v).strip() for v in sample]) # the samples should not repeat; otherwise, they will be # inconsistent with the previously introduced selectors assert sname not in self.idmgr.obj2id, 'this sample has been considered before (sample {0})'.format( self.idmgr.id(sname)) self.selv = Symbol('sample{0}_selv'.format(self.idmgr.id(sname)), typename=BOOL) self.rhypos = [] # relaxed hypotheses # transformed sample self.sample = list(self.xgb.transform(sample)[0]) self.sel2fid = {} # selectors to original feature ids self.sel2vid = {} # selectors to categorical feature ids # preparing the selectors for i, (inp, val) in enumerate(zip(self.inps, self.sample), 1): feat = inp.symbol_name().split('_')[0] selv = Symbol('selv_{0}'.format(feat)) val = float(val) self.rhypos.append(selv) if selv not in self.sel2fid: self.sel2fid[selv] = int(feat[1:]) self.sel2vid[selv] = [i - 1] else: self.sel2vid[selv].append(i - 1) # adding relaxed hypotheses to the oracle if not self.intvs: for inp, val, sel in zip(self.inps, self.sample, self.rhypos): if '_' not in inp.symbol_name(): hypo = Implies(self.selv, Implies(sel, Equals(inp, Real(float(val))))) else: hypo = Implies(self.selv, Implies(sel, inp if val else Not(inp))) self.oracle.add_assertion(hypo) else: for inp, val, sel in zip(self.inps, self.sample, self.rhypos): inp = inp.symbol_name() # determining the right interval and the corresponding variable for ub, fvar in zip(self.intvs[inp], self.ivars[inp]): if ub == '+' or val < ub: hypo = Implies(self.selv, Implies(sel, fvar)) break self.oracle.add_assertion(hypo) # in case of categorical data, there are selector duplicates # and we need to remove them self.rhypos = sorted(set(self.rhypos), key=lambda x: int(x.symbol_name()[6:])) # propagating the true observation if self.oracle.solve([self.selv] + self.rhypos): model = 
self.oracle.get_model() else: assert 0, 'Formula is unsatisfiable under given assumptions' # choosing the maximum outvals = [float(model.get_py_value(o)) for o in self.outs] maxoval = max(zip(outvals, range(len(outvals)))) # correct class id (corresponds to the maximum computed) self.out_id = maxoval[1] self.output = self.xgb.target_name[self.out_id] # forcing a misclassification, i.e. a wrong observation disj = [] for i in range(len(self.outs)): if i != self.out_id: disj.append(GT(self.outs[i], self.outs[self.out_id])) self.oracle.add_assertion(Implies(self.selv, Or(disj))) if self.verbose: inpvals = self.xgb.readable_sample(sample) self.preamble = [] for f, v in zip(self.xgb.feature_names, inpvals): if f not in v: self.preamble.append('{0} = {1}'.format(f, v)) else: self.preamble.append(v) print(' explaining: "IF {0} THEN {1}"'.format( ' AND '.join(self.preamble), self.output)) def explain(self, sample, smallest): """ Hypotheses minimization. """ # reinitializing the number of used oracle calls # 1 because of the initial call checking the entailment self.calls = 1 # adapt the solver to deal with the current sample self.prepare(sample) # saving external explanation to be minimized further self.to_consider = [True for h in self.rhypos] self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \ resource.getrusage(resource.RUSAGE_SELF).ru_utime # if satisfiable, then the observation is not implied by the hypotheses if self.oracle.solve( [self.selv] + [h for h, c in zip(self.rhypos, self.to_consider) if c]): print(' no implication!') print(self.oracle.get_model()) sys.exit(1) if self.optns.xtype == 'abductive': # abductive explanations => MUS computation and enumeration if not smallest and self.optns.xnum == 1: expls = [self.compute_minimal_abductive()] else: expls = self.enumerate_abductive(smallest=smallest) else: # contrastive explanations => MCS enumeration if self.optns.usemhs: expls = self.enumerate_contrastive() else: if not smallest: expls = 
self.enumerate_minimal_contrastive() else: # expls = self.enumerate_smallest_contrastive() expls = self.enumerate_contrastive() self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \ resource.getrusage(resource.RUSAGE_SELF).ru_utime - self.time expls = list( map(lambda expl: sorted([self.sel2fid[h] for h in expl]), expls)) if self.dualx: self.dualx = list( map(lambda expl: sorted([self.sel2fid[h] for h in expl]), self.dualx)) if self.verbose: if expls[0] != None: for expl in expls: preamble = [self.preamble[i] for i in expl] if self.optns.xtype == 'abductive': print(' explanation: "IF {0} THEN {1}"'.format( ' AND '.join(preamble), self.xgb.target_name[self.out_id])) else: print( ' explanation: "IF NOT {0} THEN NOT {1}"'.format( ' AND NOT '.join(preamble), self.xgb.target_name[self.out_id])) print(' # hypos left:', len(expl)) print(' time: {0:.2f}'.format(self.time)) # here we return the last computed explanation return expls def compute_minimal_abductive(self): """ Compute any subset-minimal explanation. """ i = 0 # filtering out unnecessary features if external explanation is given rhypos = [h for h, c in zip(self.rhypos, self.to_consider) if c] # simple deletion-based linear search while i < len(rhypos): to_test = rhypos[:i] + rhypos[(i + 1):] self.calls += 1 if self.oracle.solve([self.selv] + to_test): i += 1 else: rhypos = to_test return rhypos def enumerate_minimal_contrastive(self): """ Compute a subset-minimal contrastive explanation. 
""" def _overapprox(): model = self.oracle.get_model() for sel in self.rhypos: if int(model.get_py_value(sel)) > 0: # soft clauses contain positive literals # so if var is true then the clause is satisfied self.ss_assumps.append(sel) else: self.setd.append(sel) def _compute(): i = 0 while i < len(self.setd): if self.optns.usecld: _do_cld_check(self.setd[i:]) i = 0 if self.setd: # it may be empty after the clause D check self.calls += 1 self.ss_assumps.append(self.setd[i]) if not self.oracle.solve([self.selv] + self.ss_assumps + self.bb_assumps): self.ss_assumps.pop() self.bb_assumps.append(Not(self.setd[i])) i += 1 def _do_cld_check(cld): self.cldid += 1 sel = Symbol('{0}_{1}'.format(self.selv.symbol_name(), self.cldid)) cld.append(Not(sel)) # adding clause D self.oracle.add_assertion(Or(cld)) self.ss_assumps.append(sel) self.setd = [] st = self.oracle.solve([self.selv] + self.ss_assumps + self.bb_assumps) self.ss_assumps.pop() # removing clause D assumption if st == True: model = self.oracle.get_model() for l in cld[:-1]: # filtering all satisfied literals if int(model.get_py_value(l)) > 0: self.ss_assumps.append(l) else: self.setd.append(l) else: # clause D is unsatisfiable => all literals are backbones self.bb_assumps.extend([Not(l) for l in cld[:-1]]) # deactivating clause D self.oracle.add_assertion(Not(sel)) # sets of selectors to work with self.cldid = 0 expls = [] # detect and block unit-size MCSes immediately if self.optns.unitmcs: for i, hypo in enumerate(self.rhypos): self.calls += 1 if self.oracle.solve([self.selv] + self.rhypos[:i] + self.rhypos[(i + 1):]): expls.append([hypo]) if len(expls) != self.optns.xnum: self.oracle.add_assertion(Or([Not(self.selv), hypo])) else: break self.calls += 1 while self.oracle.solve([self.selv]): self.ss_assumps, self.bb_assumps, self.setd = [], [], [] _overapprox() _compute() expl = [list(f.get_free_variables())[0] for f in self.bb_assumps] expls.append(expl) if len(expls) == self.optns.xnum: break 
self.oracle.add_assertion(Or([Not(self.selv)] + expl)) self.calls += 1 self.calls += self.cldid return expls if expls else [None] def enumerate_abductive(self, smallest=True): """ Compute a cardinality-minimal explanation. """ # result expls = [] # just in case, let's save dual (contrastive) explanations self.dualx = [] with Hitman(bootstrap_with=[[ i for i in range(len(self.rhypos)) if self.to_consider[i] ]], htype='sorted' if smallest else 'lbx') as hitman: # computing unit-size MCSes for i, hypo in enumerate(self.rhypos): if self.to_consider[i] == False: continue self.calls += 1 if self.oracle.solve([self.selv] + self.rhypos[:i] + self.rhypos[(i + 1):]): hitman.hit([i]) self.dualx.append([self.rhypos[i]]) # main loop iters = 0 while True: hset = hitman.get() iters += 1 if self.verbose > 1: print('iter:', iters) print('cand:', hset) if hset == None: break self.calls += 1 if self.oracle.solve([self.selv] + [self.rhypos[i] for i in hset]): to_hit = [] satisfied, unsatisfied = [], [] removed = list( set(range(len(self.rhypos))).difference(set(hset))) model = self.oracle.get_model() for h in removed: i = self.sel2fid[self.rhypos[h]] if '_' not in self.inps[i].symbol_name(): # feature variable and its expected value var, exp = self.inps[i], self.sample[i] # true value true_val = float(model.get_py_value(var)) if not exp - 0.001 <= true_val <= exp + 0.001: unsatisfied.append(h) else: hset.append(h) else: for vid in self.sel2vid[self.rhypos[h]]: var, exp = self.inps[vid], int( self.sample[vid]) # true value true_val = int(model.get_py_value(var)) if exp != true_val: unsatisfied.append(h) break else: hset.append(h) # computing an MCS (expensive) for h in unsatisfied: self.calls += 1 if self.oracle.solve([self.selv] + [self.rhypos[i] for i in hset] + [self.rhypos[h]]): hset.append(h) else: to_hit.append(h) if self.verbose > 1: print('coex:', to_hit) hitman.hit(to_hit) self.dualx.append([self.rhypos[i] for i in to_hit]) else: if self.verbose > 1: print('expl:', hset) expl 
= [self.rhypos[i] for i in hset] expls.append(expl) if len(expls) != self.optns.xnum: hitman.block(hset) else: break return expls def enumerate_smallest_contrastive(self): """ Compute a cardinality-minimal contrastive explanation. """ # result expls = [] # computing unit-size MUSes muses = set([]) for hypo in self.rhypos: self.calls += 1 if not self.oracle.solve([self.selv, hypo]): muses.add(hypo) # we are going to discard unit-size MUSes from consideration rhypos = set(self.rhypos).difference(muses) # introducing interer cost literals for rhypos costlits = [] for i, hypo in enumerate(rhypos): costlit = Symbol(name='costlit_{0}_{1}'.format( self.selv.symbol_name(), i), typename=INT) costlits.append(costlit) self.oracle.add_assertion( Ite(hypo, Equals(costlit, Int(0)), Equals(costlit, Int(1)))) # main loop (linear search unsat-sat) i = 0 while i < len(rhypos) and len(expls) != self.optns.xnum: # fresh selector for the current iteration sit = Symbol('iter_{0}_{1}'.format(self.selv.symbol_name(), i)) # adding cardinality constraint self.oracle.add_assertion(Implies(sit, LE(Plus(costlits), Int(i)))) # extracting explanations from MaxSAT models while self.oracle.solve([self.selv, sit]): self.calls += 1 model = self.oracle.get_model() expl = [] for hypo in rhypos: if int(model.get_py_value(hypo)) == 0: expl.append(hypo) # each MCS contains all unit-size MUSes expls.append(list(muses) + expl) # either stop or add a blocking clause if len(expls) != self.optns.xnum: self.oracle.add_assertion(Implies(self.selv, Or(expl))) else: break i += 1 self.calls += 1 return expls def enumerate_contrastive(self, smallest=True): """ Compute a cardinality-minimal contrastive explanation. 
""" # core extraction is done via calling Z3's internal API assert self.optns.solver == 'z3', 'This procedure requires Z3' # result expls = [] # just in case, let's save dual (abductive) explanations self.dualx = [] # mapping from hypothesis variables to their indices hmap = {h: i for i, h in enumerate(self.rhypos)} # mapping from internal Z3 variable into variables of PySMT vmap = {self.oracle.converter.convert(v): v for v in self.rhypos} vmap[self.oracle.converter.convert(self.selv)] = None def _get_core(): core = self.oracle.z3.unsat_core() return sorted(filter(lambda x: x != None, map(lambda x: vmap[x], core)), key=lambda x: int(x.symbol_name()[6:])) def _do_trimming(core): for i in range(self.optns.trim): self.calls += 1 self.oracle.solve([self.selv] + core) new_core = _get_core() if len(core) == len(new_core): break return new_core def _reduce_lin(core): def _assump_needed(a): if len(to_test) > 1: to_test.remove(a) self.calls += 1 if not self.oracle.solve([self.selv] + list(to_test)): return False to_test.add(a) return True else: return True to_test = set(core) return list(filter(lambda a: _assump_needed(a), core)) def _reduce_qxp(core): coex = core[:] filt_sz = len(coex) / 2.0 while filt_sz >= 1: i = 0 while i < len(coex): to_test = coex[:i] + coex[(i + int(filt_sz)):] self.calls += 1 if to_test and not self.oracle.solve([self.selv] + to_test): # assumps are not needed coex = to_test else: # assumps are needed => check the next chunk i += int(filt_sz) # decreasing size of the set to filter filt_sz /= 2.0 if filt_sz > len(coex) / 2.0: # next size is too large => make it smaller filt_sz = len(coex) / 2.0 return coex def _reduce_coex(core): if self.optns.reduce == 'lin': return _reduce_lin(core) else: # qxp return _reduce_qxp(core) with Hitman(bootstrap_with=[[ i for i in range(len(self.rhypos)) if self.to_consider[i] ]], htype='sorted' if smallest else 'lbx') as hitman: # computing unit-size MUSes for i, hypo in enumerate(self.rhypos): if self.to_consider[i] 
== False: continue self.calls += 1 if not self.oracle.solve([self.selv, self.rhypos[i]]): hitman.hit([i]) self.dualx.append([self.rhypos[i]]) elif self.optns.unitmcs: self.calls += 1 if self.oracle.solve([self.selv] + self.rhypos[:i] + self.rhypos[(i + 1):]): # this is a unit-size MCS => block immediately hitman.block([i]) expls.append([self.rhypos[i]]) # main loop iters = 0 while True: hset = hitman.get() iters += 1 if self.verbose > 1: print('iter:', iters) print('cand:', hset) if hset == None: break self.calls += 1 if not self.oracle.solve([self.selv] + [ self.rhypos[h] for h in list( set(range(len(self.rhypos))).difference(set(hset))) ]): to_hit = _get_core() if len(to_hit) > 1 and self.optns.trim: to_hit = _do_trimming(to_hit) if len(to_hit) > 1 and self.optns.reduce != 'none': to_hit = _reduce_coex(to_hit) self.dualx.append(to_hit) to_hit = [hmap[h] for h in to_hit] if self.verbose > 1: print('coex:', to_hit) hitman.hit(to_hit) else: if self.verbose > 1: print('expl:', hset) expl = [self.rhypos[i] for i in hset] expls.append(expl) if len(expls) != self.optns.xnum: hitman.block(hset) else: break return expls
def callback(model, converter, result):
    """Collect one model reported by MathSAT's AllSAT enumeration.

    MathSAT calls this function once for every model it finds. ``model`` is a
    sequence of MathSAT terms; each one is translated back to pySMT with
    ``converter.back`` and their conjunction is appended to ``result``.
    Returning 1 tells MathSAT to keep searching; any other value stops it.
    """
    result.append(And([converter.back(term) for term in model]))
    return 1  # keep enumerating


x, y = Symbol("x"), Symbol("y")
f = Or(x, y)

msat = Solver(name="msat")
converter = msat.converter  # every pySMT solver exposes a .converter property
msat.add_assertion(f)  # the assertion is still made at the pySMT level

result = []
# Call the MathSAT API directly. The second argument is the list of
# "important" variables (already translated into MathSAT terms).
important = [converter.convert(x)]
mathsat.msat_all_sat(
    msat.msat_env(),
    important,
    lambda model: callback(model, converter, result),
)

print("'exists y . %s' is equivalent to '%s'" % (f, Or(result)))
# exists y . (x | y) is equivalent to ((! x) | x)
def get_makespan_optimal_weakly_hard_schedule(g,
                                              network,
                                              feasibility_timeout=None,
                                              optimization_timeout=10 * 60 * 1000):
    """Compute a makespan-optimal weakly-hard real-time schedule via SMT (Z3).

    Builds a linear-integer-arithmetic (``LIA``) encoding of the schedule for
    the task graph ``g`` and the network parameters in ``network``. It first
    checks feasibility, then binary-searches the smallest bound ``B`` for which
    every ``zeta`` variable can be kept strictly below ``B``.

    Args:
        g: task graph. The code reads ``g.vertices()``, ``g.num_vertices()``,
            the vertex properties ``'durations'``, ``'deadlines'`` and
            ``'weakly-hard'``, and the edge property ``'widths'``.
        network: mapping providing ``'A'``, ``'B'``, ``'C'``, ``'D'``,
            ``'GAMMA'`` (used as integers) and ``'LAMBDA'`` (a callable whose
            result is indexed with ``[0]`` and ``[1]``).
        feasibility_timeout: value for Z3's ``timeout`` parameter during the
            feasibility check (Z3 interprets it in milliseconds); a falsy
            value leaves the timeout unset.
        optimization_timeout: value for Z3's ``timeout`` parameter during the
            optimization queries; ``None`` leaves the current timeout as is.

    Returns:
        ``(zeta, chi, duration, label)`` as lists of Python values read from
        the best model found. Returns ``[None] * 4`` if the feasibility check
        is UNSAT or returns unknown (e.g. on timeout).
    """
    vprint('*computing optimal weakly-hard real-time schedule via SMT*')

    # SMT formulation
    tc = transitive_closure(g)
    logical_edges = get_logical_edges(g)
    n_edges = len(logical_edges)
    n_vertices = g.num_vertices()
    JUMPTABLE_MAX = 6
    K_MAX = 5001  # LAMBDA(i)[1] < K_MAX for all i < JUMPTABLE_MAX
    A, B, C, D, GAMMA, LAMBDA = (network[key]
                                 for key in ('A', 'B', 'C', 'D', 'GAMMA', 'LAMBDA'))
    assert all(LAMBDA(i)[1] < K_MAX for i in range(JUMPTABLE_MAX))

    vprint('\tinstantiating symvars...')
    label = [Symbol('label_%i' % i, INT) for i in range(n_edges)]
    # first half for slot, second half for beacons
    chi = [Symbol('chi_%i' % i, INT) for i in range(2 * n_edges)]
    duration = [Symbol('duration_%i' % i, INT) for i in range(n_edges)]
    zeta = [Symbol('zeta_%i' % i, INT) for i in range(n_vertices + n_edges)]
    delta_e_in_r = [[Symbol('delta_e_in_r-%i_%i' % (i, j), INT)
                     for j in range(n_edges)]
                    for i in range(n_edges)]
    delta_chi_eq_i = [[Symbol('delta_chi_eq_i-%i_%i' % (i, j), INT)
                       for j in range(JUMPTABLE_MAX)]
                      for i in range(2 * n_edges)]
    delta_tau_before_r = [[Symbol('delta_tau_before_r-%i_%i' % (i, j), INT)
                           for j in range(n_edges)]
                          for i in range(n_vertices + n_edges)]

    vprint('\tgenerating constraint clauses...')
    domain = And([
        And([And(LE(Int(1), sym), LE(sym, Int(n_edges))) for sym in label]),
        And([And(LE(Int(1), sym), LT(sym, Int(JUMPTABLE_MAX))) for sym in chi]),
        And([LE(Int(0), sym) for sym in zeta]),
        And([
            And(LE(Int(0), sym), LE(sym, Int(1)))
            for sym in chain.from_iterable(delta_e_in_r + delta_chi_eq_i +
                                           delta_tau_before_r)
        ])
    ])
    one_hot = And([
        And([
            Equals(Plus([delta_e_in_r[e][r] for r in range(n_edges)]), Int(1))
            for e in range(n_edges)
        ]),
        And([
            Equals(Plus([delta_chi_eq_i[chir][i] for i in range(JUMPTABLE_MAX)]),
                   Int(1))
            for chir in range(2 * n_edges)
        ])
    ])
    CFOP = And([
        LT(label[logical_edges.index(r)], label[logical_edges.index(s)])
        for r, s in product(logical_edges, repeat=2)
        if r.source() in tc.get_in_neighbors(s.source())
    ])
    task_partitioning_by_round = And(
        And([
            LE(delta_tau_before_r[int(tau)][r], delta_tau_before_r[int(mu)][r])
            for tau, mu, r in product(tc.vertices(), tc.vertices(), range(n_edges))
            if tau in tc.get_in_neighbors(mu)
        ]),
        And([
            Equals(delta_tau_before_r[r + n_vertices][s], Int(0))
            if r < s else
            Equals(delta_tau_before_r[r + n_vertices][s], Int(1))
            for r, s in product(range(n_edges), repeat=2)
        ]))
    round_empty = And([
        Implies(
            Equals(Plus([delta_e_in_r[e][r] for e in range(n_edges)]), Int(0)),
            Equals(chi[n_edges + r], Int(1)))
        for r in range(n_edges)
    ])
    durations = And([
        Equals(
            duration[r],
            Plus(
                Int(A),
                Times(Plus(Times(Int(2), chi[r + n_edges]), Int(B)),
                      Int(C + D * GAMMA)),
                Times(
                    Ite(GE(Plus([delta_e_in_r[e][r] for e in range(n_edges)]),
                           Int(1)),
                        Int(0), Int(-1)),
                    Int(A + (2 + B) * (C + D * GAMMA))),
                Plus([
                    Ite(
                        Equals(delta_e_in_r[e][r], Int(1)),
                        Plus(
                            Int(A),
                            Times(
                                Plus(Times(Int(2), chi[e]), Int(B)),
                                Int(C + D * g.edge_properties['widths'][
                                    logical_edges[e]]))),
                        Int(0))
                    for e in range(n_edges)
                ])))
        for r in range(n_edges)
    ])
    label_to_delta = And([
        Equals(label[e],
               Plus([Times(delta_e_in_r[e][i - 1], Int(i))
                     for i in range(1, 1 + n_edges)]))
        for e in range(n_edges)
    ])
    chi_to_delta = And([
        Equals(chi[chir],
               Plus([Times(delta_chi_eq_i[chir][i], Int(i))
                     for i in range(JUMPTABLE_MAX)]))
        for chir in range(2 * n_edges)
    ])
    order = And(
        And([
            LT(zeta[int(tau)],
               Minus(zeta[int(mu)], Int(g.vertex_properties['durations'][mu])))
            for tau, mu in product(g.vertices(), repeat=2)
            if tau in tc.get_in_neighbors(mu)
        ]),
        And([
            LT(zeta[r + n_vertices],
               Minus(zeta[r + 1 + n_vertices], duration[r + 1]))
            for r in range(n_edges - 1)
        ]),
        And([
            Implies(
                Equals(delta_e_in_r[e][r], Int(1)),
                GT(Minus(zeta[int(tau)],
                         Int(g.vertex_properties['durations'][tau])),
                   zeta[r + n_vertices]))
            for tau in g.vertices()
            for r in range(n_edges)
            for e in range(n_edges)
            if tau in tc.get_out_neighbors(logical_edges[e].source())
        ]),
        And([
            Implies(
                Equals(delta_e_in_r[e][r], Int(1)),
                GT(Minus(zeta[r + n_vertices], duration[r]), zeta[int(tau)]))
            for tau in g.vertices()
            for r in range(n_edges)
            for e in range(n_edges)
            if tau in tc.get_in_neighbors(logical_edges[e].source())
            or tau == logical_edges[e].source()
        ]))
    exclusion = And([
        And(
            Implies(
                Equals(delta_tau_before_r[int(tau)][r], Int(0)),
                GT(Minus(zeta[r + n_vertices], duration[r]), zeta[int(tau)])),
            Implies(
                Equals(delta_tau_before_r[int(tau)][r], Int(1)),
                GT(Minus(zeta[int(tau)],
                         Int(g.vertex_properties['durations'][tau])),
                   zeta[n_vertices + r])))
        for tau in g.vertices()
        for r in range(n_edges)
    ])
    deadline = And([
        LE(zeta[int(tau)], Int(g.vertex_properties['deadlines'][tau]))
        for tau in g.vertices()
        if g.vertex_properties['deadlines'][tau] >= 0
    ])

    def sum_m(tau):
        return Plus([Int(0)] + [
            Plus(
                Ite(Equals(delta_chi_eq_i[e][i], Int(1)),
                    Int(LAMBDA(i)[0]), Int(0)),
                Plus([
                    Ite(Equals(delta_chi_eq_i[n_edges + r][i], Int(1)),
                        Ite(Equals(delta_e_in_r[e][r], Int(1)),
                            Int(LAMBDA(i)[0]), Int(0)),
                        Int(0))
                    for r in range(n_edges)
                ]))
            for i in range(JUMPTABLE_MAX)
            for e in range(n_edges)
            if logical_edges[e].source() in tc.get_in_neighbors(tau)
        ])

    def min_K(tau):
        # Each term depends only on (i, e); iterating an extra round index
        # here would merely repeat identical terms inside the Min.
        return Min([Int(K_MAX)] + [
            Min(
                Ite(Equals(delta_chi_eq_i[e][i], Int(1)),
                    Int(LAMBDA(i)[1]), Int(K_MAX)),
                Min([
                    Ite(Equals(delta_chi_eq_i[n_edges + r][i], Int(1)),
                        Ite(Equals(delta_e_in_r[e][r], Int(1)),
                            Int(LAMBDA(i)[1]), Int(K_MAX)),
                        Int(K_MAX))
                    for r in range(n_edges)
                ]))
            for i in range(JUMPTABLE_MAX)
            for e in range(n_edges)
            if logical_edges[e].source() in tc.get_in_neighbors(tau)
        ])

    weakly_hard = g.vertex_properties['weakly-hard']
    wh_clauses = []
    for tau in g.vertices():
        if weakly_hard[tau][0] < 0:
            continue  # negative first component: no weakly-hard constraint
        k_min = min_K(tau)  # build the (large) Min term once per vertex
        wh_clauses.append(
            And(GE(Int(weakly_hard[tau][0]), Min(sum_m(tau), k_min)),
                LE(Int(weakly_hard[tau][1]), k_min)))
    WH = And(wh_clauses)

    formula = And([
        domain, one_hot, CFOP, task_partitioning_by_round, round_empty,
        durations, label_to_delta, chi_to_delta, order, exclusion, deadline, WH
    ])

    vprint('\tchecking feasibility...')
    solver = Solver(name='z3', incremental=True, logic='LIA')
    if feasibility_timeout:
        solver.z3.set('timeout', feasibility_timeout)
    solver.add_assertion(formula)
    try:
        result = solver.solve()
    except SolverReturnedUnknownResultError:
        result = None
    if not result:
        vprint('\tsolver returned infeasible!')
        return [None] * 4

    models = [solver.get_model()]
    vprint('\tsolver found a feasible solution, optimizing...')
    if optimization_timeout is not None:
        solver.z3.set('timeout', optimization_timeout)

    # Binary search on the makespan bound. A query at bound B asks whether all
    # zeta can be kept strictly below B. LB is the last bound that came back
    # UNSAT (or unknown, which is conservatively treated the same way). UB
    # starts at the largest zeta of the first model and afterwards is the last
    # bound found feasible.
    LB = 0
    UB = max(models[-1].get_py_value(z) for z in zeta)
    curr_B = UB // 2
    while UB - LB > 1:
        try:
            result = solver.solve(
                [And([LT(zeta_tau, Int(curr_B)) for zeta_tau in zeta])])
        except SolverReturnedUnknownResultError:
            vprint('\t(timeout, not necessarily unsat)')
            result = None
        if result:
            vprint('\tfound feasible solution of length %i, optimizing...' %
                   curr_B)
            models.append(solver.get_model())
            UB = curr_B
        else:
            vprint('\tnew lower bound %i, optimizing...' % curr_B)
            LB = curr_B
        # exact integer ceil((UB - LB) / 2), no float round-trip
        curr_B = LB + (UB - LB + 1) // 2

    vprint('\tsolver returned optimal (under composition+P.O. abstractions)!')
    best_model = models[-1]
    zeta = [best_model.get_py_value(x) for x in zeta]
    chi = [best_model.get_py_value(x) for x in chi]
    duration = [best_model.get_py_value(x) for x in duration]
    label = [best_model.get_py_value(x) for x in label]
    return zeta, chi, duration, label
o = BV(2, VECT_WIDTH)  # o - cpu player

x_turns = 0
o_turns = 0

x_val = Cell.x.value.constant_value()
o_val = Cell.o.value.constant_value()

# 3x3 grid of fresh bit-vector symbols, one per board cell
board = [[FreshSymbol(BVType(VECT_WIDTH)) for _ in range(3)] for _ in range(3)]

solver = Solver()

# initialise board cells, each one has to be blank, x or o
for row in board:
    for cell in row:
        solver.add_assertion(Or([Equals(cell, i.value) for i in Cell]))

# load board: whitespace-separated cell names per line; names other than
# Cell.x.name / Cell.o.name leave the cell unconstrained beyond the above
test = 'tests/blank.txt'
with open(test) as fh:
    for row, line in enumerate(fh):
        for col, cell in enumerate(line.strip().split(' ')):
            if cell == Cell.x.name:
                solver.add_assertion(Equals(board[row][col], Cell.x.value))
            elif cell == Cell.o.name:
                solver.add_assertion(Equals(board[row][col], Cell.o.value))


def already_played(row, col):
    """Return True if board[row][col] holds a mark, False if it is blank.

    The cell's value is read from the solver's current model and compared to
    ``Cell.s.value`` (presumably the blank marker, given the "blank, x or o"
    domain constraint above). NOTE(review): ``solver.get_value`` needs a
    prior satisfiable ``solver.solve()`` call — confirm callers do this.
    """
    return solver.get_value(board[row][col]) != Cell.s.value
def comb_attack(self):
    """Run the iterative DIP-based key-recovery attack and print the key.

    ``solver_obf`` is loaded with ``dip_gen_ckt`` and ``key_inequality_ckt``.
    Each of its models yields a distinguishing input pattern (DIP): an
    assignment to the oracle circuit's input wires. The DIP is sent to
    ``self.query_oracle``, and the observed outputs are asserted for two
    fresh circuit copies in both ``solver_obf`` and ``solver_key``.

    If a DIP repeats, the first output wire whose ``@dc1``/``@dc2`` values
    differ is inspected. If copy ``@dc1`` disagrees with the oracle, the
    ``_0`` key is blamed; otherwise the ``_1`` key is blamed. The current
    assignment of the blamed key is then banned in both solvers as
    "stateful". Once ``solver_obf`` becomes UNSAT, ``solver_key`` is solved
    and the key bits (``_0`` copy) are printed as ``key=<bits>``.

    Side effects: sets ``self.solver_oracle`` to a solver holding the oracle
    circuit's wire formulas; logs progress; prints the recovered key.

    Raises:
        RuntimeError: a DIP repeated, but no output differs between the two
            circuit copies, so the key to ban cannot be determined.
    """
    solver_name = 'yices'
    # solver_obf / solver_key are local to the attack, so release them on exit
    with Solver(name=solver_name) as solver_obf, \
            Solver(name=solver_name) as solver_key:
        self.solver_oracle = Solver(name=solver_name)
        attack_formulas = FormulaGenerator(self.oracle_cir, self.obf_cir)
        solver_obf.add_assertion(attack_formulas.dip_gen_ckt)
        solver_obf.add_assertion(attack_formulas.key_inequality_ckt)
        for wire in self.oracle_cir.wire_objs:
            self.solver_oracle.add_assertion(wire.formula)

        seen_dips = set()  # tuples of DIP literals; O(1) repeat detection
        n_stateful = 0
        iteration = 0
        while solver_obf.solve():
            # read the DIP off the current model
            dip_formula = []
            dip_boolean = []
            for wire in self.oracle_cir.input_wires:
                s = Symbol(wire.name)
                if solver_obf.get_py_value(s):
                    dip_formula.append(s)
                    dip_boolean.append(TRUE())
                else:
                    dip_formula.append(Not(s))
                    dip_boolean.append(FALSE())
            logging.info(dip_formula)

            # query oracle
            dip_out = self.query_oracle(dip_formula)
            logging.info(dip_out)

            dip_key = tuple(dip_formula)
            if dip_key in seen_dips:
                # repeated DIP => ban the stateful key
                logging.info("found a repeated dip!")
                key = None
                for wire in self.obf_cir.output_wires:
                    out1 = solver_obf.get_py_value(Symbol(wire.name + '@dc1'))
                    out2 = solver_obf.get_py_value(Symbol(wire.name + '@dc2'))
                    if out1 != out2:
                        oracle_out = self.solver_oracle.get_py_value(
                            Symbol(wire.name))
                        key = '0' if out1 != oracle_out else '1'
                        break
                if key is None:
                    logging.critical('something is wrong when banning keys')
                    # without a key suffix the ban below cannot be built
                    raise RuntimeError(
                        'repeated DIP, but no output differs between copies')

                key_list = []
                for wire in self.obf_cir.key_wires:
                    k = Symbol(wire.name + '_' + key)
                    key_list.append(k if solver_obf.get_py_value(k) else Not(k))
                n_stateful += 1

                ban = Not(And(key_list))
                solver_obf.add_assertion(ban)
                solver_key.add_assertion(ban)
                if n_stateful % 5000 == 0:
                    logging.warning('current stateful keys: %d', n_stateful)
                continue

            seen_dips.add(dip_key)
            # add dip checker: two new circuit copies constrained to the DIP
            # and to the oracle's outputs
            checks = [
                attack_formulas.gen_dip_chk(iteration * 2, '_0', dip_boolean),
                attack_formulas.gen_dip_chk(iteration * 2 + 1, '_1', dip_boolean),
            ]
            for i, wire in enumerate(self.obf_cir.output_wires):
                checks.append(
                    And(
                        Iff(dip_out[i],
                            Symbol(wire.name + '@{}'.format(iteration * 2))),
                        Iff(dip_out[i],
                            Symbol(wire.name + '@{}'.format(iteration * 2 + 1)))))
            check = And(checks)
            solver_obf.add_assertion(check)
            solver_key.add_assertion(check)
            iteration += 1
            logging.warning('iteration: %d', iteration)

        logging.warning('print keys')
        logging.warning('stateful keys: %d', n_stateful)
        if solver_key.solve():
            key = ''.join(
                '1' if solver_key.get_py_value(Symbol(wire.name + '_0')) else '0'
                for wire in self.obf_cir.key_wires)
            print("key=%s" % key)
        else:
            logging.critical('key solver returned UNSAT')