class TestSolve(unittest.TestCase):
    """Tests for adding plain CNF clauses to a pycryptosat ``Solver``."""

    def setUp(self):
        # Every test starts from a fresh two-threaded solver.
        self.solver = Solver(threads=2)

    def test_wrong_args(self):
        """Non-integer clauses/literals raise TypeError; literal 0 raises ValueError."""
        self.assertRaises(TypeError, self.solver.add_clause, 'A')
        self.assertRaises(TypeError, self.solver.add_clause, 1)
        self.assertRaises(TypeError, self.solver.add_clause, 1.0)
        self.assertRaises(TypeError, self.solver.add_clause, object())
        self.assertRaises(TypeError, self.solver.add_clause, ['a'])
        self.assertRaises(TypeError, self.solver.add_clause, [[1, 2], [3, None]])
        self.assertRaises(ValueError, self.solver.add_clause, [1, 0])

    def test_no_clauses(self):
        """An empty formula is satisfiable, repeatedly, with an empty model."""
        for _ in range(7):
            self.assertEqual(self.solver.solve([]), (True, (None,)))

    def test_cnf1(self):
        """clauses1 is satisfiable and the returned model satisfies it."""
        for cl in clauses1:
            self.solver.add_clause(cl)
        res, solution = self.solver.solve()
        self.assertEqual(res, True)
        self.assertTrue(check_solution(clauses1, solution))

    def test_bad_iter(self):
        """An object whose ``__iter__`` returns a non-iterator is rejected."""
        class Liar:
            def __iter__(self):
                return None
        self.assertRaises(TypeError, self.solver.add_clause, Liar())

    def test_cnf2(self):
        """clauses2 is unsatisfiable: ``solve()`` returns ``(False, None)``."""
        for cl in clauses2:
            self.solver.add_clause(cl)
        self.assertEqual(self.solver.solve(), (False, None))

    def test_cnf3(self):
        """clauses3 is satisfiable and the returned model satisfies it."""
        for cl in clauses3:
            self.solver.add_clause(cl)
        res, solution = self.solver.solve()
        self.assertEqual(res, True)
        self.assertTrue(check_solution(clauses3, solution))

    def test_cnf1_confl_limit(self):
        """Solving clauses1 under conflict limits 1..19 gives up or is correct.

        A ``None`` result is accepted as "limit reached before a verdict";
        any other result must come with a model satisfying ``clauses1``.
        """
        for lim in range(1, 20):
            # Build a fresh solver that actually carries the conflict budget;
            # previously ``lim`` was never passed on, so no limit was tested.
            self.solver = Solver(threads=2, confl_limit=lim)
            for cl in clauses1:
                self.solver.add_clause(cl)

            res, solution = self.solver.solve()
            self.assertTrue(res is None or check_solution(clauses1, solution))
class TestDump(unittest.TestCase):
    """Tests for retrieving learnt small clauses from the solver."""

    def setUp(self):
        self.solver = Solver()

    def test_max_glue_missing(self):
        """Calling ``start_getting_small_clauses`` without ``max_glue`` raises TypeError."""
        self.assertRaises(TypeError,
                          self.solver.start_getting_small_clauses, 4)

    def test_one_dump(self):
        """After solving tests/test.cnf at least one small clause can be fetched."""
        with open("tests/test.cnf", "r") as cnf_file:
            for line in cnf_file:
                line = line.strip()
                # Skip blank lines (they would otherwise be added as an empty
                # clause) as well as header/comment lines.
                if not line or "p" in line or "c" in line:
                    continue

                # Drop the trailing DIMACS terminator "0".
                clause = [int(lit) for lit in line.split()[:-1]]
                self.solver.add_clause(clause)

        res, _ = self.solver.solve()
        self.assertEqual(res, True)

        self.solver.start_getting_small_clauses(4, max_glue=10)
        try:
            small_clause = self.solver.get_next_small_clause()
            # assertNotEquals was a deprecated alias removed in Python 3.12.
            self.assertIsNotNone(small_clause)
        finally:
            # Always close the dump session, even if the assertion fails.
            self.solver.end_getting_small_clauses()
Beispiel #3
0
 def test_cnf3(self):
     """Check that clauses3 is satisfiable and that the model satisfies it."""
     sat_solver = Solver()
     for clause in clauses3:
         sat_solver.add_clause(clause)
     outcome = sat_solver.solve()
     res, solution = outcome
     self.assertEqual(res, True)
     self.assertTrue(check_solution(clauses3, solution))
Beispiel #4
0
    def test_cnf1_confl_limit(self):
        """Solving clauses1 under conflict limits 1..19 gives up or is correct.

        A ``None`` result is accepted as "limit reached before a verdict";
        any other result must come with a model satisfying ``clauses1``.
        """
        for lim in range(1, 20):
            solver = Solver(confl_limit=lim)
            for cl in clauses1:
                solver.add_clause(cl)

            res, solution = solver.solve()
            # None is a singleton: compare by identity (PEP 8), not ``==``.
            self.assertTrue(res is None or check_solution(clauses1, solution))
Beispiel #5
0
 def cryptominisat_solve(self):
     """Solve ``self.clauses`` with a fresh CryptoMiniSat solver.

     Returns the model tuple produced by ``Solver.solve`` when the first
     element of its result is truthy, and the string ``'UNSAT'`` otherwise.
     """
     solver = Solver()
     for lits in self.clauses:
         solver.add_clause(lits)
     is_sat, model = solver.solve()
     return model if is_sat else 'UNSAT'
class TestXor(unittest.TestCase):
    """Tests for XOR constraints added through ``Solver.add_xor_clause``."""

    def setUp(self):
        self.solver = Solver(threads=2)

    def test_wrong_args(self):
        """A missing rhs, a 0 literal and a negated literal are all refused."""
        add_xor = self.solver.add_xor_clause
        self.assertRaises(TypeError, add_xor, [1, 2])
        self.assertRaises(ValueError, add_xor, [1, 0], True)
        self.assertRaises(ValueError, add_xor, [-1, 2], True)

    def test_binary(self):
        """x1 ^ x2 = False with x1 assumed True forces x2 to True."""
        self.solver.add_xor_clause([1, 2], False)
        sat, model = self.solver.solve([1])
        self.assertEqual(sat, True)
        self.assertEqual(model, (None, True, True))

    def test_unit(self):
        """A unit XOR with rhs False forces x1 to False."""
        self.solver.add_xor_clause([1], False)
        sat, model = self.solver.solve()
        self.assertEqual(sat, True)
        self.assertEqual(model, (None, False))

    def test_unit2(self):
        """A unit XOR with rhs True forces x1 to True."""
        self.solver.add_xor_clause([1], True)
        sat, model = self.solver.solve()
        self.assertEqual(sat, True)
        self.assertEqual(model, (None, True))

    def test_3_long(self):
        """x1 ^ x2 ^ x3 = False is satisfiable with x1 and x2 assumed True."""
        self.solver.add_xor_clause([1, 2, 3], False)
        sat, model = self.solver.solve([1, 2])
        self.assertEqual(sat, True)
        # Only satisfiability is checked here; the model itself (where x3
        # would have to be False) is deliberately not asserted.

    def test_3_long2(self):
        """x1 ^ x2 ^ x3 = True with x1 True and x2 False forces x3 to False."""
        self.solver.add_xor_clause([1, 2, 3], True)
        sat, model = self.solver.solve([1, -2])
        self.assertEqual(sat, True)
        self.assertEqual(model, (None, True, False, False))

    def test_long(self):
        """A long all-False XOR solves to the all-False model.

        Every variable but the last is assumed False, which forces the last
        one to False as well.
        """
        for length in range(10, 30):
            self.setUp()
            variables = list(range(1, length))
            assumptions = [-var for var in variables[:-1]]
            expected = (None,) + (False,) * len(variables)

            self.solver.add_xor_clause(variables, False)
            sat, model = self.solver.solve(assumptions)
            self.assertEqual(sat, True)
            self.assertEqual(model, expected)
Beispiel #7
0
class TestXor(unittest.TestCase):
    """Exercises ``add_xor_clause`` on a two-threaded pycryptosat solver."""

    def setUp(self):
        self.solver = Solver(threads=2)

    def _xor_then_solve(self, lits, rhs, assumptions=None):
        """Add one XOR constraint, solve, and return the ``(res, solution)`` pair.

        ``assumptions`` is forwarded to ``solve`` only when it is given.
        """
        self.solver.add_xor_clause(lits, rhs)
        if assumptions is None:
            return self.solver.solve()
        return self.solver.solve(assumptions)

    def test_wrong_args(self):
        """Missing rhs -> TypeError; literal 0 or a negated literal -> ValueError."""
        bad_calls = (
            (TypeError, ([1, 2],)),
            (ValueError, ([1, 0], True)),
            (ValueError, ([-1, 2], True)),
        )
        for exc_type, call_args in bad_calls:
            self.assertRaises(exc_type, self.solver.add_xor_clause, *call_args)

    def test_binary(self):
        res, solution = self._xor_then_solve([1, 2], False, [1])
        self.assertEqual(res, True)
        self.assertEqual(solution, (None, True, True))

    def test_unit(self):
        res, solution = self._xor_then_solve([1], False)
        self.assertEqual(res, True)
        self.assertEqual(solution, (None, False))

    def test_unit2(self):
        res, solution = self._xor_then_solve([1], True)
        self.assertEqual(res, True)
        self.assertEqual(solution, (None, True))

    def test_3_long(self):
        res, solution = self._xor_then_solve([1, 2, 3], False, [1, 2])
        self.assertEqual(res, True)
        # The model is not checked; with x1 = x2 = True, x3 would be False.

    def test_3_long2(self):
        res, solution = self._xor_then_solve([1, 2, 3], True, [1, -2])
        self.assertEqual(res, True)
        self.assertEqual(solution, (None, True, False, False))

    def test_long(self):
        """XORs over 9..28 variables with rhs False solve to all-False."""
        for size in range(10, 30):
            self.setUp()
            lits = list(range(1, size))
            pinned = [-lit for lit in lits if lit != size - 1]
            res, solution = self._xor_then_solve(lits, False, pinned)
            self.assertEqual(res, True)
            self.assertEqual(solution, tuple([None] + [False] * len(lits)))
Beispiel #8
0
    def test_long(self):
        """XOR chains of 9..28 variables with rhs False solve to all-False."""
        for chain_len in range(10, 30):
            solver = Solver()
            lits = list(range(1, chain_len))
            # Pin every variable except the last; the XOR then forces it too.
            assumptions = [-lit for lit in lits[:-1]]

            solver.add_xor_clause(lits, False)
            res, solution = solver.solve(assumptions)
            self.assertEqual(res, True)
            self.assertEqual(solution, tuple([None] + [False] * len(lits)))
