示例#1
0
    def dominating_subset(self, k=1):
        """
        Check whether the graph has a dominating set of at most ``k`` vertices.

        A dominating set is a set of vertices such that every vertex is either
        in the set or adjacent to a vertex in the set.

        Accepts as params:
        - k: maximum number of vertices allowed in the dominating set

        Returns the boolean answer of the SAT solver. When the graph has no
        edges an empty list (falsy) is returned, as before.
        """
        if not self.edges():
            # NOTE(review): without edges a dominating set must contain every
            # vertex, so this shortcut is only right when k < #vertices; kept
            # for backward compatibility — confirm intended semantics.
            return []

        logging.info('\nCodifying SAT Solver...')
        vpool = IDPool()
        vertices_ids = [vpool.id(vertex) for vertex in self.vertices()]

        # The context manager releases the native solver once we are done.
        with Solver(name='cd') as solver:
            logging.info(' -> Codifying: Every vertex must be accessible')
            for vertex in self.vertices():
                # The vertex itself or at least one neighbour is selected.
                solver.add_clause([vpool.id(vertex)] + [
                    vpool.id(adjacent_vertex)
                    for adjacent_vertex in self[vertex]
                ])

            # logging uses lazy %-formatting: extra args need placeholders.
            logging.info(
                ' -> Codifying: At most %s vertices should be selected', k)

            cnf = CardEnc.atmost(lits=vertices_ids, bound=k, vpool=vpool)
            solver.append_formula(cnf)

            logging.info('Running SAT Solver...')
            return solver.solve()
def get_unsat_core_pysat(fmla, alpha=None):
    """Compute an UNSAT core of ``fmla`` expressed as clause indices.

    Each clause ``i`` of a copy of ``fmla`` is extended with a fresh
    relaxation literal ``r(i)`` (numbered above ``fmla.nv``). The solver is
    then run under the assumptions ``-r(i)`` for every clause, plus the
    optional extra literals ``alpha``.

    Args:
        fmla: pysat CNF formula. Only a copy is modified.
        alpha: optional iterable of additional assumption literals.

    Returns:
        ``(result, bad_asms)``:
        - ``result``: indices of the clauses that appear in the core.
        - ``bad_asms``: core literals whose variable is ``<= fmla.nv``
          (i.e. coming from ``alpha``).

    Raises:
        Exception: if the formula is satisfiable under the assumptions.
    """
    n_vars = fmla.nv
    vpool = IDPool(start_from=n_vars + 1)
    new_fmla = fmla.copy()
    num_clauses = len(new_fmla.clauses)

    # r(i) is the relaxation selector of clause i.
    relax = [vpool.id(i) for i in range(num_clauses)]
    for clause, r_i in zip(new_fmla.clauses, relax):
        clause.append(r_i)

    asms = [-r_i for r_i in relax]
    if alpha is not None:
        asms.extend(alpha)

    # The context manager frees the native solver even when we raise.
    with Solver(name="cdl") as s:
        s.append_formula(new_fmla)
        if s.solve(assumptions=asms):
            # TODO(jesse): better error handling
            raise Exception("formula is sat")
        core_aux = s.get_core()

    result = []
    bad_asms = []
    for lit in core_aux:
        if abs(lit) > n_vars:
            # A relaxation literal: map it back to its clause index.
            result.append(vpool.obj(abs(lit)))
        else:
            bad_asms.append(lit)
    return result, bad_asms
示例#3
0
 def _order_parents_using_np_variables(self,
                                       solver: Solver,
                                       size: int,
                                       old_size: int = 0) -> None:
     """Append clauses p(child, parent) -> np(child + 1, parent - 1).

     Children range over [max(1, old_size - 1), size - 1) and parents over
     [0, child). The lower bound depends on ``old_size``, presumably so that
     incremental calls only add clauses not produced earlier — confirm.
     NOTE(review): for parent == 0 this references np(child + 1, -1).
     """
     var = self._vars.var
     for child in range(max(1, old_size - 1), size - 1):
         next_child = child + 1
         for parent in range(child):
             antecedent = var('p', child, parent)
             consequent = var('np', next_child, parent - 1)
             solver.append_formula(
                 _implication_to_clauses(antecedent, consequent))
示例#4
0
 def _preserve_parent_order_on_children(self,
                                        solver: Solver,
                                        size: int,
                                        old_size: int = 0) -> None:
     """Append clauses p(child, parent) -> not p(child + 1, pre_parent).

     Emitted for every child in [max(2, old_size - 1), size - 1), every
     parent in [1, child) and every pre_parent in [0, parent).
     """
     var = self._vars.var
     index_triples = ((child, parent, pre_parent)
                      for child in range(max(2, old_size - 1), size - 1)
                      for parent in range(1, child)
                      for pre_parent in range(parent))
     for child, parent, pre_parent in index_triples:
         solver.append_formula(
             _implication_to_clauses(var('p', child, parent),
                                     -var('p', child + 1, pre_parent)))
示例#5
0
 def _define_p_variables(self,
                         solver: Solver,
                         size: int,
                         old_size: int = 0) -> None:
     """Define p(child, parent) for every child in [old_size, size).

     The definition has the form
         p(child, parent) <-> (AND over prev < parent of not t(prev, child))
                              AND t(parent, child)
     and is added for every parent < child.
     """
     var = self._vars.var
     for child in range(old_size, size):
         for parent in range(child):
             defined = var('p', child, parent)
             no_earlier_t = tuple(
                 [-var('t', prev, child) for prev in range(parent)])
             conjuncts = no_earlier_t + (var('t', parent, child), )
             solver.append_formula(
                 _iff_conjunction_to_clauses(defined, conjuncts))
示例#6
0
 def _define_t_variables(self,
                         solver: Solver,
                         size: int,
                         old_size: int = 0) -> None:
     """Define t(from_, to) as the disjunction of y(from_, l, to).

     The disjunction is over every label l in [0, self._alphabet_size).
     It is added for each target ``to`` in [old_size, size) and each
     source ``from_`` < ``to``.
     """
     var = self._vars.var
     for to in range(old_size, size):
         for from_ in range(to):
             t_lit = var('t', from_, to)
             y_lits = tuple(var('y', from_, label, to)
                            for label in range(self._alphabet_size))
             solver.append_formula(_iff_disjunction_to_clauses(t_lit, y_lits))
示例#7
0
def find_truth(f, n, solver_name='cd', assumptions=None):
    """Return up to ``n`` models of the formula ``f``.

    Args:
        f: object exposing ``clauses`` (e.g. a pysat CNF).
        n: maximum number of models to return. ``n <= 0`` yields ``[]``;
            previously one model was still returned in that case.
        solver_name: pysat solver name.
        assumptions: optional literals assumed during enumeration.

    Returns:
        A list of at most ``n`` models, in enumeration order.
    """
    from itertools import islice

    if n <= 0:
        return []
    if assumptions is None:
        # pysat's enum_models defaults to an empty assumption list.
        assumptions = []
    # The context manager deletes the native solver even on error.
    with Solver(name=solver_name) as s:
        s.append_formula(f.clauses)
        return list(islice(s.enum_models(assumptions=assumptions), n))
示例#8
0
 def _define_m_variables(self,
                         solver: Solver,
                         size: int,
                         old_size: int = 0) -> None:
     """Define m(parent, l_num, child) for children in [old_size, size).

     The definition has the form
         m(parent, l, child) <-> (AND over l' < l of not y(parent, l', child))
                                 AND y(parent, l, child)
     for every parent < child and every label l below self._alphabet_size.
     Presumably this marks l as the smallest label on parent -> child.
     """
     var = self._vars.var
     for child in range(old_size, size):
         for parent in range(child):
             for l_num in range(self._alphabet_size):
                 m_lit = var('m', parent, l_num, child)
                 smaller_absent = tuple(-var('y', parent, smaller, child)
                                        for smaller in range(l_num))
                 conjuncts = smaller_absent + (var('y', parent, l_num, child), )
                 solver.append_formula(
                     _iff_conjunction_to_clauses(m_lit, conjuncts))
示例#9
0
 def _define_np_variables(self,
                          solver: Solver,
                          size: int,
                          old_size: int = 0) -> None:
     """Define np(child, parent) recursively over ``parent``.

     For every child in [max(old_size, 2), size):
         np(child, 0)      <-> not p(child, 0)
         np(child, parent) <-> np(child, parent - 1) AND not p(child, parent)
     The second line applies for every parent in [1, child).
     """
     var = self._vars.var
     for child in range(max(old_size, 2), size):
         base = _iff_to_clauses(var('np', child, 0), -var('p', child, 0))
         solver.append_formula(base)
         for parent in range(1, child):
             current = var('np', child, parent)
             previous = var('np', child, parent - 1)
             not_parent = -var('p', child, parent)
             solver.append_formula(
                 _iff_conjunction_to_clauses(current, (previous, not_parent)))
示例#10
0
 def _define_p_variables_using_nt(self,
                                  solver: Solver,
                                  size: int,
                                  old_size: int = 0) -> None:
     """Define p(child, parent) in terms of t and nt variables.

     For every child in [max(1, old_size), size):
         p(child, 0)      <-> t(0, child)
         p(child, parent) <-> t(parent, child) AND nt(parent - 1, child)
     The second line applies for every parent in [1, child).
     """
     var = self._vars.var
     for child in range(max(1, old_size), size):
         first = _iff_to_clauses(var('p', child, 0), var('t', 0, child))
         solver.append_formula(first)
         for parent in range(1, child):
             p_lit = var('p', child, parent)
             t_lit = var('t', parent, child)
             nt_lit = var('nt', parent - 1, child)
             solver.append_formula(
                 _iff_conjunction_to_clauses(p_lit, (t_lit, nt_lit)))
示例#11
0
def skeptical_entailment(KB, seed, q):
    """Return True iff the clauses of ``KB`` and ``seed`` entail ``q``.

    Entailment is checked by refutation. The query is negated and added to
    KB and seed; if the result is unsatisfiable, the query is entailed.

    Args:
        KB: iterable of clauses (lists of int literals).
        seed: iterable of clauses (lists of int literals).
        q: pysat CNF formula; ``q.nv`` and ``q.negate`` are used.
    """
    kb_clauses = list(KB)
    seed_clauses = list(seed)

    # CNF.negate() may introduce auxiliary variables numbered above its
    # ``topv`` (default: q.nv). Those could clash with variables used by
    # KB/seed, so pass the largest variable id seen anywhere.
    top_var = q.nv
    for clause in kb_clauses + seed_clauses:
        for lit in clause:
            top_var = max(top_var, abs(lit))

    with Solver(name='g4') as s:
        for clause in kb_clauses:
            s.add_clause(clause)
        for clause in seed_clauses:
            s.add_clause(clause)
        # add negation of query
        s.append_formula(q.negate(topv=top_var).clauses)
        return not s.solve()
示例#12
0
 def _order_children_using_zm(self,
                              solver: Solver,
                              size: int,
                              old_size: int = 0) -> None:
     """Append ordering clauses for consecutive children of one parent.

     For child in [max(0, old_size - 1), size - 1), parent < child and label
     l_num in [1, self._alphabet_size):
         p(child, parent) AND p(child + 1, parent) AND m(parent, l_num, child)
             -> zm(parent, l_num - 1, child + 1)
     """
     var = self._vars.var
     for child in range(max(0, old_size - 1), size - 1):
         sibling = child + 1
         for parent in range(child):
             for l_num in range(1, self._alphabet_size):
                 premise = (var('p', child, parent),
                            var('p', sibling, parent),
                            var('m', parent, l_num, child))
                 conclusion = var('zm', parent, l_num - 1, sibling)
                 solver.append_formula(
                     _conjunction_implies_to_clauses(premise, conclusion))
