def get_unsat_core_pysat(fmla, alpha=None):
    """
    Compute an unsatisfiable core of ``fmla`` expressed as clause indices.

    Every clause ``i`` of ``fmla`` is extended with a fresh relaxation
    variable ``r_i`` (allocated just above ``fmla.nv``), and the solver is
    called under the assumptions ``-r_i`` for all clauses plus the optional
    extra literals in ``alpha``. The core reported by the solver is then split
    into the indices of the clauses whose relaxation literal occurs in it and
    the remaining assumption literals (those coming from ``alpha``).

    ``fmla`` itself is not modified.

    :param fmla: CNF formula (needs ``nv`` and ``clauses`` attributes).
    :param alpha: optional iterable of extra assumption literals.
    :returns: ``(result, bad_asms)``. ``result`` lists the indices of the
        clauses of ``fmla`` that are in the core. ``bad_asms`` lists the core
        literals that are not relaxation literals, i.e. literals from
        ``alpha``.
    :raises Exception: if the formula is satisfiable under the assumptions.
    """
    n_vars = fmla.nv
    num_clauses = len(fmla.clauses)
    vpool = IDPool(start_from=n_vars + 1)
    # One fresh relaxation variable per clause, allocated in clause order.
    relax = [vpool.id(i) for i in range(num_clauses)]
    # Maps a relaxation variable back to the index of its clause.
    # NOTE(review): alpha literals are assumed not to reuse these ids.
    clause_of = {r_i: i for i, r_i in enumerate(relax)}

    # Build the relaxed clauses without touching the caller's formula.
    relaxed = [list(clause) + [r_i] for clause, r_i in zip(fmla.clauses, relax)]

    asms = [-r_i for r_i in relax]
    if alpha is not None:
        asms = asms + list(alpha)

    # The context manager frees the native solver even if we raise below.
    with Solver(name="cdl") as s:
        s.append_formula(relaxed)
        if s.solve(assumptions=asms):
            # TODO(jesse): better error handling
            raise Exception("formula is sat")
        core_aux = s.get_core() or []

    result = []
    bad_asms = []
    for lit in core_aux:
        if abs(lit) in clause_of:
            result.append(clause_of[abs(lit)])
        else:
            bad_asms.append(lit)
    return result, bad_asms
# Example no. 2
# 0
class VarPool:
    """
    Small convenience layer over ``IDPool`` for variables named by a base
    name plus up to three indices.
    """

    def __init__(self) -> None:
        # backing pool that hands out the integer identifiers
        self._vpool = IDPool()

    def var(self, name: str, ind1, ind2=0, ind3=0) -> int:
        """
        Return the integer id of variable ``name`` at indices
        ``(ind1, ind2, ind3)``. The id is created on first use.
        """
        key = '{}_{}_{}_{}'.format(name, ind1, ind2, ind3)
        return self._vpool.id(key)

    def var_name(self, id_: int):
        """Return the key string registered for ``id_`` in the pool."""
        pool = self._vpool
        return pool.obj(id_)