Beispiel #9
0
class TestDump(unittest.TestCase):
    """Tests for retrieving learnt small clauses from the solver."""

    def setUp(self):
        self.solver = Solver()

    def test_one_dump(self):
        """After solving tests/test.cnf at least one small clause can be fetched."""
        with open("tests/test.cnf", "r") as cnf_file:
            for line in cnf_file:
                line = line.strip()
                # Skip blank lines (they would otherwise be added as an empty
                # clause) as well as header/comment lines.
                if not line or "p" in line or "c" in line:
                    continue

                # Drop the trailing DIMACS terminator "0".
                clause = [int(lit) for lit in line.split()[:-1]]
                self.solver.add_clause(clause)

        res, _ = self.solver.solve()
        self.assertEqual(res, True)

        self.solver.start_getting_small_clauses(4)
        try:
            small_clause = self.solver.get_next_small_clause()
            # assertNotEquals was a deprecated alias removed in Python 3.12.
            self.assertIsNotNone(small_clause)
        finally:
            # Always close the dump session, even if the assertion fails.
            self.solver.end_getting_small_clauses()
Beispiel #10
0
    def is_system_uniquely_satisfiable(self, system, n):
        """Check whether the all-zero assignment is the only solution of ``system``.

        Each equation is added as an XOR constraint with right-hand side
        ``False``, so the all-zero assignment always satisfies the system.
        A clause over variables ``1..n`` then forbids that assignment; if the
        result is unsatisfiable, no other solution exists.

        :param system: iterable of XOR equations, each a sequence of positive
            variable indices. An empty or falsy ``system`` yields ``False``.
        :param n: number of variables covered by the clause that bans the
            all-zero assignment.
        :return: ``True`` when the system with all-zero banned is
            unsatisfiable, ``False`` otherwise.
        """
        if not system:
            return False

        xor_solver = Solver()
        for equation in system:
            xor_solver.add_xor_clause(equation, False)

        # At least one of the variables 1..n must be True: reject all-zero.
        xor_solver.add_clause(range(1, n + 1))

        is_sat, _ = xor_solver.solve()
        return not is_sat
Beispiel #11
0
class ProgramSolver():
    def __init__(self,filename):
        self.s = Solver(threads = 3)
        self.tt = 0

        h2v = {} # hole 2 variable
                
        self.maximum_variable = -1
        with open(filename,'r') as f:
            for l in f:
                if len(l) > len('c hole ') and l[:len('c hole ')] == 'c hole ':
                    ms = re.findall(r'(\d+) \- (\d+)', l)[0]
                    assert int(ms[0]) == int(ms[1])
                    ms = int(ms[0])
                    n = int(re.findall(r'H__\S+_(\S+)\s',l)[0])
                    h2v[n] = ms
                elif len(l) > 0 and not 'c' in l and not 'p' in l:
                    vs = re.findall(r'(\-?\d+)',l)
                    assert vs[-1] == '0'
                    clause = [int(v) for v in vs[:-1] ]
                    self.maximum_variable = max([self.maximum_variable] +
                                                [abs(v) for v in clause ])
                    self.s.add_clause(clause)
        print "Loaded",filename," with",len(h2v),"holes"
        # convert the tape index into a sat variable
        self.tape2variable = [ v for h,v in sorted(h2v.items()) ]
        
        # converts a sat variable to a tape index
        self.variable2tape = dict([ (v,h) for h,v in h2v.items() ])



    def generate_variable(self):
        self.maximum_variable += 1
        return self.maximum_variable
    
    def random_projection(self):
        self.s.add_xor_clause([v for v in self.variable2tape if random.random() > 0.5 ],random.random() > 0.5)
        
    def try_solving(self,assumptions = None):
        print "About to run solver ==  ==  ==  > "
        start_time = time.time()
        if assumptions != None:
            result = self.s.solve(assumptions)
        else:
            result = self.s.solve()
        dt = (time.time() - start_time)
        self.tt += dt
        print "Ran solver in time",dt
        if result[0]:
            bindings = {}
            for v in range(len(result[1])):
                if v in self.variable2tape:
                    bindings[v] = result[1][v]
            print "Satisfiable."
            return bindings
        else:
            print "Unsatisfiable."
            return False

    def uniqueness_clause(self,tape):
        p,bit_mask = parse_tape(tape)
        clause = []
        for j in range(len(tape)):
            if bit_mask[j] == 1:
                # jth tape position
                v = self.tape2variable[j]
                if tape[j] == 1: v = -v
                clause += [v]
        return clause
        
    def is_solution_unique(self,tape):
        d = self.generate_variable()
        clause = [d] + self.uniqueness_clause(tape)
        print "uniqueness clause",clause
        self.s.add_clause(clause)
        result = self.try_solving([-d])
        self.s.add_clause([d]) # make the clause documents they satisfied
        if result:
            tp = self.holes2tape(result)
            print "alternative:",parse_tape(tp)
            print "alternative tape:",tp
            return False
        else:
            return True
                

    def holes2tape(self,result):
        return [ (1 if result[v] else 0) for v in self.tape2variable ]

    def try_sampling(self,subspace_dimension):
        for j in range(subspace_dimension):
            self.random_projection()
        result = self.try_solving()
        if result:
            print "Random projection satisfied"
            tp = self.holes2tape(result)
            print parse_tape(tp)[0]
            if self.is_solution_unique(tp):
                print "Unique. Accepted."
            else:
                print "Sample rejected"

    def adaptive_sample(self):
        subspace_dimension = 1
        result = self.try_solving()
        if result:
            print "Formula satisfied"
            for j in range(subspace_dimension):
                self.random_projection()
            while True:
                print "\n\niterating:"
                result = self.try_solving()
                if result:
                    print "Satisfied %d constraints" % subspace_dimension
                    tp = self.holes2tape(result)
                    print parse_tape(tp)
                    print "tape = ",tp
                    if self.is_solution_unique(tp):
                        print "UNIQUE"
                        print "<<< ==  ==  == >>>"
                    self.random_projection()
                    subspace_dimension += 1
                else:
                    print "Rejected %d projections" % subspace_dimension
                    print "total time = ",self.tt
                    break


    def enumerate_solutions(self):
        solutions = []
        result = self.try_solving()
        d = self.generate_variable()
        logZ = float('-inf')
        while result:
            tp = self.holes2tape(result)
            program,mask = parse_tape(tp)
            solutions = solutions + [program]
            specified = sum(mask)
            logZ = lse(logZ, -specified * 0.693)
            print "Enumerated program", program, "with", specified, "specified bits."
            self.s.add_clause([d] + self.uniqueness_clause(tp))
            result = self.try_solving([-d])
        print "log(z) = ",logZ, "\t1/p = ", math.exp(-logZ)
        return solutions
