Exemple #1
0
    def _accept_word(self, ts_enc, ts, word, final_states):
        """Return whether ``ts`` accepts ``word`` ending in ``final_states``.

        The transition system is unrolled with BMC for ``len(word)`` steps,
        ``final_states`` is asserted at the last step and the i-th letter of
        ``word`` is asserted on the input variables at step i. The result of
        the z3 satisfiability check is returned.
        """
        steps = len(word)
        # The error condition lives in ``final_states``, so the BMC property
        # itself is just TRUE().
        bmc = BMC(ts_enc.helper, ts, TRUE())
        helper = bmc.helper

        solver = Solver(name='z3', logic=QF_BOOL)
        all_vars = set(ts.state_vars) | set(ts.input_vars)

        bmc.encode_up_to_k(solver, all_vars, steps)
        solver.add_assertion(
            helper.get_formula_at_i(all_vars, final_states, steps))

        # Pin the inputs of each step to the corresponding letter of the word.
        for step, letter in enumerate(word):
            letter_eq = ts_enc.r2a.get_msg_eq(letter)
            solver.add_assertion(
                helper.get_formula_at_i(ts.input_vars, letter_eq, step))

        return solver.solve()
Exemple #2
0
 def test_btor_does_not_support_int_arrays(self):
     """Boolector must refuse formulas over Int-indexed Int arrays."""
     arr = Symbol("a", ARRAY_INT_INT)
     stored = Store(arr, Int(10), Int(100))
     formula = Equals(Select(stored, Int(10)), Int(100))
     btor = Solver(name="btor")
     with self.assertRaises(ConvertExpressionError):
         btor.add_assertion(formula)
Exemple #3
0
 def _compute_WMI_PA_no_boolean(self, lab_formula, pa_vars, labels, other_assignments=None):
     """Enumerate the assignments satisfying a labelled formula using AllSAT.

     ``lab_formula`` is asserted in a fresh MathSAT solver with AllSAT model
     minimization enabled, and MathSAT's AllSAT procedure collects every
     assignment over ``pa_vars`` (through ``WMI._callback``). Each collected
     assignment is translated back before being yielded: atoms that are keys
     of ``labels`` are replaced by the value they map to, and
     ``other_assignments`` is merged on top.

     Args:
         lab_formula (FNode): The labelled pysmt formula to examine.
         pa_vars (iterable of FNode): The variables the AllSAT enumeration
             is performed over.
         labels (dict): The dictionary containing the correspondence between
             each label and its true value.
         other_assignments (dict, optional): Extra assignments added to (and
             overriding entries of) every yielded assignment. ``None`` (the
             default) means no extra assignments.

     Yields:
         dict: One of the assignments that satisfies the formula.

     """
     # A None default avoids sharing a single mutable dict between calls.
     if other_assignments is None:
         other_assignments = {}
     solver = Solver(name="msat",
                     solver_options={"dpll.allsat_minimize_model": "true"})
     converter = solver.converter
     solver.add_assertion(lab_formula)
     lra_assignments = []
     mathsat.msat_all_sat(
         solver.msat_env(),
         [converter.convert(v) for v in pa_vars],
         lambda model: WMI._callback(model, converter, lra_assignments))
     for mu_lra in lra_assignments:
         # Replace labels by their value; other atoms are kept as they are.
         assignments = {labels.get(atom, atom): value
                        for atom, value in WMI._get_assignments(mu_lra).items()}
         assignments.update(other_assignments)
         yield assignments
Exemple #4
0
 def _compute_TTAs(self, formula):
     """Compute the total truth assignments of the given formula.

     The atoms of ``formula`` are first replaced by fresh Boolean labels
     (via ``self.label_formula``); MathSAT's AllSAT functionality is then
     used to retrieve all the total truth assignments over the label
     variables.

     Args:
         formula (FNode): The pysmt formula to examine.

     Returns:
         tuple: ``(models, labels)`` where ``models`` is the list of total
         truth assignments collected by ``WMI._callback`` and ``labels`` is
         the dictionary returned by ``label_formula`` holding the
         correspondence between the labels and their true value.

     """
     # Label LRA atoms with fresh boolean variables
     labelled_formula, pa_vars, labels = self.label_formula(formula, formula.get_atoms())

     # Perform AllSMT on the labelled formula
     solver = Solver(name="msat")
     converter = solver.converter
     solver.add_assertion(labelled_formula)
     models = []
     mathsat.msat_all_sat(solver.msat_env(),
                          [converter.convert(v) for v in pa_vars],
                          lambda model: WMI._callback(model, converter, models))
     return models, labels
Exemple #5
0
 def test_btor_does_not_support_int_arrays(self):
     """Asserting an Int->Int array formula in Boolector raises."""
     int_array = Symbol("a", ARRAY_INT_INT)
     read_back = Select(Store(int_array, Int(10), Int(100)), Int(10))
     formula = Equals(read_back, Int(100))
     btor = Solver(name="btor")
     with self.assertRaises(ConvertExpressionError):
         btor.add_assertion(formula)
Exemple #6
0
def configure(inpcode, options):
  """Parse SMT-LIB text and load its last formula into a fresh solver.

  Sets the module-level globals ``solver`` and ``solver_name``.

  Solver selection:
    * ``options["pysmt"]`` -- name of a solver known to pysmt; takes
      precedence when present.
    * otherwise ``options["smtpipe"]`` -- path of an SMT-LIB executable,
      registered as a generic solver named ``"custom: <basename>"`` for the
      single logic named by ``options["pipe_logic"]``.
    * with neither key, the already existing global ``solver_name`` is used.

  Args:
    inpcode: SMT-LIB input handed to ``SmtLibParser.get_script``.
    options: dict of options (see above).
  """
  global solver
  global solver_name
  # Track which branch configured the solver, so the pipe-specific Solver
  # call below only runs when the pipe branch ran; with both keys present
  # ``solver_logic`` would otherwise be unbound there.
  use_pipe = False
  if "pysmt" in options:
    solver_name = options["pysmt"]
  elif "smtpipe" in options:
    use_pipe = True
    solver_cmd = options["smtpipe"]
    solver_name = "custom: " + solver_cmd.split("/")[-1]
    # NOTE(review): eval() executes arbitrary code taken from ``options``;
    # only acceptable for trusted input. It presumably resolves a logic name
    # such as "QF_LRA" to the pysmt logic object of the same name.
    solver_logic = eval(options["pipe_logic"])
    env = get_env()
    env.factory.add_generic_solver(solver_name, solver_cmd, [solver_logic])
  verbprint = utils.verbprint
  verbprint(2, "Parsing the input...", False)
  parser = SmtLibParser()
  script = parser.get_script(inpcode)
  formula = script.get_last_formula()
  verbprint(2, "done.", True)
  verbprint(2, "Initializing the solver...", False)
  if use_pipe:
    solver = Solver(name=solver_name, logic=solver_logic)
  else:
    solver = Solver(name=solver_name)
  verbprint(2, "done.", True)
  verbprint(2, "Loading the input file in the solver...", False)
  solver.add_assertion(formula)
  verbprint(2, "done.", True)
Exemple #7
0
def main(path, solver_name):
    """Solve the last formula of the SMT-LIB file ``path`` with the solver
    ``solver_name`` and print the satisfiability result."""
    solver = Solver(solver_name)
    script = SmtLibParser().get_script_fname(path)
    solver.add_assertion(script.get_last_formula())
    print(solver.solve())
Exemple #8
0
 def test_btor_options(self):
     """Boolector, configured with custom options, decides every QF_BV
     example formula with the expected satisfiability."""
     for (f, _, sat, logic) in get_example_formulae():
         if logic != QF_BV:
             continue
         solver = Solver(name="btor",
                         solver_options={"rewrite-level": 0,
                                         "fun:dual-prop": 1,
                                         "eliminate-slices": 1})
         solver.add_assertion(f)
         # assertEqual reports both values on failure, unlike assertTrue(a == b).
         self.assertEqual(solver.solve(), sat)
Exemple #9
0
 def test_examples_solving(self):
     """The BDD solver decides every pure-Boolean example correctly."""
     bool_examples = (ex for ex in EXAMPLE_FORMULAS
                      if ex.logic == pysmt.logics.BOOL)
     for example in bool_examples:
         solver = Solver(logic=pysmt.logics.BOOL, name='bdd')
         solver.add_assertion(example.expr)
         check = self.assertTrue if example.is_sat else self.assertFalse
         check(solver.solve())
Exemple #10
0
 def test_examples_solving(self):
     """Every BOOL example formula is decided correctly by the BDD solver."""
     for example in get_example_formulae():
         if example.logic == BOOL:
             bdd = Solver(logic=BOOL, name='bdd')
             bdd.add_assertion(example.expr)
             verdict = bdd.solve()
             if example.is_sat:
                 self.assertTrue(verdict)
             else:
                 self.assertFalse(verdict)
Exemple #11
0
 def test_btor_options(self):
     """Every QF_BV example keeps its expected satisfiability when solved by
     Boolector with non-default options."""
     for (f, _, sat, logic) in get_example_formulae():
         if logic != QF_BV:
             continue
         solver = Solver(name="btor",
                         solver_options={"rewrite-level": 0,
                                         "fun:dual-prop": 1,
                                         "eliminate-slices": 1})
         solver.add_assertion(f)
         # assertEqual shows the actual and expected value when it fails.
         self.assertEqual(solver.solve(), sat)
Exemple #12
0
 def test_examples_solving(self):
     """Check the BDD solver against the expected result of each BOOL example."""
     for example in get_example_formulae():
         if example.logic != BOOL:
             continue
         bdd_solver = Solver(logic=BOOL, name='bdd')
         bdd_solver.add_assertion(example.expr)
         expectation = self.assertTrue if example.is_sat else self.assertFalse
         expectation(bdd_solver.solve())
Exemple #13
0
    def test_msat_partial_model(self):
        """A MathSAT model only contains symbols occurring in the assertions."""
        # The context manager calls exit() on the solver even when one of the
        # assertions below fails (the explicit exit() call was skipped then).
        with Solver(name="msat") as msat:
            x, y = Symbol("x"), Symbol("y")
            msat.add_assertion(x)
            c = msat.solve()
            self.assertTrue(c)

            model = msat.get_model()
            self.assertNotIn(y, model)
            self.assertIn(x, model)
Exemple #14
0
    def test_msat_partial_model(self):
        """Symbols absent from the assertions are absent from the MathSAT model."""
        # Using the solver as a context manager guarantees exit() is called,
        # including when an assertion fails midway.
        with Solver(name="msat") as msat:
            x, y = Symbol("x"), Symbol("y")
            msat.add_assertion(x)
            self.assertTrue(msat.solve())

            model = msat.get_model()
            self.assertNotIn(y, model)
            self.assertIn(x, model)
Exemple #15
0
 def test_examples_solving(self):
     """The BDD solver agrees with the expected verdict on BOOL examples."""
     for example in EXAMPLE_FORMULAS:
         if example.logic == pysmt.logics.BOOL:
             bdd = Solver(logic=pysmt.logics.BOOL,
                          name='bdd')
             bdd.add_assertion(example.expr)
             outcome = bdd.solve()
             if example.is_sat:
                 self.assertTrue(outcome)
             else:
                 self.assertFalse(outcome)