示例#13
0
def find_false(f, n, max_try=100, solver_name='cd'):
    """Collect up to ``n`` random assignments under which ``f`` is UNSAT.

    Up to ``max_try`` assignments are drawn with ``rand_assign(f.nv)``. Each
    one is kept when solving ``f`` under it as assumptions returns False.

    Returns:
        A list of at most ``n`` assignments; ``[]`` when ``n <= 0``.
    """
    if n <= 0:
        return []
    nv = f.nv
    out = []
    # The context manager frees the native solver even if an error occurs.
    with Solver(name=solver_name) as s:
        s.append_formula(f.clauses)
        for _ in range(max_try):
            if len(out) >= n:
                break
            assign = rand_assign(nv)
            if s.solve(assumptions=assign) is False:
                out.append(assign)
    return out
示例#14
0
def run_solver(clauses, solver_name):
    """Solve ``clauses`` and report satisfiability plus elapsed time.

    Args:
        clauses: CNF as a list of clauses of int literals.
        solver_name: ``'pycosat'`` or any pysat solver name.

    Returns:
        ``(sat_result, seconds)``:
        - ``sat_result``: ``'SAT'`` or ``'UNSAT'``.
        - ``seconds``: wall time rounded to two decimals.
    """
    timing = time.time()
    if solver_name == 'pycosat':
        result = pycosat.solve(clauses)
        is_sat = result != 'UNSAT'
    else:
        # Solve the formula the caller passed in. Do not re-read a
        # hard-coded 'clique.cnf' file.
        with Solver(name=solver_name) as s:
            s.append_formula(clauses)
            is_sat = s.solve()
    sat_result = 'SAT' if is_sat else 'UNSAT'
    return sat_result, round(time.time() - timing, 2)
示例#15
0
 def _order_children_with_binary_alphabet(self,
                                          solver: Solver,
                                          size: int,
                                          old_size: int = 0) -> None:
     """Order consecutive children of one parent for a two-letter alphabet.

     For child in [max(0, old_size - 1), size - 1) and parent < child, when
     both p(child, parent) and p(child + 1, parent) hold, this requires:
         y(parent, 0, child)
         y(parent, 1, child + 1)
     """
     var = self._vars.var
     for child in range(max(0, old_size - 1), size - 1):
         for parent in range(child):
             for letter, target in ((0, child), (1, child + 1)):
                 both_children = (var('p', child, parent),
                                  var('p', child + 1, parent))
                 solver.append_formula(
                     _conjunction_implies_to_clauses(
                         both_children, var('y', parent, letter, target)))
示例#16
0
 def _define_zm_variables(self,
                          solver: Solver,
                          size: int,
                          old_size: int = 0) -> None:
     """Define zm(parent, l_num, child) recursively over ``l_num``.

     For child in [old_size, size) and parent < child:
         zm(parent, 0, child) <-> not m(parent, 0, child)
         zm(parent, l, child) <-> zm(parent, l - 1, child)
                                  AND not m(parent, l, child)
     The second line applies for 1 <= l < self._alphabet_size.
     """
     var = self._vars.var
     for child in range(old_size, size):
         for parent in range(child):
             solver.append_formula(
                 _iff_to_clauses(var('zm', parent, 0, child),
                                 -var('m', parent, 0, child)))
             for l_num in range(1, self._alphabet_size):
                 current = var('zm', parent, l_num, child)
                 previous = var('zm', parent, l_num - 1, child)
                 not_m = -var('m', parent, l_num, child)
                 solver.append_formula(
                     _iff_conjunction_to_clauses(current, (previous, not_m)))
示例#17
0
    def coloring(self, n_color):
        """
        Returns whether or not there exists a vertex coloring
        of, at most, n_color colors.

        Accepts one param:
        - n_color: number of colors to check (non-negative integer)

        With ``n_color == 0`` the answer is True only for a graph without
        vertices; otherwise the SAT solver's boolean answer is returned.

        Raises ValueError if n_color is negative.
        """
        if n_color < 0:
            raise ValueError('Number of colors must be a non-negative integer')

        if n_color == 0:
            return not bool(self.vertices())

        logging.info('\nCodifying SAT Solver...')
        vpool = IDPool()

        def color_var(vertex, color):
            # Tuple keys cannot collide. String keys like '{}color{}' could:
            # vertex 'acolor1' with color 1 equals vertex 'a' with color 11.
            return vpool.id((vertex, color))

        # The context manager releases the native solver afterwards.
        with Solver(name='cd') as solver:
            logging.info(
                ' -> Codifying: Every vertex must have a color, and only one')
            for vertex in self.vertices():
                cnf = CardEnc.equals(
                    lits=[color_var(vertex, color) for color in range(n_color)],
                    vpool=vpool,
                    encoding=0)  # 0 is pysat's EncType.pairwise
                solver.append_formula(cnf)

            logging.info(
                ' -> Codifying: No two neighbours can have the same color')
            for vertex in self.vertices():
                for neighbour in self[vertex]:
                    for color in range(n_color):
                        solver.add_clause([
                            -color_var(vertex, color),
                            -color_var(neighbour, color),
                        ])

            logging.info('Running SAT Solver...')
            return solver.solve()
示例#18
0
 def _state_status_compatible_with_node_status(
         self,
         solver: Solver,
         size: int,
         new_node_from: int = 0,
         old_size: int = 0,
         changed_statuses=None) -> None:
     """Tie each state's z variable to the status of the APTA nodes mapped to it.

     Nodes examined: ids in [new_node_from, self._apta.size) followed by
     ``changed_statuses``. For each state j in [old_size, size):
         accepting node i: x(i, j) -> z(j)
         rejecting node i: x(i, j) -> not z(j)
     Nodes that are neither produce no clauses.
     """
     extra_nodes = [] if changed_statuses is None else changed_statuses
     var = self._vars.var
     for i in chain(range(new_node_from, self._apta.size), extra_nodes):
         node = self._apta.get_node(i)
         if node.is_accepting():
             accepting = True
         elif node.is_rejecting():
             accepting = False
         else:
             continue
         for j in range(old_size, size):
             x_lit = var('x', i, j)
             z_lit = var('z', j)
             solver.append_formula(
                 _implication_to_clauses(x_lit, z_lit if accepting else -z_lit))
示例#19
0
 def _mapped_node_and_transition_force_mapping(self,
                                               solver: Solver,
                                               size: int,
                                               new_node_from: int = 0,
                                               old_size: int = 0) -> None:
     """Force APTA children onto the target state of a taken DFA transition.

     For every APTA edge parent --label--> child where the parent or the
     child id is >= new_node_from, this adds
         x(parent, from_) AND y(from_, label, to) -> x(child, to)
     The (from_, to) pairs are visited in this order:
         new x new, then old x new, then new x old,
     where "new" is [old_size, size) and "old" is [0, old_size).
     """
     var = self._vars.var
     new_states = range(old_size, size)
     old_states = range(old_size)
     for parent in self._apta.nodes:
         for label, child in parent.children.items():
             if parent.id_ < new_node_from and child.id_ < new_node_from:
                 continue
             state_pairs = [(a, b) for a in new_states for b in new_states]
             if old_size > 0:
                 state_pairs += [(a, b) for a in old_states for b in new_states]
                 state_pairs += [(a, b) for a in new_states for b in old_states]
             for from_, to in state_pairs:
                 premise = (var('x', parent.id_, from_),
                            var('y', from_, label, to))
                 solver.append_formula(
                     _conjunction_implies_to_clauses(
                         premise, var('x', child.id_, to)))
示例#20
0
    def __solve(quantifiers, formula, propagate):
        """
        Private recursive helper that decides the QBF by exhaustive expansion.

        - quantifiers: sequence of signed variables. The LAST entry is
          expanded first. A negative entry is treated universally (both
          polarities must succeed). A positive entry is treated
          existentially (at least one polarity must succeed).
        - formula: CNF matrix, appended to a fresh solver at each leaf.
        - propagate: list of literals fixed so far, used as assumptions.
        """
        if quantifiers:
            new_quant = list(quantifiers)
            quantifier = new_quant.pop()

            if quantifier < 0:
                # Universal: the formula must hold under both polarities.
                return (NaiveQBF.__solve(new_quant, formula,
                                         propagate + [quantifier])
                        and NaiveQBF.__solve(new_quant, formula,
                                             propagate + [-quantifier]))
            # Existential: one polarity suffices. The fallback must try the
            # opposite literal, not the same one again.
            return (NaiveQBF.__solve(new_quant, formula,
                                     propagate + [quantifier])
                    or NaiveQBF.__solve(new_quant, formula,
                                        propagate + [-quantifier]))

        # Leaf: every quantified variable is fixed; ask the SAT solver.
        with Solver(name='cd') as solver:
            solver.append_formula(formula)
            return solver.solve(assumptions=propagate)
示例#21
0
def find_false_worse(f, n, solver_name='cd', sat_at_most=None):
    """Generate ``n`` models of randomly "falsified" variants of ``f``.

    For the i-th model, each clause of ``f`` is kept with a probability
    given by ``sat_at_most[i]``: it is kept when its uniform draw is below
    that value. Otherwise the clause is replaced by the unit clauses of
    its negated literals, which forces that clause to be false. Variants
    that turn out UNSAT are discarded and redrawn.

    Args:
        f: object exposing ``clauses`` (e.g. a pysat CNF).
        n: number of models to produce; must equal ``len(sat_at_most)``.
        solver_name: pysat solver name.
        sat_at_most: per-model keep thresholds (defaults to an empty list).

    Returns:
        A list of ``n`` models.
    """
    if sat_at_most is None:
        sat_at_most = []
    assert n == len(sat_at_most)
    out = []

    while len(out) < n:
        threshold = sat_at_most[len(out)]
        flips = np.random.rand(1, len(f.clauses))
        clauses = []
        for idx, c in enumerate(f.clauses):
            if flips[0][idx] >= threshold:
                # Negate the clause: every one of its literals becomes false.
                clauses.extend([-v] for v in c)
            else:
                clauses.append(c)

        # The context manager frees the native solver each iteration.
        with Solver(name=solver_name) as s:
            s.append_formula(clauses)
            model = next(s.enum_models(), None)
        if model is not None:
            out.append(model)
    return out
示例#22
0
文件: ex2.py 项目: AlexTuisov/HW2
    def solve(self, solver_name="m22"):
        """Answer every query in ``self.queries`` with 'T', 'F' or '?'.

        For a query ``((i, j), t, s)``, the world dynamics, the observations
        and the translated query are solved together:
        - 'F' if that is unsatisfiable;
        - '?' if it is satisfiable and some other state among
          Q/U/I/H/S for the same tile and time is also satisfiable;
        - 'T' otherwise.
        """
        answers_dict = {}
        world_dynamics = self.world_dynamics()

        def is_satisfiable(query):
            # Fresh solver per check; the context manager releases it.
            with Solver(name=solver_name) as solver:
                solver.append_formula(world_dynamics)
                solver.append_formula(self.read_observations())
                solver.append_formula(self.translate_query(query, state=True))
                assumptions = self.observations_to_assumptions()
                return solver.solve(assumptions=assumptions)

        for q in self.queries:
            (i, j), t, s = q
            if not is_satisfiable(q):  # solution was false
                answers_dict[q] = 'F'
                continue

            other_states = ["Q", "U", "I", "H", "S"]
            other_states.remove(s)
            # check for ambiguity: stops at the first satisfiable alternative
            ambiguous = any(
                is_satisfiable(((i, j), t, other_state))
                for other_state in other_states)
            answers_dict[q] = '?' if ambiguous else 'T'

        return answers_dict