Beispiel #12
0
class CryptoMiniSat(SatSolver):
    r"""
    CryptoMiniSat Solver.

    INPUT:

    - ``verbosity`` -- an integer between 0 and 15 (default: 0). Verbosity.

    - ``confl_limit`` -- an integer (default: ``None``). Abort after this many
      conflicts. If set to ``None``, no conflict limit is passed to the
      underlying solver.

    - ``threads`` -- an integer (default: None). The number of threads to
      use. If set to ``None``, the number of threads used corresponds to the
      number of cpus.

    EXAMPLES::

        sage: from sage.sat.solvers.cryptominisat import CryptoMiniSat
        sage: solver = CryptoMiniSat()                                  # optional - cryptominisat
    """
    def __init__(self, verbosity=0, confl_limit=None, threads=None):
        r"""
        Construct a new CryptoMiniSat instance.

        See the documentation class for the description of inputs.

        EXAMPLES::

            sage: from sage.sat.solvers.cryptominisat import CryptoMiniSat
            sage: solver = CryptoMiniSat(threads=1)                     # optional - cryptominisat
        """
        if threads is None:
            from sage.parallel.ncpus import ncpus
            threads = ncpus()
        try:
            from pycryptosat import Solver
        except ImportError:
            from sage.misc.package import PackageNotFoundError
            raise PackageNotFoundError("cryptominisat")
        # Only forward ``confl_limit`` when the caller set one. The former
        # ``sys.maxint`` fallback does not exist on Python 3; omitting the
        # keyword leaves pycryptosat's own default in place (presumably "no
        # limit" -- confirm against the installed pycryptosat).
        solver_kwargs = {'verbose': int(verbosity), 'threads': int(threads)}
        if confl_limit is not None:
            solver_kwargs['confl_limit'] = int(confl_limit)
        self._solver = Solver(**solver_kwargs)
        self._nvars = 0
        self._clauses = []

    def var(self, decision=None):
        r"""
        Return the next unused variable index, i.e. ``nvars() + 1``.

        The index is not reserved: it only counts as used once it appears in
        a clause, so calling this twice in a row returns the same value.

        INPUT:

        - ``decision`` -- accepted for compatibility with other solvers, ignored.

        EXAMPLES::

            sage: from sage.sat.solvers.cryptominisat import CryptoMiniSat
            sage: solver = CryptoMiniSat()                                  # optional - cryptominisat
            sage: solver.var()                                              # optional - cryptominisat
            1

            sage: solver.add_clause((-1,2,-4))                              # optional - cryptominisat
            sage: solver.var()                                              # optional - cryptominisat
            5
        """
        return self._nvars + 1

    def nvars(self):
        r"""
        Return the number of variables. Note that for compatibility with DIMACS
        convention, the number of variables corresponds to the maximal index of
        the variables used.

        EXAMPLES::

            sage: from sage.sat.solvers.cryptominisat import CryptoMiniSat
            sage: solver = CryptoMiniSat()                                  # optional - cryptominisat
            sage: solver.nvars()                                            # optional - cryptominisat
            0

        If a variable with intermediate index is not used, it is still
        considered as a variable::

            sage: solver.add_clause((1,-2,4))                               # optional - cryptominisat
            sage: solver.nvars()                                            # optional - cryptominisat
            4
        """
        return self._nvars

    def add_clause(self, lits):
        r"""
        Add a new clause to set of clauses.

        INPUT:

        - ``lits`` -- an iterable (e.g. a tuple) of nonzero integers.

        .. note::

            If any element ``e`` in ``lits`` has ``abs(e)`` greater
            than the number of variables generated so far, then new
            variables are created automatically.

        EXAMPLES::

            sage: from sage.sat.solvers.cryptominisat import CryptoMiniSat
            sage: solver = CryptoMiniSat()                                  # optional - cryptominisat
            sage: solver.add_clause((1, -2 , 3))                            # optional - cryptominisat
        """
        # Materialize once: ``lits`` may be a one-shot iterator, which the
        # ``0 in lits`` test would otherwise exhaust. cryptominisat does not
        # handle Sage integers, hence ``int``.
        lits = tuple(int(i) for i in lits)
        if 0 in lits:
            raise ValueError("0 should not appear in the clause: {}".format(lits))
        # An empty clause leaves the variable count unchanged.
        self._nvars = max([self._nvars] + [abs(i) for i in lits])
        self._solver.add_clause(lits)
        self._clauses.append((lits, False, None))

    def add_xor_clause(self, lits, rhs=True):
        r"""
        Add a new XOR clause to set of clauses.

        INPUT:

        - ``lits`` -- an iterable (e.g. a tuple) of positive integers.

        - ``rhs`` -- boolean (default: ``True``). Whether this XOR clause should
          be evaluated to ``True`` or ``False``.

        EXAMPLES::

            sage: from sage.sat.solvers.cryptominisat import CryptoMiniSat
            sage: solver = CryptoMiniSat()                                  # optional - cryptominisat
            sage: solver.add_xor_clause((1, 2 , 3), False)                  # optional - cryptominisat
        """
        # Materialize once (see ``add_clause``); cryptominisat does not handle
        # Sage integers.
        lits = tuple(int(i) for i in lits)
        if 0 in lits:
            raise ValueError("0 should not appear in the clause: {}".format(lits))
        self._nvars = max([self._nvars] + [abs(i) for i in lits])
        self._solver.add_xor_clause(lits, rhs)
        self._clauses.append((lits, True, rhs))

    def __call__(self, assumptions=None):
        r"""
        Solve this instance.

        INPUT:

        - ``assumptions`` -- an iterable of nonzero integers (default:
          ``None``). Literals passed as assumptions to this ``solve`` call.

        OUTPUT:

        - If this instance is SAT: A tuple of length ``nvars()+1``
          where the ``i``-th entry holds an assignment for the
          ``i``-th variables (the ``0``-th entry is always ``None``).

        - If this instance is UNSAT: ``False``.

        - If the solver stopped before deciding satisfiability (for
          instance because ``confl_limit`` was reached): ``None``.

        EXAMPLES::

            sage: from sage.sat.solvers.cryptominisat import CryptoMiniSat
            sage: solver = CryptoMiniSat()                                  # optional - cryptominisat
            sage: solver.add_clause((1,2))                                  # optional - cryptominisat
            sage: solver.add_clause((-1,2))                                 # optional - cryptominisat
            sage: solver.add_clause((-1,-2))                                # optional - cryptominisat
            sage: solver()                                                  # optional - cryptominisat
            (None, False, True)

            sage: solver.add_clause((1,-2))                                 # optional - cryptominisat
            sage: solver()                                                  # optional - cryptominisat
            False

        Solving under assumptions::

            sage: solver = CryptoMiniSat()                                  # optional - cryptominisat
            sage: solver.add_clause((1,2))                                  # optional - cryptominisat
            sage: solver(assumptions=(-1,))                                 # optional - cryptominisat
            (None, False, True)
        """
        if assumptions is None:
            satisfiable, assignments = self._solver.solve()
        else:
            # ``assumptions`` used to be silently ignored; forward them.
            # cryptominisat does not handle Sage integers.
            satisfiable, assignments = self._solver.solve(
                [int(lit) for lit in assumptions])
        if satisfiable is None:
            # Undecided (e.g. conflict limit hit) -- not the same as UNSAT.
            return None
        if satisfiable:
            return assignments
        return False

    def __repr__(self):
        r"""
        TESTS::

            sage: from sage.sat.solvers.cryptominisat import CryptoMiniSat
            sage: solver = CryptoMiniSat()                                  # optional - cryptominisat
            sage: solver                                                    # optional - cryptominisat
            CryptoMiniSat solver: 0 variables, 0 clauses.
        """
        return "CryptoMiniSat solver: {} variables, {} clauses.".format(self.nvars(), len(self.clauses()))

    def clauses(self, filename=None):
        r"""
        Return original clauses.

        INPUT:

        - ``filename`` -- if not ``None`` clauses are written to ``filename`` in
          DIMACS format (default: ``None``)

        OUTPUT:

            If ``filename`` is ``None`` then a list of ``lits, is_xor, rhs``
            tuples is returned, where ``lits`` is a tuple of literals,
            ``is_xor`` tells whether the clause was added as an XOR clause,
            and ``rhs`` is the XOR right-hand side (``None`` for ordinary
            clauses).

            If ``filename`` points to a writable file, then the list of original
            clauses is written to that file in DIMACS format.

        EXAMPLES::

            sage: from sage.sat.solvers import CryptoMiniSat
            sage: solver = CryptoMiniSat()                              # optional - cryptominisat
            sage: solver.add_clause((1,2,3,4,5,6,7,8,-9))               # optional - cryptominisat
            sage: solver.add_xor_clause((1,2,3,4,5,6,7,8,9), rhs=True)  # optional - cryptominisat
            sage: solver.clauses()                                      # optional - cryptominisat
            [((1, 2, 3, 4, 5, 6, 7, 8, -9), False, None),
            ((1, 2, 3, 4, 5, 6, 7, 8, 9), True, True)]

        DIMACS format output::

            sage: from sage.sat.solvers import CryptoMiniSat
            sage: solver = CryptoMiniSat()                      # optional - cryptominisat
            sage: solver.add_clause((1, 2, 4))                  # optional - cryptominisat
            sage: solver.add_clause((1, 2, -4))                 # optional - cryptominisat
            sage: fn = tmp_filename()                           # optional - cryptominisat
            sage: solver.clauses(fn)                            # optional - cryptominisat
            sage: print(open(fn).read())                        # optional - cryptominisat
            p cnf 4 2
            1 2 4 0
            1 2 -4 0
            <BLANKLINE>

        Note that in cryptominisat, the DIMACS standard format is augmented with
        the following extension: having an ``x`` in front of a line makes that
        line an XOR clause::

            sage: solver.add_xor_clause((1,2,3), rhs=True)      # optional - cryptominisat
            sage: solver.clauses(fn)                            # optional - cryptominisat
            sage: print(open(fn).read())                        # optional - cryptominisat
            p cnf 4 3
            1 2 4 0
            1 2 -4 0
            x1 2 3 0
            <BLANKLINE>

        Note that inverting an xor-clause is equivalent to inverting one of the
        variables::

            sage: solver.add_xor_clause((1,2,5),rhs=False)      # optional - cryptominisat
            sage: solver.clauses(fn)                            # optional - cryptominisat
            sage: print(open(fn).read())                        # optional - cryptominisat
            p cnf 5 4
            1 2 4 0
            1 2 -4 0
            x1 2 3 0
            x1 2 -5 0
            <BLANKLINE>
        """
        if filename is None:
            return self._clauses
        else:
            from sage.sat.solvers.dimacs import DIMACS
            DIMACS.render_dimacs(self._clauses, filename, self.nvars())
Beispiel #13
0
 def test_unit2(self):
     """A unit XOR with right-hand side True forces variable 1 to True."""
     xor_solver = Solver()
     xor_solver.add_xor_clause([1], True)
     outcome, model = xor_solver.solve()
     self.assertEqual(outcome, True)
     self.assertEqual(model, (None, True))
Beispiel #14
0
 def test_no_clauses(self):
     """Solving an empty formula repeatedly always yields (True, (None,))."""
     empty_solver = Solver()
     expected = (True, (None,))
     for _ in range(7):
         self.assertEqual(empty_solver.solve([]), expected)
Beispiel #15
0
class TestSolve(unittest.TestCase):
    """Tests for adding plain CNF clauses to a pycryptosat ``Solver``."""

    def setUp(self):
        # Every test starts from a fresh two-threaded solver.
        self.solver = Solver(threads=2)

    def test_wrong_args(self):
        """Non-integer clauses/literals raise TypeError; literal 0 raises ValueError."""
        self.assertRaises(TypeError, self.solver.add_clause, 'A')
        self.assertRaises(TypeError, self.solver.add_clause, 1)
        self.assertRaises(TypeError, self.solver.add_clause, 1.0)
        self.assertRaises(TypeError, self.solver.add_clause, object())
        self.assertRaises(TypeError, self.solver.add_clause, ['a'])
        self.assertRaises(TypeError, self.solver.add_clause,
                          [[1, 2], [3, None]])
        self.assertRaises(ValueError, self.solver.add_clause, [1, 0])

    def test_no_clauses(self):
        """An empty formula is satisfiable, repeatedly, with an empty model."""
        for _ in range(7):
            self.assertEqual(self.solver.solve([]), (True, (None, )))

    def test_cnf1(self):
        """clauses1 is satisfiable and the returned model satisfies it."""
        for cl in clauses1:
            self.solver.add_clause(cl)
        res, solution = self.solver.solve()
        self.assertEqual(res, True)
        self.assertTrue(check_solution(clauses1, solution))

    def test_bad_iter(self):
        """An object whose ``__iter__`` returns a non-iterator is rejected."""
        class Liar:
            def __iter__(self):
                return None

        self.assertRaises(TypeError, self.solver.add_clause, Liar())

    def test_cnf2(self):
        """clauses2 is unsatisfiable: ``solve()`` returns ``(False, None)``."""
        for cl in clauses2:
            self.solver.add_clause(cl)
        self.assertEqual(self.solver.solve(), (False, None))

    def test_cnf3(self):
        """clauses3 is satisfiable and the returned model satisfies it."""
        for cl in clauses3:
            self.solver.add_clause(cl)
        res, solution = self.solver.solve()
        self.assertEqual(res, True)
        self.assertTrue(check_solution(clauses3, solution))

    def test_cnf1_confl_limit(self):
        """Solving clauses1 under conflict limits 1..19 gives up or is correct.

        A ``None`` result is accepted as "limit reached before a verdict";
        any other result must come with a model satisfying ``clauses1``.
        """
        for lim in range(1, 20):
            # Build a fresh solver that actually carries the conflict budget;
            # previously the limit was never passed on, so nothing was tested.
            self.solver = Solver(threads=2, confl_limit=lim)
            for cl in clauses1:
                self.solver.add_clause(cl)

            res, solution = self.solver.solve()
            self.assertTrue(res is None or check_solution(clauses1, solution))

    def test_by_re_curse(self):
        """The solver can be re-solved after more clauses are added."""
        self.solver.add_clause([-1, -2, 3])
        res, _ = self.solver.solve()
        self.assertEqual(res, True)

        self.solver.add_clause([-5, 1])
        self.solver.add_clause([4, -3])
        self.solver.add_clause([2, 3, 5])
        res, _ = self.solver.solve()
        self.assertEqual(res, True)