Exemple #16
0
def get_value(fname="get_value.smt2"):
    """Solve the last formula of an SMT-LIB script and evaluate its
    ``(get-value ...)`` terms with z3.

    Args:
        fname: path of the SMT-LIB script (defaults to "get_value.smt2").

    Returns:
        tuple: ``(is_sat, values)`` where ``values`` is the result of
        ``solver.get_values`` on the arguments of every ``get-value``
        command, or ``None`` when the formula is unsatisfiable.
    """
    solver = Solver('z3')
    parser = SmtLibParser()
    script = parser.get_script_fname(fname)
    exprs = [arg
             for get_val_cmd in script.filter_by_command_name("get-value")
             for arg in get_val_cmd.args]
    solver.add_assertion(script.get_last_formula())
    is_sat = solver.solve()
    # Values can only be read from a model, which exists only when SAT.
    values = solver.get_values(exprs) if is_sat else None
    return is_sat, values
Exemple #17
0
def solve(formula, n, max_models=None, solver="msat"):
    """Enumerate distinct models of ``formula``.

    After each model, a clause forbidding the same values for all the
    variables in ``variables(n)`` is asserted, so every yielded model
    differs from the previous ones on at least one of those variables.

    Args:
        formula: the pysmt formula whose models are enumerated.
        n: size parameter forwarded to ``variables`` and ``to_bn``.
        max_models: maximum number of models to yield; ``None`` (or 0)
            means no limit.
        solver: name of the pysmt solver to use.

    Yields:
        ``to_bn(model, n)`` for each model found.
    """
    s = Solver(name=solver)
    s.add_assertion(formula)
    vs = [x for xs in variables(n) for x in xs]
    k = 0
    # The limit is tested before solving, so no solver call is wasted once
    # max_models models have been produced. A separate up-front is_sat()
    # check is unnecessary: the first solve() already answers it.
    while (not max_models or k < max_models) and s.solve():
        k += 1
        model = s.get_model()
        s.add_assertion(Not(And([EqualsOrIff(v, model[v]) for v in vs])))
        yield to_bn(model, n)
Exemple #18
0
    def get_prepared_solver(self, logic, solver_name=None):
        """Returns a solver initialized with the sudoku constraints and a
        matrix of SMT variables, each representing a cell of the game.

        Args:
            logic: logic passed to ``Solver``.
            solver_name: name of the solver to use, passed to ``Solver``.

        Returns:
            tuple: ``(solver, var_table)`` where ``var_table`` is a
            ``size**2 x size**2`` list of lists of fresh symbols of type
            ``self.get_type()``.
        """
        # ``range`` replaces the Python-2-only ``xrange`` so the method also
        # runs on Python 3 (behavior is the same on Python 2).
        sq_size = self.size**2
        ty = self.get_type()
        var_table = [[FreshSymbol(ty) for _ in range(sq_size)]
                     for _ in range(sq_size)]

        solver = Solver(logic=logic, name=solver_name)
        # Sudoku constraints
        # every variable takes a value between 1 and sq_size (inclusive)
        for row in var_table:
            for var in row:
                solver.add_assertion(Or([Equals(var, self.const(i))
                                         for i in range(1, sq_size + 1)]))

        # each row and each column contains all different numbers
        for i in range(sq_size):
            solver.add_assertion(AllDifferent(var_table[i]))
            solver.add_assertion(AllDifferent([x[i] for x in var_table]))

        # each square contains all different numbers
        for sx in range(self.size):
            for sy in range(self.size):
                square = [var_table[i + sx * self.size][j + sy * self.size]
                          for i in range(self.size) for j in range(self.size)]
                solver.add_assertion(AllDifferent(square))

        return solver, var_table
Exemple #19
0
    def get_prepared_solver(self, logic, solver_name=None):
        """Build a solver loaded with the generic sudoku rules.

        Returns a pair ``(solver, var_table)``: ``solver`` already holds the
        domain, row, column and box constraints, and ``var_table`` is a
        ``size**2 x size**2`` matrix of fresh SMT symbols, one per cell.
        """
        n = self.size
        side = n ** 2
        cell_type = self.get_type()
        var_table = []
        for _ in range(side):
            var_table.append([FreshSymbol(cell_type) for _ in range(side)])

        solver = Solver(logic=logic, name=solver_name)

        # Domain: every cell holds one of the values 1..side.
        for cell in (c for row in var_table for c in row):
            solver.add_assertion(
                Or([Equals(cell, self.const(v)) for v in range(1, side + 1)]))

        # Rows and columns: pairwise distinct values.
        for k in range(side):
            solver.add_assertion(AllDifferent(var_table[k]))
            solver.add_assertion(AllDifferent([row[k] for row in var_table]))

        # Boxes: each n x n sub-square holds pairwise distinct values.
        for bx in range(n):
            for by in range(n):
                box = [var_table[bx * n + r][by * n + c]
                       for r in range(n) for c in range(n)]
                solver.add_assertion(AllDifferent(box))

        return solver, var_table
Exemple #20
0
def configure(inpcode):
  """Parse the SMT-LIB text ``inpcode`` and assert its last formula in a
  fresh MathSAT solver, stored in the module-level global ``solver``."""
  global solver
  log = utils.verbprint

  log(2, "Parsing the input...", False)
  formula = SmtLibParser().get_script(inpcode).get_last_formula()
  log(2, "done.", True)

  log(2, "Initializing the solver...", False)
  solver = Solver(name="msat")
  log(2, "done.", True)

  log(2, "Loading the input file in the solver...", False)
  solver.add_assertion(formula)
  log(2, "done.", True)
Exemple #21
0
def solve_soduku_model(model):
    """Solve a sudoku ``model`` and return the filled table.

    Returns a ``model.n x model.n`` table (created by
    ``tabletools.generateEmptyTable``) whose cell ``[y][x]`` holds the solver
    value of ``model.getSymbol(x, y)``, or ``None`` if no solution exists.
    """
    solver = Solver()
    solver.add_assertion(model.extractConstraints())

    if not solver.solve():
        return None  # the constraints admit no solution

    size = model.n
    table = tabletools.generateEmptyTable(size)
    for row in range(size):
        for col in range(size):
            table[row][col] = solver.get_value(model.getSymbol(col, row))
    return table
Exemple #22
0
    def test_create_and_solve(self):
        """Substituting B by A in A & !B gives an UNSAT formula that, once
        conjoined with False, simplifies to False."""
        solver = Solver(logic=QF_BOOL)

        a = Symbol("A", BOOL)
        b = Symbol("B", BOOL)

        # A & !B with B := A becomes A & !A.
        g = And(a, Not(b)).substitute({b: a})
        solver.add_assertion(g)
        self.assertFalse(solver.solve(), "Formula was expected to be UNSAT")

        self.assertEqual(And(g, Bool(False)).simplify(), Bool(False))
Exemple #23
0
    def test_create_and_solve(self):
        """Build, substitute, solve and simplify a small Boolean formula."""
        solver = Solver(logic=QF_BOOL)

        var_a, var_b = Symbol("A", BOOL), Symbol("B", BOOL)

        contradiction = And(var_a, Not(var_b)).substitute({var_b: var_a})
        solver.add_assertion(contradiction)
        outcome = solver.solve()
        self.assertFalse(outcome, "Formula was expected to be UNSAT")

        simplified = And(contradiction, Bool(False)).simplify()
        self.assertEqual(simplified, Bool(False))
Exemple #24
0
def k_sequence_WH(m, K, K_seq_len=100, count=100):
    """Yield ``count`` 0/1 sequences of length ``K_seq_len`` in which every
    window of ``K`` consecutive entries sums to at most ``m``.

    Each position is an INT symbol ``x_i`` restricted to {0, 1}. When ``K``
    exceeds ``K_seq_len`` a single window covering the whole sequence is
    used. Each yielded value is built with ``array(..., dtype=bool)``
    (presumably numpy). After a sequence is yielded, a clause excluding
    exactly that sequence is added to the current solver.

    NOTE(review): if yices reports unsat (e.g. once every sequence has been
    excluded), a fresh z3 solver is created with only ``formula``. The
    exclusion clauses added so far are not replayed, so earlier sequences
    may be yielded again. The result of that z3 ``solve()`` call is not
    checked before ``get_model()``.
    """
    k_seq = [Symbol('x_%i' % i, INT) for i in range(K_seq_len)]
    # every entry is 0 or 1
    domain = And([Or(Equals(x, Int(0)), Equals(x, Int(1))) for x in k_seq])
    # sum of each sliding window of length K is <= m
    K_window = And([
        LE(Plus(k_seq[t:min(K_seq_len, t + K)]), Int(m))
        for t in range(max(1, K_seq_len - K + 1))
    ])
    formula = And(domain, K_window)
    # randint is presumably numpy.random.randint: a seed in [0, 2**31)
    solver = Solver(name='yices',
                    incremental=True,
                    random_seed=randint(2 << 30))
    solver.add_assertion(formula)
    for _ in range(count):
        result = solver.solve()
        if not result:
            # fall back to a new z3 solver; later iterations keep using it
            solver = Solver(name='z3',
                            incremental=True,
                            random_seed=randint(2 << 30))
            solver.add_assertion(formula)
            solver.solve()
        model = solver.get_model()
        # the lambda reads the pysmt model before ``model`` is rebound
        model = array(list(map(lambda x: model.get_py_value(x), k_seq)),
                      dtype=bool)
        yield model
        # exclude the sequence just yielded from future solutions
        solver.add_assertion(
            Or([NotEquals(k_seq[i], Int(model[i])) for i in range(K_seq_len)]))
Exemple #25
0
    def test_msat_preferred_variable(self):
        """set_preferred_var steers MathSAT towards the requested polarity."""
        a, b, c = (Symbol(name) for name in "abc")
        na, nb, nc = Not(a), Not(b), Not(c)

        # a forces b and c true; !a forces b and c false. Both are models.
        f = And(Implies(a, And(b, c)), Implies(na, And(nb, nc)))

        s_pos = Solver("msat")
        s_pos.add_assertion(f)
        s_pos.set_preferred_var(a, True)
        self.assertTrue(s_pos.solve())
        self.assertTrue(s_pos.get_value(a).is_true())

        s_neg = Solver("msat")
        s_neg.add_assertion(f)
        s_neg.set_preferred_var(a, False)
        self.assertTrue(s_neg.solve())
        self.assertTrue(s_neg.get_value(a).is_false())

        # Omitting the polarity must still be accepted. Its effect (splitting
        # on ``a`` first) is hard to observe, so only the call is exercised.
        s_pos.set_preferred_var(a)
Exemple #26
0
 def _model_iterator_base(formula):
     """Find all the total truth assignments that satisfy the given formula.

     After each model is yielded, a blocking clause forbidding the same
     truth values on every atom of ``formula`` is asserted. Each following
     model therefore differs from the previous ones on at least one atom.

     Args:
         formula (FNode): The pysmt formula to examine.

     Yields:
         model: The model representing the next total truth assignment that
             satisfies the formula.

     """
     # The atom set is fixed: compute it once instead of on every iteration.
     atoms = formula.get_atoms()
     solver = Solver(name="msat")
     solver.add_assertion(formula)
     while solver.solve():
         model = solver.get_model()
         yield model
         # Constrain the solver to find a different assignment
         solver.add_assertion(
             Not(And([Iff(atom, model.get_value(atom)) for atom in atoms])))
Exemple #27
0
    def test_msat_preferred_variable(self):
        """Preferred-variable polarity is reflected in MathSAT's model."""
        a, b, c = (Symbol(ch) for ch in "abc")
        na, nb, nc = (Not(Symbol(ch)) for ch in "abc")

        f = And(Implies(a, And(b, c)),
                Implies(na, And(nb, nc)))

        prefer_true = Solver("msat")
        prefer_true.add_assertion(f)
        prefer_true.set_preferred_var(a, True)
        self.assertTrue(prefer_true.solve())
        self.assertTrue(prefer_true.get_value(a).is_true())

        prefer_false = Solver("msat")
        prefer_false.add_assertion(f)
        prefer_false.set_preferred_var(a, False)
        self.assertTrue(prefer_false.solve())
        self.assertTrue(prefer_false.get_value(a).is_false())

        # A call without polarity must still work. It only asks the solver to
        # split on ``a`` first, which is hard to check from the outside.
        prefer_true.set_preferred_var(a)