示例#23
0
    def solve_cnf_formula(self, solver=None, verbose=0):
        """Solve the circuit CNF and decode a ``Circuit`` from the model.

        solver: None (pycosat), 'minisat' (not implemented) or 'pysat'.
        verbose: forwarded to pycosat.

        Returns False when no circuit exists (the formula is unsatisfiable,
        or, for zero gates, an output is neither constant nor a literal).
        Otherwise returns a ``Circuit``.
        """
        # corner case: looking for a circuit of size 0
        if self.number_of_gates == 0:
            for tt in self.output_truth_tables:
                f = BooleanFunction(tt)
                if not f.is_constant() and not f.is_any_literal():
                    return False

            return Circuit()

        self.finalize_cnf_formula()

        if solver is None:
            result = pycosat.solve(self.clauses, verbose=verbose)
            if result == 'UNSAT':
                return False
        elif solver == 'minisat':
            cnf_file_name = 'tmp.cnf'
            self.save_cnf_formula_to_file(cnf_file_name)
            # TODO: complete
            assert False
        elif solver == 'pysat':
            # Feed the clauses directly instead of round-tripping through a
            # hard-coded 'tmp.cnf'; the context manager frees the solver.
            with Solver() as sat_solver:
                sat_solver.append_formula(self.clauses)
                sat_solver.solve()
                result = sat_solver.get_model()
            if result is None:
                return False
        else:
            assert False

        # Set for O(1) literal lookups during decoding.
        model = set(result)

        gate_descriptions = {}
        for gate in self.internal_gates:
            first_predecessor, second_predecessor = None, None
            for first, second in combinations(range(gate), 2):
                variable = self.predecessors_variable(gate, first, second)
                if variable in model:
                    first_predecessor, second_predecessor = first, second
                else:
                    assert -variable in model

            gate_type = []
            for p, q in product(range(2), repeat=2):
                variable = self.gate_type_variable(gate, p, q)
                if variable in model:
                    gate_type.append(1)
                else:
                    assert -variable in model
                    gate_type.append(0)

            # Input gates are reported by their label rather than their index.
            if first_predecessor in self.input_gates:
                first_predecessor = self.input_labels[first_predecessor]
            if second_predecessor in self.input_gates:
                second_predecessor = self.input_labels[second_predecessor]
            gate_descriptions[gate] = (first_predecessor, second_predecessor,
                                       ''.join(map(str, gate_type)))

        output_gates = []
        for h in self.outputs:
            for gate in self.gates:
                if self.output_gate_variable(h, gate) in model:
                    output_gates.append(gate)

        return Circuit(self.input_labels, gate_descriptions, output_gates)
示例#24
0
文件: ex2.py 项目: AlexTuisov/HW2
def solve_problem(input):
    """Answer each query in ``input["queries"]`` with 'T', 'F' or '?'.

    The encoding depends on the teams. With no police and no medics it comes
    from ``solve_problem1`` and uses ``get_n`` literals. Otherwise it comes
    from ``solve_problem2`` and uses ``get_k`` literals; for 'S' the counter
    variants are 1..3 and for 'Q' they are 1..2.

    For a query with candidate literals L:
    - 'F' if no literal in L is satisfiable on its own;
    - '?' if all satisfiable ones can be false simultaneously;
    - 'T' otherwise.
    """
    num_police = input["police"]
    num_medics = input["medics"]

    def classify(sat, literals):
        feasible = [lit for lit in literals if sat.solve(assumptions=[lit])]
        if not feasible:  # unsatisfiable
            return 'F'
        if sat.solve(assumptions=[-lit for lit in feasible]):  # negation satisfiable
            return '?'
        return 'T'

    if num_police == 0 and num_medics == 0:
        formula = solve_problem1(input)

        def candidates(q):
            return [get_n(q[2], q[1], q[0][0], q[0][1])]
    else:
        formula = solve_problem2(input)

        def candidates(q):
            if q[2] == 'S':
                counters = (1, 2, 3)
            elif q[2] == 'Q':
                counters = (1, 2)
            else:
                return [get_k(q[2], q[1], q[0][0], q[0][1])]
            return [get_k(q[2], q[1], q[0][0], q[0][1], k) for k in counters]

    answer = {}
    s = Solver()
    s.append_formula(formula)
    for q in input["queries"]:
        answer[q] = classify(s, candidates(q))
    s.delete()
    return answer