Beispiel #16
0
class ProgramSolver():
    def __init__(self, filename):
        self.s = Solver(threads=3)
        self.tt = 0

        h2v = {}  # hole 2 variable

        self.maximum_variable = -1
        with open(filename, 'r') as f:
            for l in f:
                if len(l) > len('c hole ') and l[:len('c hole ')] == 'c hole ':
                    ms = re.findall(r'(\d+) \- (\d+)', l)[0]
                    assert int(ms[0]) == int(ms[1])
                    ms = int(ms[0])
                    n = int(re.findall(r'H__\S+_(\S+)\s', l)[0])
                    h2v[n] = ms
                elif len(l) > 0 and not 'c' in l and not 'p' in l:
                    vs = re.findall(r'(\-?\d+)', l)
                    assert vs[-1] == '0'
                    clause = [int(v) for v in vs[:-1]]
                    self.maximum_variable = max([self.maximum_variable] +
                                                [abs(v) for v in clause])
                    self.s.add_clause(clause)
        print "Loaded", filename, " with", len(h2v), "holes"
        # convert the tape index into a sat variable
        self.tape2variable = [v for h, v in sorted(h2v.items())]

        # converts a sat variable to a tape index
        self.variable2tape = dict([(v, h) for h, v in h2v.items()])

    def generate_variable(self):
        self.maximum_variable += 1
        return self.maximum_variable

    def random_projection(self):
        self.s.add_xor_clause(
            [v for v in self.variable2tape if random.random() > 0.5],
            random.random() > 0.5)

    def try_solving(self, assumptions=None):
        print "About to run solver ==  ==  ==  > "
        start_time = time.time()
        if assumptions != None:
            result = self.s.solve(assumptions)
        else:
            result = self.s.solve()
        dt = (time.time() - start_time)
        self.tt += dt
        print "Ran solver in time", dt
        if result[0]:
            bindings = {}
            for v in range(len(result[1])):
                if v in self.variable2tape:
                    bindings[v] = result[1][v]
            print "Satisfiable."
            return bindings
        else:
            print "Unsatisfiable."
            return False

    def uniqueness_clause(self, tape):
        p, bit_mask = parse_tape(tape)
        clause = []
        for j in range(len(tape)):
            if bit_mask[j] == 1:
                # jth tape position
                v = self.tape2variable[j]
                if tape[j] == 1: v = -v
                clause += [v]
        return clause

    def is_solution_unique(self, tape):
        d = self.generate_variable()
        clause = [d] + self.uniqueness_clause(tape)
        print "uniqueness clause", clause
        self.s.add_clause(clause)
        result = self.try_solving([-d])
        self.s.add_clause([d])  # make the clause documents they satisfied
        if result:
            tp = self.holes2tape(result)
            print "alternative:", parse_tape(tp)
            print "alternative tape:", tp
            return False
        else:
            return True

    def holes2tape(self, result):
        return [(1 if result[v] else 0) for v in self.tape2variable]

    def try_sampling(self, subspace_dimension):
        for j in range(subspace_dimension):
            self.random_projection()
        result = self.try_solving()
        if result:
            print "Random projection satisfied"
            tp = self.holes2tape(result)
            print parse_tape(tp)[0]
            if self.is_solution_unique(tp):
                print "Unique. Accepted."
            else:
                print "Sample rejected"

    def adaptive_sample(self):
        subspace_dimension = 1
        result = self.try_solving()
        if result:
            print "Formula satisfied"
            for j in range(subspace_dimension):
                self.random_projection()
            while True:
                print "\n\niterating:"
                result = self.try_solving()
                if result:
                    print "Satisfied %d constraints" % subspace_dimension
                    tp = self.holes2tape(result)
                    print parse_tape(tp)
                    print "tape = ", tp
                    if self.is_solution_unique(tp):
                        print "UNIQUE"
                        print "<<< ==  ==  == >>>"
                    self.random_projection()
                    subspace_dimension += 1
                else:
                    print "Rejected %d projections" % subspace_dimension
                    print "total time = ", self.tt
                    break

    def enumerate_solutions(self):
        solutions = []
        result = self.try_solving()
        d = self.generate_variable()
        logZ = float('-inf')
        while result:
            tp = self.holes2tape(result)
            program, mask = parse_tape(tp)
            solutions = solutions + [program]
            specified = sum(mask)
            logZ = lse(logZ, -specified * 0.693)
            print "Enumerated program", program, "with", specified, "specified bits."
            self.s.add_clause([d] + self.uniqueness_clause(tp))
            result = self.try_solving([-d])
        print "log(z) = ", logZ, "\t1/p = ", math.exp(-logZ)
        return solutions