# Example no. 3
# 0
class Hitman(object):
    """
        A cardinality-/subset-minimal hitting set enumerator. The enumerator
        can be set up to use either a MaxSAT solver :class:`.RC2` or an MCS
        enumerator (either :class:`.LBX` or :class:`.MCSls`). In the former
        case, the hitting sets enumerated are ordered by size (smallest size
        hitting sets are computed first), i.e. *sorted*. In the latter case,
        subset-minimal hitting sets are enumerated in an arbitrary order, i.e.
        *unsorted*.

        This is handled with the use of parameter ``htype``, which is set to be
        ``'sorted'`` by default. The MaxSAT-based enumerator can be chosen by
        setting ``htype`` to one of the following values: ``'maxsat'``,
        ``'mxsat'``, or ``'rc2'``. Alternatively, by setting it to ``'mcs'`` or
        ``'lbx'``, a user can enforce using the :class:`.LBX` MCS enumerator.
        If ``htype`` is set to ``'mcsls'`` (or any other unrecognised value),
        the :class:`.MCSls` enumerator is used.

        In either case, an underlying problem solver can use a SAT oracle
        specified as an input parameter ``solver``. The default SAT solver is
        Glucose3 (specified as ``g3``, see :class:`.SolverNames` for details).

        Objects of class :class:`Hitman` can be bootstrapped with an iterable
        of iterables, e.g. a list of lists. This is handled using the
        ``bootstrap_with`` parameter. Each set to hit can comprise elements of
        any type, e.g. integers, strings or objects of any Python class, as
        well as their combinations. The bootstrapping phase is done in
        :func:`init`.

        A few other optional parameters include the possible options for RC2
        as well as for LBX- and MCSls-like MCS enumerators that control the
        behaviour of the underlying solvers.

        :param bootstrap_with: input set of sets to hit (``None`` means none)
        :param weights: a mapping from objects to their weights (if weighted)
        :param solver: name of SAT solver
        :param htype: enumerator type
        :param mxs_adapt: detect and process AtMost1 constraints in RC2
        :param mxs_exhaust: apply unsatisfiable core exhaustion in RC2
        :param mxs_minz: apply heuristic core minimization in RC2
        :param mxs_trim: trim unsatisfiable cores at most this number of times
        :param mcs_usecld: use clause-D heuristic in the MCS enumerator

        :type bootstrap_with: iterable(iterable(obj))
        :type weights: dict(obj)
        :type solver: str
        :type htype: str
        :type mxs_adapt: bool
        :type mxs_exhaust: bool
        :type mxs_minz: bool
        :type mxs_trim: int
        :type mcs_usecld: bool
    """
    def __init__(self,
                 bootstrap_with=None,
                 weights=None,
                 solver='g3',
                 htype='sorted',
                 mxs_adapt=False,
                 mxs_exhaust=False,
                 mxs_minz=False,
                 mxs_trim=0,
                 mcs_usecld=False):
        """
            Constructor. ``bootstrap_with=None`` (the default) is treated as
            an empty collection of sets to hit.
        """

        # hitting set solver
        self.oracle = None

        # name of SAT solver
        self.solver = solver

        # various oracle options
        self.adapt = mxs_adapt
        self.exhaust = mxs_exhaust
        self.minz = mxs_minz
        self.trim = mxs_trim
        self.usecld = mcs_usecld

        # hitman type: either a MaxSAT solver or an MCS enumerator
        if htype in ('maxsat', 'mxsat', 'rc2', 'sorted'):
            self.htype = 'rc2'
        elif htype in ('mcs', 'lbx'):
            self.htype = 'lbx'
        else:  # 'mcsls'
            self.htype = 'mcsls'

        # pool of variable identifiers (for objects to hit)
        self.idpool = IDPool()

        # initialize hitting set solver
        self.init(bootstrap_with if bootstrap_with is not None else [],
                  weights)

    def __del__(self):
        """
            Destructor.
        """

        self.delete()

    def __enter__(self):
        """
            'with' constructor.
        """

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """
            'with' destructor.
        """

        self.delete()

    def init(self, bootstrap_with, weights=None):
        """
            This method serves for initializing the hitting set solver with a
            given list of sets to hit. Concretely, the hitting set problem is
            encoded into partial MaxSAT: every set to hit becomes a hard
            clause over the objects' variables, and every object gets a soft
            unit clause ``[-var]``. The resulting formula is then fed either
            to a MaxSAT solver or to an MCS enumerator.

            An additional optional parameter is ``weights``, which can be used
            to specify non-unit weights for the target objects in the sets to
            hit. This only works if ``'sorted'`` enumeration of hitting sets
            is applied.

            :param bootstrap_with: input set of sets to hit
            :param weights: weights of the objects in case the problem is weighted
            :type bootstrap_with: iterable(iterable(obj))
            :type weights: dict(obj)
        """

        # formula encoding the sets to hit
        formula = WCNF()

        # hard clauses
        for to_hit in bootstrap_with:
            formula.append([self.idpool.id(obj) for obj in to_hit])

        # soft clauses: one per object known to the pool
        for obj_id in self.idpool.id2obj:
            formula.append(
                [-obj_id],
                weight=1 if not weights else weights[self.idpool.obj(obj_id)])

        if self.htype == 'rc2':
            # stratification only pays off when the weights actually differ
            if not weights or min(weights.values()) == max(weights.values()):
                self.oracle = RC2(formula,
                                  solver=self.solver,
                                  adapt=self.adapt,
                                  exhaust=self.exhaust,
                                  minz=self.minz,
                                  trim=self.trim)
            else:
                self.oracle = RC2Stratified(formula,
                                            solver=self.solver,
                                            adapt=self.adapt,
                                            exhaust=self.exhaust,
                                            minz=self.minz,
                                            nohard=True,
                                            trim=self.trim)
        elif self.htype == 'lbx':
            self.oracle = LBX(formula,
                              solver_name=self.solver,
                              use_cld=self.usecld)
        else:
            self.oracle = MCSls(formula,
                                solver_name=self.solver,
                                use_cld=self.usecld)

    def delete(self):
        """
            Explicit destructor of the internal hitting set oracle. Safe to
            call several times, and on a partially constructed object.
        """

        oracle = getattr(self, 'oracle', None)
        if oracle:
            oracle.delete()
            self.oracle = None

    def get(self):
        """
            This method computes and returns a hitting set. The hitting set is
            obtained using the underlying oracle operating the MaxSAT problem
            formulation. The computed solution is mapped back to objects of the
            problem domain. The raw variable ids are kept in ``self.hset``.

            :return: the hitting set, or ``None`` if the oracle reports that
                no (further) solution exists.
            :rtype: list(obj)
        """

        model = self.oracle.compute()

        if model is None:
            return None

        if self.htype == 'rc2':
            # extracting a hitting set: the positively assigned variables
            self.hset = [v for v in model if v > 0]
        else:
            # the MCS enumerators report the (1-based) indices of the soft
            # clauses, which coincide with the object ids assigned in init()
            self.hset = list(model)

        return [self.idpool.id2obj[vid] for vid in self.hset]

    def hit(self, to_hit, weights=None):
        """
            This method adds a new set to hit to the hitting set solver. This
            is done by translating the input iterable of objects into a list of
            Boolean variables in the MaxSAT problem formulation.

            Note that an optional parameter that can be passed to this method
            is ``weights``, which contains a mapping the objects under
            question into weights. Also note that the weight of an object must
            not change from one call of :meth:`hit` to another.

            :param to_hit: a new set to hit
            :param weights: a mapping from objects to weights

            :type to_hit: iterable(obj)
            :type weights: dict(obj)
        """

        # translating objects to variables
        to_hit = [self.idpool.id(obj) for obj in to_hit]

        # a soft clause should be added for each new object; must be computed
        # before the hard clause registers the variables with the oracle
        new_obj = self._unseen(to_hit)

        # new hard clause
        self.oracle.add_clause(to_hit)

        # new soft clauses
        self._add_soft(new_obj, weights)

    def block(self, to_block, weights=None):
        """
            The method serves for imposing a constraint forbidding the hitting
            set solver to compute a given hitting set. Each set to block is
            encoded as a hard clause in the MaxSAT problem formulation, which
            is then added to the underlying oracle.

            Note that an optional parameter that can be passed to this method
            is ``weights``, which contains a mapping the objects under
            question into weights. Also note that the weight of an object must
            not change from one call of :meth:`hit` or :meth:`block` to
            another.

            :param to_block: a set to block
            :param weights: a mapping from objects to weights

            :type to_block: iterable(obj)
            :type weights: dict(obj)
        """

        # translating objects to variables
        to_block = [self.idpool.id(obj) for obj in to_block]

        # a soft clause should be added for each new object
        new_obj = self._unseen(to_block)

        # new hard clause: not all of these objects may be picked together
        self.oracle.add_clause([-vid for vid in to_block])

        # new soft clauses
        self._add_soft(new_obj, weights)

    def _unseen(self, vids):
        """
            Return the distinct ids of ``vids`` (in first-occurrence order)
            that the oracle's variable map does not know yet.
        """

        seen = set()
        fresh = []
        for vid in vids:
            # deduplicate so that a repeated object gets one soft clause only
            if vid not in seen and vid not in self.oracle.vmap.e2i:
                seen.add(vid)
                fresh.append(vid)
        return fresh

    def _add_soft(self, vids, weights):
        """
            Add a soft unit clause ``[-vid]`` for every id in ``vids``,
            weighted by ``weights`` (unit weight when no mapping is given).
        """

        for vid in vids:
            self.oracle.add_clause(
                [-vid], 1 if not weights else weights[self.idpool.obj(vid)])

    def enumerate(self):
        """
            The method can be used as a simple iterator computing and blocking
            the hitting sets on the fly. It essentially calls :func:`get`
            followed by :func:`block`. Each hitting set is reported as a list
            of objects in the original problem domain, i.e. it is mapped back
            from the solutions over Boolean variables computed by the
            underlying oracle.

            :rtype: list(obj)
        """

        while True:
            hset = self.get()
            if hset is None:
                return

            self.block(hset)
            yield hset

    def oracle_time(self):
        """
            Report the total SAT solving time.
        """

        return self.oracle.oracle_time()