示例#25
0
文件: ex2.py 项目: AlexTuisov/HW2
def solve_problem_t(input, T):

    n_police = input['police']
    n_medics = input['medics']
    observations = input['observations']
    queries = input['queries']
    H = len(observations[0])
    W = len(observations[0][0])

    vpool = IDPool()
    clauses = []
    tile = lambda code, r, c, t: '{0}{1}{2}_{3}'.format(code, r, c, t)
    CODES = ['U', 'H', 'S']
    ACTIONS = []
    if n_medics > 0:
        CODES.append('I')
        ACTIONS.append('medics')
    if n_police > 0:
        CODES.append('Q')
        ACTIONS.append('police')

    # create list of predicates and associate an integer in pySat for each predicate
    for t in range(T):
        for code in CODES + ACTIONS:
            for r in range(H):
                for c in range(W):
                    vpool.id(tile(code, r, c, t))

    # clauses for the known tiles in the observation, both positive and negative
    for t in range(T):
        curr_observ = observations[t]
        for r in range(H):
            for c in range(W):
                curr_code = curr_observ[r][c]
                for code in CODES:
                    if (code == curr_code) & (curr_code != '?'):
                        clauses.append(tile(code, r, c, t))
                    elif (code != curr_code) & (curr_code != '?'):
                        clauses.append('~' + tile(code, r, c, t))

    # Uxy_t ==> Uxy_t-1 pre-condition
    for t in range(1, T):
        for r in range(H):
            for c in range(W):
                clauses.append(
                    tile('U', r, c, t) + ' ==> ' + tile('U', r, c, t - 1))

    # Uxy_t ==> Uxy_t+1 add effect (no del effect)
    for t in range(T - 1):
        for r in range(H):
            for c in range(W):
                clauses.append(
                    tile('U', r, c, t) + ' ==> ' + tile('U', r, c, t + 1))

    # Ixy_t ==> Ixy_t-1 | (Hxy_t-1 & medicsxy_t-1)' - pre-condition of 'I'
    if n_medics > 0:
        for t in range(1, T):
            for r in range(H):
                for c in range(W):
                    #clauses.append(tile('I',r,c,t) + ' ==> (' + tile('I',r,c,t-1) + ' | ' + tile('H',r,c,t-1)+')')
                    clauses.append(tile('I',r,c,t) + ' ==> (' + tile('I',r,c,t-1) + \
                        ' | (' + tile('H',r,c,t-1) + ' & ' + tile('medics',r,c,t-1) + '))')

    # Ixy_t ==> Ixy_t+1 - add effect of 'I' (no del effect)
    if n_medics > 0:
        for t in range(T - 1):
            for r in range(H):
                for c in range(W):
                    clauses.append(
                        tile('I', r, c, t) + ' ==> ' + tile('I', r, c, t + 1))

    # Qxy_t ==> Qxy_t-1 | (Sxy_t-1 & policexy_t-1)' - pre-condition of 'Q'
    if n_police > 0:
        for t in range(1, T):
            for r in range(H):
                for c in range(W):
                    clauses.append(tile('Q',r,c,t) + ' ==> (' + tile('Q',r,c,t-1) + \
                        ' | (' + tile('S',r,c,t-1) + ' & ' + tile('police',r,c,t-1) + '))')

    # add and del effects of Qxy_t
    if n_police > 0:
        for t in range(T - 1):
            for r in range(H):
                for c in range(W):
                    if t < 1:
                        # Qxy_t ==> Qxy_t+1
                        clauses.append(
                            tile('Q', r, c, t) + ' ==> ' +
                            tile('Q', r, c, t + 1))
                    else:
                        # Qxy_t & ~Qxy_t-1 ==> Qxy_t+1
                        clauses.append('(' + tile('Q',r,c,t) + ' & ~' + tile('Q',r,c,t-1) + ')'\
                            + ' ==> ' + tile('Q',r,c,t+1))
                        # Qxy_t & Qxy_t-1 ==> Hxy_t+1
                        clauses.append('(' + tile('Q',r,c,t) +' & ' + tile('Q',r,c,t-1) + \
                            ') ==> ' + tile('H',r,c,t+1))
                        # Qxy_t & Qxy_t-1 ==> ~Qxy_t+1
                        clauses.append('(' + tile('Q',r,c,t) +' & ' + tile('Q',r,c,t-1) + \
                            ') ==> ~' + tile('Q',r,c,t+1))

    # precondition of S(x,y,t) is either S(x,y,t-1) or H(x,y,t-1) and at least one sick neighbor
    for t in range(1, T):
        for r in range(H):
            for c in range(W):
                #n_coords = get_neighbors(r,c,H,W)
                #curr_clause = tile('S',r,c,t) + ' ==> (' + tile('S',r,c,t-1) + ' | (' + tile('H',r,c,t-1) + ' & ('
                #for coord in n_coords:
                #    curr_clause+= tile('S',coord[0],coord[1],t-1) + ' | '
                #clauses.append(curr_clause[:-3] + ')))')
                clauses.append(
                    tile('S', r, c, t) + ' ==> (' + tile('S', r, c, t - 1) +
                    ' | ' + tile('H', r, c, t - 1) + ')')

    # add and del effects of Sxy_t
    for t in range(T - 1):
        for r in range(H):
            for c in range(W):
                if n_police > 0:
                    # Sxy_t & policexy_t ==> Qxy_t+1 - add effect of 'S' if there's police
                    clauses.append(
                        tile('S', r, c, t) + ' & ' + tile('police', r, c, t) +
                        ' ==> ' + tile('Q', r, c, t + 1))
                    # Sxy_t & policexy_t ==> ~Sxy_t+1 - del effect of 'S' if there's police
                    clauses.append(
                        tile('S', r, c, t) + ' & ' + tile('police', r, c, t) +
                        ' ==> ~' + tile('S', r, c, t + 1))
                    if t < 2:
                        # Sxy_t & ~policexy_t ==> Sxy_t+1
                        clauses.append('(' + tile('S',r,c,t) + ' & ' + tile('~police',r,c,t) + \
                            ') ==> ' + tile('S',r,c,t+1))
                    else:
                        # Sxy_t & ~policexy_t & ~Sxy_t-1 ==> Sxy_t+1
                        clauses.append('(' + tile('S',r,c,t) + ' & ~' + tile('police',r,c,t) + ' & ~' + tile('S',r,c,t-1) + ')'\
                            + ') ==> ' + tile('S',r,c,t+1))
                        # Sxy_t & ~policexy_t & Sxy_t-1 & ~Sxy_t-2 ==> Sxy_t+1
                        clauses.append('(' + tile('S',r,c,t) + ' & ~' + tile('police',r,c,t) + ' & ' + tile('S',r,c,t-1) + \
                            ' & ~' + tile('S',r,c,t-2) + ') ==> ' + tile('S',r,c,t+1))
                        # Sxy_t & ~policexy_t & Sxy_t-1 & Sxy_t-2 ==> Hxy_t+1
                        clauses.append('(' + tile('S',r,c,t) + ' & ~' + tile('police',r,c,t) + ' & ' + tile('S',r,c,t-1) + \
                            ' & ' + tile('S',r,c,t-2) + ') ==> ' + tile('H',r,c,t+1))
                        # Sxy_t & ~policexy_t & Sxy_t-1 & Sxy_t-2 ==> ~Sxy_t+1
                        clauses.append('(' + tile('S',r,c,t) + ' & ~' + tile('police',r,c,t) + ' & ' + tile('S',r,c,t-1) + \
                            ' & ' + tile('S',r,c,t-2) + ') ==> ~' + tile('S',r,c,t+1))
                else:
                    if t < 2:
                        # Sxy_t ==> Sxy_t+1
                        clauses.append(
                            tile('S', r, c, t) + ' ==> ' +
                            tile('S', r, c, t + 1))
                    else:
                        # Sxy_t & ~Sxy_t-1 ==> Sxy_t+1
                        clauses.append('(' + tile('S',r,c,t) + ' & ~' + tile('S',r,c,t-1) + ')'\
                            + ' ==> ' + tile('S',r,c,t+1))
                        # Sxy_t & Sxy_t-1 & ~Sxy_t-2 ==> Sxy_t+1
                        clauses.append('(' + tile('S',r,c,t) + ' & ' + tile('S',r,c,t-1) + \
                            ' & ~' + tile('S',r,c,t-2) + ') ==> ' + tile('S',r,c,t+1))
                        # Sxy_t & Sxy_t-1 & Sxy_t-2 ==> Hxy_t+1
                        clauses.append('(' + tile('S',r,c,t) + ' & ' + tile('S',r,c,t-1) + \
                            ' & ' + tile('S',r,c,t-2) + ' ) ==> ' + tile('H',r,c,t+1))
                        # Sxy_t & Sxy_t-1 & Sxy_t-2 ==> ~Sxy_t+1
                        clauses.append('(' + tile('S',r,c,t) + ' & ' + tile('S',r,c,t-1) + \
                            ' & ' + tile('S',r,c,t-2) + ' ) ==> ~' + tile('S',r,c,t+1))

    # pre-conditions of 'H'
    for t in range(1, T):
        for r in range(H):
            for c in range(W):
                if t < 3:
                    # Hxy_t ==> Hxy_t-1
                    clauses.append(
                        tile('H', r, c, t) + ' ==> ' + tile('H', r, c, t - 1))
                else:
                    # Hxy_t ==> Hxy_t-1 | (Sxy_t-1 & Sxy_t-2 & Sxy_t-3) | (Qxy_t-1, Qxy_t-2)
                    curr_clause = tile('H',r,c,t) + ' ==> ' + tile('H',r,c,t-1) + ' | (' + \
                        tile('S',r,c,t-1) + ' & ' + tile('S',r,c,t-2) + ' & ' + tile('S',r,c,t-3) + ')'
                    if n_police > 0:
                        curr_clause += ' | (' + tile(
                            'Q', r, c, t - 1) + ' & ' + tile('Q', r, c,
                                                             t - 2) + ')'
                    clauses.append(curr_clause)

    # add effect of 'H'
    for t in range(T - 1):
        for r in range(H):
            for c in range(W):
                n_coords = get_neighbors(r, c, H, W)
                if n_medics > 0:
                    # Hxy_t & medicsxy_t ==> Ixy_t+1
                    clauses.append(
                        tile('H', r, c, t) + ' & ' + tile('medics', r, c, t) +
                        ' ==> ' + tile('I', r, c, t + 1))
                    # Hxy_t & medicsxy_T ==> ~Hxy_t+1
                    clauses.append(
                        tile('H', r, c, t) + ' & ' + tile('medics', r, c, t) +
                        ' ==> ' + tile('~H', r, c, t + 1))
                    # Hxy_t & ~medicsxy_t & (at least one sick neighbor) ==> Sxy_t+1
                    curr_clause = tile('H', r, c, t) + ' & ' + tile(
                        '~medics', r, c, t) + ' & ('
                    for coord in n_coords:
                        if n_police > 0:  # if neighbors 'S' do not turn to 'Q' in the next turn
                            curr_clause += '(' + tile('S',coord[0],coord[1],t) + ' & ' + \
                                tile('~Q',coord[0],coord[1],t+1) + ') | '
                        else:
                            curr_clause += tile('S', coord[0], coord[1],
                                                t) + ' | '
                    curr_clause = curr_clause[:-3] + ') ==> ' + tile(
                        'S', r, c, t + 1)
                    clauses.append(curr_clause)
                    # Hxy_t & ~medicsxy_t & (at least one sick neighbor) ==> ~Hxy_t+1
                    curr_clause = tile('H', r, c, t) + ' & ' + tile(
                        '~medics', r, c, t) + ' & ('
                    for coord in n_coords:
                        curr_clause += tile('S', coord[0], coord[1], t) + ' | '
                    curr_clause = curr_clause[:-3] + ') ==> ' + tile(
                        '~H', r, c, t + 1)
                    clauses.append(curr_clause)
                    # Hxy_t & ~medicsxy_t & (no sick neighbors) ==> Hxy_t+1
                    curr_clause = []
                    curr_clause = tile('H', r, c, t) + ' & ' + tile(
                        '~medics', r, c, t) + ' & ('
                    for coord in n_coords:
                        curr_clause += tile('~S', coord[0], coord[1],
                                            t) + ' & '
                    curr_clause = curr_clause[:-3] + ') ==> ' + tile(
                        'H', r, c, t + 1)
                    clauses.append(curr_clause)
                else:
                    # Hxy_t & (at least one sick neighbor) ==> Sxy_t+1
                    curr_clause = tile('H', r, c, t) + ' & ('
                    for coord in n_coords:
                        if n_police > 0:
                            curr_clause += '(' + tile('S',coord[0],coord[1],t) + ' & ' + \
                                tile('~Q',coord[0],coord[1],t+1) + ') | '
                        else:
                            curr_clause += tile('S', coord[0], coord[1],
                                                t) + ' | '
                    curr_clause = curr_clause[:-3] + ') ==> ' + tile(
                        'S', r, c, t + 1)
                    clauses.append(curr_clause)
                    # Hxy_t & (at least one sick neighbor) ==> ~Hxy_t+1
                    curr_clause = tile('H', r, c, t) + ' & ('
                    for coord in n_coords:
                        curr_clause += tile('S', coord[0], coord[1], t) + ' | '
                    curr_clause = curr_clause[:-3] + ') ==> ' + tile(
                        '~H', r, c, t + 1)
                    clauses.append(curr_clause)
                    # Hxy_t & (no sick neighbors) ==> Hxy_t+1
                    curr_clause = tile('H', r, c, t) + ' & ('
                    for coord in n_coords:
                        curr_clause += tile('~S', coord[0], coord[1],
                                            t) + ' & '
                    curr_clause = curr_clause[:-3] + ') ==> ' + tile(
                        'H', r, c, t + 1)
                    clauses.append(curr_clause)

    ## Qxy_t ==> Qxy_t+1 - add effect of 'Q'
    #if n_police > 0:
    #    for t in range(T-1):
    #        for r in range(H):
    #            for c in range(W):
    #                clauses.append(tile('Q',r,c,t) + ' ==> ' + tile('Q',r,c,t+1))

    ## Sxy_t ==> Sxy_t-1 - precondition of 'S' if there's no sick tiles
    #for t in range(1,T):
    #    for r in range(H):
    #        for c in range(W):
    #            if T<3:
    #                clauses.append(tile('S',r,c,t) + ' ==> ' + tile('S',r,c,t-1))

    ## action-add effect of 'H' - if tile 'H' in (x,y,t) and no sick neighbors, then 'H' in (x,y,t+1)
    #for t in range(1,T):
    #    for r in range(H):
    #        for c in range(W):
    #            n_coords = get_neighbors(r,c,H,W)
    #            curr_clause = tile('H',r,c,t) + ' <== (' + tile('H',r,c,t-1)
    #            for coord in n_coords:
    #                curr_clause+= ' & ' + tile('~S',coord[0],coord[1],t-1)
    #            clauses.append(curr_clause + ')')

    ## Hxy_t ==> Hxy_t-1 - precondition of 'H' if there's no sick tiles
    #for t in range(1,T):
    #    for r in range(H):
    #        for c in range(W):
    #            if T<3:
    #                clauses.append(tile('H',r,c,t) + ' ==> ' + tile('H',r,c,t-1))

    # Hxy_t & medicsxy_t ==> Ixy_t+1 - add effect of 'H' if there's no sick
    #for t in range(T-1):
    #    for r in range(H):
    #        for c in range(W):
    #            clauses.append(tile('H',r,c,t) + ' & ' + tile('medics',r,c,t) + ' ==> ' + tile('I',r,c,t+1))
    #            # Hxy_t & medicsxy_t ==> ~Hxy_t+1 - del effect of 'H'
    #            clauses.append(tile('H',r,c,t) + ' & ' + tile('medics',r,c,t) + ' ==> ~' + tile('H',r,c,t+1))

    # a single tile can only contain one code, i.e. or 'H' or 'S' or 'U' or 'I' or 'Q'
    #exclude_clause = lambda code_1,code_2,r,c,t : '~{0}{1}{2}_{3} | ~{4}{5}{6}_{7}'.format(code_1,r,c,t,code_2,r,c,t)
    for t in range(T):
        for r in range(H):
            for c in range(W):
                literal_list = []
                for code in CODES:
                    literal_list.append('~' + tile(code, r, c, t))
                powerset_res = lim_powerset(literal_list, 2)
                for combo in powerset_res:
                    clauses.append(combo[0] + ' | ' + combo[1])
                    #CODES_reduced = []
                    #[CODES_reduced.append(code) if code != code_1 else '' for code in CODES]
                    #for code_2 in CODES_reduced:
                    #    clauses.append(exclude_clause(code_1,code_2,r,c,t))

    # medics is only valid for 'H' tiles
    if n_medics > 0:
        for code in CODES:
            for t in range(T):
                for r in range(H):
                    for c in range(W):
                        if code != 'H':
                            clauses.append(
                                tile(code, r, c, t) + ' ==> ' +
                                tile('~medics', r, c, t))

    if n_medics > 1:
        # medics has to be exactly n_medics times
        for t in range(T - 1):
            tile_coords = []
            for r in range(H):
                for c in range(W):
                    tile_coords.append(tile('medics', r, c, t))
            positive_tiles = lim_powerset(tile_coords, n_medics)
            curr_clause = '(('
            for combo in positive_tiles:
                for predicate in combo:
                    curr_clause += predicate + ' & '
                for curr_tile in tile_coords:
                    if curr_tile not in combo:
                        curr_clause += '~' + curr_tile + ' & '
                curr_clause = curr_clause[:-3] + ') | ('
            clauses.append(curr_clause[:-3] + ')')
    elif n_medics == 1:
        for t in range(T - 1):
            curr_clause = '('
            predicate_list = []
            for r in range(H):
                for c in range(W):
                    predicate_list.append(tile('~medics', r, c, t))
                    curr_clause += tile('medics', r, c, t) + ' | '
            clauses.append(curr_clause[:-3] + ')')
            powerset_res = lim_powerset(predicate_list, 2)
            for combo in powerset_res:
                clauses.append(combo[0] + ' | ' + combo[1])

    # police is only valid for 'S' tiles
    if n_police > 0:
        for code in CODES:
            for t in range(T):
                for r in range(H):
                    for c in range(W):
                        if code != 'S':
                            clauses.append(
                                tile(code, r, c, t) + ' ==> ' +
                                tile('~police', r, c, t))

    # police has to be exactly n_police times
    if n_police > 1:
        for t in range(T - 1):
            tile_coords = []
            for r in range(H):
                for c in range(W):
                    tile_coords.append(tile('police', r, c, t))
            positive_tiles = lim_powerset(tile_coords, n_police)
            curr_clause = '(('
            for combo in positive_tiles:
                for predicate in combo:
                    curr_clause += predicate + ' & '
                for curr_tile in tile_coords:
                    if curr_tile not in combo:
                        curr_clause += '~' + curr_tile + ' & '
                curr_clause = curr_clause[:-3] + ') | ('
            clauses.append(curr_clause[:-3] + ')')
    elif n_police == 1:
        for t in range(T - 1):
            curr_clause = '('
            predicate_list = []
            for r in range(H):
                for c in range(W):
                    predicate_list.append(tile('~police', r, c, t))
                    curr_clause += tile('police', r, c, t) + ' | '
            clauses.append(curr_clause[:-3] + ')')
            powerset_res = lim_powerset(predicate_list, 2)
            for combo in powerset_res:
                clauses.append(combo[0] + ' | ' + combo[1])

    clauses_in_cnf = all_clauses_in_cnf(clauses)
    clauses_in_pysat = all_clauses_in_pysat(clauses_in_cnf, vpool)
    s = Solver()
    s.append_formula(clauses_in_pysat)

    q_dict = dict()
    for q in queries:
        if VERBOSE:
            print('\n')
            print('Initial Observations')
            print_model(observations, T, H, W, n_police, n_medics)
            print('\n')
            print('Query')
            print(q)
            print('\n')
        res_list = []
        for code in CODES:
            clause_to_check = cnf_to_pysat(
                to_cnf(tile(code, q[0][0], q[0][1], q[1])), vpool)
            if s.solve(assumptions=clause_to_check):
                res_list.append(1)
                if VERBOSE:
                    print('Satisfiable for code=%s as follows:' % (code))
                    sat_observations = get_model(s.get_model(), vpool, T, H, W)
                    print_model(sat_observations, T, H, W, n_police, n_medics)
                    print('\n')
            else:
                res_list.append(0)
                if VERBOSE:
                    print('NOT Satisfiable for code=%s' % (code))
                    print('\n')
                    print(vpool.obj(s.get_core()[0]))
        if np.sum(res_list) == 1:
            if CODES[res_list.index(1)] == q[2]:
                q_dict[q] = 'T'
            else:
                q_dict[q] = 'F'
        else:
            q_dict[q] = '?'

    return q_dict