Beispiel #17
0
class Solvatore(object):
    '''
    Build a CNF model of how bits propagate through a cipher and query it
    with a SAT ``Solver``.

    The cipher description (see ``load_cipher``) lists ``transition``
    steps, which ``create_conditions`` applies once per round. Bit names
    are classified by prefix:

    * names starting with ``'s'`` are state bits;
    * names starting with ``'t'`` are temporaries;
    * any other name (such as the ``'copy_...'`` names made by
      ``apply_sbox``) is an S-box scratch bit.

    Every write to a bit allocates a fresh solver variable
    (``get_variables``). The ``apply_*`` docstrings list the allowed
    old -> new value trails.
    '''

    def __init__(self):
        '''
        Create a new Solvatore instance
        '''
        self.cipher = None
        # round_states[r]: solver variables of the state after r rounds.
        self.round_states = []
        self.rounds = None
        # True while the solver model matches the current cipher and rounds.
        self.fresh_conditions = False
        # Cache of get_sbox_cnf results, keyed by S-box name.
        self.sbox_cnfs = {}
        # Number of clauses added through addclause.
        self.model_size = 0

    def load_cipher(self, cipher):
        '''
        Load a cipher description to Solvatore.

        ``state_size`` and ``rounds`` are read here. ``transition``,
        ``temporaries`` and ``sboxes`` are read when the model is built.
        The model is rebuilt on next use and cached S-box CNFs are dropped.
        '''
        self.model_size = 0
        self.cipher = cipher
        self.state_size = cipher.state_size
        self.fresh_conditions = False
        self.set_rounds(cipher.rounds)
        self.sbox_cnfs = {}

    def addclause(self, clause):
        '''Add ``clause`` to the solver and count it in ``model_size``.'''
        self.solver.add_clause(clause)
        self.model_size += 1

    def set_rounds(self, rounds):
        '''
        Specify the number of times the round transition is applied.

        ``rounds`` is converted with ``int``. A negative value prints an
        error and exits the process. The current model is marked stale.
        '''
        rounds = int(rounds)
        if rounds < 0:
            print('The number of rounds needs to be positive')
            sys.exit(1)
        self.rounds = rounds
        self.fresh_conditions = False

    def create_conditions(self):
        '''
        Create conditions for solver from cipher description.

        Builds a new ``Solver`` containing ``self.rounds`` applications of
        ``cipher.transition``. ``round_states`` receives the input state
        variables followed by the state after every round. Does nothing if
        the model is already fresh. Exits the process if no cipher or round
        count has been set.
        '''
        if self.cipher is None:
            print('You need to load a cipher.')
            sys.exit(1)
        if self.rounds is None:
            print('You need to specify the number of rounds.')
            sys.exit(1)
        if self.fresh_conditions:
            return
        self.solver = Solver()
        self.next_variable = 1
        self.state_bit = [i + 1 for i in range(self.state_size)]
        # Start round_states afresh. Entries from an earlier build refer to
        # variables of the discarded solver; appending to them would leave
        # round_states[0] pointing at the stale model.
        self.round_states = [list(self.state_bit)]
        self.next_variable += self.state_size
        self.temporary = {name: None for name in self.cipher.temporaries}
        self.sbox_tmps = {}

        for _ in range(self.rounds):
            for step in self.cipher.transition:
                operation = step[-1]
                if operation == 'XOR':
                    self.apply_xor(step[0], step[1], step[2])
                elif operation == 'AND':
                    self.apply_and(step[0], step[1], step[2])
                elif operation == 'PERM':
                    self.apply_permutation(step[0])
                elif operation == 'SBOX':
                    self.apply_sbox(step[0], step[1], step[2])
                elif operation == 'MOV':
                    self.apply_mov(step[0], step[1])
            self.set_temporaries_to_zero()
            self.round_states.append(list(self.state_bit))
        self.fresh_conditions = True

    def apply_mov(self, target, source):
        """ When target != source, we have the following allowed trails
            on source, target:
              0 0 -> 0 0
              1 0 -> 0 1
              1 0 -> 1 0
            In particular, when the target is a state bit, it cannot be equal
            to 1. (The old target variable, when one exists, is forced to 0.)
        """
        if target == source:
            return
        old_source, new_source = self.get_variables(source)
        old_target, new_target = self.get_variables(target)
        # Ensure old target is 0 (an unassigned temporary/scratch bit has none).
        if old_target is not None:
            self.addclause([-old_target])
        self.addclause([-old_source, new_source, new_target])
        self.addclause([old_source, -new_source])
        self.addclause([old_source, -new_target])
        self.addclause([-new_source, -new_target])

    def apply_xor(self, target, source_1, source_2):
        """ First move source_2 to target, then apply the following
            rules on source_1, target:
              0 0 -> 0 0
              0 1 -> 0 1
              1 0 -> 1 0
              1 0 -> 0 1
              1 1 -> 1 1
            If source_1 == source_2 the result is constant, so the new
            target variable is simply forced to 0.
        """
        if source_1 == source_2:
            old_target, new_target = self.get_variables(target)
            self.addclause([-new_target])
            return
        if source_1 != target:
            self.apply_mov(target, source_2)
            source = source_1
        else:
            source = source_2
        old_source, new_source = self.get_variables(source)
        old_target, new_target = self.get_variables(target)
        self.addclause([-new_source, old_source])
        self.addclause([new_target, -old_target])
        self.addclause([-new_source, -new_target, old_target])
        self.addclause([new_source, new_target, -old_source])
        self.addclause([-new_target, old_source, old_target])
        self.addclause([new_source, -old_source, -old_target])

    def apply_and(self, target, source_1, source_2):
        """ First move source_2 to target, then apply the following
            rules on source_1, target:
              0 0 -> 0 0
              0 1 -> 0 1
              1 0 -> 1 0
              1 0 -> 0 1
              1 1 -> 0 1
            If source_1 == source_2 this is a plain move of source_1.
        """
        if source_1 == source_2:
            self.apply_mov(target, source_1)
            return
        if source_1 != target:
            self.apply_mov(target, source_2)
            source = source_1
        else:
            source = source_2
        old_source, new_source = self.get_variables(source)
        old_target, new_target = self.get_variables(target)
        self.addclause([-new_source, old_source])
        self.addclause([new_target, -old_target])
        self.addclause([-new_source, -new_target])
        self.addclause([new_source, new_target, -old_source])
        self.addclause([-new_target, old_source, old_target])

    def apply_permutation(self, permutation):
        '''
        Relabel the state bits according to the permutation.

        ``permutation`` is a sequence of state-bit names (``'s<i>'``). The
        variables are rotated along it cyclically: position
        ``permutation[k]`` receives the variable previously at
        ``permutation[k - 1]``, and the first position receives the one
        from the last. No clauses are added.
        '''
        last_value = int(permutation[-1][1:])
        tmp = self.state_bit[last_value]
        for bit in permutation:
            number = int(bit[1:])
            self.state_bit[number], tmp = tmp, self.state_bit[number]

    def apply_sbox(self, sbox_name, input_bits, output_bits):
        '''
        Apply S-box ``sbox_name`` from ``input_bits`` to ``output_bits``.

        Each input bit is first transferred with ``apply_mov`` into a
        ``'copy_<bit>'`` scratch bit. The S-box CNF then links those copies
        to fresh variables for the output bits, whose previous variables
        are forced to 0.
        '''
        cnf = self.get_sbox_cnf(sbox_name, len(input_bits), len(output_bits))
        for bit in input_bits:
            self.apply_mov('copy_' + bit, bit)
        # Only the old variable (the copied value) of each scratch bit is used.
        sources = [self.get_variables('copy_' + bit) for bit in input_bits]
        targets = [self.get_variables(bit) for bit in output_bits]
        for in_clause, out_clause in cnf:
            clause = [src[0] if flag == 1 else -src[0]
                      for src, flag in zip(sources, in_clause)]
            clause += [tgt[1] if flag == 1 else -tgt[1]
                       for tgt, flag in zip(targets, out_clause)]
            self.addclause(clause)
        # Ensure that the target bits did not contain 1s
        for old_target, _ in targets:
            self.addclause([-old_target])

    def get_sbox_cnf(self, sbox_name, n, m):
        '''
        Return, and cache by name, the CNF excluding impossible S-box
        transitions for an ``n``-bit to ``m``-bit S-box.

        For every output mask ``v`` the ANF of the product of the selected
        output coordinates is computed. Input mask ``u`` is excluded for
        ``v`` when no monomial of that product contains all bits of ``u``.
        Each excluded ``(u, v)`` yields one clause forbidding exactly that
        pattern. The clause is a pair of 0/1 tuples in which 1 means a
        positive literal (as consumed by ``apply_sbox``).

        Raises ``KeyError`` if the cipher has no S-box of that name.
        '''
        if sbox_name in self.sbox_cnfs:
            return self.sbox_cnfs[sbox_name]
        if sbox_name not in self.cipher.sboxes:
            raise KeyError('sbox {!r} not in cipher.sboxes'.format(sbox_name))
        anf = self.get_anf(self.cipher.sboxes[sbox_name])
        size = 2 ** n
        # products[v]: 0/1 ANF coefficients of the product of the output
        # coordinates selected by mask v.
        products = []
        for output_dp in range(2 ** m):
            # Empty product: the constant 1.
            prod = [0] * size
            prod[0] = 1
            for i in range(m):
                if (output_dp >> i) & 1:
                    # Monomials present in output coordinate i.
                    terms = [k for k in range(size) if (anf[k] >> i) & 1]
                    result = [0] * size
                    for j in range(size):
                        if prod[j]:
                            # Multiplying monomials over GF(2) ORs their masks.
                            for k in terms:
                                result[j | k] ^= 1
                    prod = result
            products.append(prod)
        cnf = []
        for input_dp in range(size):
            for output_dp in range(2 ** m):
                prod = products[output_dp]
                # If no term of the product contains all selected input
                # bits, the derivative is zero: exclude this transition.
                if not any((input_dp & k) == input_dp
                           for k in range(size) if prod[k]):
                    cnf.append((tuple(((input_dp >> i) & 1) ^ 1 for i in range(n)),
                                tuple(((output_dp >> i) & 1) ^ 1 for i in range(m))))
        self.sbox_cnfs[sbox_name] = cnf
        return cnf

    def get_anf(self, sbox):
        '''
        Return the algebraic normal form of ``sbox`` (binary Moebius
        transform).

        Entry ``k`` holds the coefficient of monomial ``k``, bit-packed over
        the output coordinates. ``len(sbox)`` is expected to be a power of
        two.
        '''
        # Exact integer log2; float log() can round down for large sizes.
        n = len(sbox).bit_length() - 1
        anf = list(sbox)
        for i in range(n):
            mask = 1 << i
            for j in range(len(anf)):
                if j & mask:
                    anf[j] ^= anf[j ^ mask]
        return anf

    def is_sbox_bijective(self, sbox, n, m):
        '''Return True iff ``n == m`` and the first ``2**n`` entries are distinct.'''
        if n != m:
            return False
        size = 2 ** n
        return len(set(sbox[i] for i in range(size))) == size

    def set_temporaries_to_zero(self):
        '''
        To guarantee that temporaries are zero at end of rounds.

        Temporaries that were never assigned a variable (still ``None``)
        carry nothing and are skipped; negating ``None`` would raise
        ``TypeError``.
        '''
        for tmp in self.temporary.values():
            if tmp is not None:
                self.addclause([-tmp])

    def get_variables(self, bit):
        '''
        Allocate a fresh solver variable for ``bit``.

        Returns ``(old, new)``. ``old`` is the variable that held the bit
        before this write, or ``None`` for a temporary or scratch bit that
        had none. ``bit`` is dispatched on its first character: ``'s<i>'``
        is state bit ``i``, ``'t...'`` is a temporary, and anything else is
        an S-box scratch bit.
        '''
        new_bit = self.next_variable
        self.next_variable += 1
        if bit[0] == 's':
            number = int(bit[1:])
            old_bit = self.state_bit[number]
            self.state_bit[number] = new_bit
        elif bit[0] == 't':
            old_bit = self.temporary[bit]
            self.temporary[bit] = new_bit
        else:
            old_bit = self.sbox_tmps.get(bit)
            self.sbox_tmps[bit] = new_bit
        return old_bit, new_bit

    def _boundary_conditions(self, rnd, active_bits, output_bits):
        '''
        Return assumption literals that fix the round-0 state to exactly
        ``active_bits`` and the round-``rnd`` state to exactly
        ``output_bits`` (both containers of bit indices).
        '''
        conditions = [var if i in active_bits else -var
                      for i, var in enumerate(self.round_states[0])]
        conditions += [var if i in output_bits else -var
                       for i, var in enumerate(self.round_states[rnd])]
        return conditions

    def is_bit_balanced(self, bit, rnd, active):
        '''
        Return whether state bit ``bit`` is balanced after ``rnd`` rounds
        when exactly the input bits in ``active`` are active.

        The bit counts as balanced when the solver cannot reach the unit
        vector at ``bit`` from that input pattern. ``active`` items are
        converted with ``int``. Raises ``ValueError`` for an out-of-range
        ``bit``; exits the process for an out-of-range active bit.
        '''
        self.create_conditions()
        if bit >= self.state_size or bit < 0:
            raise ValueError("There are only {} state bits."
                             .format(self.state_size))
        # A real list: it is iterated here and reused below (a Python 3
        # ``map`` iterator would already be exhausted).
        active_bits = [int(b) for b in active]
        for active_bit in active_bits:
            if active_bit >= self.state_size or active_bit < 0:
                print('Bit {} designated as active bit, but there are only '
                      '{} state bits.'.format(active_bit, self.state_size))
                sys.exit(1)
        conditions = self._boundary_conditions(rnd, set(active_bits), {bit})
        unbalanced, _ = self.solver.solve(conditions)
        return not unbalanced

    def distinguisher_exists(self, rnd):
        '''
        Search for an input pattern that is neither all-zero nor all-active
        and whose round-``rnd`` state has exactly one bit set.

        TODO: not working at the moment. The constraints are added to the
        solver permanently, so the model is marked stale afterwards and the
        next query rebuilds it.

        Returns ``(sat, solution)`` from the solver.
        '''
        self.create_conditions()

        # No all zero input, no all active input
        input_allzero = [self.round_states[0][i] for i in range(self.state_size)]
        input_allone = [-self.round_states[0][i] for i in range(self.state_size)]
        self.addclause(input_allzero)
        self.addclause(input_allone)

        # At least one unit vector reachable
        out_cond = [self.round_states[rnd][i] for i in range(self.state_size)]
        self.addclause(out_cond)

        # No pair should be sat together
        for first, second in combinations(range(self.state_size), 2):
            self.addclause([-self.round_states[rnd][first],
                            -self.round_states[rnd][second]])
        sat, solution = self.solver.solve()
        # The clauses above are permanent; force a rebuild before other queries.
        self.fresh_conditions = False
        return sat, solution

    def is_reachable(self, output_bits, rnd, active_bits):
        '''
        Return whether the round-``rnd`` state can have exactly
        ``output_bits`` set, starting with exactly ``active_bits`` set in
        round 0.
        '''
        self.create_conditions()
        conditions = self._boundary_conditions(rnd, set(active_bits),
                                               set(output_bits))
        reachable, _ = self.solver.solve(conditions)
        return reachable
Beispiel #18
0
def solve_or(model):
    """Solve ``model.clauses`` as ordinary disjunctive (CNF) clauses.

    A new ``Solver`` is created for each call. Returns the result of its
    ``solve()``.
    """
    solver = Solver()
    clauses = model.clauses
    for disjunction in clauses:
        solver.add_clause(disjunction)
    return solver.solve()
Beispiel #19
0
def solve_xor(model):
    """Solve ``model.clauses`` with every clause read as an XOR constraint.

    Each clause contributes its variables (via ``_only_positive``) and its
    right-hand side (via ``_get_clause_parity``). Returns the result of
    ``Solver.solve()``.
    """
    solver = Solver()
    for clause in model.clauses:
        variables = _only_positive(clause)
        parity = _get_clause_parity(clause)
        solver.add_xor_clause(variables, rhs=parity)
    return solver.solve()
Beispiel #20
0
 def test_binary(self):
     # x1 XOR x2 == False, so assuming x1 forces x2 to True as well.
     s = Solver()
     s.add_xor_clause([1, 2], False)
     sat, model = s.solve([1])
     self.assertEqual(sat, True)
     self.assertEqual(model, (None, True, True))