Exemple #28
0
def k_sequence_WH_worst_case(m, K, K_seq_len=100, count=100):
    """Yield up to ``count`` 0/1 sequences that are tight for the window bound.

    Each sequence has length ``K_seq_len``. Every window of ``K`` consecutive
    entries sums to at most ``m`` (``K_window``) and to more than ``m - 1``
    (``violate_up``), i.e. to exactly ``m``. Every window of
    ``K + right_shift`` entries sums to more than ``m``.

    The search starts with ``right_shift = 1``. Whenever z3 does not produce
    a model (unsat, or any exception raised by ``solve()``, presumably
    including the five-minute timeout), a fresh z3 solver is built with
    ``right_shift`` increased by one. It also receives the exclusion clauses
    of all the sequences already yielded (``solutions``). Shifting stops once
    ``right_shift + K`` reaches ``K_seq_len``. If no model can then be read,
    the generator stops early.

    NOTE(review): ``except BaseException`` also swallows KeyboardInterrupt
    and SystemExit.
    """
    k_seq = [Symbol('x_%i' % i, INT) for i in range(K_seq_len)]
    # every entry is 0 or 1
    domain = And([Or(Equals(x, Int(0)), Equals(x, Int(1))) for x in k_seq])
    # each length-K window sums to at most m ...
    K_window = And([
        LE(Plus(k_seq[t:min(K_seq_len, t + K)]), Int(m))
        for t in range(max(1, K_seq_len - K + 1))
    ])
    # ... and to more than m - 1
    violate_up = And([
        GT(Plus(k_seq[t:min(K_seq_len, t + K)]), Int(m - 1))
        for t in range(max(1, K_seq_len - K + 1))
    ])

    def violate_right_generator(n):
        # every window of length K + n sums to more than m
        return And([
            GT(Plus(k_seq[t:min(K_seq_len, t + K + n)]), Int(m))
            for t in range(max(1, K_seq_len - (K + n) + 1))
        ])

    right_shift = 1
    formula = And(domain, K_window, violate_up,
                  violate_right_generator(right_shift))
    # randint is presumably numpy.random.randint: a seed in [0, 2**31)
    solver = Solver(name='z3', incremental=True, random_seed=randint(2 << 30))
    solver.add_assertion(formula)
    # z3 timeouts are in milliseconds: 5 minutes
    solver.z3.set('timeout', 5 * 60 * 1000)
    # conjunction of the exclusion clauses of all sequences yielded so far
    # (And() with no arguments is TRUE)
    solutions = And()
    for _ in range(count):
        while right_shift + K < K_seq_len:
            try:
                result = solver.solve()
            except BaseException:
                result = None
            if not result:
                # retry on a fresh solver with a longer right window,
                # keeping the exclusion clauses of earlier solutions
                solver = Solver(name='z3',
                                incremental=True,
                                random_seed=randint(2 << 30))
                right_shift += 1
                solver.z3.set('timeout', 5 * 60 * 1000)
                solver.add_assertion(
                    And(solutions, domain, K_window, violate_up,
                        violate_right_generator(right_shift)))
            else:
                break
        try:
            # fails if the last solver never found a model; then stop
            model = solver.get_model()
        except BaseException:
            break
        # the lambda reads the pysmt model before ``model`` is rebound
        model = array(list(map(lambda x: model.get_py_value(x), k_seq)),
                      dtype=bool)
        yield model
        # exclude the sequence just yielded, now and in any future solver
        solution = Or(
            [NotEquals(k_seq[i], Int(model[i])) for i in range(K_seq_len)])
        solutions = And(solutions, solution)
        solver.add_assertion(solution)
Exemple #29
0
class SATSolver:
    """Thin wrapper around a pysmt QF_BOOL solver.

    Besides forwarding assertions and push/pop calls to the underlying
    solver, it records the asserted clauses (respecting push/pop levels) so
    that ``unsat_core`` can extract a core from them.
    """

    def __init__(self, solver_name):
        self.solver = Solver(solver_name, logic=QF_BOOL)
        # Clauses currently asserted in ``self.solver``, in assertion order.
        self._clauses = []
        # One entry per open push level: len(self._clauses) when it was opened.
        self._frames = []
        # Explicit override for ``fml`` (None means: use the tracked clauses).
        self._fml = None

    @property
    def fml(self):
        """Formula examined by ``unsat_core``.

        It is the conjunction of the currently asserted clauses, unless a
        formula has been assigned to ``fml`` explicitly.
        """
        if self._fml is not None:
            return self._fml
        return And(self._clauses)

    @fml.setter
    def fml(self, formula):
        self._fml = formula

    def _assert(self, formula):
        """Assert ``formula`` in the solver and record it."""
        self.solver.add_assertion(formula)
        self._clauses.append(formula)

    def get_model(self):
        """Return a model of the current assertions, or None if UNSAT."""
        if self.solver.solve():
            return self.solver.get_model()
        return None

    def enum_model(self, blocking_cls):
        """Assert ``Not(And(blocking_cls))`` and return a new model, or None
        if no further model exists."""
        self._assert(Not(And(blocking_cls)))
        return self.get_model()

    def assert_cls(self, _cls):
        """Assert ``_cls`` at the current backtracking level."""
        self._assert(_cls)

    def add_cls(self, _cls, level):
        """Open ``level`` new backtracking points, then assert ``_cls``."""
        self.solver.push(level)
        self._frames.extend([len(self._clauses)] * level)
        self._assert(_cls)

    def remove_cls(self, level):
        """Drop the last ``level`` backtracking points and their clauses."""
        self.solver.pop(level)
        if level > 0:
            del self._clauses[self._frames[-level]:]
            del self._frames[-level:]

    def unsat_core(self):
        """If the current assertions are UNSAT, print them together with an
        unsat core and return the core; return None when they are SAT.

        Previously ``self.fml`` was never set (AttributeError). In the SAT
        case the core loop also referenced an undefined ``ucore``.
        """
        if self.solver.solve():
            return None
        print('Assertions:', self.fml)
        conj = conjunctive_partition(self.fml)
        ucore = get_unsat_core(conj)
        print("UNSAT-Core size '%d'" % len(ucore))
        for f in ucore:
            print(f.serialize())
        return ucore
Exemple #30
0
class Reluzy:
    """Lemma-refinement solver for ReLU networks encoded by ``Nnet2Smt``.

    ``solver`` starts from the network constraints (``formulae``) only. It
    is strengthened with ReLU lemmas (see ``refine``) until it becomes
    unsatisfiable or its model violates no lemma. ``sat_checker`` also
    asserts the exact ReLU semantics ``r1 == max(r2, 0)`` for every pair in
    ``relus``; ``check_sat`` uses it.
    """

    def __init__(self, filename, violationfile, logger):
        self.logger = logger
        self.nnet2smt = Nnet2Smt(filename, violationfile)
        self.nnet2smt.convert(True)
        self.input_vars = self.nnet2smt.input_vars
        self.output_vars = self.nnet2smt.output_vars
        self.formulae = self.nnet2smt.formulae
        # (r1, r2) pairs; sat_checker encodes each as r1 == max(r2, 0)
        self.relus = self.nnet2smt.relus
        # groups of (r1, r2) pairs; the *_ub refinements stop at the first
        # group that yields a lemma
        self.relus_level = self.nnet2smt.relus_level
        self.solver = Solver(name='yices')
        self.sat_checker = Solver(name='yices')
        self.init()

    def init(self):
        """Assert the network constraints in both solvers, and the exact ReLU
        semantics in ``sat_checker`` only."""
        self.solver.add_assertion(And(self.formulae))
        self.sat_checker.add_assertion(And(self.formulae))
        for r1, r2 in self.relus:
            self.sat_checker.add_assertion(Equals(r1, Max(r2, Real(0))))

    def solve(self):
        """Alternate solving and refinement, printing 'unsat' or 'sat'.

        'unsat' is printed when ``solver`` has no model. 'sat' is printed
        when its model violates no lemma produced by ``refine``.
        """
        while True:
            self.logger.info('Solving')
            if not self.solver.solve():
                print('unsat')
                break
            lemmas = self.refine()
            if not lemmas:
                print('sat')
                break
            self.solver.add_assertion(And(lemmas))

    def check_sat(self):
        """Check whether the input values of the current ``solver`` model
        satisfy ``sat_checker`` (network plus exact ReLU semantics).

        The input values are asserted inside a push/pop frame that is always
        popped, so ``sat_checker`` is left unchanged and later checks do not
        inherit stale input equalities. On success the input assignment
        found by ``sat_checker`` is printed.

        Returns:
            bool: True if ``sat_checker`` is satisfiable with those inputs.
        """
        self.logger.info('Checking for Sat')
        self.sat_checker.push()
        try:
            for v in self.input_vars:
                self.sat_checker.add_assertion(
                    Equals(v, self.solver.get_value(v)))
            if not self.sat_checker.solve():
                return False
            for x in self.input_vars:
                # print the variable actually queried (was a stale ``v``)
                print(x, self.sat_checker.get_value(x))
            return True
        finally:
            self.sat_checker.pop()

    def refine(self):
        """Return the first non-empty batch of lemmas violated by the current
        ``solver`` model. The kinds are tried in order: zero_lb, slope_lb,
        zero_ub, slope_ub.

        When no lemma is violated, the input values of the model are printed
        and an empty list is returned.
        """
        self.logger.info('Refining')
        lemmas = (self.refine_zero_lb() or self.refine_slope_lb()
                  or self.refine_zero_ub() or self.refine_slope_ub())

        if not lemmas:
            for v in self.input_vars:
                print(v, self.solver.get_value(v))

        return lemmas

    def _select(self, candidates, check=True):
        """Return the candidate lemmas that are false in the current
        ``solver`` model, logging each one. When ``check`` is False, return
        all candidates without evaluating or logging them."""
        if not check:
            return list(candidates)
        lemmas = []
        for lemma in candidates:
            if self.solver.get_value(lemma).is_false():
                lemmas.append(lemma)
                self.logger.debug('Adding %s', lemma)
        return lemmas

    def _select_first_level(self, make_lemma):
        """Apply ``_select`` to ``make_lemma(r1, r2)`` one ``relus_level``
        group at a time and return the lemmas of the first group with any
        violated lemma (an empty list if none)."""
        for level in self.relus_level:
            lemmas = self._select(make_lemma(r1, r2) for r1, r2 in level)
            if lemmas:
                return lemmas
        return []

    def refine_zero_lb(self, check=True):
        """Lemmas ``r1 >= 0`` (only the violated ones when ``check``)."""
        zero = Real(0)
        return self._select((GE(r1, zero) for r1, _ in self.relus), check)

    def refine_zero_ub(self):
        """Violated lemmas ``r2 <= 0 -> r1 <= 0`` from the first group."""
        zero = Real(0)
        return self._select_first_level(
            lambda r1, r2: Implies(LE(r2, zero), LE(r1, zero)))

    def refine_slope_lb(self, check=True):
        """Lemmas ``r1 >= r2`` (only the violated ones when ``check``)."""
        return self._select((GE(r1, r2) for r1, r2 in self.relus), check)

    def refine_slope_ub(self):
        """Violated lemmas ``r2 >= 0 -> r1 <= r2`` from the first group."""
        zero = Real(0)
        return self._select_first_level(
            lambda r1, r2: Implies(GE(r2, zero), LE(r1, r2)))