示例#26
0
def closest_string(bitarray_list, distance=4):
    """
    Return whether some bitarray lies within Hamming distance at most
    'distance' of every bitarray in 'bitarray_list'.

    Shorter inputs are right-padded with '0' bits up to the length of the
    longest input before comparing.

    Use example:

    s1=bitarray('0010')
    s2=bitarray('0011')
    closest_string([s1,s2], distance=0)
    > False
    closest_string([s1,s2], distance=2)
    > True

    Accepts as params:
    - bitarray_list: list of bitarrays (the list itself is not modified)
    - distance: maximum number of differing positions allowed (>= 0)

    Returns the result of the SAT call (truthy when such a bitarray exists).
    Raises ValueError if 'distance' is negative.
    """
    if distance < 0:
        raise ValueError('Distance must be a non-negative integer')

    logging.info('\nCodifying SAT Solver...')

    # default=0: an empty list produces no variables and no clauses
    length = max((len(bit_arr) for bit_arr in bitarray_list), default=0)
    vpool = IDPool()

    logging.info(' -> Codifying: normalizing strings')
    local_list = [bitarr + (length - len(bitarr)) * bitarray('0')
                  for bitarr in bitarray_list]
    n_words = len(local_list)

    logging.info(' -> Codifying: imposing distance condition')
    # Register all x, then y, then z variables up front so their ids are
    # allocated in a fixed order before any auxiliary variables created by
    # the cardinality encoding below.
    for index in range(n_words):
        for pos in range(length):
            vpool.id(ut.xvar(index, pos))
    for pos in range(length):
        vpool.id(ut.yvar(pos))
    for index in range(n_words):
        for pos in range(length):
            vpool.id(ut.zvar(index, pos))

    # at least `bound` z-literals per word must hold; when bound <= 0 the
    # constraint is vacuous, so it is not encoded at all (this also avoids
    # passing a negative bound to CardEnc when distance > length)
    bound = length - distance

    # the context manager releases the underlying solver on exit
    with Solver(name='mcm') as solver:
        for index in range(n_words):
            for pos in range(length):
                # presumably links z(index, pos) to x(index, pos) == y(pos);
                # see ut.triple_equal
                for clause in ut.triple_equal(ut.xvar(index, pos),
                                              ut.yvar(pos),
                                              ut.zvar(index, pos),
                                              vpool=vpool):
                    solver.add_clause(clause)
            if bound > 0:
                cnf = CardEnc.atleast(
                    lits=[vpool.id(ut.zvar(index, pos))
                          for pos in range(length)],
                    bound=bound,
                    vpool=vpool)
                solver.append_formula(cnf)

        logging.info(' -> Codifying: Words Value')
        # fix every x variable to the corresponding bit of the input words
        assumptions = []
        for index, word in enumerate(local_list):
            for pos in range(length):
                lit = vpool.id(ut.xvar(index, pos))
                assumptions.append(lit if word[pos] else -lit)

        logging.info('Running SAT Solver...')
        return solver.solve(assumptions=assumptions)