# Example no. 4
# 0
class Solver:
    """
    CNF encoding of a grid epidemic observed over several turns.

    Propositional variables are allocated in ``self.vpool`` under keys of the
    form ``(turn, row, col, state)``. Each one stands for "tile ``(row, col)``
    is in ``state`` at ``turn``". Further anonymous variables are introduced
    by the ``CardEnc`` cardinality encodings. The whole clause list is built
    once by the constructor and stored in ``self.clauses``.

    :param solver_input: dict with the team sizes under ``'police'`` and
        ``'medics'`` and, under ``'observations'``, one grid of state codes
        per turn.
    """

    def __init__(self, solver_input):
        self.num_police = solver_input['police']
        self.num_medics = solver_input['medics']
        self.observations = solver_input['observations']
        self.num_turns = len(self.observations)  # TODO maybe max on queries
        self.height = len(self.observations[0])
        self.width = len(self.observations[0][0])
        self.vpool = IDPool()
        # every board coordinate, in row-major order
        self.tiles = [(i, j) for i in range(self.height)
                      for j in range(self.width)]

        self.clauses = self.generate_clauses()

    def generate_clauses(self):
        """Collect the observation, validity, dynamics and action clauses."""
        clauses = []
        clauses.extend(
            self.generate_observations_clauses())  # TODO check validity
        clauses.extend(self.generate_validity_clauses())  # TODO check validity
        clauses.extend(self.generate_dynamics_clauses())  # TODO check validity
        clauses.extend(
            self.generate_valid_actions_clauses())  # TODO check validity
        return clauses

    def generate_observations_clauses(self):
        """
        Unit/disjunctive clauses pinning every observed tile to the detailed
        states compatible with its observed code. Unknown tiles (``UNK``) are
        left unconstrained.
        """
        clauses = []

        for turn, observation in enumerate(self.observations):
            for row in range(self.height):
                for col in range(self.width):
                    state = observation[row][col]
                    if state == SICK:
                        clauses.append([
                            self.vpool.id((turn, row, col, SICK_0)),
                            self.vpool.id((turn, row, col, SICK_1)),
                            self.vpool.id((turn, row, col, SICK_2))
                        ])
                    elif state == QUARANTINED:
                        clauses.append([
                            self.vpool.id((turn, row, col, QUARANTINED_0)),
                            self.vpool.id((turn, row, col, QUARANTINED_1))
                        ])
                    elif state == IMMUNE:
                        clauses.append([
                            self.vpool.id((turn, row, col, IMMUNE_RECENTLY)),
                            self.vpool.id((turn, row, col, IMMUNE))
                        ])
                    elif state == UNK:
                        continue
                    else:
                        clauses.append(
                            [self.vpool.id((turn, row, col, state))])

        return clauses

    def generate_validity_clauses(self):
        """Per-tile "exactly one state per turn" constraints for all turns."""
        clauses = []
        for row in range(self.height):
            for col in range(self.width):
                clauses.extend(self.first_turn_clauses(row, col))
                clauses.extend(self.second_turn_clauses(row, col))
                clauses.extend(self.uniqueness_clauses(row, col))

        return clauses

    def first_turn_clauses(self, row, col):
        """
        Turn 0: exactly one of ``FIRST_TURN_STATES`` holds; every other state
        is forbidden.
        """
        lits = [
            self.vpool.id((0, row, col, state)) for state in FIRST_TURN_STATES
        ]
        clauses = CardEnc.equals(lits, bound=1, vpool=self.vpool).clauses
        for state in STATES:
            if state not in FIRST_TURN_STATES:
                clauses.append([-self.vpool.id((0, row, col, state))])
        return clauses

    def second_turn_clauses(self, row, col):
        """
        Turn 1: exactly one of ``SECOND_TURN_STATES`` holds; every other state
        is forbidden at turn 1.
        """
        lits = [
            self.vpool.id((1, row, col, state)) for state in SECOND_TURN_STATES
        ]
        clauses = CardEnc.equals(lits, bound=1, vpool=self.vpool).clauses
        for state in STATES:
            if state not in SECOND_TURN_STATES:
                # the remaining states must be forbidden at turn 1 itself
                clauses.append([-self.vpool.id((1, row, col, state))])
        return clauses

    def uniqueness_clauses(self, row, col):
        """Turns >= 2: exactly one state of ``STATES`` holds."""
        clauses = []
        for turn in range(2, self.num_turns):
            lits = [self.vpool.id((turn, row, col, state)) for state in STATES]
            clauses.extend(
                CardEnc.equals(lits, bound=1, vpool=self.vpool).clauses)
        return clauses

    def generate_dynamics_clauses(self):
        """Transition constraints between consecutive turns for every tile."""
        clauses = []
        for turn in range(self.num_turns):
            for row in range(self.height):
                for col in range(self.width):
                    clauses.extend(self.unpopulated_clauses(turn, row, col))
                    clauses.extend(self.sick_clauses(turn, row, col))
                    clauses.extend(self.healthy_clauses(turn, row, col))
                    clauses.extend(self.immune_clauses(turn, row, col))
                    clauses.extend(self.quarantine_clauses(turn, row, col))

        return clauses

    def unpopulated_clauses(self, turn, row, col):
        """An unpopulated tile was and stays unpopulated."""
        clauses = []

        # Previous Turn
        if 0 < turn:
            # U_t = > U_t-1
            clauses.append([
                -self.vpool.id((turn, row, col, UNPOPULATED)),
                self.vpool.id((turn - 1, row, col, UNPOPULATED))
            ])

        # Next Turn
        if turn < self.num_turns - 1:
            # U_t => U_t+1
            clauses.append([
                -self.vpool.id((turn, row, col, UNPOPULATED)),
                self.vpool.id((turn + 1, row, col, UNPOPULATED))
            ])

        return clauses

    def sick_clauses(self, turn, row, col):
        """Sickness countdown (S2 -> S1 -> S0 -> H) and infection spreading."""
        clauses = []
        neighbors = self.get_neighbors(row, col)

        # Is sick
        # Previous Turn
        if 0 < turn:
            # S2_t => H_t-1
            clauses.append([
                -self.vpool.id((turn, row, col, SICK_2)),
                self.vpool.id((turn - 1, row, col, HEALTHY))
            ])
            # S1_t => S2_t-1
            clauses.append([
                -self.vpool.id((turn, row, col, SICK_1)),
                self.vpool.id((turn - 1, row, col, SICK_2))
            ])
            # S0_t => S1_t-1
            clauses.append([
                -self.vpool.id((turn, row, col, SICK_0)),
                self.vpool.id((turn - 1, row, col, SICK_1))
            ])

        # Next Turn
        if turn < self.num_turns - 1:
            # S2_t => S1_t+1 v Q1_t+1
            clauses.append([
                -self.vpool.id((turn, row, col, SICK_2)),
                self.vpool.id((turn + 1, row, col, SICK_1)),
                self.vpool.id((turn + 1, row, col, QUARANTINED_1))
            ])
            # S1_t => S0_t+1 v Q1_t+1
            clauses.append([
                -self.vpool.id((turn, row, col, SICK_1)),
                self.vpool.id((turn + 1, row, col, SICK_0)),
                self.vpool.id((turn + 1, row, col, QUARANTINED_1))
            ])
            # S0_t => H_t+1 v Q1_t+1
            clauses.append([
                -self.vpool.id((turn, row, col, SICK_0)),
                self.vpool.id((turn + 1, row, col, HEALTHY)),
                self.vpool.id((turn + 1, row, col, QUARANTINED_1))
            ])

        # Infected By Someone
        if 0 < turn:
            # S2_t => V (S2n_t-1 v S1n_t-1 v S0n_t-1) for n in neighbors
            clause = [-self.vpool.id((turn, row, col, SICK_2))]
            for (n_row, n_col) in neighbors:
                clause.extend([
                    self.vpool.id((turn - 1, n_row, n_col, SICK_2)),
                    self.vpool.id((turn - 1, n_row, n_col, SICK_1)),
                    self.vpool.id((turn - 1, n_row, n_col, SICK_0))
                ])
            clauses.append(clause)

        # Infecting Others
        if turn < self.num_turns - 1:
            for (n_row, n_col) in neighbors:
                for sick_i in [SICK_0, SICK_1, SICK_2]:
                    # Si_t /\ Hn_t /\ -Q1_t+1 /\ -I_recent_t+1 => S2n_t+1 (Sn, Hn stand for neighbor)
                    clauses.append([
                        -self.vpool.id((turn, row, col, sick_i)),
                        -self.vpool.id((turn, n_row, n_col, HEALTHY)),
                        self.vpool.id((turn + 1, row, col, QUARANTINED_1)),
                        self.vpool.id(
                            (turn + 1, n_row, n_col, IMMUNE_RECENTLY)),
                        self.vpool.id((turn + 1, n_row, n_col, SICK_2))
                    ])

        return clauses

    def healthy_clauses(self, turn, row, col):
        """Where a healthy tile can come from and what it can turn into."""
        clauses = []

        # Previous Turn
        if 0 < turn:
            # H_t => H_t-1 v Q0_t-1 v S0_t-1
            clauses.append([
                -self.vpool.id((turn, row, col, HEALTHY)),
                self.vpool.id((turn - 1, row, col, HEALTHY)),
                self.vpool.id((turn - 1, row, col, QUARANTINED_0)),
                self.vpool.id((turn - 1, row, col, SICK_0))
            ])

        # Next Turn
        if turn < self.num_turns - 1:
            # H_t => H_t+1 \/ S2_t+1 \/ I_recent_t+1
            clauses.append([
                -self.vpool.id((turn, row, col, HEALTHY)),
                self.vpool.id((turn + 1, row, col, HEALTHY)),
                self.vpool.id((turn + 1, row, col, SICK_2)),
                self.vpool.id((turn + 1, row, col, IMMUNE_RECENTLY))
            ])
        return clauses

    def immune_clauses(self, turn, row, col):
        """Vaccination (H -> I_recent -> I) and permanence of immunity."""
        clauses = []

        # Previous Turn
        if 0 < turn:
            # I_t => I_t-1 v I_recent_t-1
            clauses.append([
                -self.vpool.id((turn, row, col, IMMUNE)),
                self.vpool.id((turn - 1, row, col, IMMUNE)),
                self.vpool.id((turn - 1, row, col, IMMUNE_RECENTLY))
            ])

            # I_recent_t => H_t-1
            clauses.append([
                -self.vpool.id((turn, row, col, IMMUNE_RECENTLY)),
                self.vpool.id((turn - 1, row, col, HEALTHY))
            ])

        # Next Turn
        if turn < self.num_turns - 1:
            # I_t => I_t+1
            clauses.append([
                -self.vpool.id((turn, row, col, IMMUNE)),
                self.vpool.id((turn + 1, row, col, IMMUNE))
            ])

            # I_recent_t => I_t+1
            clauses.append([
                -self.vpool.id((turn, row, col, IMMUNE_RECENTLY)),
                self.vpool.id((turn + 1, row, col, IMMUNE))
            ])

        return clauses

    def quarantine_clauses(self, turn, row, col):
        """Quarantine countdown (sick -> Q1 -> Q0 -> H)."""
        clauses = []

        # Previous Turn
        if 0 < turn:
            # Q1_t => S2_t-1 v S1_t-1 v S0_t-1
            clauses.append([
                -self.vpool.id((turn, row, col, QUARANTINED_1)),
                self.vpool.id((turn - 1, row, col, SICK_2)),
                self.vpool.id((turn - 1, row, col, SICK_1)),
                self.vpool.id((turn - 1, row, col, SICK_0))
            ])

            # Q0_t => Q1_t-1
            clauses.append([
                -self.vpool.id((turn, row, col, QUARANTINED_0)),
                self.vpool.id((turn - 1, row, col, QUARANTINED_1))
            ])

        # Next Turn
        if turn < self.num_turns - 1:
            # Q1_t => Q0_t+1
            clauses.append([
                -self.vpool.id((turn, row, col, QUARANTINED_1)),
                self.vpool.id((turn + 1, row, col, QUARANTINED_0))
            ])

            # Q0_t => H_t+1
            clauses.append([
                -self.vpool.id((turn, row, col, QUARANTINED_0)),
                self.vpool.id((turn + 1, row, col, HEALTHY))
            ])

        return clauses

    def generate_valid_actions_clauses(self):
        """Constraints on the police and medic actions."""
        clauses = []
        clauses.extend(self.generate_police_clauses())
        clauses.extend(self.generate_medic_clauses())
        return clauses

    def _guarded_equals(self, guard, lits, bound):
        """
        Return the clauses of ``CardEnc.equals(lits, bound)`` each extended
        with the literals of ``guard``. The cardinality constraint is thus
        enforced only when every literal of ``guard`` is false.
        """
        equals_clauses = CardEnc.equals(lits, bound=bound,
                                        vpool=self.vpool).clauses
        return [guard + sub_clause for sub_clause in equals_clauses]

    def generate_police_clauses(self):
        """
        Police constraints.

        At most ``num_police`` tiles become newly quarantined (Q1) per turn.
        For every possible set of sick tiles at ``turn``, exactly
        ``min(num_police, #sick)`` of those tiles are Q1 at ``turn + 1``.
        """
        clauses = []

        for turn in range(1, self.num_turns):
            lits = [
                self.vpool.id((turn, row, col, QUARANTINED_1))
                for row, col in self.tiles
            ]
            clauses.extend(
                CardEnc.atmost(lits, bound=self.num_police,
                               vpool=self.vpool).clauses)
        # TODO check case of 0 policemen
        if self.num_police == 0:
            return clauses

        for turn in range(self.num_turns - 1):
            sick_states = self.possible_sick_states(turn)
            # Each tile is in exactly one state of STATES at every turn
            # (validity clauses). Hence "tile is not sick" is equivalent to
            # "tile is in one of the other states". Using that avoids
            # enumerating every assignment of sick sub-states to the sick
            # tiles (the former combinations_with_replacement covered only
            # some of them).
            other_states = [s for s in STATES if s not in sick_states]
            # inclusive upper bound: the all-sick board is a valid case too
            for num_sick in range(len(self.tiles) + 1):
                for sick_tiles in itertools.combinations(self.tiles, num_sick):
                    sick_set = set(sick_tiles)
                    guard = []
                    # negated premise, part 1: some chosen tile is not sick
                    for row, col in sick_tiles:
                        guard.extend(self.vpool.id((turn, row, col, state))
                                     for state in other_states)
                    # negated premise, part 2: some other tile is sick
                    for row, col in self.tiles:
                        if (row, col) not in sick_set:
                            guard.extend(
                                self.vpool.id((turn, row, col, state))
                                for state in sick_states)

                    lits = [
                        self.vpool.id((turn + 1, row, col, QUARANTINED_1))
                        for row, col in sick_tiles
                    ]
                    clauses.extend(
                        self._guarded_equals(guard, lits,
                                             min(self.num_police, num_sick)))
        return clauses

    def generate_medic_clauses(self):
        """
        Medic constraints.

        At most ``num_medics`` tiles are recently immunised per turn. For
        every possible set of healthy tiles at ``turn``, exactly
        ``min(num_medics, #healthy)`` of them are ``IMMUNE_RECENTLY`` at
        ``turn + 1``.
        """
        clauses = []

        for turn in range(self.num_turns):
            lits = [
                self.vpool.id((turn, row, col, IMMUNE_RECENTLY))
                for row, col in self.tiles
            ]
            clauses.extend(
                CardEnc.atmost(lits, bound=self.num_medics,
                               vpool=self.vpool).clauses)
        # TODO check case of 0 medics
        if self.num_medics == 0:
            return clauses

        for turn in range(self.num_turns - 1):
            # inclusive upper bound: the all-healthy board is a valid case too
            for num_healthy in range(len(self.tiles) + 1):
                for healthy_tiles in itertools.combinations(
                        self.tiles, num_healthy):
                    healthy_set = set(healthy_tiles)
                    # negated premise: a chosen tile is not healthy ...
                    guard = [
                        -self.vpool.id((turn, row, col, HEALTHY))
                        for row, col in healthy_tiles
                    ]
                    # ... or some other tile is healthy
                    guard.extend(
                        self.vpool.id((turn, row, col, HEALTHY))
                        for row, col in self.tiles
                        if (row, col) not in healthy_set)

                    lits = [
                        self.vpool.id((turn + 1, row, col, IMMUNE_RECENTLY))
                        for row, col in healthy_tiles
                    ]
                    clauses.extend(
                        self._guarded_equals(guard, lits,
                                             min(self.num_medics,
                                                 num_healthy)))

        return clauses

    def get_neighbors(self, i, j):
        """Orthogonal neighbours of ``(i, j)`` that lie on the board."""
        return [
            val for val in [(i - 1, j), (i + 1, j), (i, j - 1), (i, j + 1)]
            if self.in_board(*val)
        ]

    def in_board(self, i, j):
        """Whether ``(i, j)`` is a valid board coordinate."""
        return 0 <= i < self.height and 0 <= j < self.width

    def generate_query_clause(self, query):
        """
        Clause that holds iff the queried tile is in the queried (coarse)
        state at the queried turn.

        :param query: ``((row, col), turn, state)``
        """
        (q_row, q_col), turn, state = query

        if state == SICK:
            clause = [
                self.vpool.id((turn, q_row, q_col, SICK_0)),
                self.vpool.id((turn, q_row, q_col, SICK_1)),
                self.vpool.id((turn, q_row, q_col, SICK_2))
            ]

        elif state == QUARANTINED:
            clause = [
                self.vpool.id((turn, q_row, q_col, QUARANTINED_0)),
                self.vpool.id((turn, q_row, q_col, QUARANTINED_1))
            ]

        elif state == IMMUNE:
            clause = [
                self.vpool.id((turn, q_row, q_col, IMMUNE)),
                self.vpool.id((turn, q_row, q_col, IMMUNE_RECENTLY))
            ]

        else:
            clause = [self.vpool.id((turn, q_row, q_col, state))]

        return clause

    def __str__(self):
        return '\n'.join(self.repr_clauses())

    def repr_clauses(self):
        """Human-readable rendering of every clause in ``self.clauses``."""
        return [self.clause2str(clause) for clause in self.clauses]

    def clause2str(self, clause):
        """Render one clause as ``lit \\/ lit ...``; negations get a ``-``."""
        out = ' \\/ '.join([
            '-' * (ind < 0) + self.prop2str(self.vpool.obj(abs(ind)))
            for ind in clause
        ])
        return out

    @staticmethod
    def prop2str(prop):
        """Render a variable key; auxiliary (keyless) variables are 'Fictive'."""
        if prop is None:
            return 'Fictive'
        turn, row, col, state = prop
        return f'{state}_{turn}_({row},{col})'

    @staticmethod
    def possible_sick_states(turn):
        """Sick sub-states reachable at ``turn`` (the countdown starts at S2)."""
        if turn == 0:
            return [SICK_2]
        if turn == 1:
            return [SICK_1, SICK_2]
        return SICK_STATES