Exemple #31
0
def solve_with_dreal(formula):
    """Check ``formula`` with the external dReal solver and return the result.

    dReal is registered in the global pysmt environment as a generic
    SMT-LIB solver named ``DREAL_NAME``. It is launched as
    ``[DREAL_PATH, DREAL_ARGS]`` and supports ``DREAL_LOGICS``. Registration
    happens on the first call only; later calls reuse it, because
    registering the same name twice is rejected by pysmt.

    Returns:
        bool: the value of ``solve()`` (True when satisfiable).
    """
    factory = get_env().factory
    if DREAL_NAME not in factory.all_solvers():
        factory.add_generic_solver(DREAL_NAME, [DREAL_PATH, DREAL_ARGS], DREAL_LOGICS)
    # Look the solver up by the name it was registered under.
    solver = Solver(DREAL_NAME)
    solver.add_assertion(formula)
    return solver.solve()
Exemple #32
0
    def comb_attack(self):
        """Recover the key of ``self.obf_cir`` by iteratively querying
        ``self.oracle_cir`` on distinguishing inputs.

        The name and the symbols suggest a SAT attack on a key-locked
        combinational circuit (DIP = distinguishing input pattern); this is
        inferred, not established here. Three Boolector solvers are used:

        * ``solver_obf`` holds ``dip_gen_ckt`` and ``key_inequality_ckt``.
          Each model gives new values for the circuit input wires (a DIP).
        * ``solver_oracle`` evaluates each oracle output on the DIP.
        * ``solver_key`` accumulates the input/output constraints only.

        Each round, the oracle outputs for the DIP are tied to both
        ``dip_chk1`` and ``dip_chk2``. The DIP input values are substituted,
        and the result is asserted in ``solver_obf`` and ``solver_key``. When
        ``solver_obf`` becomes UNSAT, one key consistent with all
        observations is read from ``solver_key`` and printed as
        ``key=<bits>``, from symbols named ``keyinput<i>_0``.
        """
        # DIP generator
        solver_name = 'btor'
        solver_obf = Solver(name=solver_name)
        solver_key = Solver(name=solver_name)
        solver_oracle = Solver(name=solver_name)
        attack_formulas = FormulaGenerator(self.oracle_cir, self.obf_cir)

        f = attack_formulas.dip_gen_ckt
        # f = simplify(f)
        solver_obf.add_assertion(f)

        f = attack_formulas.key_inequality_ckt
        # f = simplify(f)
        solver_obf.add_assertion(f)

        iteration = 0
        while 1:
            # query dip generator
            if solver_obf.solve():
                # read the DIP both as literals (oracle assumptions) and as
                # TRUE/FALSE constants (for substitution below)
                dip_formula = []
                dip_boolean = []
                for l in self.obf_cir.input_wires:
                    t = Symbol(l)
                    if solver_obf.get_py_value(t):
                        dip_formula.append(t)
                        dip_boolean.append(TRUE())
                    else:
                        dip_formula.append(Not(t))
                        dip_boolean.append(FALSE())
                logging.info(dip_formula)

                # query oracle: an output is recorded as TRUE when its formula
                # is satisfiable under the DIP literals (presumably the wire
                # formula depends only on the inputs)
                dip_out = []
                for l in self.oracle_cir.output_wires:
                    t = self.oracle_cir.wire_objs[l].formula
                    solver_oracle.reset_assertions()
                    solver_oracle.add_assertion(t)
                    if solver_oracle.solve(dip_formula):
                        dip_out.append(TRUE())
                    else:
                        dip_out.append(FALSE())
                logging.info(dip_out)

                # add dip checker: both checker copies must produce the oracle
                # outputs observed for this DIP
                f = []
                for i in range(len(attack_formulas.dip_chk1)):
                    f.append(And(Iff(dip_out[i], attack_formulas.dip_chk1[i]),
                                 Iff(dip_out[i], attack_formulas.dip_chk2[i])))
                f = And(f)

                # fix the input wires to the DIP values
                subs = {}
                for i in range(len(self.obf_cir.input_wires)):
                    subs[Symbol(self.obf_cir.input_wires[i])] = dip_boolean[i]

                # f = simplify(f)
                f = substitute(f, subs)
                solver_obf.add_assertion(f)
                solver_key.add_assertion(f)

                iteration += 1
                logging.warning('iteration: {}'.format(iteration))
            else:
                # no further DIP exists: extract a key from solver_key
                logging.warning('print keys')
                if solver_key.solve():
                    key = ''
                    for i in range(len(self.obf_cir.key_wires)):
                        k = 'keyinput{}_0'.format(i)
                        if solver_key.get_py_value(Symbol(k)):
                            key += '1'
                        else:
                            key += '0'
                    print("key=%s" % key)
                else:
                    logging.critical('key solver returned UNSAT')
                return
Exemple #33
0
 def test_btor_does_not_support_const_arryas(self):
     """Boolector must refuse an equality involving a constant array."""
     with self.assertRaises(ConvertExpressionError):
         btor = Solver(name="btor")
         const_array = Array(BV8, BV(0, 8))
         fresh_array = FreshSymbol(ArrayType(BV8, BV8))
         btor.add_assertion(Equals(const_array, fresh_array))
Exemple #34
0
class SMTValidator(object):
    """
        Validating Anchor's explanations using SMT solving.

        The oracle holds the model encoding. For each sample a fresh selector
        literal guards the sample-specific constraints. Constraints of an
        earlier sample are switched off by asserting the negation of its
        selector.
    """
    def __init__(self, formula, feats, nof_classes, xgb):
        """
            Constructor.

            formula: SMT encoding asserted in the oracle.
            feats: feature names; their positions are stored in ``ftids``.
            nof_classes: number of classes (one score variable each).
            xgb: the booster providing options, feature names and the
                sample (inverse) transformations.
        """

        self.ftids = {f: i for i, f in enumerate(feats)}
        self.nofcl = nof_classes
        self.idmgr = IDPool()
        self.optns = xgb.options

        # xgbooster will also be needed
        self.xgb = xgb

        self.verbose = self.optns.verb
        self.oracle = Solver(name=self.xgb.options.solver)

        self.inps = []  # input (feature value) variables
        for f in self.xgb.extended_feature_names_as_array_strings:
            if '_' not in f:
                self.inps.append(Symbol(f, typename=REAL))
            else:
                self.inps.append(Symbol(f, typename=BOOL))

        self.outs = []  # output (class  score) variables
        for c in range(self.nofcl):
            self.outs.append(Symbol('class{0}_score'.format(c), typename=REAL))

        # theory
        self.oracle.add_assertion(formula)

        # current selector
        self.selv = None

    def _preamble(self, inpvals):
        """
            Render readable input values as conditions: 'feature = value',
            or the value itself when it already mentions the feature name.
        """
        preamble = []
        for f, v in zip(self.xgb.feature_names, inpvals):
            if f not in v:
                preamble.append('{0} = {1}'.format(f, v))
            else:
                preamble.append(v)
        return preamble

    def prepare(self, sample, expl):
        """
            Prepare the oracle for validating an explanation given a sample.

            Raises AssertionError if the sample was already prepared, or if
            the oracle is unsatisfiable under the sample's hypotheses. Both
            checks use explicit raises so they also hold under ``python -O``;
            there the unsat case previously left ``model`` unbound.
        """

        if self.selv:
            # disable the previous assumption if any
            self.oracle.add_assertion(Not(self.selv))

        # creating a fresh selector for a new sample
        sname = ','.join([str(v).strip() for v in sample])

        # the samples should not repeat; otherwise, they will be
        # inconsistent with the previously introduced selectors
        if sname in self.idmgr.obj2id:
            raise AssertionError(
                'this sample has been considered before (sample {0})'.format(
                    self.idmgr.id(sname)))
        self.selv = Symbol('sample{0}_selv'.format(self.idmgr.id(sname)),
                           typename=BOOL)

        # transformed sample
        self.sample = list(self.xgb.transform(sample)[0])

        # relaxed hypotheses: one selector per input variable; variables
        # whose names share the prefix before '_' share the same selector
        self.rhypos = []
        for inp, _ in zip(self.inps, self.sample):
            feat = inp.symbol_name().split('_')[0]
            self.rhypos.append(Symbol('selv_{0}'.format(feat)))

        # adding relaxed hypotheses to the oracle
        for inp, val, sel in zip(self.inps, self.sample, self.rhypos):
            if '_' not in inp.symbol_name():
                hypo = Implies(self.selv,
                               Implies(sel, Equals(inp, Real(float(val)))))
            else:
                hypo = Implies(self.selv, Implies(sel,
                                                  inp if val else Not(inp)))

            self.oracle.add_assertion(hypo)

        # propagating the true observation
        if not self.oracle.solve([self.selv] + self.rhypos):
            raise AssertionError('Formula is unsatisfiable under given assumptions')
        model = self.oracle.get_model()

        # choosing the maximum
        outvals = [float(model.get_py_value(o)) for o in self.outs]
        maxoval = max(zip(outvals, range(len(outvals))))

        # correct class id (corresponds to the maximum computed)
        true_output = maxoval[1]

        # forcing a misclassification, i.e. a wrong observation
        disj = []
        for i in range(len(self.outs)):
            if i != true_output:
                disj.append(GT(self.outs[i], self.outs[true_output]))
        self.oracle.add_assertion(Implies(self.selv, Or(disj)))

        # removing all hypotheses except for those in the explanation
        self.rhypos = [
            hypo for i, hypo in enumerate(self.rhypos)
            if self.ftids[self.xgb.transform_inverse_by_index(i)[0]] in expl]

        if self.verbose:
            inpvals = self.xgb.readable_sample(sample)
            preamble = self._preamble(inpvals)

            print('  explanation for:  "IF {0} THEN {1}"'.format(
                ' AND '.join(preamble), self.xgb.target_name[true_output]))

    def validate(self, sample, expl):
        """
            Make an effort to show that the explanation is too optimistic.

            Returns a counterexample ``(input values, class id)`` if one is
            found, or None when the explanation holds. The elapsed user time
            is stored in ``self.time``.
        """

        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
                resource.getrusage(resource.RUSAGE_SELF).ru_utime

        # adapt the solver to deal with the current sample
        self.prepare(sample, expl)

        # if satisfiable, then there is a counterexample
        if self.oracle.solve([self.selv] + self.rhypos):
            model = self.oracle.get_model()
            inpvals = [float(model.get_py_value(i)) for i in self.inps]
            outvals = [float(model.get_py_value(o)) for o in self.outs]
            maxoval = max(zip(outvals, range(len(outvals))))

            inpvals = self.xgb.transform_inverse(np.array(inpvals))[0]
            self.coex = tuple([inpvals, maxoval[1]])
            inpvals = self.xgb.readable_sample(inpvals)

            if self.verbose:
                preamble = self._preamble(inpvals)

                print('  explanation is incorrect')
                print('  counterexample: "IF {0} THEN {1}"'.format(
                    ' AND '.join(preamble), self.xgb.target_name[maxoval[1]]))
        else:
            self.coex = None

            if self.verbose:
                print('  explanation is correct')

        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
                resource.getrusage(resource.RUSAGE_SELF).ru_utime - self.time

        if self.verbose:
            print('  time: {0:.2f}'.format(self.time))

        return self.coex
Exemple #35
0
 def test_btor_does_not_support_const_arryas(self):
     """Constant arrays cannot be converted by the Boolector backend."""
     with self.assertRaises(ConvertExpressionError):
         btor = Solver(name="btor")
         zero_filled = Array(BV8, BV(0, 8))
         formula = Equals(zero_filled, FreshSymbol(ArrayType(BV8, BV8)))
         btor.add_assertion(formula)