示例#27
0
class MXExplainer(object):
    """
        An SMT-inspired minimal explanation extractor for XGBoost models.

        One MaxSAT-based reasoner (``MXReasoner``) is created per class label,
        and a separate SAT solver is used to propagate the model's prediction
        for a sample. Explanations are lists of feature-category indices,
        i.e. indices into ``self.fcats``.
    """

    def __init__(self, formula, intvs, imaps, ivars, feats, nof_classes,
                 options, xgb):
        """
            Constructor.

            :param formula: per-class encoding; ``formula[clid]`` is handed
                to ``MXReasoner`` and ``formula[clid].formula`` is loaded
                into the SAT-based predictor.
            :param intvs: stored as-is in ``self.intvs``.
            :param imaps: stored as-is in ``self.imaps``.
            :param ivars: mapping from feature name to its variables.
            :param feats: stored as-is in ``self.feats``.
            :param nof_classes: number of class labels.
            :param options: options object; this class reads ``verb``,
                ``encode``, ``solver``, ``am1``, ``exhaust``, ``minz``,
                ``trim``, ``xtype``, ``xnum``, ``unit_mcs`` and ``reduce``.
            :param xgb: the XGBoost model wrapper whose ``mxe`` encoder maps
                samples and features to literals.
        """

        self.feats = feats
        self.intvs = intvs
        self.imaps = imaps
        self.ivars = ivars
        self.nofcl = nof_classes
        self.optns = options
        self.idmgr = IDPool()
        self.fcats = []

        # saving XGBooster
        self.xgb = xgb

        self.verbose = self.optns.verb

        # MaxSAT-based oracles, one per class label
        self.oracles = {}
        if self.optns.encode == 'mxa':
            ortype = 'alien'
        elif self.optns.encode == 'mxe':
            ortype = 'ext'
        else:
            ortype = 'int'
        for clid in range(nof_classes):
            self.oracles[clid] = MXReasoner(formula,
                                            clid,
                                            solver=self.optns.solver,
                                            oracle=ortype,
                                            am1=self.optns.am1,
                                            exhaust=self.optns.exhaust,
                                            minz=self.optns.minz,
                                            trim=self.optns.trim)

        # a reference to the current oracle (selected in prepare())
        self.oracle = None

        # SAT-based predictor holding the encodings of all classes
        self.poracle = SATSolver(name='g3')
        for clid in range(nof_classes):
            self.poracle.append_formula(formula[clid].formula)

        # determining which features should go hand in hand
        categories = collections.defaultdict(list)
        for f in self.xgb.extended_feature_names_as_array_strings:
            if f in self.ivars:
                if '_' in f or len(self.ivars[f]) == 2:
                    categories[f.split('_')[0]].append(
                        self.xgb.mxe.vpos[self.ivars[f][0]])
                else:
                    for v in self.ivars[f]:
                        # this has to be checked and updated
                        categories[f].append(self.xgb.mxe.vpos[abs(v)])

        # these are the result indices of features going together:
        # each category is an inclusive [first, last] range into self.hypos
        self.fcats = [[min(ftups), max(ftups)]
                      for ftups in categories.values()]
        self.fcats_copy = self.fcats[:]

        # all used feature categories
        self.allcats = list(range(len(self.fcats)))

        # variable to original feature index in the sample; feature names
        # presumably look like 'f<index>[_...]' -- TODO confirm
        self.v2feat = {}
        for var in self.xgb.mxe.vid2fid:
            feat, _ = self.xgb.mxe.vid2fid[var]
            self.v2feat[var] = int(feat.split('_')[0][1:])

        # number of oracle calls involved
        self.calls = 0

    def __del__(self):
        """
            Destructor.
        """

        self.delete()

    def delete(self):
        """
            Actual destructor: release all reasoners and the SAT predictor.

            Tolerates a partially constructed object (e.g. when ``__init__``
            raised before the solvers were created).
        """

        # deleting MaxSAT-based reasoners
        oracles = getattr(self, 'oracles', None)
        if oracles:
            for oracle in oracles.values():
                if oracle:
                    oracle.delete()
            self.oracles = {}
        self.oracle = None

        # deleting the SAT-based predictor
        poracle = getattr(self, 'poracle', None)
        if poracle:
            poracle.delete()
            self.poracle = None

    def predict(self, sample):
        """
            Run the encoding and determine the corresponding class.

            Side effects: sets ``self.hypos`` (assumption literals of the
            sample) and ``self.v2cat`` (literal -> category index).

            :return: id of the class with the largest summed leaf weight.
        """

        # translating sample into assumption literals
        self.hypos = self.xgb.mxe.get_literals(sample)

        # variable to the category in use; this differs from
        # v2feat as here we may not have all the features here
        self.v2cat = {}
        for i, cat in enumerate(self.fcats):
            for v in range(cat[0], cat[1] + 1):
                self.v2cat[self.hypos[v]] = i

        # running the solver to propagate the prediction;
        # using solve() instead of propagate() to be able to extract a model.
        # The call is kept outside the assert so that it still runs under
        # `python -O`, which strips assert statements.
        satisfiable = self.poracle.solve(assumptions=self.hypos)
        assert satisfiable, 'Formula must be satisfiable!'
        model = self.poracle.get_model()

        # computing all the class scores: sum the weights of the leaves
        # whose literal is true in the model
        scores = {}
        for clid in range(self.nofcl):
            scores[clid] = 0

            for lit, wght in self.xgb.mxe.enc[clid].leaves:
                if model[abs(lit) - 1] > 0:
                    scores[clid] += wght

        # returning the class corresponding to the max score
        return max(list(scores.items()), key=lambda t: t[1])[0]

    def prepare(self, sample):
        """
            Prepare the oracle for computing an explanation.

            Sets ``self.out_id``, ``self.oracle``, ``self.sample`` and
            ``self.output``; in verbose mode also ``self.preamble``.
        """

        # first, we need to determine the prediction, according to the model
        self.out_id = self.predict(sample)

        # selecting the right oracle
        self.oracle = self.oracles[self.out_id]

        # transformed sample
        self.sample = list(self.xgb.transform(sample)[0])

        # correct class id (corresponds to the maximum computed)
        self.output = self.xgb.target_name[self.out_id]

        if self.verbose:
            inpvals = self.xgb.readable_sample(sample)

            self.preamble = []
            for f, v in zip(self.xgb.feature_names, inpvals):
                if f not in str(v):
                    self.preamble.append('{0} == {1}'.format(f, v))
                else:
                    self.preamble.append(str(v))

            print('  explaining:  "IF {0} THEN {1}"'.format(
                ' AND '.join(self.preamble), self.output))

    def explain(self, sample, smallest, expl_ext=None, prefer_ext=False):
        """
            Hypotheses minimization.

            Computes explanations of the prediction for ``sample`` according
            to ``self.optns`` and records the elapsed user time
            (``self.time``) and the growth of ``ru_maxrss``
            (``self.used_mem``) for this process and its children.

            :param sample: the input sample to explain.
            :param smallest: whether cardinality-minimal explanations are
                requested.
            :param expl_ext: currently unused.
            :param prefer_ext: currently unused.
            :return: ``self.expls``, a list of explanations (lists of
                category indices).
        """

        start_mem = resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss + \
                    resource.getrusage(resource.RUSAGE_SELF).ru_maxrss

        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
                resource.getrusage(resource.RUSAGE_SELF).ru_utime

        # adapt the solver to deal with the current sample
        self.prepare(sample)

        if self.optns.encode != 'mxe':
            # dummy call with the full instance to detect all the necessary cores
            self.oracle.get_coex(self.hypos,
                                 full_instance=True,
                                 early_stop=True)

        # calling the actual explanation procedure
        self._explain(sample,
                      smallest=smallest,
                      xtype=self.optns.xtype,
                      xnum=self.optns.xnum,
                      unit_mcs=self.optns.unit_mcs,
                      reduce_=self.optns.reduce)

        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
                resource.getrusage(resource.RUSAGE_SELF).ru_utime - self.time

        self.used_mem = resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss + \
                    resource.getrusage(resource.RUSAGE_SELF).ru_maxrss - start_mem

        if self.verbose:
            for expl in self.expls:
                hyps = self._cats2hypos(expl)
                feat_ids = sorted(set(self.v2feat[v] for v in hyps))
                preamble = [self.preamble[i] for i in feat_ids]
                label = self.xgb.target_name[self.out_id]

                if self.optns.xtype in ('contrastive', 'con'):
                    preamble = [cond.replace('==', '!=') for cond in preamble]
                    label = 'NOT {0}'.format(label)

                print('  explanation: "IF {0} THEN {1}"'.format(
                    ' AND '.join(preamble), label))
                print('  # hypos left:', len(feat_ids))

            print('  calls:', self.calls)
            print('  rtime: {0:.2f}'.format(self.time))

        return self.expls

    def _explain(self,
                 sample,
                 smallest=True,
                 xtype='abd',
                 xnum=1,
                 unit_mcs=False,
                 reduce_='none'):
        """
            Compute an explanation; the result is stored in ``self.expls``.

            NOTE(review): ``unit_mcs`` is not forwarded to
            ``mhs_mcs_enumeration``, so contrastive enumeration always runs
            with its default ``unit_mcs=False``; abductive enumeration reads
            ``self.optns.unit_mcs`` directly instead.
        """

        if xtype in ('abductive', 'abd'):
            # abductive explanations => MUS computation and enumeration
            if not smallest and xnum == 1:
                self.expls = [self.extract_mus(reduce_=reduce_)]
            else:
                self.mhs_mus_enumeration(xnum, smallest=smallest)
        else:  # contrastive explanations => MCS enumeration
            self.mhs_mcs_enumeration(xnum, smallest, reduce_)

    def extract_mus(self, reduce_='lin', start_from=None):
        """
            Compute one abductive explanation.

            :param reduce_: 'qxp' for QuickXplain-like reduction, anything
                else for linear reduction.
            :param start_from: optional initial core (list of categories);
                when ``None`` the core is obtained from the current oracle.
            :return: a subset of the core (list of category indices).

            When the core is computed here, ``self.calls`` is reset to 1
            (the call made below); when ``start_from`` is given, the
            caller's running count is preserved and only incremented.
        """
        def _do_linear(core):
            """
                Do linear search: drop each category whose removal still
                admits no counterexample.
            """

            to_test = set(core)

            def _assump_needed(a):
                if len(to_test) > 1:
                    to_test.remove(a)
                    self.calls += 1
                    # actual binary hypotheses to test
                    if not self.oracle.get_coex(self._cats2hypos(to_test),
                                                early_stop=True):
                        return False
                    to_test.add(a)
                    return True
                else:
                    return True

            # evaluated in core order; to_test shrinks as categories are dropped
            return [a for a in core if _assump_needed(a)]

        def _do_quickxplain(core):
            """
                Do QuickXplain-like search.
            """

            wset = core[:]
            filt_sz = len(wset) / 2.0
            while filt_sz >= 1:
                i = 0
                while i < len(wset):
                    to_test = wset[:i] + wset[(i + int(filt_sz)):]
                    # actual binary hypotheses to test
                    self.calls += 1
                    if to_test and not self.oracle.get_coex(
                            self._cats2hypos(to_test), early_stop=True):
                        # assumps are not needed
                        wset = to_test
                    else:
                        # assumps are needed => check the next chunk
                        i += int(filt_sz)
                # decreasing size of the set to filter
                filt_sz /= 2.0
                if filt_sz > len(wset) / 2.0:
                    # next size is too large => make it smaller
                    filt_sz = len(wset) / 2.0
            return wset

        self.fcats = self.fcats_copy[:]

        if start_from is None:
            # this is our MUS over-approximation; the oracle call is kept
            # outside the assert so it still runs under `python -O`
            coex = self.oracle.get_coex(self.hypos,
                                        full_instance=True,
                                        early_stop=True)
            assert coex is None, 'No prediction'

            # getting the core
            core = self.oracle.get_reason(self.v2cat)

            self.calls = 1  # we have already made one call
        else:
            # do not reset the counter: the caller keeps a running total
            core = start_from

        if self.verbose > 2:
            print('core:', core)

        if reduce_ == 'qxp':
            expl = _do_quickxplain(core)
        else:  # by default, linear MUS extraction is used
            expl = _do_linear(core)

        return expl

    def mhs_mus_enumeration(self, xnum, smallest=False):
        """
            Enumerate subset- and cardinality-minimal explanations.

            Results go to ``self.expls``; the MCSes found on the way are
            stored in ``self.duals`` (each a list of category indices).
        """

        # result
        self.expls = []

        # just in case, let's save dual (contrastive) explanations
        self.duals = []

        with Hitman(bootstrap_with=[self.allcats],
                    htype='sorted' if smallest else 'lbx') as hitman:
            # computing unit-size MCSes
            if self.optns.unit_mcs:
                for c in self.allcats:
                    self.calls += 1
                    if self.oracle.get_coex(
                            self._cats2hypos(self.allcats[:c] +
                                             self.allcats[(c + 1):]),
                            early_stop=True):
                        hitman.hit([c])
                        self.duals.append([c])

            # main loop
            iters = 0
            while True:
                hset = hitman.get()
                iters += 1

                if self.verbose > 2:
                    print('iter:', iters)
                    print('cand:', hset)

                if hset is None:
                    break

                self.calls += 1
                hypos = self._cats2hypos(hset)
                coex = self.oracle.get_coex(hypos, early_stop=True)
                if coex:
                    to_hit = []
                    unsatisfied = []

                    removed = list(set(self.hypos).difference(set(hypos)))

                    # hypotheses the counterexample already agrees with can
                    # be added to the candidate for free
                    for h in removed:
                        if coex[abs(h) - 1] != h:
                            unsatisfied.append(self.v2cat[h])
                        else:
                            hset.append(self.v2cat[h])

                    unsatisfied = list(set(unsatisfied))
                    hset = list(set(hset))

                    # computing an MCS (expensive)
                    for h in unsatisfied:
                        self.calls += 1
                        if self.oracle.get_coex(self._cats2hypos(hset + [h]),
                                                early_stop=True):
                            hset.append(h)
                        else:
                            to_hit.append(h)

                    if self.verbose > 2:
                        print('coex:', to_hit)

                    hitman.hit(to_hit)

                    # store the MCS itself (a flat list of categories), as is
                    # done for the unit-size MCSes above
                    self.duals.append(to_hit)
                else:
                    if self.verbose > 2:
                        print('expl:', hset)

                    self.expls.append(hset)

                    if len(self.expls) != xnum:
                        hitman.block(hset)
                    else:
                        break

    def mhs_mcs_enumeration(self,
                            xnum,
                            smallest=False,
                            reduce_='none',
                            unit_mcs=False):
        """
            Enumerate subset- and cardinality-minimal contrastive explanations.

            Results go to ``self.expls``; the MUSes found on the way are
            stored in ``self.duals``.
        """

        # result
        self.expls = []

        # just in case, let's save dual (abductive) explanations
        self.duals = []

        with Hitman(bootstrap_with=[self.allcats],
                    htype='sorted' if smallest else 'lbx') as hitman:
            # computing unit-size MUSes
            for c in self.allcats:
                self.calls += 1

                if not self.oracle.get_coex(self._cats2hypos([c]),
                                            early_stop=True):
                    hitman.hit([c])
                    self.duals.append([c])
                elif unit_mcs:
                    # count the second oracle call whatever its outcome
                    self.calls += 1
                    if self.oracle.get_coex(
                            self._cats2hypos(self.allcats[:c] +
                                             self.allcats[(c + 1):]),
                            early_stop=True):
                        # this is a unit-size MCS => block immediately
                        hitman.block([c])
                        self.expls.append([c])

            # main loop
            iters = 0
            while True:
                hset = hitman.get()
                iters += 1

                if self.verbose > 2:
                    print('iter:', iters)
                    print('cand:', hset)

                if hset is None:
                    break

                self.calls += 1
                if not self.oracle.get_coex(self._cats2hypos(
                        set(self.allcats).difference(set(hset))),
                                            early_stop=True):
                    to_hit = self.oracle.get_reason(self.v2cat)

                    if len(to_hit) > 1 and reduce_ != 'none':
                        to_hit = self.extract_mus(reduce_=reduce_,
                                                  start_from=to_hit)

                    self.duals.append(to_hit)

                    if self.verbose > 2:
                        print('coex:', to_hit)

                    hitman.hit(to_hit)
                else:
                    if self.verbose > 2:
                        print('expl:', hset)

                    self.expls.append(hset)

                    if len(self.expls) != xnum:
                        hitman.block(hset)
                    else:
                        break

    def _cats2hypos(self, scats):
        """
            Translate selected categories into propositional hypotheses.

            Concatenates, in the iteration order of ``scats``, the slices
            ``self.hypos[first:last + 1]`` given by ``self.fcats``.
        """

        return [hypo
                for cat in scats
                for hypo in self.hypos[self.fcats[cat][0]:self.fcats[cat][1] + 1]]

    def _hypos2cats(self, hypos):
        """
            Translate propositional hypotheses into a list of categories.

            Not implemented: currently returns ``None``.
        """

        pass