# Example no. 5
# 0
def solve_problem_t(input, T):
    """Answer each query of ``input`` with a SAT encoding of the first T turns.

    Every fact is written as a propositional sentence (a string using ``&``,
    ``|``, ``~`` and ``==>``) over predicates produced by ``tile``, e.g.
    ``tile('S', 0, 1, 2) == 'S01_2'``, plus the action predicates
    ``'medics...'`` / ``'police...'``.  The sentences are converted to CNF by
    ``all_clauses_in_cnf``, mapped to pysat literals through ``vpool`` by
    ``all_clauses_in_pysat`` and loaded into one solver.  Each query is then
    answered by assuming, one at a time, every possible code at the queried
    tile and turn.

    Args:
        input: dict with keys ``'police'`` and ``'medics'`` (team sizes),
            ``'observations'`` (per-turn grids of codes, ``'?'`` = unknown;
            at least T of them) and ``'queries'`` (iterable of hashable
            ``((row, col), turn, code)`` items).
        T: number of turns to encode.

    Returns:
        dict mapping each query to ``'T'`` when exactly one code is
        satisfiable and it is the queried code, ``'F'`` when exactly one code
        is satisfiable and it is a different one, and ``'?'`` otherwise.

    NOTE(review): ``tile`` joins row and column without a separator, so on
    boards with a side longer than 10 two tiles can share a name (e.g.
    (1, 12) and (11, 2) both give ``'U112_0'``).  The format is kept because
    helpers outside this function may rely on it -- confirm before changing.
    """
    n_police = input['police']
    n_medics = input['medics']
    observations = input['observations']
    queries = input['queries']
    H = len(observations[0])
    W = len(observations[0][0])

    vpool = IDPool()
    clauses = []
    tile = lambda code, r, c, t: '{0}{1}{2}_{3}'.format(code, r, c, t)
    CODES = ['U', 'H', 'S']
    ACTIONS = []
    if n_medics > 0:
        CODES.append('I')
        ACTIONS.append('medics')
    if n_police > 0:
        CODES.append('Q')
        ACTIONS.append('police')

    def restrict_action(action, allowed_code):
        """Forbid ``action`` on every tile whose code is not ``allowed_code``."""
        for code in CODES:
            if code == allowed_code:
                continue
            for t in range(T):
                for r in range(H):
                    for c in range(W):
                        clauses.append(
                            tile(code, r, c, t) + ' ==> ' +
                            tile('~' + action, r, c, t))

    def require_exactly(action, k):
        """Force exactly k ``action`` predicates true in each turn t < T-1."""
        if k < 1:
            return
        for t in range(T - 1):
            tile_coords = [tile(action, r, c, t)
                           for r in range(H) for c in range(W)]
            if k == 1:
                # at least one ...
                clauses.append('(' + ' | '.join(tile_coords) + ')')
                # ... and pairwise at most one
                negated = ['~' + name for name in tile_coords]
                for combo in lim_powerset(negated, 2):
                    clauses.append(combo[0] + ' | ' + combo[1])
            else:
                # one conjunction per k-subset (as returned by lim_powerset):
                # the subset's predicates are true, every other one is false
                disjuncts = []
                for combo in lim_powerset(tile_coords, k):
                    literals = list(combo) + [
                        '~' + name for name in tile_coords if name not in combo
                    ]
                    disjuncts.append('(' + ' & '.join(literals) + ')')
                clauses.append('(' + ' | '.join(disjuncts) + ')')

    # register every predicate with the pool up front so ids are assigned in
    # a fixed order
    for t in range(T):
        for code in CODES + ACTIONS:
            for r in range(H):
                for c in range(W):
                    vpool.id(tile(code, r, c, t))

    # observed tiles: the observed code is true and every other code is
    # false; '?' tiles are left unconstrained
    for t in range(T):
        curr_observ = observations[t]
        for r in range(H):
            for c in range(W):
                curr_code = curr_observ[r][c]
                if curr_code == '?':
                    continue
                for code in CODES:
                    if code == curr_code:
                        clauses.append(tile(code, r, c, t))
                    else:
                        clauses.append('~' + tile(code, r, c, t))

    # Uxy_t ==> Uxy_t-1 (pre-condition of 'U')
    for t in range(1, T):
        for r in range(H):
            for c in range(W):
                clauses.append(
                    tile('U', r, c, t) + ' ==> ' + tile('U', r, c, t - 1))

    # Uxy_t ==> Uxy_t+1 (add effect of 'U', no del effect)
    for t in range(T - 1):
        for r in range(H):
            for c in range(W):
                clauses.append(
                    tile('U', r, c, t) + ' ==> ' + tile('U', r, c, t + 1))

    if n_medics > 0:
        # Ixy_t ==> (Ixy_t-1 | (Hxy_t-1 & medicsxy_t-1)) (pre-condition)
        for t in range(1, T):
            for r in range(H):
                for c in range(W):
                    clauses.append(
                        tile('I', r, c, t) + ' ==> (' +
                        tile('I', r, c, t - 1) + ' | (' +
                        tile('H', r, c, t - 1) + ' & ' +
                        tile('medics', r, c, t - 1) + '))')
        # Ixy_t ==> Ixy_t+1 (add effect of 'I', no del effect)
        for t in range(T - 1):
            for r in range(H):
                for c in range(W):
                    clauses.append(
                        tile('I', r, c, t) + ' ==> ' + tile('I', r, c, t + 1))

    if n_police > 0:
        # Qxy_t ==> (Qxy_t-1 | (Sxy_t-1 & policexy_t-1)) (pre-condition)
        for t in range(1, T):
            for r in range(H):
                for c in range(W):
                    clauses.append(
                        tile('Q', r, c, t) + ' ==> (' +
                        tile('Q', r, c, t - 1) + ' | (' +
                        tile('S', r, c, t - 1) + ' & ' +
                        tile('police', r, c, t - 1) + '))')
        # add and del effects of 'Q': a tile quarantined on two consecutive
        # turns becomes 'H' on the next one
        for t in range(T - 1):
            for r in range(H):
                for c in range(W):
                    q_now = tile('Q', r, c, t)
                    q_next = tile('Q', r, c, t + 1)
                    if t < 1:
                        # Qxy_t ==> Qxy_t+1
                        clauses.append(q_now + ' ==> ' + q_next)
                    else:
                        q_prev = tile('Q', r, c, t - 1)
                        # (Qxy_t & ~Qxy_t-1) ==> Qxy_t+1
                        clauses.append('(' + q_now + ' & ~' + q_prev +
                                       ') ==> ' + q_next)
                        # (Qxy_t & Qxy_t-1) ==> Hxy_t+1
                        clauses.append('(' + q_now + ' & ' + q_prev +
                                       ') ==> ' + tile('H', r, c, t + 1))
                        # (Qxy_t & Qxy_t-1) ==> ~Qxy_t+1
                        clauses.append('(' + q_now + ' & ' + q_prev +
                                       ') ==> ~' + q_next)

    # Sxy_t ==> (Sxy_t-1 | Hxy_t-1) (pre-condition of 'S'); the "sick
    # neighbour" requirement follows from the 'H' effect clauses below
    for t in range(1, T):
        for r in range(H):
            for c in range(W):
                clauses.append(
                    tile('S', r, c, t) + ' ==> (' + tile('S', r, c, t - 1) +
                    ' | ' + tile('H', r, c, t - 1) + ')')

    # add and del effects of 'S': sick for three consecutive turns -> 'H';
    # with police, a policed sick tile becomes 'Q'
    for t in range(T - 1):
        for r in range(H):
            for c in range(W):
                s_now = tile('S', r, c, t)
                s_next = tile('S', r, c, t + 1)
                if n_police > 0:
                    police_now = tile('police', r, c, t)
                    # Sxy_t & policexy_t ==> Qxy_t+1
                    clauses.append(s_now + ' & ' + police_now + ' ==> ' +
                                   tile('Q', r, c, t + 1))
                    # Sxy_t & policexy_t ==> ~Sxy_t+1
                    clauses.append(s_now + ' & ' + police_now + ' ==> ~' +
                                   s_next)
                    if t < 2:
                        # (Sxy_t & ~policexy_t) ==> Sxy_t+1
                        clauses.append('(' + s_now + ' & ~' + police_now +
                                       ') ==> ' + s_next)
                    else:
                        s_prev = tile('S', r, c, t - 1)
                        s_prev2 = tile('S', r, c, t - 2)
                        # (Sxy_t & ~policexy_t & ~Sxy_t-1) ==> Sxy_t+1
                        # (the original emitted an extra ')' here)
                        clauses.append('(' + s_now + ' & ~' + police_now +
                                       ' & ~' + s_prev + ') ==> ' + s_next)
                        # (Sxy_t & ~policexy_t & Sxy_t-1 & ~Sxy_t-2) ==> Sxy_t+1
                        clauses.append('(' + s_now + ' & ~' + police_now +
                                       ' & ' + s_prev + ' & ~' + s_prev2 +
                                       ') ==> ' + s_next)
                        # (Sxy_t & ~policexy_t & Sxy_t-1 & Sxy_t-2) ==> Hxy_t+1
                        clauses.append('(' + s_now + ' & ~' + police_now +
                                       ' & ' + s_prev + ' & ' + s_prev2 +
                                       ') ==> ' + tile('H', r, c, t + 1))
                        # (Sxy_t & ~policexy_t & Sxy_t-1 & Sxy_t-2) ==> ~Sxy_t+1
                        clauses.append('(' + s_now + ' & ~' + police_now +
                                       ' & ' + s_prev + ' & ' + s_prev2 +
                                       ') ==> ~' + s_next)
                elif t < 2:
                    # Sxy_t ==> Sxy_t+1
                    clauses.append(s_now + ' ==> ' + s_next)
                else:
                    s_prev = tile('S', r, c, t - 1)
                    s_prev2 = tile('S', r, c, t - 2)
                    # (Sxy_t & ~Sxy_t-1) ==> Sxy_t+1
                    clauses.append('(' + s_now + ' & ~' + s_prev + ') ==> ' +
                                   s_next)
                    # (Sxy_t & Sxy_t-1 & ~Sxy_t-2) ==> Sxy_t+1
                    clauses.append('(' + s_now + ' & ' + s_prev + ' & ~' +
                                   s_prev2 + ') ==> ' + s_next)
                    # (Sxy_t & Sxy_t-1 & Sxy_t-2) ==> Hxy_t+1
                    clauses.append('(' + s_now + ' & ' + s_prev + ' & ' +
                                   s_prev2 + ') ==> ' + tile('H', r, c, t + 1))
                    # (Sxy_t & Sxy_t-1 & Sxy_t-2) ==> ~Sxy_t+1
                    clauses.append('(' + s_now + ' & ' + s_prev + ' & ' +
                                   s_prev2 + ') ==> ~' + s_next)

    # pre-conditions of 'H'
    for t in range(1, T):
        for r in range(H):
            for c in range(W):
                h_now = tile('H', r, c, t)
                h_prev = tile('H', r, c, t - 1)
                if t < 3:
                    # Hxy_t ==> Hxy_t-1
                    clauses.append(h_now + ' ==> ' + h_prev)
                else:
                    # Hxy_t ==> (Hxy_t-1 | (Sxy_t-1 & Sxy_t-2 & Sxy_t-3)
                    #            [| (Qxy_t-1 & Qxy_t-2)])
                    # The whole right-hand side is parenthesised so the
                    # sentence does not depend on how the parser ranks
                    # '==>' against '|'.
                    rhs = (h_prev + ' | (' + tile('S', r, c, t - 1) + ' & ' +
                           tile('S', r, c, t - 2) + ' & ' +
                           tile('S', r, c, t - 3) + ')')
                    if n_police > 0:
                        rhs += (' | (' + tile('Q', r, c, t - 1) + ' & ' +
                                tile('Q', r, c, t - 2) + ')')
                    clauses.append(h_now + ' ==> (' + rhs + ')')

    # add and del effects of 'H'
    for t in range(T - 1):
        for r in range(H):
            for c in range(W):
                n_coords = get_neighbors(r, c, H, W)
                h_now = tile('H', r, c, t)
                if n_medics > 0:
                    medics_now = tile('medics', r, c, t)
                    # Hxy_t & medicsxy_t ==> Ixy_t+1
                    clauses.append(h_now + ' & ' + medics_now + ' ==> ' +
                                   tile('I', r, c, t + 1))
                    # Hxy_t & medicsxy_t ==> ~Hxy_t+1
                    clauses.append(h_now + ' & ' + medics_now + ' ==> ' +
                                   tile('~H', r, c, t + 1))
                    # the remaining effects apply to non-vaccinated tiles
                    base = h_now + ' & ' + tile('~medics', r, c, t)
                else:
                    base = h_now

                if not n_coords:
                    # no neighbours: nothing can infect this tile
                    clauses.append(base + ' ==> ' + tile('H', r, c, t + 1))
                    continue

                if n_police > 0:
                    # a sick neighbour quarantined next turn does not infect
                    infecting = [
                        '(' + tile('S', coord[0], coord[1], t) + ' & ' +
                        tile('~Q', coord[0], coord[1], t + 1) + ')'
                        for coord in n_coords
                    ]
                else:
                    infecting = [tile('S', coord[0], coord[1], t)
                                 for coord in n_coords]
                any_sick = [tile('S', coord[0], coord[1], t)
                            for coord in n_coords]
                no_sick = [tile('~S', coord[0], coord[1], t)
                           for coord in n_coords]

                # base & (at least one infecting neighbour) ==> Sxy_t+1
                clauses.append(base + ' & (' + ' | '.join(infecting) +
                               ') ==> ' + tile('S', r, c, t + 1))
                # base & (at least one sick neighbour) ==> ~Hxy_t+1
                clauses.append(base + ' & (' + ' | '.join(any_sick) +
                               ') ==> ' + tile('~H', r, c, t + 1))
                # base & (no sick neighbours) ==> Hxy_t+1
                clauses.append(base + ' & (' + ' & '.join(no_sick) +
                               ') ==> ' + tile('H', r, c, t + 1))

    # a tile holds exactly one code per turn: pairwise at-most-one
    for t in range(T):
        for r in range(H):
            for c in range(W):
                negated = ['~' + tile(code, r, c, t) for code in CODES]
                for combo in lim_powerset(negated, 2):
                    clauses.append(combo[0] + ' | ' + combo[1])

    # medics may only act on 'H' tiles and are used exactly n_medics times
    if n_medics > 0:
        restrict_action('medics', 'H')
    require_exactly('medics', n_medics)

    # police may only act on 'S' tiles and are used exactly n_police times
    if n_police > 0:
        restrict_action('police', 'S')
    require_exactly('police', n_police)

    clauses_in_cnf = all_clauses_in_cnf(clauses)
    clauses_in_pysat = all_clauses_in_pysat(clauses_in_cnf, vpool)

    q_dict = dict()
    # the context manager releases the native solver even if a helper raises
    with Solver() as s:
        s.append_formula(clauses_in_pysat)
        for q in queries:
            q_row, q_col, q_turn, q_state = q[0][0], q[0][1], q[1], q[2]
            if VERBOSE:
                print('\n')
                print('Initial Observations')
                print_model(observations, T, H, W, n_police, n_medics)
                print('\n')
                print('Query')
                print(q)
                print('\n')
            res_list = []
            for code in CODES:
                clause_to_check = cnf_to_pysat(
                    to_cnf(tile(code, q_row, q_col, q_turn)), vpool)
                if s.solve(assumptions=clause_to_check):
                    res_list.append(1)
                    if VERBOSE:
                        print('Satisfiable for code=%s as follows:' % (code))
                        sat_observations = get_model(s.get_model(), vpool, T,
                                                     H, W)
                        print_model(sat_observations, T, H, W, n_police,
                                    n_medics)
                        print('\n')
                else:
                    res_list.append(0)
                    if VERBOSE:
                        print('NOT Satisfiable for code=%s' % (code))
                        print('\n')
                        # get_core() can be empty (or None), e.g. when the
                        # formula is UNSAT regardless of the assumption
                        core = s.get_core()
                        if core:
                            print(vpool.obj(abs(core[0])))
            if sum(res_list) == 1:
                satisfiable_code = CODES[res_list.index(1)]
                q_dict[q] = 'T' if satisfiable_code == q_state else 'F'
            else:
                q_dict[q] = '?'

    return q_dict