Beispiel #21
0
class SatHandler:
    '''
    Encode "is there an FSM with ``numOfNodes`` nodes consistent with the
    given input/output traces?" as SAT, and search for the smallest one.

    Variable families, as named in the clauses below:

    * ``x_i_k`` -- trace-tree node ``i`` is mapped to FSM node ``k``.
    * ``y_a_i_j`` -- on input ``a`` the FSM moves from node ``i`` to ``j``.

    There are two ways to produce the clauses:

    * ``constructClauses`` collects them in a ``ClauseHandler``;
    * ``constructSatFile`` writes them as DIMACS through a ``FileHandler``.

    They are solved with pycryptosat or with the ``cryptominisat5``
    executable.
    '''

    def __init__(self, tracesList, numOfInputs, usePyCrypto=False):
        '''
        :param tracesList: traces, each a sequence of ``(input, output)``
            tuples; ``input`` indexes a node's ``transitions`` list.
        :param numOfInputs: size of the input alphabet.
        :param usePyCrypto: solve in-process with pycryptosat (when its
            import succeeded) instead of running ``cryptominisat5``.
        '''
        self.numOfInputs = numOfInputs
        self.tracesList = tracesList
        self.usePyCrypto = usePyCrypto

        self.s = None
        if self.usePyCrypto and PyCryptoSat_Import_Successful:
            self.s = Solver()
            print("Initialized solver")

        self.clauses = ClauseHandler()

        # Creates the acyclic FSM of trace nodes (fills self.traceTree).
        self.constructTraceTree()

        self.numOfTraceNodes = len(self.traceTree.nodes)

        # Number of FSM nodes currently being tried. Shared mutable state:
        # the encoding/decoding methods below all read it.
        self.numOfNodes = 0

    def varToNum(self, var):
        '''Map ``"i_k"`` (variable x_i_k) to DIMACS number ``i*numOfNodes + k + 1``.'''
        x, s = var.split("_")
        return int(x) * self.numOfNodes + int(s) + 1

    def numToVar(self, num):
        '''Inverse of ``varToNum``: map a positive DIMACS number to ``"i_k"``.'''
        x, s = divmod(num - 1, self.numOfNodes)
        return "{}_{}".format(x, s)

    def newVarToNum(self, newVar):
        '''
        Map ``"y_a_i_j"`` to its DIMACS number.

        The x variables occupy ``1 .. numOfTraceNodes*numOfNodes``. After
        them comes one block of ``numOfNodes**2`` numbers per input ``a``.
        A block sized by the number of FSM nodes, rather than by the number
        of trace nodes, keeps the blocks for different inputs disjoint even
        when ``numOfNodes > numOfTraceNodes``. The leading ``y``/``-y``
        token is ignored.
        '''
        _, a, i, j = newVar.split("_")
        base = self.numOfTraceNodes * self.numOfNodes
        block = self.numOfNodes * self.numOfNodes
        return base + int(a) * block + int(i) * self.numOfNodes + int(j) + 1

    def numToNewVar(self, num):
        '''Inverse of ``newVarToNum``: map a DIMACS number to ``"y_a_i_j"``.'''
        base = self.numOfTraceNodes * self.numOfNodes
        block = self.numOfNodes * self.numOfNodes
        a, rem = divmod(num - base - 1, block)
        i, j = divmod(rem, self.numOfNodes)
        return "y_{}_{}_{}".format(a, i, j)

    def constructTraceTree(self, nameInBreathFirst = True):
        '''
        Merge all traces into a prefix tree stored in ``self.traceTree``.

        Each trace step follows the existing transition for its input, or
        creates a new child recording the step's output. When an existing
        transition is followed, the step's output is not compared with the
        stored one. If ``nameInBreathFirst`` is true, the nodes are then
        re-indexed in breadth-first order (root = 0), and each child gets a
        ``parent`` attribute.
        '''
        self.traceTree = FSM(0, self.numOfInputs, 0)  # numOfNodes, numOfInputs, numOfOutputs

        rootNode = FSM.Node(self.numOfInputs, 0)  # 0 is the index
        self.traceTree.nodes.append(rootNode)

        for trace in self.tracesList:
            currentNode = rootNode  # every trace starts at the root
            for ioTuple in trace:
                nextNode = currentNode.transitions[ioTuple[0]][0]
                if nextNode is None:
                    # This transition is new: create a new branch.
                    nextNode = FSM.Node(self.numOfInputs, len(self.traceTree.nodes))
                    self.traceTree.nodes.append(nextNode)
                    currentNode.transitions[ioTuple[0]] = (nextNode, ioTuple[1])
                currentNode = nextNode

        if nameInBreathFirst:
            newNodeList = []
            nodeQ = deque([rootNode])
            while nodeQ:
                currentNode = nodeQ.popleft()
                for transition in currentNode.transitions:
                    child = transition[0]
                    if child is not None:
                        child.parent = currentNode
                        nodeQ.append(child)
                # Set the index and append to the list.
                currentNode.index = len(newNodeList)
                newNodeList.append(currentNode)
            self.traceTree.nodes = newNodeList

    @staticmethod
    def _outputsConflict(nodeA, nodeB):
        '''True if some input has a recorded, differing output on both nodes.'''
        for k in range(len(nodeA.transitions)):
            outA = nodeA.transitions[k][1]
            outB = nodeB.transitions[k][1]
            if outA is not None and outB is not None and outA != outB:
                return True
        return False

    @staticmethod
    def _runCryptominisat(filename, outputFile):
        '''Run ``cryptominisat5 --verb 0 <filename>`` with stdout sent to ``outputFile``.'''
        # An argument list instead of a shell string: file names are passed
        # verbatim, and a missing executable raises instead of being ignored.
        with open(outputFile, "w") as out:
            subprocess.run(["cryptominisat5", "--verb", "0", filename], stdout=out)

    def constructClauses(self, writeToFile=False, filename="SatFile", numOfNodes=-1):
        '''
        Add the FSM-identification clauses for ``self.numOfNodes`` nodes to
        ``self.clauses``, using named ``x``/``y`` variables.

        When pycryptosat is not used, a ``p cnf`` header is added and
        ``ClauseHandler.writeTheFile`` is called. ``writeToFile`` and
        ``filename`` are unused. NOTE(review): the output file is whatever
        ``ClauseHandler`` writes to.

        :param numOfNodes: if not -1, overrides ``self.numOfNodes`` (debugging aid).
        '''
        if numOfNodes != -1:
            self.numOfNodes = numOfNodes

        nodes = self.traceTree.nodes
        numOfInputs = len(nodes[0].transitions)
        countClauses = 0

        # Each trace node must correspond to at least one node
        # (does not include the 0_0 condition).
        for i in range(self.numOfTraceNodes):
            self.clauses.addClause(["x_{}_{}".format(i, k) for k in range(self.numOfNodes)])
            countClauses += 1

        # Each trace node must correspond to at most one node.
        for i in range(self.numOfTraceNodes):
            for k in range(self.numOfNodes - 1):
                for j in range(k + 1, self.numOfNodes):
                    self.clauses.addClause(["-x_{}_{}".format(i, k), "-x_{}_{}".format(i, j)])
                    countClauses += 1

        # Trace nodes whose outputs differ on some input cannot share a node.
        for i in range(self.numOfTraceNodes - 1):
            for j in range(i + 1, self.numOfTraceNodes):
                if self._outputsConflict(nodes[i], nodes[j]):
                    for h in range(self.numOfNodes):
                        self.clauses.addClause(["-x_{}_{}".format(i, h), "-x_{}_{}".format(j, h)])
                        countClauses += 1

        # x_x_i & x_child_j -> y_a_i_j: observed steps fix FSM transitions.
        for x in range(self.numOfTraceNodes):
            for a in range(len(nodes[x].transitions)):
                child = nodes[x].transitions[a][0]
                if child is None:
                    continue
                for i in range(self.numOfNodes):
                    for j in range(self.numOfNodes):
                        self.clauses.addClause(["y_{}_{}_{}".format(a, i, j),
                                                "-x_{}_{}".format(x, i),
                                                "-x_{}_{}".format(child.index, j)])
                        countClauses += 1

        # At most one target per (input, node).
        for a in range(numOfInputs):
            for i in range(self.numOfNodes):
                for h in range(self.numOfNodes - 1):
                    for j in range(h + 1, self.numOfNodes):
                        self.clauses.addClause(["-y_{}_{}_{}".format(a, i, h), "-y_{}_{}_{}".format(a, i, j)])
                        countClauses += 1

        # At least one target per (input, node).
        for a in range(numOfInputs):
            for i in range(self.numOfNodes):
                self.clauses.addClause(["y_{}_{}_{}".format(a, i, j) for j in range(self.numOfNodes)])
                countClauses += 1

        # y_a_i_j & x_x_i -> x_child_j: children follow the FSM transition.
        for x in range(self.numOfTraceNodes):
            for a in range(len(nodes[x].transitions)):
                child = nodes[x].transitions[a][0]
                if child is None:
                    continue
                for i in range(self.numOfNodes):
                    for j in range(self.numOfNodes):
                        self.clauses.addClause(["-y_{}_{}_{}".format(a, i, j),
                                                "-x_{}_{}".format(x, i),
                                                "x_{}_{}".format(child.index, j)])
                        countClauses += 1

        if not self.usePyCrypto:
            varCount = self.numOfNodes * self.numOfTraceNodes * (numOfInputs + 1)
            self.clauses.addFirstLine("p cnf {} {}".format(varCount, countClauses))
            self.clauses.writeTheFile()

    def constructSatFile(self, writeFile=True, filename="SatFile", verbose=True, numOfNodes=-1):
        '''
        Write the FSM-identification CNF for ``self.numOfNodes`` nodes to
        ``filename`` in DIMACS form.

        The constraints are the same as in ``constructClauses``, numbered
        with ``varToNum`` / ``newVarToNum``. With ``verbose``, section
        headers and a per-clause comment naming its variables are written.
        ``writeFile`` is unused.

        :param numOfNodes: if not -1, overrides ``self.numOfNodes`` (debugging aid).
        '''
        if numOfNodes != -1:
            self.numOfNodes = numOfNodes

        satFile = FileHandler(filename)
        nodes = self.traceTree.nodes
        numOfInputs = len(nodes[0].transitions)
        countClauses = 0

        def comment(text):
            if verbose:
                satFile.writeComment(text)

        def xNum(node, fsmNode):
            return self.varToNum("{}_{}".format(node, fsmNode))

        def yNum(a, i, j):
            return self.newVarToNum("y_{}_{}_{}".format(a, i, j))

        # Each trace node must correspond to at least one node.
        comment("##### Each must correspond to at least one node #####")
        for i in range(self.numOfTraceNodes):
            comment("".join("{}_{} ".format(i, k) for k in range(self.numOfNodes)))
            satFile.write("".join("{} ".format(xNum(i, k)) for k in range(self.numOfNodes)))
            countClauses += 1

        # Each trace node must correspond to at most one node.
        comment("##### Each must correspond to at most one node #####")
        for i in range(self.numOfTraceNodes):
            for k in range(self.numOfNodes - 1):
                for j in range(k + 1, self.numOfNodes):
                    comment("-{}_{} -{}_{}".format(i, k, i, j))
                    satFile.write("-{} -{}".format(xNum(i, k), xNum(i, j)))
                    countClauses += 1

        # Trace nodes whose outputs differ on some input cannot share a node.
        # Pairs with equal outputs are constrained through the y variables.
        for i in range(self.numOfTraceNodes - 1):
            for j in range(i + 1, self.numOfTraceNodes):
                if self._outputsConflict(nodes[i], nodes[j]):
                    for h in range(self.numOfNodes):
                        comment("-{}_{} -{}_{}".format(i, h, j, h))
                        satFile.write("-{} -{}".format(xNum(i, h), xNum(j, h)))
                        countClauses += 1

        # x_x_i & x_child_j -> y_a_i_j
        for x in range(self.numOfTraceNodes):
            for a in range(len(nodes[x].transitions)):
                child = nodes[x].transitions[a][0]
                if child is None:
                    continue
                for i in range(self.numOfNodes):
                    for j in range(self.numOfNodes):
                        comment("y_{}_{}_{} -{}_{} -{}_{}".format(a, i, j, x, i, child.index, j))
                        satFile.write("{} -{} -{}".format(yNum(a, i, j), xNum(x, i), xNum(child.index, j)))
                        countClauses += 1

        # At most one target per (input, node).
        for a in range(numOfInputs):
            for i in range(self.numOfNodes):
                for h in range(self.numOfNodes - 1):
                    for j in range(h + 1, self.numOfNodes):
                        comment("-y_{}_{}_{} -y_{}_{}_{}".format(a, i, h, a, i, j))
                        satFile.write("-{} -{}".format(yNum(a, i, h), yNum(a, i, j)))
                        countClauses += 1

        # At least one target per (input, node).
        for a in range(numOfInputs):
            for i in range(self.numOfNodes):
                comment("".join("y_{}_{}_{} ".format(a, i, j) for j in range(self.numOfNodes)))
                satFile.write("".join("{} ".format(yNum(a, i, j)) for j in range(self.numOfNodes)))
                countClauses += 1

        # y_a_i_j & x_x_i -> x_child_j
        for x in range(self.numOfTraceNodes):
            for a in range(len(nodes[x].transitions)):
                child = nodes[x].transitions[a][0]
                if child is None:
                    continue
                for i in range(self.numOfNodes):
                    for j in range(self.numOfNodes):
                        comment("-y_{}_{}_{} -{}_{} {}_{}".format(a, i, j, x, i, child.index, j))
                        satFile.write("-{} -{} {}".format(yNum(a, i, j), xNum(x, i), xNum(child.index, j)))
                        countClauses += 1

        # x variables, then one numOfNodes**2 block of y variables per input
        # (matches the numbering in newVarToNum).
        varCount = (self.numOfNodes * self.numOfTraceNodes
                    + numOfInputs * self.numOfNodes * self.numOfNodes)
        satFile.addFirstLine("p cnf {} {}".format(varCount, countClauses))
        satFile.writeTheFile()

    def checkOutput(self, filename="satOutput", onlyCheck=False):
        '''
        Parse a cryptominisat result file.

        Returns:

        * ``(False, None)`` unless the first line is ``s SATISFIABLE``;
        * ``(True, None)`` if ``onlyCheck`` is set;
        * otherwise ``(True, names)``, where ``names`` holds ``numToVar`` of
          every positive literal on the remaining lines (skipping each
          line's first token).

        NOTE(review): y variables are decoded with ``numToVar`` too, which
        yields out-of-range ``"i_k"`` names for them.
        '''
        with open(filename, "r") as f:
            # If the output has no solution, don't continue.
            if f.readline().strip() != "s SATISFIABLE":
                return (False, None)
            # Satisfiable, but the caller only wants the verdict.
            if onlyCheck:
                return (True, None)

            output = []
            for line in f:
                for var in line.split()[1:]:
                    if var[0] != "-" and var != "0":
                        output.append(self.numToVar(int(var)))
            return (True, output)

    def addFormulasForSingleVariable(self, traceNo=0, numOfNodes=-1):
        '''
        Add to ``self.clauses`` the constraints involving trace node
        ``traceNo``. Conflicts are only checked against nodes
        ``0 .. traceNo-1``.

        :param traceNo: must be non-zero; 0 prints a message and returns.
        :param numOfNodes: if not -1, overrides ``self.numOfNodes``.
        '''
        if numOfNodes != -1:
            self.numOfNodes = numOfNodes
        if traceNo == 0:
            print("addFormulasForSingleVariable -> You should gice the traceNo to use this function!")
            return

        nodes = self.traceTree.nodes
        node = nodes[traceNo]

        # This trace must be at least one of the nodes...
        self.clauses.addClause(["x_{}_{}".format(traceNo, k) for k in range(self.numOfNodes)])

        # ... and at most one.
        for k in range(self.numOfNodes - 1):
            for j in range(k + 1, self.numOfNodes):
                self.clauses.addClause(["-x_{}_{}".format(traceNo, k), "-x_{}_{}".format(traceNo, j)])

        # Conflicting outputs with an earlier trace node forbid sharing a node.
        for j in range(traceNo):
            if self._outputsConflict(node, nodes[j]):
                for h in range(self.numOfNodes):
                    self.clauses.addClause(["-x_{}_{}".format(traceNo, h), "-x_{}_{}".format(j, h)])

        # x_traceNo_i & x_child_j -> y_a_i_j
        for a in range(len(node.transitions)):
            child = node.transitions[a][0]
            if child is None:
                continue
            for i in range(self.numOfNodes):
                for j in range(self.numOfNodes):
                    self.clauses.addClause(["y_{}_{}_{}".format(a, i, j),
                                            "-x_{}_{}".format(traceNo, i),
                                            "-x_{}_{}".format(child.index, j)])

        # Auxiliary for this: y_a_i_j & x_traceNo_i -> x_child_j, once per input.
        for a in range(len(node.transitions)):
            child = node.transitions[a][0]
            if child is None:
                continue
            for i in range(self.numOfNodes):
                for j in range(self.numOfNodes):
                    self.clauses.addClause(["-y_{}_{}_{}".format(a, i, j),
                                            "-x_{}_{}".format(traceNo, i),
                                            "x_{}_{}".format(child.index, j)])

    def oldfindFsmConsecutive(self, filename="SatFile", outputFile="satOutput"):
        '''
        Linear search: try 1, 2, 3, ... FSM nodes until the encoding is
        satisfiable.

        Uses pycryptosat when ``usePyCrypto`` is set. Otherwise it runs
        ``cryptominisat5`` on ``filename`` and reads ``outputFile``. Returns
        the first satisfiable output (the pycryptosat model, or the
        ``checkOutput`` name list). ``self.numOfNodes`` is left at that node
        count.
        '''
        self.numOfNodes = 1

        startTime = time.time()
        totalSolveTime = 0
        totalFileTime = 0

        while True:
            print("Trying with", self.numOfNodes, "nodes...")
            # constructClauses only appends. The previous attempt's "one of
            # the first n-1 nodes" clauses would otherwise forbid the newly
            # added node. NOTE(review): assumes ClauseHandler never clears
            # itself.
            self.clauses = ClauseHandler()

            if self.usePyCrypto:
                self.s = Solver()  # Reinitialize
                sTime = time.time()

                self.constructClauses()
                self.clauses.addToSolver(self.s)

                endTime = time.time() - sTime
                totalFileTime += endTime
                print("Sat preperations took", endTime, "seconds")

                sTime = time.time()
                isSatisfiable, output = self.s.solve()
                solveEnd = time.time() - sTime
                totalSolveTime += solveEnd
                print("Solving took", solveEnd, "seconds")

            else:
                sFileTime = time.time()

                # NOTE(review): constructClauses does not receive
                # ``filename``; this assumes ClauseHandler writes there.
                self.constructClauses()

                fileEnd = time.time() - sFileTime
                totalFileTime += fileEnd
                print("Sat file construction took", fileEnd, "seconds")

                solveTime = time.time()
                self._runCryptominisat(filename, outputFile)
                solveEnd = time.time() - solveTime
                totalSolveTime += solveEnd
                print("Solving took", solveEnd, "seconds")

                isSatisfiable, output = self.checkOutput(outputFile)

            if isSatisfiable:
                print("Satisfiable with", self.numOfNodes, "nodes!")
                print("Total construction time:", totalFileTime, "seconds")
                print("Total solving time:", totalSolveTime, "seconds")
                print("\nConsecutive approach took", time.time() - startTime, "seconds.")
                print()
                # Stop before incrementing so numOfNodes matches the output.
                break

            self.numOfNodes += 1

        return output

    def getFsmFromSolution(self, output):
        '''
        Build an FSM from a solver model. UNFINISHED.

        The draft maps the solver numbers of ``x`` variables (from
        ``self.clauses.clausedict``) back to their ``"i_k"`` suffix and
        creates an empty ``FSM``. Decoding the true entries of ``output``
        was never written.

        Raises ``NotImplementedError``.
        '''
        switchValueKey = {value: key[2:]
                          for key, value in self.clauses.clausedict.items()
                          if key[0] == "x"}

        # Init the FSM.
        traceFsm = FSM(self.numOfNodes, self.numOfInputs, 0)

        trueVars = [i for i in range(1, len(output)) if output[i]]
        # TODO: map each entry of trueVars through switchValueKey to a
        # (trace node, FSM node) pair and fill in traceFsm's transitions.
        raise NotImplementedError("getFsmFromSolution is not finished")

    def findFsmConsecutive(self, filename="SatFile", outputFile="satOutput"):
        '''
        UNFINISHED rewrite of ``oldfindFsmConsecutive``.

        The draft initialised its search state, starting at 7 nodes, and
        then had an empty ``while not FOUND:`` loop that did not parse. Use
        ``oldfindFsmConsecutive`` (linear search from 1 node) or
        ``findFsmBinary`` meanwhile.

        Raises ``NotImplementedError``.
        '''
        raise NotImplementedError(
            "findFsmConsecutive is unfinished; use oldfindFsmConsecutive")

    def findFsmBinary(self, filename="SatFile", outputFile="satOutput"):
        '''
        Find the smallest satisfiable node count by doubling, then bisection.

        Node counts 2, 4, 8, ... are tried until one is satisfiable. The
        search then bisects between the largest unsatisfiable count and the
        smallest satisfiable one. Each attempt writes ``filename`` with
        ``constructSatFile`` and runs ``cryptominisat5`` into
        ``outputFile``. Returns the ``checkOutput`` result for the smallest
        satisfiable count, which is also left in ``self.numOfNodes``.
        '''
        minNumOfNodes = 0    # largest count known to be unsatisfiable
        maxNumOfNodes = -1   # smallest count known satisfiable (-1: none yet)

        self.numOfNodes = 1
        lastOutput = None

        startTime = time.time()

        while True:
            if maxNumOfNodes == -1:
                self.numOfNodes *= 2
            else:
                self.numOfNodes = (minNumOfNodes + maxNumOfNodes) // 2

            print("Trying with", self.numOfNodes, "nodes...")

            self.constructSatFile(filename=filename)
            self._runCryptominisat(filename, outputFile)
            isSatisfiable, output = self.checkOutput(outputFile)

            if isSatisfiable:
                print("Satisfiable with", self.numOfNodes, "nodes!")
                lastOutput = output
                maxNumOfNodes = self.numOfNodes
            else:
                print("Not satisfiable with", self.numOfNodes, "nodes.")
                minNumOfNodes = self.numOfNodes

            if minNumOfNodes == maxNumOfNodes - 1:
                # The last probe may have been unsatisfiable; report the answer.
                self.numOfNodes = maxNumOfNodes
                print("Merged at", self.numOfNodes, "nodes!")
                print("\nBinary approach took", time.time() - startTime, "seconds.")
                return lastOutput

    def findFsm(self, tryBinarySearch=False, filename="SatFile", outputFile="satOutput"):
        '''Dispatch to ``findFsmBinary`` or ``findFsmConsecutive``.'''
        if tryBinarySearch:
            return self.findFsmBinary(filename, outputFile)
        return self.findFsmConsecutive(filename, outputFile)