Exemple #36
0
class SMTExplainer(object):
    """
        An SMT-inspired minimal explanation extractor for XGBoost models.

        The model encoding (``formula``) is asserted once in an SMT oracle.
        Each explained sample gets a fresh Boolean selector ``self.selv``
        that guards every sample-specific constraint, so the constraints of
        a previous sample can be disabled by asserting ``Not(selv)``.  Each
        original feature gets a hypothesis selector ``selv_<feature>``;
        assuming it pins that feature to the sample's value.  Explanations
        are sets of hypothesis selectors, reported as original feature ids.
    """
    def __init__(self, formula, intvs, imaps, ivars, feats, nof_classes,
                 options, xgb):
        """
            Constructor.

            :param formula: pySMT formula over ``self.inps``/``self.outs``;
                asserted permanently in the oracle.
            :param intvs: per-input lists of interval upper bounds (``'+'``
                marks the unbounded last interval); when non-empty,
                hypotheses are expressed via interval variables instead of
                exact values.
            :param imaps: stored as ``self.imaps`` (not used in this class).
            :param ivars: per-input interval variables, aligned with
                ``intvs``.
            :param feats: stored as ``self.feats``.
            :param nof_classes: number of classes (one score var each).
            :param options: options object; this class reads ``verb``,
                ``solver``, ``xtype``, ``xnum``, ``usemhs``, ``unitmcs``,
                ``usecld``, ``trim`` and ``reduce``.
            :param xgb: XGBooster wrapper (feature names, transformations,
                target names).
        """

        self.feats = feats
        self.intvs = intvs
        self.imaps = imaps
        self.ivars = ivars
        self.nofcl = nof_classes
        self.optns = options
        self.idmgr = IDPool()

        # saving XGBooster
        self.xgb = xgb

        self.verbose = self.optns.verb
        self.oracle = Solver(name=options.solver)

        # input (feature value) variables: names without '_' are numerical
        # (REAL), names with '_' are one-hot columns of a categorical
        # feature (BOOL)
        self.inps = []
        for f in self.xgb.extended_feature_names_as_array_strings:
            if '_' not in f:
                self.inps.append(Symbol(f, typename=REAL))
            else:
                self.inps.append(Symbol(f, typename=BOOL))

        self.outs = []  # output (class  score) variables
        for c in range(self.nofcl):
            self.outs.append(Symbol('class{0}_score'.format(c), typename=REAL))

        # theory
        self.oracle.add_assertion(formula)

        # current selector
        self.selv = None

        # save and use dual explanations whenever needed
        self.dualx = []

        # number of oracle calls involved
        self.calls = 0

    def prepare(self, sample):
        """
            Prepare the oracle for computing an explanation.

            Disables the previous sample's selector (if any), creates a new
            selector for ``sample``, adds relaxed hypotheses of the form
            ``selv -> (hypothesis selector -> input fixed to sample value)``,
            determines the predicted class and finally asserts (under
            ``selv``) that some other class scores strictly higher.

            :param sample: raw (untransformed) sample values.
            :raises AssertionError: if the sample has been prepared before or
                if the formula is unsatisfiable under the sample's
                hypotheses.
        """

        if self.selv:
            # disable the previous assumption if any
            self.oracle.add_assertion(Not(self.selv))

        # creating a fresh selector for a new sample
        sname = ','.join([str(v).strip() for v in sample])

        # the samples should not repeat; otherwise, they will be
        # inconsistent with the previously introduced selectors
        assert sname not in self.idmgr.obj2id, 'this sample has been considered before (sample {0})'.format(
            self.idmgr.id(sname))
        self.selv = Symbol('sample{0}_selv'.format(self.idmgr.id(sname)),
                           typename=BOOL)

        self.rhypos = []  # relaxed hypotheses

        # transformed sample
        self.sample = list(self.xgb.transform(sample)[0])

        self.sel2fid = {}  # selectors to original feature ids
        self.sel2vid = {}  # selectors to indices into self.inps/self.sample

        # preparing the selectors; all one-hot columns '<feat>_<value>' of a
        # categorical feature share the single selector 'selv_<feat>'
        for i, (inp, _) in enumerate(zip(self.inps, self.sample)):
            feat = inp.symbol_name().split('_')[0]
            selv = Symbol('selv_{0}'.format(feat))

            self.rhypos.append(selv)
            if selv not in self.sel2fid:
                self.sel2fid[selv] = int(feat[1:])
                self.sel2vid[selv] = [i]
            else:
                self.sel2vid[selv].append(i)

        # adding relaxed hypotheses to the oracle
        if not self.intvs:
            for inp, val, sel in zip(self.inps, self.sample, self.rhypos):
                if '_' not in inp.symbol_name():
                    hypo = Implies(self.selv,
                                   Implies(sel, Equals(inp, Real(float(val)))))
                else:
                    hypo = Implies(self.selv,
                                   Implies(sel, inp if val else Not(inp)))

                self.oracle.add_assertion(hypo)
        else:
            for inp, val, sel in zip(self.inps, self.sample, self.rhypos):
                inp = inp.symbol_name()
                # determining the right interval and the corresponding variable
                # NOTE(review): relies on some interval matching (e.g. a final
                # '+' bound); otherwise ``hypo`` is left over from the
                # previous input
                for ub, fvar in zip(self.intvs[inp], self.ivars[inp]):
                    if ub == '+' or val < ub:
                        hypo = Implies(self.selv, Implies(sel, fvar))
                        break

                self.oracle.add_assertion(hypo)

        # in case of categorical data, there are selector duplicates
        # and we need to remove them; sort by the numeric feature id that
        # follows the 'selv_f' prefix
        self.rhypos = sorted(set(self.rhypos),
                             key=lambda x: int(x.symbol_name()[6:]))

        # propagating the true observation; raised explicitly (rather than
        # via ``assert``) so that the check survives ``python -O``
        if not self.oracle.solve([self.selv] + self.rhypos):
            raise AssertionError('Formula is unsatisfiable under given assumptions')
        model = self.oracle.get_model()

        # choosing the maximum
        outvals = [float(model.get_py_value(o)) for o in self.outs]
        maxoval = max(zip(outvals, range(len(outvals))))

        # correct class id (corresponds to the maximum computed)
        self.out_id = maxoval[1]
        self.output = self.xgb.target_name[self.out_id]

        # forcing a misclassification, i.e. a wrong observation
        disj = []
        for i in range(len(self.outs)):
            if i != self.out_id:
                disj.append(GT(self.outs[i], self.outs[self.out_id]))
        self.oracle.add_assertion(Implies(self.selv, Or(disj)))

        if self.verbose:
            inpvals = self.xgb.readable_sample(sample)

            self.preamble = []
            for f, v in zip(self.xgb.feature_names, inpvals):
                if f not in v:
                    self.preamble.append('{0} = {1}'.format(f, v))
                else:
                    self.preamble.append(v)

            print('  explaining:  "IF {0} THEN {1}"'.format(
                ' AND '.join(self.preamble), self.output))

    def explain(self, sample, smallest):
        """
            Hypotheses minimization.

            Computes explanations of ``sample`` according to
            ``self.optns.xtype`` (``'abductive'``, otherwise contrastive).
            Exits the process (status 1) if the hypotheses do not entail
            the prediction.

            :param sample: raw sample to explain.
            :param smallest: request cardinality-minimal explanations.
            :return: list of explanations, each a sorted list of original
                feature ids; an entry is ``None`` when the enumeration
                procedure reports no explanation (``[None]``).
        """

        # reinitializing the number of used oracle calls
        # 1 because of the initial call checking the entailment
        self.calls = 1

        # adapt the solver to deal with the current sample
        self.prepare(sample)

        # saving external explanation to be minimized further
        self.to_consider = [True] * len(self.rhypos)

        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
                resource.getrusage(resource.RUSAGE_SELF).ru_utime

        # if satisfiable, then the observation is not implied by the hypotheses
        if self.oracle.solve(
            [self.selv] +
            [h for h, c in zip(self.rhypos, self.to_consider) if c]):
            print('  no implication!')
            print(self.oracle.get_model())
            sys.exit(1)

        if self.optns.xtype == 'abductive':
            # abductive explanations => MUS computation and enumeration
            if not smallest and self.optns.xnum == 1:
                expls = [self.compute_minimal_abductive()]
            else:
                expls = self.enumerate_abductive(smallest=smallest)
        else:  # contrastive explanations => MCS enumeration
            if self.optns.usemhs:
                expls = self.enumerate_contrastive()
            else:
                if not smallest:
                    expls = self.enumerate_minimal_contrastive()
                else:
                    # expls = self.enumerate_smallest_contrastive()
                    expls = self.enumerate_contrastive()

        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
                resource.getrusage(resource.RUSAGE_SELF).ru_utime - self.time

        # translate selectors into original feature ids; a ``None`` entry
        # (no explanation found) is passed through instead of iterated
        expls = [
            sorted(self.sel2fid[h] for h in expl) if expl is not None else None
            for expl in expls
        ]

        if self.dualx:
            self.dualx = [sorted(self.sel2fid[h] for h in expl)
                          for expl in self.dualx]

        if self.verbose:
            if expls and expls[0] is not None:
                for expl in expls:
                    preamble = [self.preamble[i] for i in expl]
                    if self.optns.xtype == 'abductive':
                        print('  explanation: "IF {0} THEN {1}"'.format(
                            ' AND '.join(preamble),
                            self.xgb.target_name[self.out_id]))
                    else:
                        print(
                            '  explanation: "IF NOT {0} THEN NOT {1}"'.format(
                                ' AND NOT '.join(preamble),
                                self.xgb.target_name[self.out_id]))
                    print('  # hypos left:', len(expl))

            print('  time: {0:.2f}'.format(self.time))

        # here we return the last computed explanation
        return expls

    def compute_minimal_abductive(self):
        """
            Compute any subset-minimal explanation.

            Deletion-based linear search: a hypothesis is dropped whenever
            the remaining ones still make the misclassification constraint
            unsatisfiable.

            :return: list of hypothesis selectors.
        """

        i = 0

        # filtering out unnecessary features if external explanation is given
        rhypos = [h for h, c in zip(self.rhypos, self.to_consider) if c]

        # simple deletion-based linear search
        while i < len(rhypos):
            to_test = rhypos[:i] + rhypos[(i + 1):]

            self.calls += 1
            if self.oracle.solve([self.selv] + to_test):
                i += 1
            else:
                rhypos = to_test

        return rhypos

    def enumerate_minimal_contrastive(self):
        """
            Enumerate subset-minimal contrastive explanations (MCSes).

            Uses an LBX-style procedure (optionally with clause-D checks,
            ``self.optns.usecld``) and stops once ``self.optns.xnum``
            explanations have been found.

            :return: list of explanations (lists of hypothesis selectors),
                or ``[None]`` if none exists.
        """
        def _overapprox():
            # split hypotheses by their value in the current model
            model = self.oracle.get_model()

            for sel in self.rhypos:
                if int(model.get_py_value(sel)) > 0:
                    # soft clauses contain positive literals
                    # so if var is true then the clause is satisfied
                    self.ss_assumps.append(sel)
                else:
                    self.setd.append(sel)

        def _compute():
            # try to satisfy each remaining hypothesis; those that cannot be
            # satisfied become backbone (negated) literals
            i = 0
            while i < len(self.setd):
                if self.optns.usecld:
                    _do_cld_check(self.setd[i:])
                    i = 0

                if self.setd:
                    # it may be empty after the clause D check

                    self.calls += 1
                    self.ss_assumps.append(self.setd[i])
                    if not self.oracle.solve([self.selv] + self.ss_assumps +
                                             self.bb_assumps):
                        self.ss_assumps.pop()
                        self.bb_assumps.append(Not(self.setd[i]))

                i += 1

        def _do_cld_check(cld):
            # try to satisfy at least one literal of ``cld`` at once
            self.cldid += 1
            sel = Symbol('{0}_{1}'.format(self.selv.symbol_name(), self.cldid))
            cld.append(Not(sel))

            # adding clause D
            self.oracle.add_assertion(Or(cld))
            self.ss_assumps.append(sel)

            self.setd = []
            st = self.oracle.solve([self.selv] + self.ss_assumps +
                                   self.bb_assumps)

            self.ss_assumps.pop()  # removing clause D assumption
            if st == True:
                model = self.oracle.get_model()

                for l in cld[:-1]:
                    # filtering all satisfied literals
                    if int(model.get_py_value(l)) > 0:
                        self.ss_assumps.append(l)
                    else:
                        self.setd.append(l)
            else:
                # clause D is unsatisfiable => all literals are backbones
                self.bb_assumps.extend([Not(l) for l in cld[:-1]])

            # deactivating clause D
            self.oracle.add_assertion(Not(sel))

        # sets of selectors to work with
        self.cldid = 0
        expls = []

        # set when unit-size MCS detection alone already yields xnum
        # explanations; the main enumeration is then skipped so that the
        # result does not overshoot the requested number
        done = False

        # detect and block unit-size MCSes immediately
        if self.optns.unitmcs:
            for i, hypo in enumerate(self.rhypos):
                self.calls += 1
                if self.oracle.solve([self.selv] + self.rhypos[:i] +
                                     self.rhypos[(i + 1):]):
                    expls.append([hypo])

                    if len(expls) != self.optns.xnum:
                        self.oracle.add_assertion(Or([Not(self.selv), hypo]))
                    else:
                        done = True
                        break

        if not done:
            self.calls += 1
            while self.oracle.solve([self.selv]):
                self.ss_assumps, self.bb_assumps, self.setd = [], [], []
                _overapprox()
                _compute()

                expl = [list(f.get_free_variables())[0] for f in self.bb_assumps]
                expls.append(expl)

                if len(expls) == self.optns.xnum:
                    break

                # block the MCS just found
                self.oracle.add_assertion(Or([Not(self.selv)] + expl))
                self.calls += 1

        self.calls += self.cldid
        return expls if expls else [None]

    def enumerate_abductive(self, smallest=True):
        """
            Compute a cardinality-minimal explanation.

            Implicit hitting set enumeration of abductive explanations (up to
            ``self.optns.xnum``); the contrastive counterexamples found on
            the way are saved in ``self.dualx``.

            :param smallest: use a cardinality-minimal ('sorted') hitting
                set enumerator instead of 'lbx'.
            :return: list of explanations (lists of hypothesis selectors).
        """

        # result
        expls = []

        # just in case, let's save dual (contrastive) explanations
        self.dualx = []

        with Hitman(bootstrap_with=[[
                i for i in range(len(self.rhypos)) if self.to_consider[i]
        ]],
                    htype='sorted' if smallest else 'lbx') as hitman:
            # computing unit-size MCSes
            for i, hypo in enumerate(self.rhypos):
                if not self.to_consider[i]:
                    continue

                self.calls += 1
                if self.oracle.solve([self.selv] + self.rhypos[:i] +
                                     self.rhypos[(i + 1):]):
                    hitman.hit([i])
                    self.dualx.append([self.rhypos[i]])

            # main loop
            iters = 0
            while True:
                hset = hitman.get()
                iters += 1

                if self.verbose > 1:
                    print('iter:', iters)
                    print('cand:', hset)

                if hset is None:
                    break

                self.calls += 1
                if self.oracle.solve([self.selv] +
                                     [self.rhypos[i] for i in hset]):
                    to_hit = []
                    satisfied, unsatisfied = [], []

                    removed = list(
                        set(range(len(self.rhypos))).difference(set(hset)))

                    model = self.oracle.get_model()
                    for h in removed:
                        # indices of this hypothesis' input variables: one
                        # for a numerical feature, one per one-hot column for
                        # a categorical feature (sel2fid holds original
                        # feature ids, which are not indices into self.inps)
                        vids = self.sel2vid[self.rhypos[h]]
                        if '_' not in self.inps[vids[0]].symbol_name():
                            # feature variable and its expected value
                            var, exp = self.inps[vids[0]], self.sample[vids[0]]

                            # true value
                            true_val = float(model.get_py_value(var))

                            if not exp - 0.001 <= true_val <= exp + 0.001:
                                unsatisfied.append(h)
                            else:
                                hset.append(h)
                        else:
                            for vid in vids:
                                var, exp = self.inps[vid], int(
                                    self.sample[vid])

                                # true value
                                true_val = int(model.get_py_value(var))

                                if exp != true_val:
                                    unsatisfied.append(h)
                                    break
                            else:
                                hset.append(h)

                    # computing an MCS (expensive)
                    for h in unsatisfied:
                        self.calls += 1
                        if self.oracle.solve([self.selv] +
                                             [self.rhypos[i] for i in hset] +
                                             [self.rhypos[h]]):
                            hset.append(h)
                        else:
                            to_hit.append(h)

                    if self.verbose > 1:
                        print('coex:', to_hit)

                    hitman.hit(to_hit)

                    self.dualx.append([self.rhypos[i] for i in to_hit])
                else:
                    if self.verbose > 1:
                        print('expl:', hset)

                    expl = [self.rhypos[i] for i in hset]
                    expls.append(expl)

                    if len(expls) != self.optns.xnum:
                        hitman.block(hset)
                    else:
                        break

        return expls

    def enumerate_smallest_contrastive(self):
        """
            Compute a cardinality-minimal contrastive explanation.

            Linear unsat-sat search over the number of relaxed hypotheses,
            using integer cost literals (0 if a hypothesis holds, 1
            otherwise); unit-size MUSes are part of every result.

            :return: list of explanations (lists of hypothesis selectors).
        """

        # result
        expls = []

        # computing unit-size MUSes
        muses = set([])
        for hypo in self.rhypos:
            self.calls += 1
            if not self.oracle.solve([self.selv, hypo]):
                muses.add(hypo)

        # we are going to discard unit-size MUSes from consideration
        rhypos = set(self.rhypos).difference(muses)

        # introducing integer cost literals for rhypos
        costlits = []
        for i, hypo in enumerate(rhypos):
            costlit = Symbol(name='costlit_{0}_{1}'.format(
                self.selv.symbol_name(), i),
                             typename=INT)
            costlits.append(costlit)

            self.oracle.add_assertion(
                Ite(hypo, Equals(costlit, Int(0)), Equals(costlit, Int(1))))

        # main loop (linear search unsat-sat)
        i = 0
        while i < len(rhypos) and len(expls) != self.optns.xnum:
            # fresh selector for the current iteration
            sit = Symbol('iter_{0}_{1}'.format(self.selv.symbol_name(), i))

            # adding cardinality constraint
            self.oracle.add_assertion(Implies(sit, LE(Plus(costlits), Int(i))))

            # extracting explanations from MaxSAT models
            while self.oracle.solve([self.selv, sit]):
                self.calls += 1
                model = self.oracle.get_model()

                expl = []
                for hypo in rhypos:
                    if int(model.get_py_value(hypo)) == 0:
                        expl.append(hypo)

                # each MCS contains all unit-size MUSes
                expls.append(list(muses) + expl)

                # either stop or add a blocking clause
                if len(expls) != self.optns.xnum:
                    self.oracle.add_assertion(Implies(self.selv, Or(expl)))
                else:
                    break

            i += 1
            self.calls += 1

        return expls

    def enumerate_contrastive(self, smallest=True):
        """
            Compute a cardinality-minimal contrastive explanation.

            Implicit hitting set enumeration of contrastive explanations (up
            to ``self.optns.xnum``) driven by Z3 unsat cores, which are
            optionally trimmed (``self.optns.trim``) and reduced
            (``self.optns.reduce``: 'lin', 'qxp' or 'none').  The abductive
            counterexamples found on the way are saved in ``self.dualx``.

            :param smallest: use a cardinality-minimal ('sorted') hitting
                set enumerator instead of 'lbx'.
            :return: list of explanations (lists of hypothesis selectors).
            :raises AssertionError: if the oracle is not Z3.
        """

        # core extraction is done via calling Z3's internal API
        assert self.optns.solver == 'z3', 'This procedure requires Z3'

        # result
        expls = []

        # just in case, let's save dual (abductive) explanations
        self.dualx = []

        # mapping from hypothesis variables to their indices
        hmap = {h: i for i, h in enumerate(self.rhypos)}

        # mapping from internal Z3 variable into variables of PySMT
        vmap = {self.oracle.converter.convert(v): v for v in self.rhypos}
        vmap[self.oracle.converter.convert(self.selv)] = None

        def _get_core():
            # hypothesis selectors of the last unsat core (sample selector
            # dropped), sorted by feature id
            core = self.oracle.z3.unsat_core()
            return sorted(filter(lambda x: x is not None,
                                 map(lambda x: vmap[x], core)),
                          key=lambda x: int(x.symbol_name()[6:]))

        def _do_trimming(core):
            # re-solve under the current core to shrink it, at most
            # ``self.optns.trim`` times or until the size stops decreasing
            for _ in range(self.optns.trim):
                self.calls += 1
                self.oracle.solve([self.selv] + core)
                new_core = _get_core()
                if len(core) == len(new_core):
                    break
                core = new_core
            return core

        def _reduce_lin(core):
            # drop, one by one, assumptions not needed for unsatisfiability
            def _assump_needed(a):
                if len(to_test) > 1:
                    to_test.remove(a)
                    self.calls += 1
                    if not self.oracle.solve([self.selv] + list(to_test)):
                        return False
                    to_test.add(a)
                    return True
                else:
                    return True

            to_test = set(core)
            return list(filter(lambda a: _assump_needed(a), core))

        def _reduce_qxp(core):
            # drop chunks of assumptions of halving size
            coex = core[:]
            filt_sz = len(coex) / 2.0
            while filt_sz >= 1:
                i = 0
                while i < len(coex):
                    to_test = coex[:i] + coex[(i + int(filt_sz)):]
                    self.calls += 1
                    if to_test and not self.oracle.solve([self.selv] +
                                                         to_test):
                        # assumps are not needed
                        coex = to_test
                    else:
                        # assumps are needed => check the next chunk
                        i += int(filt_sz)
                # decreasing size of the set to filter
                filt_sz /= 2.0
                if filt_sz > len(coex) / 2.0:
                    # next size is too large => make it smaller
                    filt_sz = len(coex) / 2.0
            return coex

        def _reduce_coex(core):
            if self.optns.reduce == 'lin':
                return _reduce_lin(core)
            else:  # qxp
                return _reduce_qxp(core)

        with Hitman(bootstrap_with=[[
                i for i in range(len(self.rhypos)) if self.to_consider[i]
        ]],
                    htype='sorted' if smallest else 'lbx') as hitman:
            # set when unit-size MCSes alone already yield xnum
            # explanations; the main loop is then skipped so that the
            # result does not overshoot the requested number
            done = False

            # computing unit-size MUSes
            for i, hypo in enumerate(self.rhypos):
                if not self.to_consider[i]:
                    continue

                self.calls += 1
                if not self.oracle.solve([self.selv, self.rhypos[i]]):
                    hitman.hit([i])
                    self.dualx.append([self.rhypos[i]])
                elif self.optns.unitmcs:
                    self.calls += 1
                    if self.oracle.solve([self.selv] + self.rhypos[:i] +
                                         self.rhypos[(i + 1):]):
                        # this is a unit-size MCS => block immediately
                        hitman.block([i])
                        expls.append([self.rhypos[i]])

                        if len(expls) == self.optns.xnum:
                            done = True
                            break

            # main loop
            iters = 0
            while not done:
                hset = hitman.get()
                iters += 1

                if self.verbose > 1:
                    print('iter:', iters)
                    print('cand:', hset)

                if hset is None:
                    break

                self.calls += 1
                if not self.oracle.solve([self.selv] + [
                        self.rhypos[h] for h in list(
                            set(range(len(self.rhypos))).difference(set(hset)))
                ]):
                    to_hit = _get_core()

                    if len(to_hit) > 1 and self.optns.trim:
                        to_hit = _do_trimming(to_hit)

                    if len(to_hit) > 1 and self.optns.reduce != 'none':
                        to_hit = _reduce_coex(to_hit)

                    self.dualx.append(to_hit)
                    to_hit = [hmap[h] for h in to_hit]

                    if self.verbose > 1:
                        print('coex:', to_hit)

                    hitman.hit(to_hit)
                else:
                    if self.verbose > 1:
                        print('expl:', hset)

                    expl = [self.rhypos[i] for i in hset]
                    expls.append(expl)

                    if len(expls) != self.optns.xnum:
                        hitman.block(hset)
                    else:
                        break

        return expls