Esempio n. 6
0
#    for l in c:
#        print(vpool.obj(l))
# Solve the weighted CNF `wcnf1` with LSU and print, per course, the TAs
# assigned to it.  Each variable in `id2varmap` maps (via `vpool.obj`) to a
# (course_name, ta) pair; "ta" names containing ":" are treated as
# constraint markers rather than TA names.
lsu = LSU(wcnf1)
res = lsu.solve()
if not res:
    print("Hard constraints could not be satisfied")
else:
    model = list(lsu.model)
    pos_lits = list(filter((lambda x: x > 0), model))
    # set for O(1) membership tests below (pos_lits itself stays a list)
    true_vars = set(pos_lits)
    unsatisfied_constraints = []
    ta_allocation = dict()
    tas_allocated = []
    for var_id in id2varmap.values():
        course_name, ta = vpool.obj(var_id)
        if var_id in true_vars:
            tas_allocated.append(ta)
            # every course with a true variable gets an entry, even when
            # only ":"-named marker variables were true for it
            allocated = ta_allocation.setdefault(course_name, [])
            if ":" not in ta:
                allocated.append(ta)
        elif ":" in ta:
            # false ":"-named variables are recorded as unsatisfied
            unsatisfied_constraints.append(ta)

    for course_name, tas in ta_allocation.items():
        print(course_name, " : ", str(len(tas)), " : ", tas)