Beispiel #22
0
 def test_3_long2(self):
     # x1 ^ x2 ^ x3 == True with x1=True, x2=False leaves x3=False.
     s = Solver()
     s.add_xor_clause([1, 2, 3], True)
     sat, model = s.solve([1, -2])
     self.assertEqual(sat, True)
     self.assertEqual(model, (None, True, False, False))
Beispiel #23
0
class TestSolve(unittest.TestCase):
    """Tests for the ``Solver`` SAT-solver bindings, using a 2-thread solver."""

    def setUp(self):
        self.solver = Solver(threads=2)

    def test_wrong_args(self):
        """Non-integer literals raise TypeError; a 0 literal raises ValueError."""
        self.assertRaises(TypeError, self.solver.add_clause, 'A')
        self.assertRaises(TypeError, self.solver.add_clause, 1)
        self.assertRaises(TypeError, self.solver.add_clause, 1.0)
        self.assertRaises(TypeError, self.solver.add_clause, object())
        self.assertRaises(TypeError, self.solver.add_clause, ['a'])
        self.assertRaises(TypeError, self.solver.add_clause,
                          [[1, 2], [3, None]])
        self.assertRaises(ValueError, self.solver.add_clause, [1, 0])

    def test_no_clauses(self):
        """Solving an empty formula is repeatedly SAT with only the placeholder."""
        for _ in range(7):
            self.assertEqual(self.solver.solve([]), (True, (None, )))

    def test_cnf1(self):
        for cl in clauses1:
            self.solver.add_clause(cl)
        res, solution = self.solver.solve()
        self.assertEqual(res, True)
        self.assertTrue(check_solution(clauses1, solution))

    def test_add_clauses(self):
        """A unit clause together with its negation is UNSAT."""
        self.solver.add_clauses([[1], [-1]])
        res, _ = self.solver.solve()
        self.assertEqual(res, False)

    def test_add_clauses_wrong_zero(self):
        """add_clauses rejects a 0 literal inside a clause, as add_clause does."""
        # Must target add_clauses: add_clause on this nested list would fail
        # with TypeError on the inner lists and never look at the 0.
        self.assertRaises(ValueError, self.solver.add_clauses, [[1, 0], [-1]])

    def test_add_clauses_array_SAT(self):
        """A flat int array of 0-terminated clauses is accepted."""
        cls = array('i', [1, 2, 0, 1, 2, 0])
        self.solver.add_clauses(cls)
        res, _ = self.solver.solve()
        self.assertEqual(res, True)

    def test_add_clauses_array_UNSAT(self):
        cls = array('i', [-1, 0, 1, 0])
        self.solver.add_clauses(cls)
        res, _ = self.solver.solve()
        self.assertEqual(res, False)

    def test_add_clauses_array_unterminated(self):
        """A flat clause array whose last clause lacks a trailing 0 is rejected."""
        cls = array('i', [1, 2, 0, 1, 2])
        # Must target add_clauses: add_clause would fail on the embedded 0 and
        # never check the missing terminator.
        self.assertRaises(ValueError, self.solver.add_clauses, cls)

    def test_bad_iter(self):
        class Liar:
            def __iter__(self):
                return None

        self.assertRaises(TypeError, self.solver.add_clause, Liar())

    def test_get_conflict(self):
        self.solver.add_clauses([[-1], [2], [3], [-4]])
        assume = [-2, 3, 4]

        res, _ = self.solver.solve(assumptions=assume)
        self.assertEqual(res, False)

        confl = self.solver.get_conflict()
        self.assertIsInstance(confl, list)
        self.assertNotIn(3, confl)
        self.assertTrue(2 in confl or -4 in confl,
                        msg="Either -2 or 4 should be conflicting!")

        assume = [2, 4]
        res, _ = self.solver.solve(assumptions=assume)
        self.assertEqual(res, False)

        confl = self.solver.get_conflict()
        self.assertIsInstance(confl, list)
        self.assertNotIn(2, confl)
        self.assertIn(-4, confl)

    def test_cnf2(self):
        for cl in clauses2:
            self.solver.add_clause(cl)
        self.assertEqual(self.solver.solve(), (False, None))

    def test_cnf3(self):
        for cl in clauses3:
            self.solver.add_clause(cl)
        res, solution = self.solver.solve()
        self.assertEqual(res, True)
        self.assertTrue(check_solution(clauses3, solution))

    def test_cnf1_confl_limit(self):
        """With a small conflict budget, solve() either gives up or is correct."""
        for lim in range(1, 20):
            # A fresh solver per budget; previously `lim` was never used, so
            # no conflict limit was ever applied.
            solver = Solver(threads=2, confl_limit=lim)
            for cl in clauses1:
                solver.add_clause(cl)

            res, solution = solver.solve()
            # res is None when the budget ran out before a verdict.
            self.assertTrue(res is None or check_solution(clauses1, solution))

    def test_by_re_curse(self):
        """Clauses added after a solve() are honoured by the next solve()."""
        self.solver.add_clause([-1, -2, 3])
        res, _ = self.solver.solve()
        self.assertEqual(res, True)

        self.solver.add_clause([-5, 1])
        self.solver.add_clause([4, -3])
        self.solver.add_clause([2, 3, 5])
        res, _ = self.solver.solve()
        self.assertEqual(res, True)