Exemple #37
0
def callback(model, converter, result):
    """Collect one model reported by ``msat_all_sat``.

    MathSAT invokes this function every time it finds a new model. Each
    MathSAT term of the model is translated back to pySMT via
    ``converter.back``, the terms are conjoined with ``And`` and the
    conjunction is appended to ``result``. Returning 1 lets the search
    continue; any other value would stop it.
    """
    literals = list(map(converter.back, model))
    result.append(And(literals))
    return 1  # keep enumerating models

x = Symbol("x")
y = Symbol("y")
f = Or(x, y)

msat = Solver(name="msat")
# every pySMT solver exposes its term converter via the .converter property
converter = msat.converter
# the assertion is still added at the pySMT level
msat.add_assertion(f)

result = []


def _collect(model):
    """Forward a MathSAT model to ``callback`` with the shared state."""
    return callback(model, converter, result)


# Call the MathSAT API directly. The second argument is the list of
# "important" variables, here only x converted into a MathSAT term.
mathsat.msat_all_sat(msat.msat_env(), [converter.convert(x)], _collect)

print("'exists y . %s' is equivalent to '%s'" % (f, Or(result)))
# exists y . (x | y) is equivalent to ((! x) | x)
Exemple #38
0
def get_makespan_optimal_weakly_hard_schedule(g,
                                              network,
                                              feasibility_timeout=None,
                                              optimization_timeout=10 * 60 *
                                              1000):
    vprint('*computing optimal weakly-hard real-time schedule via SMT*')
    # SMT formulation
    tc = transitive_closure(g)
    logical_edges = get_logical_edges(g)
    JUMPTABLE_MAX = 6
    K_MAX = 5001  # LAMBDA(i)[1] < K_MAX for all i < JUMPTABLE_MAX

    A, B, C, D, GAMMA, LAMBDA = (network[key] for key in ('A', 'B', 'C', 'D',
                                                          'GAMMA', 'LAMBDA'))

    assert (all(map(lambda x: LAMBDA(x)[1] < K_MAX, range(JUMPTABLE_MAX))))

    vprint('\tinstantiating symvars...')
    label = [Symbol('label_%i' % i, INT) for i in range(len(logical_edges))]
    # first half for slot, second half for beacons
    chi = [Symbol('chi_%i' % i, INT) for i in range(2 * len(logical_edges))]
    duration = [
        Symbol('duration_%i' % i, INT) for i in range(len(logical_edges))
    ]
    zeta = [
        Symbol('zeta_%i' % i, INT)
        for i in range(g.num_vertices() + len(logical_edges))
    ]
    delta_e_in_r = [[
        Symbol('delta_e_in_r-%i_%i' % (i, j), INT)
        for j in range(len(logical_edges))
    ] for i in range(len(logical_edges))]
    delta_chi_eq_i = [[
        Symbol('delta_chi_eq_i-%i_%i' % (i, j), INT)
        for j in range(JUMPTABLE_MAX)
    ] for i in range(2 * len(logical_edges))]
    delta_tau_before_r = [[
        Symbol('delta_tau_before_r-%i_%i' % (i, j), INT)
        for j in range(len(logical_edges))
    ] for i in range(g.num_vertices() + len(logical_edges))]

    vprint('\tgenerating constraint clauses...')
    domain = And([
        And([
            And(LE(Int(1), sym), LE(sym, Int(len(logical_edges))))
            for sym in label
        ]),
        And([And(LE(Int(1), sym), LT(sym, Int(JUMPTABLE_MAX)))
             for sym in chi]),
        And([LE(Int(0), sym) for sym in zeta]),
        And([
            And(LE(Int(0), sym), LE(sym, Int(1)))
            for sym in chain.from_iterable(delta_e_in_r + delta_chi_eq_i +
                                           delta_tau_before_r)
        ])
    ])
    one_hot = And([
        And([
            Equals(
                Plus([delta_e_in_r[e][r] for r in range(len(logical_edges))]),
                Int(1)) for e in range(len(logical_edges))
        ]),
        And([
            Equals(
                Plus([delta_chi_eq_i[chir][i] for i in range(JUMPTABLE_MAX)]),
                Int(1)) for chir in range(2 * len(logical_edges))
        ])
    ])
    CFOP = And([
        LT(label[logical_edges.index(r)], label[logical_edges.index(s)])
        for r, s in product(logical_edges, repeat=2)
        if r.source() in tc.get_in_neighbors(s.source())
    ])
    task_partitioning_by_round = And(
        And([
            LE(delta_tau_before_r[int(tau)][r], delta_tau_before_r[int(mu)][r])
            for tau, mu, r in product(tc.vertices(), tc.vertices(),
                                      range(len(logical_edges)))
            if tau in tc.get_in_neighbors(mu)
        ]),
        And([
            Equals(delta_tau_before_r[r + g.num_vertices()][s], Int(0)) if
            r < s else Equals(delta_tau_before_r[r +
                                                 g.num_vertices()][s], Int(1))
            for r, s in product(range(len(logical_edges)), repeat=2)
        ]))
    round_empty = And([
        Implies(
            Equals(
                Plus([delta_e_in_r[e][r] for e in range(len(logical_edges))]),
                Int(0)), Equals(chi[len(logical_edges) + r], Int(1)))
        for r in range(len(logical_edges))
    ])
    durations = And([
        Equals(
            duration[r],
            Plus(
                Int(A),
                Times(Plus(Times(Int(2), chi[r + len(logical_edges)]), Int(B)),
                      Int(C + D * GAMMA)),
                Times(
                    Ite(
                        GE(
                            Plus([
                                delta_e_in_r[e][r]
                                for e in range(len(logical_edges))
                            ]), Int(1)), Int(0), Int(-1)),
                    Int(A + (2 + B) * (C + D * GAMMA))),
                Plus([
                    Ite(
                        Equals(delta_e_in_r[e][r], Int(1)),
                        Plus(
                            Int(A),
                            Times(
                                Plus(Times(Int(2), chi[e]), Int(B)),
                                Int(C + D * g.edge_properties['widths'][
                                    logical_edges[e]]))), Int(0))
                    for e in range(len(logical_edges))
                ]))) for r in range(len(logical_edges))
    ])
    label_to_delta = And([
        Equals(
            label[e],
            Plus([
                Times(delta_e_in_r[e][i - 1], Int(i))
                for i in range(1, 1 + len(logical_edges))
            ])) for e in range(len(logical_edges))
    ])
    chi_to_delta = And([
        Equals(
            chi[chir],
            Plus([
                Times(delta_chi_eq_i[chir][i], Int(i))
                for i in range(JUMPTABLE_MAX)
            ])) for chir in range(2 * len(logical_edges))
    ])
    order = And(
        And([
            LT(zeta[int(tau)],
               Minus(zeta[int(mu)], Int(g.vertex_properties['durations'][mu])))
            for tau, mu in product(g.vertices(), repeat=2)
            if tau in tc.get_in_neighbors(mu)
        ]),
        And([
            LT(zeta[r + g.num_vertices()],
               Minus(zeta[r + 1 + g.num_vertices()], duration[r + 1]))
            for r in range(len(logical_edges) - 1)
        ]),
        And([
            Implies(
                Equals(delta_e_in_r[e][r], Int(1)),
                GT(
                    Minus(zeta[int(tau)],
                          Int(g.vertex_properties['durations'][tau])),
                    zeta[r + g.num_vertices()])) for tau in g.vertices()
            for r in range(len(logical_edges))
            for e in range(len(logical_edges))
            if tau in tc.get_out_neighbors(logical_edges[e].source())
        ]),
        And([
            Implies(
                Equals(delta_e_in_r[e][r], Int(1)),
                GT(Minus(zeta[r + g.num_vertices()], duration[r]),
                   zeta[int(tau)])) for tau in g.vertices()
            for r in range(len(logical_edges))
            for e in range(len(logical_edges))
            if tau in tc.get_in_neighbors(logical_edges[e].source())
            or tau == logical_edges[e].source()
        ]))
    exclusion = And([
        And(
            Implies(
                Equals(delta_tau_before_r[int(tau)][r], Int(0)),
                GT(Minus(zeta[r + g.num_vertices()], duration[r]),
                   zeta[int(tau)])),
            Implies(
                Equals(delta_tau_before_r[int(tau)][r], Int(1)),
                GT(
                    Minus(zeta[int(tau)],
                          Int(g.vertex_properties['durations'][tau])),
                    zeta[g.num_vertices() + r]))) for tau in g.vertices()
        for r in range(len(logical_edges))
    ])
    deadline = And([
        LE(zeta[int(tau)], Int(g.vertex_properties['deadlines'][tau]))
        for tau in g.vertices() if g.vertex_properties['deadlines'][tau] >= 0
    ])

    def sum_m(tau):
        """Return a pySMT term summing LAMBDA(i)[0] over choices feeding ``tau``.

        A term is emitted for every jump-table index ``i`` in
        ``range(JUMPTABLE_MAX)`` and every logical edge ``e`` whose source is
        in ``tc.get_in_neighbors(tau)``. Each term is made of two parts:

        * ``LAMBDA(i)[0]`` if ``delta_chi_eq_i[e][i] == 1``, otherwise 0;
        * for every round ``r``, ``LAMBDA(i)[0]`` if both
          ``delta_chi_eq_i[len(logical_edges) + r][i] == 1`` and
          ``delta_e_in_r[e][r] == 1``, otherwise 0.

        The sum starts from ``Int(0)``, so it is 0 when no edge qualifies.
        """
        n_edges = len(logical_edges)
        in_nbrs = tc.get_in_neighbors(tau)
        # The predecessor test depends only on e, so evaluate it once per edge
        # instead of once per (i, e) pair.
        feeding = [e for e in range(n_edges)
                   if logical_edges[e].source() in in_nbrs]

        def term(i, e):
            # Direct chi choice for edge e, plus one summand per round r.
            direct = Ite(Equals(delta_chi_eq_i[e][i], Int(1)),
                         Int(LAMBDA(i)[0]), Int(0))
            per_round = Plus([
                Ite(Equals(delta_chi_eq_i[n_edges + r][i], Int(1)),
                    Ite(Equals(delta_e_in_r[e][r], Int(1)),
                        Int(LAMBDA(i)[0]), Int(0)),
                    Int(0))
                for r in range(n_edges)
            ])
            return Plus(direct, per_round)

        return Plus([Int(0)] + [term(i, e)
                                for i in range(JUMPTABLE_MAX)
                                for e in feeding])

    def min_K(tau):
        """Return a pySMT term for the minimum LAMBDA(i)[1] over choices feeding ``tau``.

        The result is the minimum of ``K_MAX`` and one term per jump-table
        index ``i`` in ``range(JUMPTABLE_MAX)`` and logical edge ``e`` whose
        source is in ``tc.get_in_neighbors(tau)``. Each term is the minimum of
        two parts:

        * ``LAMBDA(i)[1]`` if ``delta_chi_eq_i[e][i] == 1``, otherwise ``K_MAX``;
        * for every round ``r``, ``LAMBDA(i)[1]`` if both
          ``delta_chi_eq_i[len(logical_edges) + r][i] == 1`` and
          ``delta_e_in_r[e][r] == 1``, otherwise ``K_MAX``.

        The result is ``K_MAX`` when no edge qualifies.
        """
        n_edges = len(logical_edges)
        in_nbrs = tc.get_in_neighbors(tau)
        # The predecessor test depends only on e; evaluate it once per edge.
        feeding = [e for e in range(n_edges)
                   if logical_edges[e].source() in in_nbrs]
        k_max = Int(K_MAX)

        def term(i, e):
            # Direct chi choice for edge e, combined with one option per round r.
            direct = Ite(Equals(delta_chi_eq_i[e][i], Int(1)),
                         Int(LAMBDA(i)[1]), k_max)
            per_round = Min([
                Ite(Equals(delta_chi_eq_i[n_edges + r][i], Int(1)),
                    Ite(Equals(delta_e_in_r[e][r], Int(1)),
                        Int(LAMBDA(i)[1]), k_max),
                    k_max)
                for r in range(n_edges)
            ])
            return Min(direct, per_round)

        # One term per (i, e) is enough. Repeating a term can never change a
        # minimum, so there is no extra loop over rounds here; the rounds are
        # already covered inside each term.
        return Min([k_max] + [term(i, e)
                              for i in range(JUMPTABLE_MAX)
                              for e in feeding])

    WH = And([
        And(
            GE(Int(g.vertex_properties['weakly-hard'][tau][0]),
               Min(sum_m(tau), min_K(tau))),
            LE(Int(g.vertex_properties['weakly-hard'][tau][1]), min_K(tau)))
        for tau in g.vertices()
        if g.vertex_properties['weakly-hard'][tau][0] >= 0
    ])

    formula = And([
        domain, one_hot, CFOP, task_partitioning_by_round, round_empty,
        durations, label_to_delta, chi_to_delta, order, exclusion, deadline, WH
    ])
    vprint('\tchecking feasibility...')
    solver = Solver(name='z3', incremental=True, logic='LIA')
    if feasibility_timeout:
        solver.z3.set('timeout', feasibility_timeout)
    solver.add_assertion(formula)
    try:
        result = solver.solve()
    except SolverReturnedUnknownResultError:
        result = None
    if not result:
        vprint('\tsolver returned infeasible!')
        return [None] * 4
    else:
        models = [solver.get_model()]
        vprint('\tsolver found a feasible solution, optimizing...')

    solver.z3.set('timeout', optimization_timeout)
    LB = 0
    UB = max(map(lambda x: models[-1].get_py_value(x), zeta))
    curr_B = UB // 2
    while range(LB + 1, UB):
        try:
            result = solver.solve(
                [And([LT(zeta_tau, Int(curr_B)) for zeta_tau in zeta])])
        except SolverReturnedUnknownResultError:
            vprint('\t(timeout, not necessarily unsat)')
            result = None
        if result:
            vprint('\tfound feasible solution of length %i, optimizing...' %
                   curr_B)
            models.append(solver.get_model())
            UB = curr_B
        else:
            vprint('\tnew lower bound %i, optimizing...' % curr_B)
            LB = curr_B
        curr_B = LB + int(ceil((UB - LB) / 2))
    vprint('\tsolver returned optimal (under composition+P.O. abstractions)!')
    best_model = models[-1]
    zeta = list(map(lambda x: best_model.get_py_value(x), zeta))
    chi = list(map(lambda x: best_model.get_py_value(x), chi))
    duration = list(map(lambda x: best_model.get_py_value(x), duration))
    label = list(map(lambda x: best_model.get_py_value(x), label))
    return zeta, chi, duration, label