示例#28
0
文件: ex2.py 项目: AlexTuisov/HW2
def check_queries(queries, input):
    """
    Answer observation queries about a grid epidemic by SAT solving.

    The observations are encoded as a propositional knowledge base (KB) with
    one variable per (turn, row, col, population) and per
    (turn, row, col, action). The KB contains observation facts, action
    precondition, action interference (mutual exclusion), fact achievement
    and team-usage clauses. Each query is answered by checking whether the KB
    is satisfiable with the queried literal assumed true and assumed false.

    :param queries: iterable of hashable queries shaped
        ``((row, col), turn, status)``, where ``status`` is one of
        ``'H', 'S', 'Q', 'I', 'U'``.
    :param input: dict with ``'police'`` and ``'medics'`` (team counts) and
        ``'observations'`` (list of turns, each a list of rows of status
        characters; ``'?'`` marks an unknown cell).
    :return: dict mapping each query to ``'T'`` (status is forced),
        ``'F'`` (status is impossible) or ``'?'`` (undetermined).
    """
    solutions = {}

    num_police = input['police']
    num_medics = input['medics']
    observations = input['observations']

    encoded_pop = {}
    encoded_actions = {}

    grid_size = (len(observations[0]), len(observations[0][0]))
    board_size = grid_size[0] * grid_size[1]
    # Cells are numbered row-major (i * row_len + j), so the row stride must
    # be the number of COLUMNS; using the number of rows made the ids of
    # different cells collide on non-square boards.
    row_len = grid_size[1]
    # 18 ids per cell: population codes 1-5 followed by action codes 6-18.
    pop_and_actions_num = 5 + 8 + 5
    pops_encoded = {
        pop: i + 1
        for i, pop in enumerate(['H', 'S', 'Q', 'I', 'U'])
    }
    actions_encoded = {
        action: i + 6
        for i, action in enumerate([
            'VA', 'QR', 'IR', 'IL', 'ID', 'IU', 'HS', 'HQ', 'NOH', 'NOS',
            'NOQ', 'NOI', 'NOU'
        ])
    }

    KB = []

    for turn, observation in enumerate(observations):
        for i, row in enumerate(observation):
            for j, pop in enumerate(row):
                for population in pops_encoded.keys():
                    encoded_pop[(turn, i, j, population)] = turn*board_size*pop_and_actions_num + \
                                                            (i*row_len + j)*pop_and_actions_num + pops_encoded[population]
                if turn < len(observations) - 1:
                    for action in actions_encoded.keys():
                        # Healing (HS/HQ) needs two earlier turns of history.
                        if (action == 'HS' or action == 'HQ') and turn < 2:
                            continue
                        encoded_actions[(turn, i, j, action)] = turn*board_size*pop_and_actions_num + \
                                                                (i*row_len + j)*pop_and_actions_num + actions_encoded[action]
                # Known facts:
                if pop != '?':
                    KB.append([encoded_pop[(turn, i, j, pop)]])
                    for p in pops_encoded.keys():
                        if not p == pop:
                            KB.append([-encoded_pop[(turn, i, j, p)]])

                else:
                    # Exactly one population per unknown cell.
                    KB.append([
                        encoded_pop[(turn, i, j, population)]
                        for population in pops_encoded.keys()
                    ])
                    combs = itertools.combinations(pops_encoded.keys(), 2)
                    for comb in combs:
                        KB.append([
                            -encoded_pop[(turn, i, j, comb[0])],
                            -encoded_pop[(turn, i, j, comb[1])]
                        ])
                    # First turn can't have 'I's and 'Q's
                    if turn == 0:
                        KB.append([-encoded_pop[(0, i, j, 'Q')]])
                        KB.append([-encoded_pop[(0, i, j, 'I')]])

                # no medics no vacc:
                if not num_medics:
                    KB.append([-encoded_pop[(turn, i, j, 'I')]])
                # no police no quar:
                if not num_police:
                    KB.append([-encoded_pop[(turn, i, j, 'Q')]])

    # Action Precondition Clauses
    for action_key, action_value in encoded_actions.items():
        (turn, i, j, action) = action_key

        if num_medics and action == 'VA':
            KB.append([-action_value, encoded_pop[(turn, i, j, 'H')]])

        if num_police and action == 'QR':
            KB.append([-action_value, encoded_pop[(turn, i, j, 'S')]])

        if j < grid_size[1] - 1 and action == 'IR':
            KB += [[-action_value, encoded_pop[(turn, i, j, 'H')]],
                   [-action_value, encoded_pop[(turn, i, j + 1, 'S')]]]

        if j > 0 and action == 'IL':
            KB += [[-action_value, encoded_pop[(turn, i, j, 'H')]],
                   [-action_value, encoded_pop[(turn, i, j - 1, 'S')]]]

        if i < grid_size[0] - 1 and action == 'ID':
            KB += [[-action_value, encoded_pop[(turn, i, j, 'H')]],
                   [-action_value, encoded_pop[(turn, i + 1, j, 'S')]]]

        if i > 0 and action == 'IU':
            KB += [[-action_value, encoded_pop[(turn, i, j, 'H')]],
                   [-action_value, encoded_pop[(turn, i - 1, j, 'S')]]]

        if action == 'HS':
            KB += [[-action_value, encoded_pop[(turn, i, j, 'S')]],
                   [-action_value, encoded_pop[(turn - 1, i, j, 'S')]],
                   [-action_value, encoded_pop[(turn - 2, i, j, 'S')]]]
        if action == 'HQ':
            KB += [[-action_value, encoded_pop[(turn, i, j, 'Q')]],
                   [-action_value, encoded_pop[(turn - 1, i, j, 'Q')]]]

        if action == 'NOH':
            KB.append([-action_value, encoded_pop[(turn, i, j, 'H')]])
            if j < grid_size[1] - 1:
                KB.append([
                    -action_value, -encoded_pop[(turn, i, j + 1, 'S')],
                    encoded_actions[(turn, i, j + 1, 'QR')]
                ])
            if j > 0:
                KB.append([
                    -action_value, -encoded_pop[(turn, i, j - 1, 'S')],
                    encoded_actions[(turn, i, j - 1, 'QR')]
                ])
            if i < grid_size[0] - 1:
                KB.append([
                    -action_value, -encoded_pop[(turn, i + 1, j, 'S')],
                    encoded_actions[(turn, i + 1, j, 'QR')]
                ])
            if i > 0:
                KB.append([
                    -action_value, -encoded_pop[(turn, i - 1, j, 'S')],
                    encoded_actions[(turn, i - 1, j, 'QR')]
                ])

        if action == 'NOS':
            KB.append([-action_value, encoded_pop[(turn, i, j, 'S')]])
            if turn >= 2:
                KB.append([
                    -action_value, -encoded_pop[(turn - 1, i, j, 'S')],
                    -encoded_pop[(turn - 2, i, j, 'S')]
                ])

        if action == 'NOQ':
            KB.append([-action_value, encoded_pop[(turn, i, j, 'Q')]])
            if turn >= 2:
                KB.append([-action_value, -encoded_pop[(turn - 1, i, j, 'Q')]])

        if action == 'NOI':
            KB.append([-action_value, encoded_pop[(turn, i, j, 'I')]])

        if action == 'NOU':
            KB.append([-action_value, encoded_pop[(turn, i, j, 'U')]])

    # Action Interference Clauses
    # Each of these actions excludes every other action on the same cell and
    # turn.
    exclusive_actions = ('VA', 'QR', 'NOH', 'NOS', 'NOQ', 'NOI', 'NOU')
    for turn, observation in enumerate(observations):
        if turn < len(observations) - 1:
            for i, row in enumerate(observation):
                for j, pop in enumerate(row):
                    for action_key in actions_encoded:
                        # HS/HQ variables only exist from turn 2 onwards.
                        if turn < 2 and action_key in ('HS', 'HQ'):
                            continue
                        for exclusive in exclusive_actions:
                            if action_key != exclusive:
                                KB.append([
                                    -encoded_actions[(turn, i, j, exclusive)],
                                    -encoded_actions[(turn, i, j, action_key)]
                                ])

                    if j < grid_size[1] - 1:
                        KB.append([
                            -encoded_actions[(turn, i, j, 'IR')],
                            -encoded_actions[(turn, i, j + 1, 'QR')]
                        ])

                    if j > 0:
                        KB.append([
                            -encoded_actions[(turn, i, j, 'IL')],
                            -encoded_actions[(turn, i, j - 1, 'QR')]
                        ])

                    if i < grid_size[0] - 1:
                        KB.append([
                            -encoded_actions[(turn, i, j, 'ID')],
                            -encoded_actions[(turn, i + 1, j, 'QR')]
                        ])

                    if i > 0:
                        KB.append([
                            -encoded_actions[(turn, i, j, 'IU')],
                            -encoded_actions[(turn, i - 1, j, 'QR')]
                        ])

                    if turn >= 2:
                        KB.append([
                            -encoded_actions[(turn, i, j, 'HS')],
                            -encoded_actions[(turn, i, j, 'HQ')]
                        ])
                        for ac in ['IR', 'IL', 'ID', 'IU']:
                            KB.append([
                                -encoded_actions[(turn, i, j, ac)],
                                -encoded_actions[(turn, i, j, 'HS')]
                            ])
                            KB.append([
                                -encoded_actions[(turn, i, j, ac)],
                                -encoded_actions[(turn, i, j, 'HQ')]
                            ])

    # Fact Achievement Clauses
    for pop_key, pop_val in encoded_pop.items():
        (turn, i, j, pop) = pop_key
        if turn > 0:
            if pop == 'H':
                clause = [-pop_val, encoded_actions[(turn - 1, i, j, 'NOH')]]
                if turn > 2:
                    clause.append(encoded_actions[(turn - 1, i, j, 'HS')])
                    clause.append(encoded_actions[(turn - 1, i, j, 'HQ')])
                KB.append(clause)

            if pop == 'S':
                clause = [-pop_val, encoded_actions[(turn - 1, i, j, 'NOS')]]
                if j < grid_size[1] - 1:
                    clause.append(encoded_actions[(turn - 1, i, j, 'IR')])
                if j > 0:
                    clause.append(encoded_actions[(turn - 1, i, j, 'IL')])
                if i < grid_size[0] - 1:
                    clause.append(encoded_actions[(turn - 1, i, j, 'ID')])
                if i > 0:
                    clause.append(encoded_actions[(turn - 1, i, j, 'IU')])
                KB.append(clause)

            if pop == 'Q':
                KB.append([
                    -pop_val, encoded_actions[(turn - 1, i, j, 'NOQ')],
                    encoded_actions[(turn - 1, i, j, 'QR')]
                ])

            if pop == 'I':
                KB.append([
                    -pop_val, encoded_actions[(turn - 1, i, j, 'NOI')],
                    encoded_actions[(turn - 1, i, j, 'VA')]
                ])

            if pop == 'U':
                KB.append([-pop_val, encoded_actions[(turn - 1, i, j, 'NOU')]])

    # TEAMS
    def cnf_to_clauses(teams_cnf):
        # Parse the printed sympy CNF ("(a | ~b) & (c)") back into int clauses;
        # this relies on the symbols being named after the variable ids.
        clauses = str(teams_cnf).replace(" ", "").replace("(", "").replace(
            ")", "").replace("~", "-").split('&')
        return [[int(c) for c in clause.split("|")] for clause in clauses]

    def teams_clauses(curr_num_teams, team_locs):
        # Build clauses stating how the available teams are distributed over
        # the candidate variables in team_locs.
        if len(team_locs) == 1:
            return [team_locs]

        syms = ''
        for loc in team_locs:
            syms += str(loc) + ' '

        symbs = symbols(syms)
        and_groups = [
            POSform(group, minterms=[{symb: 1
                                      for symb in group}])
            for group in combinations(
                symbs, min(curr_num_teams, len(team_locs)))
        ]

        SOP = SOPform(and_groups,
                      minterms=[{
                          group: 1
                      } for group in and_groups])

        Xor_group = True
        if len(team_locs) > curr_num_teams:
            for clause1, clause2 in combinations(SOP.args, 2):
                Xor_group = And(Xor_group,
                                Or(Not(clause1), Not(clause2)))

        final_teams_cnf = to_cnf(And(Xor_group, SOP))
        return cnf_to_clauses(final_teams_cnf)

    if num_medics or num_police:
        # Inferred from teams:
        for turn, observation in enumerate(observations):
            if turn == len(observations) - 1:
                break
            healthy = []
            sick = []

            curr_num_medics = num_medics
            curr_num_police = num_police

            for i, row in enumerate(observation):
                for j, pop in enumerate(row):
                    if pop == 'H':
                        if observations[turn + 1][i][j] == 'I':
                            curr_num_medics -= 1
                        elif observations[turn + 1][i][j] == '?':
                            healthy.append(encoded_pop[(turn + 1, i, j, 'I')])
                    elif pop == 'S':
                        if observations[turn + 1][i][j] == 'Q':
                            curr_num_police -= 1
                        elif observations[turn + 1][i][j] == '?':
                            sick.append(encoded_pop[(turn + 1, i, j, 'Q')])
                    elif pop == '?':
                        healthy.append(encoded_pop[(turn + 1, i, j, 'I')])
                        sick.append(encoded_pop[(turn + 1, i, j, 'Q')])

            for i, row in enumerate(observation):
                for j, pop in enumerate(row):
                    if observations[turn + 1][i][j] == '?':
                        if not curr_num_medics:
                            KB.append([-encoded_actions[(turn, i, j, 'VA')]])
                        if not curr_num_police:
                            KB.append([-encoded_actions[(turn, i, j, 'QR')]])

            if curr_num_medics and len(healthy):
                KB += teams_clauses(curr_num_medics, healthy)

            if curr_num_police and len(sick):
                KB += teams_clauses(curr_num_police, sick)

    # The KB is fixed from here on, so one incremental solver answers every
    # query with two assumption-based calls; the context manager frees it.
    with Solver(bootstrap_with=KB) as solver:
        for query in queries:
            lit = encoded_pop[(query[1], query[0][0], query[0][1], query[2])]

            if not solver.solve(assumptions=[lit]):
                solutions[query] = 'F'
            elif solver.solve(assumptions=[-lit]):
                solutions[query] = '?'
            else:
                solutions[query] = 'T'

    return solutions