Esempio n. 7
0
    def __init__(self, size, exhaustive=False, topv=0, verb=False):
        """
            Build the CB formula for a (2*size) x (2*size) chessboard whose
            first and last cells (both white) are removed.

            One Boolean variable is created per edge between two adjacent
            remaining cells, numbered from ``topv + 1``.  With
            ``exhaustive`` every cell gets an "exactly one incident edge"
            constraint; otherwise white cells get "at most one" and black
            cells "at least one".  With ``verb`` a header and a
            variable-name legend are appended to ``self.comments``.
        """

        # initializing CNF's internal parameters
        super(CB, self).__init__()

        side = 2 * size
        last = side * side  # index of the last (removed) cell

        def cell(row, col):
            # 1-based row-major index of the cell at (row, col)
            return (row - 1) * side + col

        pool = IDPool(start_from=topv + 1)

        def edge(a, b):
            # one variable per unordered pair of adjacent cells
            return pool.id('edge: ({0}, {1})'.format(min(a, b), max(a, b)))

        for i in range(1, side + 1):
            for j in range(1, side + 1):
                cur = cell(i, j)

                # the first and last cells are not part of the board
                if cur == 1 or cur == last:
                    continue

                # neighbours in the order up, left, down, right
                neighbours = []
                if i > 1:
                    neighbours.append(cell(i - 1, j))
                if j > 1:
                    neighbours.append(cell(i, j - 1))
                if i < side:
                    neighbours.append(cell(i + 1, j))
                if j < side:
                    neighbours.append(cell(i, j + 1))

                # edges to removed cells are skipped
                adj = [edge(cur, n) for n in neighbours if n not in (1, last)]

                if not adj:  # only possible when size == 1
                    continue

                if exhaustive:
                    exactly_one = CardEnc.equals(lits=adj, bound=1,
                                                 encoding=EncType.pairwise)
                    self.extend(exactly_one.clauses)
                elif i % 2 == cur % 2:
                    # white cell: at most one incident edge
                    at_most_one = CardEnc.atmost(lits=adj, bound=1,
                                                 encoding=EncType.pairwise)
                    self.extend(at_most_one.clauses)
                else:
                    # black cell: at least one incident edge
                    self.append(adj)

        if verb:
            header = 'c CB formula for the chessboard of size {0}x{0}'.format(
                side)
            header += '\nc The encoding is {0}exhaustive'.format(
                '' if exhaustive else 'not ')
            self.comments.append(header)
            self.comments.extend(
                'c {0}; bool var: {1}'.format(pool.obj(v), v)
                for v in range(1, pool.top + 1))