Exemple #39
0
    o = BV(2, VECT_WIDTH)  # o - cpu player


# Turn counters, one per player mark.
x_turns, o_turns = 0, 0
# Python-level constants behind the x and o cell values.
x_val, o_val = (Cell.x.value.constant_value(),
                Cell.o.value.constant_value())

# 3x3 grid of fresh bit-vector symbols (width VECT_WIDTH), created row by row.
board = [[FreshSymbol(BVType(VECT_WIDTH)) for _col in range(3)]
         for _row in range(3)]

solver = Solver()

# Domain: every cell must equal one of the Cell values (blank, x or o).
for row in board:
    for cell in row:
        solver.add_assertion(Or([Equals(cell, member.value) for member in Cell]))

# Load the starting position. A token naming x or o pins that cell; any other
# token leaves the cell constrained only by its domain.
test = 'tests/blank.txt'
with open(test) as fh:
    for row, line in enumerate(fh):
        for col, cell in enumerate(line.strip().split(' ')):
            for mark in (Cell.x, Cell.o):
                if cell == mark.name:
                    solver.add_assertion(Equals(board[row][col], mark.value))
                    break

def already_played(row, col):
    """Return True if the cell at (row, col) is not blank in the current model.

    Reads the value that the solver's latest model assigns to
    ``board[row][col]`` and compares it with ``Cell.s.value`` (the value
    treated here as "not played"). A blank cell gives False; any other value
    gives True.

    NOTE(review): ``solver.get_value`` needs a preceding successful
    ``solver.solve()``; confirm that callers guarantee this.
    """
    is_blank = solver.get_value(board[row][col]) == Cell.s.value
    # The previous version fell off the end for played cells and returned
    # None, which is falsy and indistinguishable from "not played".
    return not is_blank