示例#29
0
class LSU:
    """
        Linear SAT-UNSAT algorithm for MaxSAT [1]_. The algorithm can be seen
        as a series of satisfiability oracle calls refining an upper bound on
        the MaxSAT cost, followed by one unsatisfiability call, which stops the
        algorithm. For unweighted problems the sum of all selector literals is
        encoded with the *iterative totalizer encoding* [2]_; for weighted
        problems a pseudo-Boolean encoding (``pb_enc_type``) is used. At every
        iteration, the upper bound on the cost is reduced and enforced by
        adding clauses to the working formula. No clauses are removed during
        the execution of the algorithm. As a result, the SAT oracle is used
        incrementally.

        .. note:: For weighted problems the PB constraint is re-encoded from
            scratch after every improvement (not incrementally).

        The constructor receives an input :class:`.WCNF` formula, a name of the
        SAT solver to use (see :class:`.SolverNames` for details), and an
        integer verbosity level. The input formula is not modified.

        :param formula: input MaxSAT formula
        :param solver: name of SAT solver
        :param pb_enc_type: PB encoding type to use for solving weighted problems
        :param expect_interrupt: whether or not an :meth:`interrupt` call is expected
        :param verbose: verbosity level

        :type formula: :class:`.WCNF`
        :type solver: str
        :type pb_enc_type: int
        :type expect_interrupt: bool
        :type verbose: int
    """
    def __init__(self,
                 formula,
                 solver='g4',
                 pb_enc_type=EncType.best,
                 expect_interrupt=False,
                 verbose=0):
        """
            Constructor.
        """

        # Set first so delete()/__del__ are safe even if _init() fails.
        self.oracle = None
        self.tot = None  # totalizer encoder for the cardinality constraint
        self.model = None  # best model found by solve()
        self.cost = None  # cost of self.model

        self.verbose = verbose
        self.solver = solver
        self.pb_enc_type = pb_enc_type
        self.expect_interrupt = expect_interrupt
        self.formula = formula
        self.vpool = IDPool(occupied=[
            (1, formula.nv)
        ])  # variable pool used for managing card/PB encodings
        self.sels = []  # soft clause selector variables
        self.is_weighted = False  # auxiliary flag indicating if it's a weighted problem
        self._init(formula)  # initialize SAT oracle

    def _init(self, formula):
        """
            SAT oracle initialization. The method creates a new SAT oracle and
            feeds it with the formula's hard clauses. Afterwards, a copy of
            every soft clause, augmented with a fresh selector literal, is
            added to the solver (the input formula itself is left untouched).
            The list of all introduced selectors is stored in ``self.sels``.

            :param formula: input MaxSAT formula
            :type formula: :class:`WCNF`
        """

        self.oracle = Solver(name=self.solver,
                             bootstrap_with=formula.hard,
                             incr=True,
                             use_timer=True)

        for cl in formula.soft:
            # TODO: if clause is unit, use its literal as selector
            # (ITotalizer must be extended to support PB constraints first)
            selv = self.vpool._next()
            # Add a relaxed copy instead of appending to the caller's clause.
            self.oracle.add_clause(list(cl) + [selv])
            self.sels.append(selv)
        self.is_weighted = any(w > 1 for w in formula.wght)

        if self.verbose > 1:
            print('c formula: {0} vars, {1} hard, {2} soft'.format(
                formula.nv, len(formula.hard), len(formula.soft)))

    def __del__(self):
        """
            Destructor.
        """

        self.delete()

    def __enter__(self):
        """
            'with' constructor.
        """

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """
            'with' destructor.
        """

        self.delete()

    def delete(self):
        """
            Explicit destructor of the internal SAT oracle and the
            :class:`.ITotalizer` object. Safe to call more than once.
        """

        if getattr(self, 'oracle', None):
            self.oracle.delete()
            self.oracle = None

        if getattr(self, 'tot', None):
            self.tot.delete()
            self.tot = None

    def solve(self):
        """
            Computes a solution to the MaxSAT problem. The method implements
            the LSU/LSUS algorithm, i.e. it represents a loop, each iteration
            of which calls a SAT oracle on the working MaxSAT formula and
            refines the upper bound on the MaxSAT cost until the formula
            becomes unsatisfiable.

            Returns ``True`` if the hard part of the MaxSAT formula is
            satisfiable, i.e. if there is a MaxSAT solution, and ``False``
            otherwise.

            :rtype: bool
        """

        is_sat = False

        # solve_limited() returns None when interrupted, which ends the loop.
        while self.oracle.solve_limited(
                expect_interrupt=self.expect_interrupt):
            is_sat = True
            self.model = self.oracle.get_model()
            self.cost = self._get_model_cost(self.formula, self.model)
            if self.verbose:
                print('o {0}'.format(self.cost))
                sys.stdout.flush()
            if self.cost == 0:  # if cost is 0, then model is an optimum solution
                break
            self._assert_lt(self.cost)

        if is_sat:
            # Keep only original variables; a list (not a one-shot filter
            # iterator) so get_model() can be called repeatedly.
            self.model = [l for l in self.model if abs(l) <= self.formula.nv]
            if self.verbose:
                if self.found_optimum():
                    print('s OPTIMUM FOUND')
                else:
                    print('s SATISFIABLE')
        elif self.verbose:
            print('s UNSATISFIABLE')

        return is_sat

    def get_model(self):
        """
            This method returns a model obtained during a prior satisfiability
            oracle call made in :func:`solve` (``None`` before any model is
            found).

            :rtype: list(int)
        """

        return self.model

    def found_optimum(self):
        """
            Checks if the optimum solution was found in a prior call to
            :func:`solve`, i.e. whether the last oracle call finished
            (status is not ``None``) rather than being interrupted.

            :rtype: bool
        """

        return self.oracle.get_status() is not None

    def _get_model_cost(self, formula, model):
        """
            Given a WCNF formula and a model, the method computes the MaxSAT
            cost of the model, i.e. the sum of weights of soft clauses that are
            unsatisfied by the model.

            :param formula: an input MaxSAT formula
            :param model: a satisfying assignment

            :type formula: :class:`.WCNF`
            :type model: list(int)

            :rtype: int
        """

        model_set = set(model)
        cost = 0

        for cl, w in zip(formula.soft, formula.wght):
            # Selector/auxiliary variables (above formula.nv) are ignored.
            cost += w if all(l not in model_set for l in filter(
                lambda l: abs(l) <= self.formula.nv, cl)) else 0

        return cost

    def _assert_lt(self, cost):
        """
            The method enforces an upper bound on the cost of the MaxSAT
            solution. For unweighted problems, this is done by encoding the sum
            of all soft clause selectors with the use the iterative totalizer
            encoding, i.e. :class:`.ITotalizer`. Note that the sum is created
            once, at the beginning. Each of the following calls to this method
            only enforces the upper bound on the created sum by adding the
            corresponding unit size clause. For weighted problems, the PB
            encoding given through the :meth:`__init__` method is used.
            Each such clause is added on the fly with no restart of the
            underlying SAT oracle.

            :param cost: the cost of the next MaxSAT solution is enforced to be
                *lower* than this current cost

            :type cost: int
        """

        if self.is_weighted:
            # TODO: use incremental PB encoding
            self.oracle.append_formula(
                PBEnc.leq(self.sels,
                          weights=self.formula.wght,
                          bound=cost - 1,
                          vpool=self.vpool,
                          encoding=self.pb_enc_type))
        else:

            if self.tot is None:
                self.tot = ITotalizer(lits=self.sels,
                                      ubound=cost - 1,
                                      top_id=self.vpool.top)
                self.vpool.top = self.tot.top_id

                for cl in self.tot.cnf.clauses:
                    self.oracle.add_clause(cl)

            # rhs[k] is the totalizer output for "sum > k"; forbid sum > cost-1.
            self.oracle.add_clause([-self.tot.rhs[cost - 1]])

    def interrupt(self):
        """
            Interrupt the current execution of LSU's :meth:`solve` method.
            Can be used to enforce time limits using timer objects. The
            interrupt must be cleared before running the LSU algorithm again
            (see :meth:`clear_interrupt`).
        """

        self.oracle.interrupt()

    def clear_interrupt(self):
        """
            Clears an interruption.
        """

        self.oracle.clear_interrupt()

    def oracle_time(self):
        """
            Method for calculating and reporting the total SAT solving time.
        """

        return self.oracle.time_accum()
示例#30
0
def solve_problem(problem):
    """
    Answer every query in ``problem['queries']`` with a SAT encoding.

    All clause families produced by the ``gen_*_clauses`` helpers are
    conjoined into one CNF. For each query ``((row, col), step, status)``
    the queried status is first checked for satisfiability. If it is
    possible, a second check forbids all atoms of that status to see whether
    some other status is possible as well.

    :param problem: dict with at least the ``'observations'``, ``'medics'``,
        ``'police'`` and ``'queries'`` keys.
    :return: dict mapping each query to ``'T'`` (only this status is
        possible), ``'?'`` (this and another status are possible) or ``'F'``
        (this status is impossible).
    """
    p_atoms, last = enumerate_states(problem, 1)
    a_atoms, last = enumerate_agent_actions(last, problem)
    spa_atoms, last = enumerate_spread_actions(problem, last)
    agea_atoms, last = enumerate_states(problem, last)
    noopa_atoms, last = enumerate_states(problem, last)
    observation_clauses = gen_observation_clauses(p_atoms,
                                                  problem['observations'])
    action_preconditions_clauses = gen_precondition_clauses(
        p_atoms, a_atoms, spa_atoms, noopa_atoms, agea_atoms)
    fact_acheivement_clauses = gen_fact_acheivement_clauses(
        p_atoms, a_atoms, spa_atoms, noopa_atoms, agea_atoms)
    action_interefernce_clauses = gen_action_interefernce_clauses(
        a_atoms, spa_atoms, noopa_atoms, agea_atoms)
    must_spread_clauses = gen_must_spread_clauses(p_atoms, a_atoms, spa_atoms)
    must_age_clauses = gen_must_age_clauses(p_atoms, a_atoms, agea_atoms)
    S_representation_clauses = gen_S_representation_clauses(p_atoms)
    must_use_teams_clauses = gen_must_use_teams_clauses(
        p_atoms, a_atoms, problem['medics'], problem['police'])
    limited_teams_clauses = gen_limited_teams_clauses(p_atoms, a_atoms,
                                                      problem['medics'],
                                                      problem['police'])
    all_clauses = (observation_clauses + action_preconditions_clauses +
                   fact_acheivement_clauses + action_interefernce_clauses +
                   must_spread_clauses + must_age_clauses +
                   S_representation_clauses + must_use_teams_clauses +
                   limited_teams_clauses)
    answer = {}
    non_starter_list = ['I', 'Q']
    for query in problem['queries']:
        row = query[0][0]
        col = query[0][1]
        step = query[1]
        status = query[2]
        if step == 0 and status in non_starter_list:
            # A status that cannot appear in step 0 is answered 'F' directly.
            answer[query] = 'F'
            continue

        # The query holds iff at least one of these atoms is true ('S' and
        # 'Q' are represented by several per-duration atoms).
        if status == 'S':
            query_clause = [
                p_atoms['S3'][step][row][col],
                p_atoms['S2'][step][row][col],
                p_atoms['S1'][step][row][col]
            ]
        elif status == 'Q':
            query_clause = [
                p_atoms['Q2'][step][row][col],
                p_atoms['Q1'][step][row][col]
            ]
        else:
            query_clause = [p_atoms[status][step][row][col]]

        with Solver(bootstrap_with=all_clauses + [query_clause]) as sat_solver:
            res1 = sat_solver.solve()

        res1_other = False
        if res1:
            # Solve again with every atom of the queried status forbidden.
            # If that is still satisfiable, another status is possible too,
            # so the result cannot be conclusive.
            negative_query_clauses = [[-q] for q in query_clause]
            with Solver(bootstrap_with=all_clauses +
                        negative_query_clauses) as sat_solver:
                res1_other = sat_solver.solve()

        if res1 and not res1_other:  # only the queried status is possible
            answer[query] = 'T'
        elif res1:  # the queried status and at least one other are possible
            answer[query] = '?'
        else:  # the queried status is not possible
            answer[query] = 'F'

    return answer