Esempio n. 8
0
class CoreOracle(Solver):
    """
        Auxiliary SAT oracle recording the dependencies between the
        unsatisfiable cores found by RC2 and their sum literals.  Given a
        set of assumption literals, it reports which recorded sum literals
        are propagated, i.e. which cores can be reused.
    """
    def __init__(self, name='m22'):
        """
            Initializer; ``name`` selects the underlying SAT solver.
        """

        super(CoreOracle, self).__init__(name=name)

        # private variable numbering, independent of the caller's ids
        self.pool = IDPool(start_from=1)

        # global selector literal added to every recorded clause
        self.selv = self.pool.id()

        # internal ids of all sum literals recorded so far
        self.lits = set()

    def delete(self):
        """
            Destructor: free the solver and drop the internal state.
        """

        super(CoreOracle, self).delete()

        self.pool = None
        self.selv = None
        self.lits = None

    def record(self, core, slit):
        """
            Record the implication ``core -> slit`` as the clause
            ``-selv \\/ core \\/ slit`` in internal numbering.  The literals of
            ``core`` (a list) must already be negated by the caller.
        """

        internal = []
        for lit in core + [slit]:
            var = self.pool.id(abs(lit))
            internal.append(int(copysign(var, lit)))

        self.add_clause([-self.selv] + internal)

        # remember the translated sum literal for filtering in get_active
        self.lits.add(int(copysign(self.pool.id(abs(slit)), slit)))

    def get_active(self, assumps):
        """
            Return, as a tuple of external literals in propagation order,
            the recorded sum literals propagated under ``assumps``.
        """

        internal = [int(copysign(self.pool.id(abs(a)), a)) for a in assumps]

        ok, propagated = self.propagate(assumptions=[self.selv] + internal,
                                        phase_saving=2)
        assert ok, 'Something is wrong. The core-deps formula is unsatisfiable'

        # skip the first propagated literal (the selector itself) and keep
        # the order of the remaining ones
        active = []
        for lit in propagated[1:]:
            if lit in self.lits:
                active.append(int(copysign(self.pool.obj(abs(lit)), lit)))
        return tuple(active)