# Example #24 (0)
class TestSolve(unittest.TestCase):
    """Tests for the ``Solver`` SAT-solver bindings, using a 2-thread solver."""

    def setUp(self):
        self.solver = Solver(threads=2)

    def test_wrong_args(self):
        """Non-integer literals raise TypeError; a 0 literal raises ValueError."""
        self.assertRaises(TypeError, self.solver.add_clause, 'A')
        self.assertRaises(TypeError, self.solver.add_clause, 1)
        self.assertRaises(TypeError, self.solver.add_clause, 1.0)
        self.assertRaises(TypeError, self.solver.add_clause, object())
        self.assertRaises(TypeError, self.solver.add_clause, ['a'])
        self.assertRaises(
            TypeError, self.solver.add_clause, [[1, 2], [3, None]])
        self.assertRaises(ValueError, self.solver.add_clause, [1, 0])

    def test_no_clauses(self):
        """Solving an empty formula is repeatedly SAT with only the placeholder."""
        for _ in range(7):
            self.assertEqual(self.solver.solve([]), (True, (None,)))

    def test_cnf1(self):
        for cl in clauses1:
            self.solver.add_clause(cl)
        res, solution = self.solver.solve()
        self.assertEqual(res, True)
        self.assertTrue(check_solution(clauses1, solution))

    def test_add_clauses(self):
        """A unit clause together with its negation is UNSAT."""
        self.solver.add_clauses([[1], [-1]])
        res, _ = self.solver.solve()
        self.assertEqual(res, False)

    def test_add_clauses_wrong_zero(self):
        """add_clauses rejects a 0 literal inside a clause, as add_clause does."""
        # Must target add_clauses: add_clause on this nested list would fail
        # with TypeError on the inner lists and never look at the 0.
        self.assertRaises(ValueError, self.solver.add_clauses, [[1, 0], [-1]])

    def test_add_clauses_array_SAT(self):
        """A flat int array of 0-terminated clauses is accepted."""
        cls = array('i', [1, 2, 0, 1, 2, 0])
        self.solver.add_clauses(cls)
        res, _ = self.solver.solve()
        self.assertEqual(res, True)

    def test_add_clauses_array_UNSAT(self):
        cls = array('i', [-1, 0, 1, 0])
        self.solver.add_clauses(cls)
        res, _ = self.solver.solve()
        self.assertEqual(res, False)

    def test_add_clauses_array_unterminated(self):
        """A flat clause array whose last clause lacks a trailing 0 is rejected."""
        cls = array('i', [1, 2, 0, 1, 2])
        # Must target add_clauses: add_clause would fail on the embedded 0 and
        # never check the missing terminator.
        self.assertRaises(ValueError, self.solver.add_clauses, cls)

    def test_bad_iter(self):
        class Liar:

            def __iter__(self):
                return None
        self.assertRaises(TypeError, self.solver.add_clause, Liar())

    def test_get_conflict(self):
        self.solver.add_clauses([[-1], [2], [3], [-4]])
        assume = [-2, 3, 4]

        res, _ = self.solver.solve(assumptions=assume)
        self.assertEqual(res, False)

        confl = self.solver.get_conflict()
        self.assertIsInstance(confl, list)
        self.assertNotIn(3, confl)
        self.assertTrue(2 in confl or -4 in confl,
                        msg="Either -2 or 4 should be conflicting!")

        assume = [2, 4]
        res, _ = self.solver.solve(assumptions=assume)
        self.assertEqual(res, False)

        confl = self.solver.get_conflict()
        self.assertIsInstance(confl, list)
        self.assertNotIn(2, confl)
        self.assertIn(-4, confl)

    def test_cnf2(self):
        for cl in clauses2:
            self.solver.add_clause(cl)
        self.assertEqual(self.solver.solve(), (False, None))

    def test_cnf3(self):
        for cl in clauses3:
            self.solver.add_clause(cl)
        res, solution = self.solver.solve()
        self.assertEqual(res, True)
        self.assertTrue(check_solution(clauses3, solution))

    def test_cnf1_confl_limit(self):
        """With a small conflict budget, solve() either gives up or is correct."""
        for lim in range(1, 20):
            # A fresh solver per budget; previously `lim` was never used, so
            # no conflict limit was ever applied.
            solver = Solver(threads=2, confl_limit=lim)
            for cl in clauses1:
                solver.add_clause(cl)

            res, solution = solver.solve()
            # res is None when the budget ran out before a verdict.
            self.assertTrue(res is None or check_solution(clauses1, solution))

    def test_by_re_curse(self):
        """Clauses added after a solve() are honoured by the next solve()."""
        self.solver.add_clause([-1, -2, 3])
        res, _ = self.solver.solve()
        self.assertEqual(res, True)

        self.solver.add_clause([-5, 1])
        self.solver.add_clause([4, -3])
        self.solver.add_clause([2, 3, 5])
        res, _ = self.solver.solve()
        self.assertEqual(res, True)
 def solve_sat(self):
     """Return the verdict of solving ``self.formula`` with a fresh ``Solver``.

     Each clause of ``self.formula`` is added individually. Only the first
     element of ``solve()``'s result is returned; the assignment is discarded.
     """
     sat_solver = Solver()
     for clause in self.formula:
         sat_solver.add_clause(clause)
     verdict, _ = sat_solver.solve()
     return verdict
# Example #26 (0)
 def test_cnf2(self):
     """Solving ``clauses2`` on a fresh solver must yield exactly (False, None)."""
     unsat_solver = Solver()
     for clause in clauses2:
         unsat_solver.add_clause(clause)
     outcome = unsat_solver.solve()
     self.assertEqual(outcome, (False, None))