Exemple #40
0
    def comb_attack(self):
        """Run the iterative DIP attack and print the recovered key.

        Three ``yices`` solvers are used:

        * ``solver_obf``: the distinguishing-input (DIP) generator. It is
          seeded with ``dip_gen_ckt`` and ``key_inequality_ckt`` from
          ``FormulaGenerator(self.oracle_cir, self.obf_cir)``.
        * ``solver_key``: receives only the DIP-check and key-ban constraints.
          It is solved once at the end to extract a key.
        * ``self.solver_oracle``: holds the formula of every wire in
          ``self.oracle_cir``. Its current model is read for oracle output
          values. NOTE(review): that model is presumably set by
          ``self.query_oracle``; confirm.

        Each iteration reads an input assignment (a DIP) from the generator's
        model and passes it to ``self.query_oracle``.

        * If the same DIP was already seen, the method looks for an output
          where the ``@dc1`` and ``@dc2`` copies differ. It bans the current
          assignment of key copy ``_0`` if ``@dc1`` disagrees with the oracle,
          and of key copy ``_1`` otherwise.
        * If the DIP is new, it adds two DIP-checker instances (suffixes
          ``_0`` and ``_1``) whose outputs must match the oracle's.

        When the generator becomes UNSAT, ``solver_key`` is solved. The key
        is printed as a bit string read from the ``_0`` key copy.

        Returns:
            None. The key is printed to stdout and progress goes to ``logging``.

        Raises:
            RuntimeError: if a repeated DIP is found but no output differs
                between the two circuit copies, so there is no key to ban.
        """
        # dis generator
        solver_name = 'yices'
        solver_obf = Solver(name=solver_name)
        solver_key = Solver(name=solver_name)
        self.solver_oracle = Solver(name=solver_name)
        attack_formulas = FormulaGenerator(self.oracle_cir, self.obf_cir)

        solver_obf.add_assertion(attack_formulas.dip_gen_ckt)
        solver_obf.add_assertion(attack_formulas.key_inequality_ckt)

        for wire in self.oracle_cir.wire_objs:
            self.solver_oracle.add_assertion(wire.formula)

        # Input patterns already used as DIPs, stored as tuples of input bits.
        # A set gives O(1) repeat detection instead of rescanning a list.
        seen_dips = set()
        stateful_keys = []

        iteration = 0
        while True:
            # query dip generator
            if solver_obf.solve():
                dip_formula = []
                dip_boolean = []
                dip_bits = []
                for wire in self.oracle_cir.input_wires:
                    s = Symbol(wire.name)
                    bit = bool(solver_obf.get_py_value(s))
                    dip_bits.append(bit)
                    dip_formula.append(s if bit else Not(s))
                    dip_boolean.append(TRUE() if bit else FALSE())
                logging.info(dip_formula)

                # query oracle
                dip_out = self.query_oracle(dip_formula)
                logging.info(dip_out)

                # A repeated DIP: ban one of the two current key assignments
                # instead of adding another DIP check.
                dip_key = tuple(dip_bits)
                if dip_key in seen_dips:
                    logging.info("found a repeated dip!")

                    # Find an output on which the two copies differ. The copy
                    # that disagrees with the oracle has its key banned.
                    key = None
                    for wire in self.obf_cir.output_wires:
                        v1 = solver_obf.get_py_value(Symbol(wire.name + '@dc1'))
                        v2 = solver_obf.get_py_value(Symbol(wire.name + '@dc2'))
                        if v1 != v2:
                            oracle_v = self.solver_oracle.get_py_value(
                                Symbol(wire.name))
                            key = '0' if v1 != oracle_v else '1'
                            break
                    if key is None:
                        logging.critical(
                            'something is wrong when banning keys')
                        # Without a key copy to ban, building the blocking
                        # clause would fail (str + None); stop explicitly.
                        raise RuntimeError(
                            'repeated DIP but no output differs between the '
                            'two circuit copies; cannot ban a key')

                    # Blocking clause over the chosen copy's current key bits.
                    key_list = []
                    for wire in self.obf_cir.key_wires:
                        k = Symbol(wire.name + '_' + key)
                        key_list.append(
                            k if solver_obf.get_py_value(k) else Not(k))
                    stateful_keys.append(key_list)

                    # ban the stateful key
                    ban = Not(And(key_list))
                    solver_obf.add_assertion(ban)
                    solver_key.add_assertion(ban)
                    if len(stateful_keys) % 5000 == 0:
                        logging.warning('current stateful keys: {}'.format(
                            len(stateful_keys)))
                    continue
                seen_dips.add(dip_key)

                # add dip checker
                constraints = [
                    attack_formulas.gen_dip_chk(iteration * 2, '_0',
                                                dip_boolean),
                    attack_formulas.gen_dip_chk(iteration * 2 + 1, '_1',
                                                dip_boolean),
                ]
                # Both checker copies must reproduce the oracle's outputs.
                for i, out_wire in enumerate(self.obf_cir.output_wires):
                    name = out_wire.name
                    constraints.append(
                        And(
                            Iff(dip_out[i],
                                Symbol(name + '@{}'.format(iteration * 2))),
                            Iff(dip_out[i],
                                Symbol(name + '@{}'.format(iteration * 2 + 1)))))
                dip_check = And(constraints)

                solver_obf.add_assertion(dip_check)
                solver_key.add_assertion(dip_check)
                iteration += 1
                logging.warning('iteration: {}'.format(iteration))
            else:
                logging.warning('print keys')
                logging.warning('stateful keys: {}'.format(len(stateful_keys)))
                if solver_key.solve():
                    key = ''.join(
                        '1' if solver_key.get_py_value(Symbol(wire.name + '_0'))
                        else '0'
                        for wire in self.obf_cir.key_wires)
                    print("key=%s" % key)
                else:
                    logging.critical('key solver returned UNSAT')
                return