Example #1
0
    def __init__(self,
                 bootstrap_with=[],
                 costs=[],
                 solver='g3',
                 htype='sorted'):
        """
            Set up a hitting set enumerator.

            The constructor records the SAT solver name, normalises
            ``htype`` to one of ``'rc2'``, ``'lbx'`` or ``'mcsls'``, creates
            an empty :class:`IDPool` and finally passes ``bootstrap_with``
            and ``costs`` on to :meth:`init`.
        """

        # no oracle exists yet at this point
        self.oracle = None

        # SAT solver name is stored as given
        self.solver = solver

        # normalise the requested hitman flavour; the two alias groups are
        # disjoint, and anything outside of them falls back to 'mcsls'
        if htype in ('mcs', 'lbx'):
            self.htype = 'lbx'
        elif htype in ('maxsat', 'mxsat', 'rc2', 'sorted'):
            self.htype = 'rc2'
        else:
            self.htype = 'mcsls'

        # identifiers of the objects to hit are managed by this pool
        self.idpool = IDPool()

        # the remaining set-up work is delegated to init()
        self.init(bootstrap_with, costs)
def get_unsat_core_pysat(fmla, alpha=None):
    """
    Extract an unsatisfiable core of ``fmla`` in terms of clause indices.

    Every clause ``i`` of a copy of the formula is extended with a fresh
    selector variable ``r_i``; the solver is then called under the
    assumptions ``-r_i`` (plus the literals of ``alpha``, if given), and the
    reported core is split into clause indices and original assumptions.

    :param fmla: formula exposing ``nv``, ``clauses`` and ``copy()``
    :param alpha: optional iterable of extra assumption literals over the
        original variables

    :return: a pair ``(result, bad_asms)`` where ``result`` lists the indices
        of the clauses whose selectors appear in the core and ``bad_asms``
        lists the core literals that are not selectors

    :raises Exception: if the formula is satisfiable under the assumptions
    """
    n_vars = fmla.nv

    # selector variables are numbered right after the original variables
    vpool = IDPool(start_from=n_vars + 1)

    # NOTE(review): assumes copy() duplicates the clause lists so that the
    # caller's formula is left untouched -- confirm for the formula type used
    new_fmla = fmla.copy()
    num_clauses = len(new_fmla.clauses)

    selectors = [vpool.id(i) for i in range(num_clauses)]
    for clause, sel in zip(new_fmla.clauses, selectors):
        clause.append(sel)  # add r_i to the ith clause

    asms = [-sel for sel in selectors]
    if alpha is not None:
        asms = asms + list(alpha)

    # the context manager releases the solver on every exit path
    with Solver(name="cdl") as s:
        s.append_formula(new_fmla)
        if s.solve(assumptions=asms):
            # TODO(jesse): better error handling
            raise Exception("formula is sat")
        core_aux = s.get_core()

    result = []
    bad_asms = []
    for lit in core_aux:
        if abs(lit) > n_vars:
            # selector literal => map it back to its clause index
            result.append(vpool.obj(abs(lit)))
        else:
            bad_asms.append(lit)
    return result, bad_asms
Example #3
0
    def __init__(self, size, topv=0, verb=False):
        """
            Build the clauses of the parity formula over ``2 * size + 1``
            vertices. One variable is used per unordered vertex pair; the
            variable identifiers start from ``topv + 1``. If ``verb`` is
            set, descriptive comments are attached to the formula.
        """

        # base class set-up comes first
        super(Parity, self).__init__()

        vpool = IDPool(start_from=topv + 1)

        def var(i, j):
            # an edge is identified by its endpoints in increasing order
            return vpool.id('v_{0}_{1}'.format(min(i, j), max(i, j)))

        vertices = range(1, 2 * size + 2)

        # every vertex has at least one incident edge selected
        for i in vertices:
            self.append([var(i, j) for j in vertices if j != i])

        # no two selected edges share vertex j
        for j in vertices:
            for i, k in itertools.combinations(vertices, 2):
                if i != j and k != j:
                    self.append([-var(i, j), -var(k, j)])

        if not verb:
            return

        self.comments.append(
            'c Parity formula for m == {0} ({1} vertices)'.format(
                size, 2 * size + 1))

        # combinations() enumerates the pairs i < j in the same order as
        # two nested loops would
        for edge in itertools.combinations(vertices, 2):
            self.comments.append('c edge: {0}; bool var: {1}'.format(
                edge, var(edge[0], edge[1])))
Example #4
0
    def __init__(self, nof_holes, kval=1, topv=0, verb=False):
        """
            Build the clauses of the (k-)pigeonhole formula for
            ``kval * nof_holes + 1`` pigeons and ``nof_holes`` holes.
            Variable identifiers start from ``topv + 1``. If ``verb`` is
            set, descriptive comments are attached to the formula.
        """

        # base class set-up comes first
        super(PHP, self).__init__()

        vpool = IDPool(start_from=topv + 1)

        def var(i, j):
            # variable stating that pigeon i sits in hole j
            return vpool.id('v_{0}_{1}'.format(i, j))

        pigeons = range(1, kval * nof_holes + 2)
        holes = range(1, nof_holes + 1)

        # each pigeon is placed into some hole
        for i in pigeons:
            self.append([var(i, j) for j in holes])

        # no hole hosts kval + 1 pigeons at once
        for j in holes:
            for comb in itertools.combinations(pigeons, kval + 1):
                self.append([-var(i, j) for i in comb])

        if not verb:
            return

        prefix = '' if kval == 1 else str(kval) + '-'
        head = 'c {0}PHP formula for'.format(prefix)
        head += ' {0} pigeons and {1} holes'.format(
            kval * nof_holes + 1, nof_holes)
        self.comments.append(head)

        for i in pigeons:
            for j in holes:
                self.comments.append(
                    'c (pigeon, hole) pair: ({0}, {1}); bool var: {2}'.
                    format(i, j, var(i, j)))
Example #5
0
    def __init__(self, formula, feats, nof_classes, xgb):
        """
            Prepare an explainer for an already encoded formula.

            Stores the feature-to-index map, the number of classes and the
            booster object, creates the oracle named in the booster's
            options, declares one symbol per extended feature name and one
            per class score, and asserts ``formula`` in the oracle.
        """

        # bookkeeping
        self.ftids = {name: pos for pos, name in enumerate(feats)}
        self.nofcl = nof_classes
        self.idmgr = IDPool()
        self.optns = xgb.options

        # the booster itself is kept as well
        self.xgb = xgb

        self.verbose = self.optns.verb
        self.oracle = Solver(name=self.xgb.options.solver)

        # input (feature value) variables: names without an underscore are
        # declared REAL, all the others BOOL
        self.inps = [
            Symbol(f, typename=BOOL if '_' in f else REAL)
            for f in self.xgb.extended_feature_names_as_array_strings
        ]

        # output (class score) variables
        self.outs = [
            Symbol('class{0}_score'.format(c), typename=REAL)
            for c in range(self.nofcl)
        ]

        # the encoded theory goes straight into the oracle
        self.oracle.add_assertion(formula)

        # no selector is active yet
        self.selv = None
Example #6
0
    def __init__(self, bootstrap_with=[], weights=None, subject_to=[],
            solver='g3', htype='sorted', mxs_adapt=False, mxs_exhaust=False,
            mxs_minz=False, mxs_trim=0, mcs_usecld=False):
        """
            Set up a hitting set enumerator.

            All ``mxs_*`` / ``mcs_*`` options are stored as attributes,
            ``htype`` is normalised to ``'rc2'``, ``'lbx'`` or ``'mcsls'``,
            an empty :class:`IDPool` is created, and the actual
            initialisation is delegated to :meth:`init`.
        """

        # no oracle exists yet at this point
        self.oracle = None

        # SAT solver name is stored as given
        self.solver = solver

        # oracle options, stored verbatim
        self.adapt, self.exhaust = mxs_adapt, mxs_exhaust
        self.minz, self.trim = mxs_minz, mxs_trim
        self.usecld = mcs_usecld

        # normalise the requested hitman flavour; anything that is neither
        # a MaxSAT alias nor an LBX alias is treated as 'mcsls'
        is_maxsat = htype in ('maxsat', 'mxsat', 'rc2', 'sorted')
        is_lbx = htype in ('mcs', 'lbx')
        if is_maxsat:
            self.htype = 'rc2'
        elif is_lbx:
            self.htype = 'lbx'
        else:
            self.htype = 'mcsls'

        # identifiers of the objects to hit are managed by this pool
        self.idpool = IDPool()

        # the remaining set-up work is delegated to init()
        self.init(bootstrap_with, weights=weights, subject_to=subject_to)
Example #7
0
def solve_problem(input_):
    """
    Encode ``input_`` into clauses and answer its queries.

    A fresh :class:`IDPool` and a variable-naming helper are created and
    handed to ``_build_clauses``; the resulting clauses are passed, together
    with ``input_['queries']``, to ``sat_solver``, whose result is returned.
    """
    vpool = IDPool()

    def var(t, pos, turn):
        # one pool identifier per (type, position, turn) combination
        return vpool.id(f'{t}_({pos[0]},{pos[1]})_{turn}')

    cnf = _build_clauses(initial=input_, var=var, vpool=vpool)
    return sat_solver(cnf=cnf, queries=input_['queries'], vpool=vpool, var=var)
Example #8
0
    def dominating_subset(self, k=1):
        """
        Check if there exists a dominating set of, at most, k vertices.

        One boolean variable is created per vertex. For every vertex, the
        vertex itself or one of its adjacent vertices (``self[vertex]``)
        must be selected, and at most ``k`` variables may be true.

        Accepts one param:
        - k: maximum number of vertices that may be selected

        Returns an empty list if the graph has no edges; otherwise returns
        the result of the SAT solver call.
        """
        if not self.edges():
            return []

        logging.info('\nCodifying SAT Solver...')
        vpool = IDPool()
        vertices_ids = [vpool.id(vertex) for vertex in self.vertices()]

        # the context manager releases the solver on every exit path
        with Solver(name='cd') as solver:
            logging.info(' -> Codifying: Every vertex must be accessible')
            for vertex in self.vertices():
                solver.add_clause([vpool.id(vertex)] + [
                    vpool.id(adjacent_vertex)
                    for adjacent_vertex in self[vertex]
                ])

            # extra positional arguments of logging calls are %-format
            # arguments, so the message needs a placeholder for k
            logging.info(
                ' -> Codifying: At most %s vertices should be selected', k)

            cnf = CardEnc.atmost(lits=vertices_ids, bound=k, vpool=vpool)
            solver.append_formula(cnf)

            logging.info('Running SAT Solver...')
            return solver.solve()
Example #9
0
def gen_constraints(idpool: IDPool, id2varmap, courses: tCourses,
                    constraints: List[tConstraint]) -> WCNF:
    """ Generate complete formula for all the constraints including conflicting course constraints

        The formula starts from the conflicting-course clauses. The clauses
        of every hard constraint are added as they are. For a constraint
        that is not hard, every clause is extended with the negation of an
        auxiliary variable, and that variable alone is added as a soft
        clause of weight ``soft_weight``.
    """
    wcnf = gen_constraint_conflict_courses(idpool, id2varmap, courses)

    for con in constraints:
        cnf = get_constraint(idpool, id2varmap, con)

        if con.ishard:
            for clause in cnf.clauses.copy():
                wcnf.append(clause)
            continue

        # Soft constraint: keep only an auxiliary variable as soft. This is
        # to allow displaying to the user which high level constraint
        # specified by the user was satisfied.
        key = (con.course_name, con.course_name + "->" + con.con_str)
        if key not in id2varmap:
            id2varmap[key] = idpool.id(key)
        selector = idpool.id(key)

        for clause in cnf.clauses.copy():
            clause.append(-selector)
            wcnf.append(clause)

        wcnf.append([selector], soft_weight)

    return wcnf
Example #10
0
class VarPool:
    """Wrapper around an IDPool whose objects are indexed variable names."""

    def __init__(self) -> None:
        self._vpool = IDPool()

    def var(self, name: str, ind1, ind2=0, ind3=0) -> int:
        """Return the pool identifier of ``name`` with the given indices."""
        key = '{}_{}_{}_{}'.format(name, ind1, ind2, ind3)
        return self._vpool.id(key)

    def var_name(self, id_: int):
        """Return the object the pool associates with identifier ``id_``."""
        return self._vpool.obj(id_)
Example #11
0
def get_all_problem_variables(observations, rows_num, cols_num):
    """
    Register one pool identifier per (state, row, column, turn) combination.

    Turns range over ``len(observations)`` and states over the module-level
    ``possible_states``. The populated :class:`IDPool` is returned.
    """
    pool = IDPool()

    for turn in range(len(observations)):
        for r in range(rows_num):
            for c in range(cols_num):
                for state in possible_states:
                    pool.id('{}_{}_{}_{}'.format(state, r, c, turn))

    return pool
Example #12
0
    def __init__(self, solver_input):
        """
        Unpack ``solver_input`` and generate the clauses.

        Reads the ``'police'``, ``'medics'`` and ``'observations'`` entries,
        derives the number of turns and the grid dimensions from the
        observations, lists all (row, column) tiles and stores the result
        of ``generate_clauses()``.
        """
        self.num_police = solver_input['police']
        self.num_medics = solver_input['medics']

        observations = solver_input['observations']
        self.observations = observations
        self.num_turns = len(observations)  # TODO maybe max on queries

        # grid dimensions are taken from the first observation
        first_grid = observations[0]
        self.height = len(first_grid)
        self.width = len(first_grid[0])

        self.vpool = IDPool()

        # row-major list of all tile coordinates
        tiles = []
        for i in range(self.height):
            for j in range(self.width):
                tiles.append((i, j))
        self.tiles = tiles

        self.clauses = self.generate_clauses()
Example #13
0
def get_constraint(idpool: IDPool, id2varmap,
                   constraint: tConstraint) -> CNFPlus:
    """ Generate formula for a given cardinality constraint

        One literal is used per ``(course_name, ta)`` pair of the
        constraint; pairs not yet present in ``id2varmap`` get a new
        identifier from ``idpool`` and are recorded in the map.

        For a GREATEROREQUALS constraint whose bound exceeds the number of
        literals, a warning is printed to stderr and ``constraint.bound``
        is lowered to the number of literals (the constraint object is
        modified in place).

        Raises ValueError if the constraint type is neither GREATEROREQUALS
        nor LESSOREQUALS.
    """
    validate_constraint(constraint)

    lits = []
    for ta in constraint.tas:
        key = (constraint.course_name, ta)
        if key not in id2varmap:
            id2varmap[key] = idpool.id(key)
        lits.append(id2varmap[key])

    if constraint.type == tCardType.GREATEROREQUALS:
        # a bound of exactly 1 is never clamped, as before
        if constraint.bound != 1 and constraint.bound > len(lits):
            msg = ("Num TAs available for constraint: " + constraint.con_str +
                   " is less than the bound in the constraint. "
                   "Changing the bound to " + str(len(lits)) + ".\n")
            print(msg, file=sys.stderr)
            constraint.bound = len(lits)

        return CardEnc.atleast(lits, vpool=idpool, bound=constraint.bound)

    if constraint.type == tCardType.LESSOREQUALS:
        return CardEnc.atmost(lits, vpool=idpool, bound=constraint.bound)

    # previously this fell through to an UnboundLocalError on `cnf`
    raise ValueError("unsupported constraint type: " + repr(constraint.type))
Example #14
0
    def __init__(self, name='m22'):
        """
            Initializer. Calls the base class initializer with ``name`` and
            then creates the oracle's own variable pool, its global
            selector and an empty set of sum literals.
        """

        # base class goes first
        super(CoreOracle, self).__init__(name=name)

        # a separate pool of variables avoids conflicts
        self.pool = IDPool(start_from=1)

        # global selector taken from the pool; all clauses should have it
        self.selv = self.pool.id()

        # no sum literals are known yet
        self.lits = set()
Example #15
0
def test_atmost():
    """
    AtMost-50 over 20 literals must yield no clauses, and the top
    identifier of the pool must not decrease.
    """
    n, b = 20, 50
    assert n <= b

    vp = IDPool()
    literals = [vp.id(v) for v in range(1, n + 1)]
    top = vp.top

    encoding = CardEnc.atmost(literals, b, vpool=vp)

    assert len(encoding.clauses) == 0

    try:
        assert vp.top >= top
    except AssertionError:
        # report the offending values before letting the failure propagate
        print(f"\nvp.top = {vp.top} (expected >= {top})\n")
        raise
Example #16
0
    def coloring(self, n_color):
        """
        Returns whether or not there exists a vertex coloring
        of, at most, n_color colors.

        Accepts one param:
        - n_color: number of color to check

        With zero colors the answer is True only if the graph has no
        vertices. Otherwise a SAT instance is built in which every vertex
        gets exactly one color and adjacent vertices (``self[vertex]``)
        never share a color; the result of the solver call is returned.

        Might raise ValueError exception (negative n_color).
        """
        if n_color < 0:
            raise ValueError('Number of colors must be positive integer')

        if n_color == 0:
            return not bool(self.vertices())

        logging.info('\nCodifying SAT Solver...')
        vpool = IDPool()

        def color_var(vertex, color):
            # variable stating that `vertex` is painted with `color`
            return vpool.id('{}color{}'.format(vertex, color))

        # the context manager releases the solver on every exit path
        with Solver(name='cd') as solver:
            logging.info(
                ' -> Codifying: Every vertex must have a color, and only one')
            for vertex in self.vertices():
                cnf = CardEnc.equals(
                    lits=[color_var(vertex, color)
                          for color in range(n_color)],
                    vpool=vpool,
                    encoding=0)
                solver.append_formula(cnf)

            logging.info(
                ' -> Codifying: No two neighbours can have the same color')
            for vertex in self.vertices():
                for neighbour in self[vertex]:
                    for color in range(n_color):
                        solver.add_clause([
                            -color_var(vertex, color),
                            -color_var(neighbour, color)
                        ])

            logging.info('Running SAT Solver...')
            return solver.solve()
Example #17
0
    def __init__(self, model, feats, nof_classes, xgb, from_file=None):
        """
            Store the model, the feature-to-index map, the number of
            classes and the booster object. If ``from_file`` is given, the
            encoder state is loaded through :meth:`load_from`.
        """

        self.model = model
        self.feats = {name: pos for pos, name in enumerate(feats)}
        self.nofcl = nof_classes
        self.idmgr = IDPool()
        self.optns = xgb.options

        # the booster itself is kept as well
        self.xgb = xgb

        # interval-based encoding data, unset at this point
        self.intvs = None
        self.imaps = None
        self.ivars = None

        if from_file:
            self.load_from(from_file)
Example #18
0
    def __init__(self, formula, intvs, imaps, ivars, feats, nof_classes,
                 options, xgb):
        """
            Prepare an explainer for an already encoded formula.

            Stores the interval data, features, options and booster object,
            creates the oracle named in ``options``, declares one symbol
            per extended feature name and one per class score, asserts
            ``formula`` in the oracle and resets the explanation state.
        """

        # encoding data and settings, stored verbatim
        self.feats, self.intvs = feats, intvs
        self.imaps, self.ivars = imaps, ivars
        self.nofcl = nof_classes
        self.optns = options
        self.idmgr = IDPool()

        # the booster itself is kept as well
        self.xgb = xgb

        self.verbose = self.optns.verb
        self.oracle = Solver(name=options.solver)

        # input (feature value) variables: names without an underscore are
        # declared REAL, all the others BOOL
        self.inps = [
            Symbol(f, typename=BOOL if '_' in f else REAL)
            for f in self.xgb.extended_feature_names_as_array_strings
        ]

        # output (class score) variables
        self.outs = [
            Symbol('class{0}_score'.format(c), typename=REAL)
            for c in range(self.nofcl)
        ]

        # the encoded theory goes straight into the oracle
        self.oracle.add_assertion(formula)

        # no selector is active yet
        self.selv = None

        # dual explanations collected so far
        self.dualx = []

        # counter of oracle calls
        self.calls = 0
Example #19
0
    def __init__(self, inputs):
        """
        Unpack ``inputs``, derive the grid dimensions from the first
        observation, create the variable pool and fill it through
        ``fill_predicates()``.
        """
        # raw problem description
        self.police = inputs["police"]
        self.medics = inputs["medics"]
        self.observations = inputs["observations"]
        self.queries = inputs["queries"]

        # sizes derived from the observations
        nof_observations = len(self.observations)
        self.t_max = nof_observations - 1
        self.num_observations = nof_observations

        first_grid = self.observations[0]
        self.rows = len(first_grid)
        self.cols = len(first_grid[0])
        self.num_tiles = self.rows * self.cols

        # set of all (row, column) coordinates
        self.tiles = {(i, j)
                      for j in range(self.cols) for i in range(self.rows)}

        # predicates live in the pool; its object-to-id map is exposed too
        self.pool = IDPool()
        self.fill_predicates()
        self.obj2id = self.pool.obj2id
Example #20
0
    def __init__(self,
                 formula,
                 solver='g4',
                 pb_enc_type=EncType.best,
                 expect_interrupt=False,
                 verbose=0):
        """
            Store the configuration, reserve the formula's variables in a
            fresh :class:`IDPool` and initialise the SAT oracle through
            :meth:`_init`.
        """

        # configuration, stored verbatim
        self.formula = formula
        self.solver = solver
        self.pb_enc_type = pb_enc_type
        self.expect_interrupt = expect_interrupt
        self.verbose = verbose

        # identifiers 1..formula.nv are marked as occupied, so the pool
        # only hands out identifiers beyond the formula's own variables
        occupied = [(1, formula.nv)]
        self.vpool = IDPool(occupied=occupied)

        # solver state
        self.sels = []            # soft clause selector variables
        self.is_weighted = False  # set by _init() for weighted problems
        self.tot = None           # totalizer object, none so far

        # initialize SAT oracle
        self._init(formula)
Example #21
0
    def __init__(self, size, topv=0, verb=False):
        """
            Build the clauses of the GT formula for ``size`` elements:
            anti-symmetry, transitivity and successor clauses over one
            variable per ordered pair of distinct elements. Variable
            identifiers start from ``topv + 1``. If ``verb`` is set,
            descriptive comments are attached to the formula.
        """

        # base class set-up comes first
        super(GT, self).__init__()

        vpool = IDPool(start_from=topv + 1)

        def var(i, j):
            # variable of the ordered pair (i, j)
            return vpool.id('v_{0}_{1}'.format(i, j))

        elements = range(1, size + 1)

        # anti-symmetric relation clauses
        for i in range(1, size):
            for j in range(i + 1, size + 1):
                self.append([-var(i, j), -var(j, i)])

        # transitive relation clauses over triples of distinct elements
        for i in elements:
            for j in elements:
                if j == i:
                    continue
                for k in elements:
                    if k == i or k == j:
                        continue
                    self.append([-var(i, j), -var(j, k), var(i, k)])

        # successor clauses
        for j in elements:
            self.append([var(k, j) for k in elements if k != j])

        if not verb:
            return

        self.comments.append('c GT formula for {0} elements'.format(size))
        for i in elements:
            for j in elements:
                if i == j:
                    continue
                self.comments.append(
                    'c orig pair: {0}; bool var: {1}'.format((i, j),
                                                             var(i, j)))
Example #22
0
def get_clause(t, c):
    """
    Build the CNF of the puzzle described by grid ``t`` and inequalities ``c``.

    ``t[i][j]`` is either ``"_"`` or a prefilled value; every element of
    ``c`` is a pair of cells ``((i1, j1), (i2, j2))``. The grid size ``n``
    and the literal function ``s`` come from the enclosing module.
    """
    p = CNF()
    vpool = IDPool()
    cells = range(n)
    values = range(1, n + 1)

    # register every boolean variable in the pool and fix prefilled cells
    for i in cells:
        for j in cells:
            for z in values:
                vpool.id('v{0}'.format(s(i, j, z)))
            if t[i][j] == "_":
                continue
            fixed = PBEnc.equals(lits=[s(i, j, t[i][j])], bound=1,
                                 vpool=vpool)
            p.extend(fixed.clauses)

    # at least one value per square
    for x in cells:
        for y in cells:
            lits = [s(x, y, z) for z in values]
            p.extend(PBEnc.atleast(lits=lits, bound=1, vpool=vpool).clauses)

    # each value occurs exactly once in each row and column
    for z in values:
        for a in cells:
            lits_row = [s(a, b, z) for b in cells]
            lits_col = [s(b, a, z) for b in cells]
            p.extend(PBEnc.equals(lits=lits_row, bound=1, vpool=vpool).clauses)
            p.extend(PBEnc.equals(lits=lits_col, bound=1, vpool=vpool).clauses)

    # inequalities: positive weights on the first cell's values and
    # negative weights on the second cell's values, weighted sum >= 1
    for (i1, j1), (i2, j2) in c:
        lits = [s(i1, j1, z) for z in values] + \
               [s(i2, j2, z) for z in values]
        weights = [z for z in values] + [-z for z in values]
        p.extend(
            PBEnc.atleast(lits=lits, weights=weights, bound=1,
                          vpool=vpool).clauses)

    return p
Example #23
0
def gen_constraint_conflict_courses(idpool: IDPool, id2varmap,
                                    courses: tCourses) -> WCNF:
    """ Generate a constraint that two conflicting courses can not share TAs

        For every pair of conflicting courses and every TA available to
        both, a binary clause forbids assigning that TA to both courses.
        The ``(course, ta)`` variables used are recorded in ``id2varmap``
        unless they are already there.
    """
    wcnf = WCNF()
    conflict_courses = compute_conflict_courses(courses)

    for course in conflict_courses:
        for rival in conflict_courses[course]:
            for ta in courses[course].tas_available:
                if ta not in courses[rival].tas_available:
                    continue

                key_here, key_there = (course, ta), (rival, ta)
                id_here = idpool.id(key_here)
                id_there = idpool.id(key_there)

                # existing map entries are never overwritten
                if key_here not in id2varmap:
                    id2varmap[key_here] = id_here
                if key_there not in id2varmap:
                    id2varmap[key_there] = id_there

                wcnf.append([-id_here, -id_there])

    return wcnf
Example #24
0
class LSU:
    """
        Linear SAT-UNSAT algorithm for MaxSAT [1]_. The algorithm can be seen
        as a series of satisfiability oracle calls refining an upper bound on
        the MaxSAT cost, followed by one unsatisfiability call, which stops the
        algorithm. The implementation encodes the sum of all selector literals
        using the *iterative totalizer encoding* [2]_. At every iteration, the
        upper bound on the cost is reduced and enforced by adding the
        corresponding unit size clause to the working formula. No clauses are
        removed during the execution of the algorithm. As a result, the SAT
        oracle is used incrementally.

        .. note:: The iterative totalizer is used for **unweighted**
            problems only. If some soft clause has a weight greater than
            1, every new bound is enforced with a PB constraint instead
            (see :meth:`_assert_lt`).

        The constructor receives an input :class:`.WCNF` formula, a name of the
        SAT solver to use (see :class:`.SolverNames` for details), a PB
        encoding type, a flag telling whether interrupts are expected, and
        an integer verbosity level. The input formula is not modified.

        :param formula: input MaxSAT formula
        :param solver: name of SAT solver
        :param pb_enc_type: PB encoding type to use for solving weighted problems
        :param expect_interrupt: whether or not an :meth:`interrupt` call is expected
        :param verbose: verbosity level

        :type formula: :class:`.WCNF`
        :type solver: str
        :type expect_interrupt: bool
        :type verbose: int
    """
    def __init__(self,
                 formula,
                 solver='g4',
                 pb_enc_type=EncType.best,
                 expect_interrupt=False,
                 verbose=0):
        """
            Constructor.
        """

        self.verbose = verbose
        self.solver = solver
        self.pb_enc_type = pb_enc_type
        self.expect_interrupt = expect_interrupt
        self.formula = formula
        self.vpool = IDPool(occupied=[
            (1, formula.nv)
        ])  # variable pool used for managing card/PB encodings
        self.sels = []  # soft clause selector variables
        self.is_weighted = False  # auxiliary flag indicating if it's a weighted problem
        self.tot = None  # totalizer encoder for the cardinality constraint

        # defined before _init() so that delete() / __del__ work even if
        # creating the oracle fails, and so that get_model() can be called
        # before solve()
        self.oracle = None
        self.model = None
        self.cost = None

        self._init(formula)  # initialize SAT oracle

    def _init(self, formula):
        """
            SAT oracle initialization. The method creates a new SAT oracle and
            feeds it with the formula's hard clauses. Afterwards, a copy of
            every soft clause augmented with a selector literal is added to
            the solver; the clauses of ``formula`` themselves are left
            untouched. The list of all introduced selectors is stored in
            variable ``self.sels``.

            :param formula: input MaxSAT formula
            :type formula: :class:`WCNF`
        """

        self.oracle = Solver(name=self.solver,
                             bootstrap_with=formula.hard,
                             incr=True,
                             use_timer=True)

        for cl in formula.soft:
            # TODO: if clause is unit, use its literal as selector
            # (ITotalizer must be extended to support PB constraints first)
            selv = self.vpool._next()

            # a new list is built so that the caller's clause is not mutated
            self.oracle.add_clause(list(cl) + [selv])
            self.sels.append(selv)

        self.is_weighted = any(w > 1 for w in formula.wght)

        if self.verbose > 1:
            print('c formula: {0} vars, {1} hard, {2} soft'.format(
                formula.nv, len(formula.hard), len(formula.soft)))

    def __del__(self):
        """
            Destructor.
        """

        self.delete()

    def __enter__(self):
        """
            'with' constructor.
        """

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """
            'with' destructor.
        """

        self.delete()

    def delete(self):
        """
            Explicit destructor of the internal SAT oracle and the
            :class:`.ITotalizer` object. Calling it more than once is
            harmless.
        """

        if self.oracle:
            self.oracle.delete()
            self.oracle = None

        if self.tot:
            self.tot.delete()
            self.tot = None

    def solve(self):
        """
            Computes a solution to the MaxSAT problem. The method implements
            the LSU/LSUS algorithm, i.e. it represents a loop, each iteration
            of which calls a SAT oracle on the working MaxSAT formula and
            refines the upper bound on the MaxSAT cost until the formula
            becomes unsatisfiable.

            Returns ``True`` if the hard part of the MaxSAT formula is
            satisfiable, i.e. if there is a MaxSAT solution, and ``False``
            otherwise.

            :rtype: bool
        """

        is_sat = False

        while self.oracle.solve_limited(
                expect_interrupt=self.expect_interrupt):
            is_sat = True
            self.model = self.oracle.get_model()
            self.cost = self._get_model_cost(self.formula, self.model)
            if self.verbose:
                print('o {0}'.format(self.cost))
                sys.stdout.flush()
            if self.cost == 0:  # if cost is 0, then model is an optimum solution
                break
            self._assert_lt(self.cost)

        if is_sat:
            # keep the original variables only; a real list is stored so
            # that the model can be read more than once
            self.model = [
                l for l in self.model if abs(l) <= self.formula.nv
            ]
            if self.verbose:
                if self.found_optimum():
                    print('s OPTIMUM FOUND')
                else:
                    print('s SATISFIABLE')
        elif self.verbose:
            print('s UNSATISFIABLE')

        return is_sat

    def get_model(self):
        """
            This method returns a model obtained during a prior satisfiability
            oracle call made in :func:`solve`. The value is ``None`` if no
            model has been found so far.

            :rtype: list(int)
        """

        return self.model

    def found_optimum(self):
        """
            Checks if the optimum solution was found in a prior call to
            :func:`solve`, i.e. whether the status reported by the oracle
            is not ``None``.

            :rtype: bool
        """

        return self.oracle.get_status() is not None

    def _get_model_cost(self, formula, model):
        """
            Given a WCNF formula and a model, the method computes the MaxSAT
            cost of the model, i.e. the sum of weights of soft clauses that are
            unsatisfied by the model.

            :param formula: an input MaxSAT formula
            :param model: a satisfying assignment

            :type formula: :class:`.WCNF`
            :type model: list(int)

            :rtype: int
        """

        model_set = set(model)
        nof_vars = self.formula.nv
        cost = 0

        for cl, w in zip(formula.soft, formula.wght):
            # literals beyond the formula's own variables are ignored
            falsified = all(l not in model_set
                            for l in cl if abs(l) <= nof_vars)
            if falsified:
                cost += w

        return cost

    def _assert_lt(self, cost):
        """
            The method enforces an upper bound on the cost of the MaxSAT
            solution. For unweighted problems, this is done by encoding the sum
            of all soft clause selectors with the use the iterative totalizer
            encoding, i.e. :class:`.ITotalizer`. Note that the sum is created
            once, at the beginning. Each of the following calls to this method
            only enforces the upper bound on the created sum by adding the
            corresponding unit size clause. For weighted problems, a new PB
            constraint over the selectors is encoded on every call.
            Each such clause is added on the fly with no restart of the
            underlying SAT oracle.

            :param cost: the cost of the next MaxSAT solution is enforced to be
                *lower* than this current cost

            :type cost: int
        """

        if self.is_weighted:
            # TODO: use incremental PB encoding
            self.oracle.append_formula(
                PBEnc.leq(self.sels,
                          weights=self.formula.wght,
                          bound=cost - 1,
                          vpool=self.vpool))
        else:

            if self.tot is None:
                self.tot = ITotalizer(lits=self.sels,
                                      ubound=cost - 1,
                                      top_id=self.vpool.top)

                # keep the pool in sync with the variables the totalizer used
                self.vpool.top = self.tot.top_id

                for cl in self.tot.cnf.clauses:
                    self.oracle.add_clause(cl)

            self.oracle.add_clause([-self.tot.rhs[cost - 1]])

    def interrupt(self):
        """
            Interrupt the current execution of LSU's :meth:`solve` method.
            Can be used to enforce time limits using timer objects. The
            interrupt must be cleared before running the LSU algorithm again
            (see :meth:`clear_interrupt`).
        """

        self.oracle.interrupt()

    def clear_interrupt(self):
        """
            Clears an interruption.
        """

        self.oracle.clear_interrupt()

    def oracle_time(self):
        """
            Method for calculating and reporting the total SAT solving time.
        """

        return self.oracle.time_accum()
Example #25
0
class SMTEncoder(object):
    """
        Encoder of XGBoost tree ensembles into SMT.

        The encoder traverses every tree of the ensemble and collects a list
        of SMT assertions relating feature variables to per-tree score
        variables (``tr<i>_score``) and per-class score variables
        (``class<j>_score``). The assertions are conjoined into a single
        formula stored in ``self.enc``. An encoding can also be saved to and
        loaded from a file in the SMT-LIB format.
    """

    def __init__(self, model, feats, nof_classes, xgb, from_file=None):
        """
            Constructor.

            :param model: model object handed over to :class:`TreeEnsemble`
            :param feats: feature names (each is mapped to its position)
            :param nof_classes: number of classes
            :param xgb: the booster wrapper; its ``options`` attribute and a
                few helper methods are used by the encoder
            :param from_file: if set, the encoding is loaded from this file
        """

        self.model = model
        self.feats = {f: i for i, f in enumerate(feats)}
        self.nofcl = nof_classes
        self.idmgr = IDPool()
        self.optns = xgb.options

        # xgbooster will also be needed
        self.xgb = xgb

        # for interval-based encoding
        self.intvs, self.imaps, self.ivars = None, None, None

        if from_file:
            self.load_from(from_file)

    def traverse(self, tree, tvar, prefix=None):
        """
            Traverse a tree and encode each node.

            For an internal node, the node's condition is encoded with
            :meth:`encode_node` and both children are visited recursively:
            the first child with the positive literal appended to the path
            and the second child with the negative one. For a leaf, an
            assertion is added stating that the path condition implies that
            ``tvar`` equals the value of the leaf.

            :param tree: (sub)tree to traverse
            :param tvar: score variable of the tree being encoded
            :param prefix: list of literals on the path from the root
                (``None`` stands for the empty path)
        """

        # never share a mutable default between calls
        if prefix is None:
            prefix = []

        if tree.children:
            pos, neg = self.encode_node(tree)

            # a fresh list is created for each branch,
            # i.e. the caller's prefix is never modified
            self.traverse(tree.children[0], tvar, prefix + [pos])
            self.traverse(tree.children[1], tvar, prefix + [neg])
        else:  # leaf node
            if prefix:
                self.enc.append(
                    Implies(And(prefix), Equals(tvar, Real(tree.values))))
            else:
                # a tree consisting of a single leaf
                self.enc.append(Equals(tvar, Real(tree.values)))

    def encode_node(self, node):
        """
            Encode a node of a tree.

            Returns a pair of literals: the first one is to be used for the
            first child of the node and the second one for the second child.

            :param node: internal node with attributes ``name`` and (for
                continuous features) ``threshold``
        """

        if '_' not in node.name:
            # continuous features => expecting an upper bound
            # feature and its upper bound (value)
            f, v = node.name, node.threshold

            # the same (feature, threshold) pair always gets
            # the same Boolean variable, which is defined only once
            key = (f, v)
            existing = key in self.idmgr.obj2id
            vid = self.idmgr.id(key)
            bv = Symbol('bvar{0}'.format(vid), typename=BOOL)

            if not existing:
                if self.intvs:
                    # interval-based encoding: bv holds iff one of the
                    # intervals up to (and including) the threshold is chosen
                    d = self.imaps[f][v] + 1
                    pos, neg = self.ivars[f][:d], self.ivars[f][d:]
                    self.enc.append(Iff(bv, Or(pos)))
                    self.enc.append(Iff(Not(bv), Or(neg)))
                else:
                    fvar, fval = Symbol(f, typename=REAL), Real(v)
                    self.enc.append(Iff(bv, LT(fvar, fval)))

            return bv, Not(bv)
        else:
            # all features are expected to be categorical and
            # encoded with one-hot encoding into Booleans
            # each node is expected to be of the form: f_i < 0.5
            bv = Symbol(node.name, typename=BOOL)

            # left branch is positive,  i.e. bv is true
            # right branch is negative, i.e. bv is false
            return Not(bv), bv

    def compute_intervals(self):
        """
            Traverse all trees in the ensemble and extract intervals for each
            feature.

            As a result, ``self.intvs`` maps each feature ``f<i>`` to the
            sorted list of its thresholds followed by ``'+'``, ``self.imaps``
            maps each such upper bound to its position, and ``self.ivars``
            holds a Boolean variable per interval.

            At this point, the method only works for numerical datasets!
        """
        def traverse_intervals(tree):
            """
                Auxiliary function. Recursive tree traversal.
            """

            if tree.children:
                f = tree.name
                v = tree.threshold
                self.intvs[f].add(v)

                traverse_intervals(tree.children[0])
                traverse_intervals(tree.children[1])

        # initializing the intervals
        self.intvs = {
            'f{0}'.format(i): set()
            for i in range(len(self.feats))
        }

        for tree in self.ensemble.trees:
            traverse_intervals(tree)

        # OK, we got all intervals; let's sort the values
        # ('+' marks the last, unbounded interval)
        self.intvs = {
            f: sorted(self.intvs[f]) + ['+']
            for f in six.iterkeys(self.intvs)
        }

        self.imaps, self.ivars = {}, {}
        for feat, intvs in six.iteritems(self.intvs):
            self.imaps[feat] = {}
            self.ivars[feat] = []
            for i, ub in enumerate(intvs):
                self.imaps[feat][ub] = i

                ivar = Symbol(name='{0}_intv{1}'.format(feat, i),
                              typename=BOOL)
                self.ivars[feat].append(ivar)

    def encode(self):
        """
            Do the job, i.e. build the encoding of the whole ensemble.

            :return: a tuple ``(enc, intvs, imaps, ivars)`` comprising the
                resulting formula and the interval-related data (the latter
                is computed only if option ``encode`` equals ``'smtbool'``)
        """

        self.enc = []

        # getting a tree ensemble
        self.ensemble = TreeEnsemble(
            self.model,
            self.xgb.extended_feature_names_as_array_strings,
            nb_classes=self.nofcl)

        # introducing class score variables; each of them is
        # paired with the list of scores of the contributing trees
        csum = []
        for j in range(self.nofcl):
            cvar = Symbol('class{0}_score'.format(j), typename=REAL)
            csum.append((cvar, []))

        # if targeting interval-based encoding,
        # traverse all trees and extract all possible intervals
        # for each feature
        if self.optns.encode == 'smtbool':
            self.compute_intervals()

        # traversing and encoding each tree
        for i, tree in enumerate(self.ensemble.trees):
            # getting class id
            clid = i % self.nofcl

            # encoding the tree
            tvar = Symbol('tr{0}_score'.format(i + 1), typename=REAL)
            self.traverse(tree, tvar, prefix=[])

            # this tree contributes to class with clid
            csum[clid][1].append(tvar)

        # encoding the sums
        for cvar, tvars in csum:
            self.enc.append(Equals(cvar, Plus(tvars)))

        # enforce exactly one of the feature values to be chosen
        # (for categorical features)
        categories = collections.defaultdict(list)
        for f in self.xgb.extended_feature_names_as_array_strings:
            if '_' in f:
                categories[f.split('_')[0]].append(
                    Symbol(name=f, typename=BOOL))
        for feats in six.itervalues(categories):
            self.enc.append(ExactlyOne(feats))

        # number of assertions
        nof_asserts = len(self.enc)

        # making conjunction
        self.enc = And(self.enc)

        # number of variables
        nof_vars = len(self.enc.get_free_variables())

        if self.optns.verb:
            print('encoding vars:', nof_vars)
            print('encoding asserts:', nof_asserts)

        return self.enc, self.intvs, self.imaps, self.ivars

    def test_sample(self, sample):
        """
            Check whether or not the encoding "predicts" the same class
            as the classifier given an input sample.

            :param sample: input sample to test

            :raises AssertionError: if no interval matches a feature value
                or if the class scores computed by the trees and by the
                encoding differ by more than 0.001
        """

        # first, compute the scores for all classes as would be
        # predicted by the classifier

        # score arrays computed for each class
        csum = [[] for _ in range(self.nofcl)]

        if self.optns.verb:
            print('testing sample:', list(sample))

        sample_internal = list(self.xgb.transform(sample)[0])

        # traversing all trees
        for i, tree in enumerate(self.ensemble.trees):
            # getting class id
            clid = i % self.nofcl

            # a score computed by the current tree
            score = scores_tree(tree, sample_internal)

            # this tree contributes to class with clid
            csum[clid].append(score)

        # final scores for each class
        cscores = [sum(scores) for scores in csum]

        # second, get the scores computed with the use of the encoding

        # asserting the sample
        hypos = []

        if not self.intvs:
            for i, fval in enumerate(sample_internal):
                feat, vid = self.xgb.transform_inverse_by_index(i)
                fid = self.feats[feat]

                if vid is None:
                    fvar = Symbol('f{0}'.format(fid), typename=REAL)
                    hypos.append(Equals(fvar, Real(float(fval))))
                else:
                    fvar = Symbol('f{0}_{1}'.format(fid, vid), typename=BOOL)
                    if int(fval) == 1:
                        hypos.append(fvar)
                    else:
                        hypos.append(Not(fvar))
        else:
            for i, fval in enumerate(sample_internal):
                feat, _ = self.xgb.transform_inverse_by_index(i)
                feat = 'f{0}'.format(self.feats[feat])

                # determining the right interval and the corresponding variable
                for ub, fvar in zip(self.intvs[feat], self.ivars[feat]):
                    if ub == '+' or fval < ub:
                        hypos.append(fvar)
                        break
                else:
                    # raised explicitly so that the check
                    # is not stripped when running with -O
                    raise AssertionError(
                        'No proper interval found for {0}'.format(feat))

        # now, getting the model
        escores = []
        model = get_model(And(self.enc, *hypos), solver_name=self.optns.solver)
        for c in range(self.nofcl):
            v = Symbol('class{0}_score'.format(c), typename=REAL)
            escores.append(float(model.get_py_value(v)))

        # the whole point of the method is this comparison; hence,
        # it must not silently disappear in optimized mode
        if not all(abs(c - e) <= 0.001 for c, e in zip(cscores, escores)):
            raise AssertionError(
                'wrong prediction: {0} vs {1}'.format(cscores, escores))

        if self.optns.verb:
            print('xgb scores:', cscores)
            print('enc scores:', escores)

    def save_to(self, outfile):
        """
            Save the encoding into a file with a given name.

            The formula is written in the SMT-LIB format and then prefixed
            with comment lines listing the features, the number of classes
            and (if any) the intervals. If the name ends with ``.txt``, the
            extension is replaced by ``.smt2``.

            :param outfile: name of the output file
        """

        if outfile.endswith('.txt'):
            outfile = outfile[:-3] + 'smt2'

        write_smtlib(self.enc, outfile)

        # appending additional information
        with open(outfile, 'r') as fp:
            contents = fp.readlines()

        # comments
        comments = [
            '; features: {0}\n'.format(', '.join(self.feats)),
            '; classes: {0}\n'.format(self.nofcl)
        ]

        if self.intvs:
            for f in self.xgb.extended_feature_names_as_array_strings:
                c = '; i {0}: '.format(f)
                c += ', '.join([
                    '{0}<->{1}'.format(u, v)
                    for u, v in zip(self.intvs[f], self.ivars[f])
                ])
                comments.append(c + '\n')

        # the comments go first, followed by the formula itself
        contents = comments + contents
        with open(outfile, 'w') as fp:
            fp.writelines(contents)

    def load_from(self, infile):
        """
            Loads the encoding from an input file.

            The leading comment lines (as written by :meth:`save_to`) are
            parsed to restore the intervals, the feature names and the
            number of classes. Note that after loading ``self.feats`` is a
            *list* of names rather than a name-to-index dictionary.

            :param infile: name of the input file
        """

        with open(infile, 'r') as fp:
            file_content = fp.readlines()

        # empty intervals for the standard encoding
        self.intvs, self.imaps, self.ivars = {}, {}, {}

        for line in file_content:
            if not line.startswith(';'):
                # the comment header is over
                break
            elif line.startswith('; i '):
                f, arr = line[4:].strip().split(': ', 1)
                f = f.replace('-', '_')
                self.intvs[f], self.imaps[f], self.ivars[f] = [], {}, []

                for i, pair in enumerate(arr.split(', ')):
                    ub, symb = pair.split('<->')

                    # '+' denotes the last interval and is kept as is
                    if ub[0] != '+':
                        ub = float(ub)
                    symb = Symbol(symb, typename=BOOL)

                    self.intvs[f].append(ub)
                    self.ivars[f].append(symb)
                    self.imaps[f][ub] = i

            elif line.startswith('; features:'):
                self.feats = line[11:].strip().split(', ')
            elif line.startswith('; classes:'):
                self.nofcl = int(line[10:].strip())

        parser = SmtLibParser()
        script = parser.get_script(StringIO(''.join(file_content)))

        self.enc = script.get_last_formula()

    def access(self):
        """
            Get access to the encoding, features names, and the number of
            classes.

            :return: a tuple ``(enc, intvs, imaps, ivars, feats, nofcl)``
        """

        return self.enc, self.intvs, self.imaps, self.ivars, self.feats, self.nofcl
Example #26
0
    def init_soft(self, encoding, clid):
        """
            Processing the leaves and creating the set of soft clauses.

            The weights of the leaf literals of class ``clid`` (taken with
            coefficient 1) and, if there are more than two classes, of the
            target class ``self.target`` (taken with coefficient -1) are
            summed up per literal. The resulting weighted literals are then
            normalized (opposite literals are merged, negative weights are
            flipped), AtMost1 constraints made of the leaves of each tree
            are processed with ``self.process_am1()``, and the remaining
            literals are added to ``self.formulas[clid]`` as soft clauses.

            Besides the soft clauses, the method sets attributes ``vmax``
            and ``cost`` of ``self.formulas[clid]``.

            :param encoding: per-class encodings; each one is expected to
                provide ``trees`` (pairs of leaf index bounds) and
                ``leaves`` (pairs of literal and weight)
            :param clid: identifier of the class to process
        """

        # new vpool for the leaves, and total cost
        vpool = IDPool(start_from=self.formulas[clid].nv + 1)

        # all leaves to be used in the formula, am1 constraints and cost
        wghts, atmosts, cost = collections.defaultdict(lambda: 0), [], 0

        for label in (clid, self.target):
            if label != self.target:
                coeff = 1
            else:  # this is the target class
                if len(encoding) > 2:
                    coeff = -1
                else:
                    # we don't encode the target class if there are
                    # only two classes - it duplicates the other class
                    continue

            # here we are going to automatically detect am1 constraints
            for tree in encoding[label].trees:
                am1 = []
                # tree[0] and tree[1] delimit the leaves of this tree
                for i in range(tree[0], tree[1]):
                    lit, wght = encoding[label].leaves[i]

                    # all leaves of each tree comprise an AtMost1 constraint
                    am1.append(lit)

                    # updating literal's final weight
                    wghts[lit] += coeff * wght

                atmosts.append(am1)

        # filtering out those with zero-weights
        # (note that this also turns the defaultdict into a plain dict)
        wghts = dict(filter(lambda p: p[1] != 0, wghts.items()))

        # processing the opposite literals, if any
        # (the sorting key places l right before -l for every variable)
        i, lits = 0, sorted(wghts.keys(), key=lambda l: 2 * abs(l) + (0 if l > 0 else 1))
        while i < len(lits) - 1:
            if lits[i] == -lits[i + 1]:
                l1, l2 = lits[i], lits[i + 1]
                # the weight with the smaller absolute value
                minw = min(wghts[l1], wghts[l2], key=lambda w: abs(w))

                # updating the weights
                wghts[l1] -= minw
                wghts[l2] -= minw

                # updating the cost if there is a conflict between l and -l
                # NOTE(review): minw equals one of the two weights, so after
                # the subtraction above one of them is 0 and this product
                # can never be positive - confirm whether the check was
                # meant to be done before updating the weights
                if wghts[l1] * wghts[l2] > 0:
                    cost += abs(minw)

                i += 2
            else:
                i += 1

        # flipping literals with negative weights
        # (iterating over a snapshot of the keys since the dict is modified)
        lits = list(wghts.keys())
        for l in lits:
            if wghts[l] < 0:
                cost += -wghts[l]
                wghts[-l] = -wghts[l]
                del wghts[l]

        # maximum value of the objective function
        self.formulas[clid].vmax = sum(wghts.values())

        # processing all AtMost1 constraints
        # (restricted to literals with non-zero weights; duplicates are
        # removed and larger constraints are processed first)
        atmosts = set([tuple([l for l in am1 if l in wghts and wghts[l] != 0]) for am1 in atmosts])
        for am1 in sorted(atmosts, key=lambda am1: len(am1), reverse=True):
            if len(am1) < 2:
                continue
            cost += self.process_am1(self.formulas[clid], am1, wghts, vpool)

        # here is the start cost
        self.formulas[clid].cost = cost

        # adding remaining leaves with non-zero weights as soft clauses
        for lit, wght in wghts.items():
            if wght != 0:
                self.formulas[clid].append([ lit], weight=wght)
Example #27
0
class SMTValidator(object):
    """
        Validating Anchor's explanations using SMT solving.

        Given an SMT formula and a sample with its explanation (a set of
        feature ids), the validator asks an SMT oracle whether some other
        class can get a higher score while the features of the explanation
        are fixed to their values in the sample. If so, the model found
        serves as a counterexample to the explanation.
    """

    def __init__(self, formula, feats, nof_classes, xgb):
        """
            Constructor.

            :param formula: SMT formula asserted in the oracle as the theory
            :param feats: feature names (each is mapped to its position)
            :param nof_classes: number of classes
            :param xgb: the booster wrapper; its ``options`` attribute and a
                few helper methods are used by the validator
        """

        self.ftids = {f: i for i, f in enumerate(feats)}
        self.nofcl = nof_classes
        self.idmgr = IDPool()
        self.optns = xgb.options

        # xgbooster will also be needed
        self.xgb = xgb

        self.verbose = self.optns.verb
        self.oracle = Solver(name=self.xgb.options.solver)

        self.inps = []  # input (feature value) variables
        for f in self.xgb.extended_feature_names_as_array_strings:
            if '_' not in f:
                self.inps.append(Symbol(f, typename=REAL))
            else:
                self.inps.append(Symbol(f, typename=BOOL))

        self.outs = []  # output (class  score) variables
        for c in range(self.nofcl):
            self.outs.append(Symbol('class{0}_score'.format(c), typename=REAL))

        # theory
        self.oracle.add_assertion(formula)

        # current selector
        self.selv = None

    @staticmethod
    def _cpu_time():
        """
            User CPU time consumed so far by this process and its children.
        """

        return resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
                resource.getrusage(resource.RUSAGE_SELF).ru_utime

    def _rule_premise(self, values):
        """
            Make the list of conditions of a rule to print given readable
            feature values. A value that already mentions the name of its
            feature is used as is; otherwise, it is prefixed with the name.
        """

        premise = []
        for f, v in zip(self.xgb.feature_names, values):
            if f not in v:
                premise.append('{0} = {1}'.format(f, v))
            else:
                premise.append(v)

        return premise

    def prepare(self, sample, expl):
        """
            Prepare the oracle for validating an explanation given a sample.

            :param sample: input sample
            :param expl: explanation, i.e. a collection of feature ids

            :raises AssertionError: if the sample was already considered or
                if the formula is unsatisfiable for the sample
        """

        if self.selv:
            # disable the previous assumption if any
            self.oracle.add_assertion(Not(self.selv))

        # creating a fresh selector for a new sample
        sname = ','.join([str(v).strip() for v in sample])

        # the samples should not repeat; otherwise, they will be
        # inconsistent with the previously introduced selectors
        # (raised explicitly so that the check survives running with -O)
        if sname in self.idmgr.obj2id:
            raise AssertionError(
                'this sample has been considered before (sample {0})'.format(
                    self.idmgr.id(sname)))
        self.selv = Symbol('sample{0}_selv'.format(self.idmgr.id(sname)),
                           typename=BOOL)

        # transformed sample
        self.sample = list(self.xgb.transform(sample)[0])

        # preparing the selectors, i.e. relaxed hypotheses; there is one
        # per (input, value) pair, and all the inputs whose names share
        # the prefix before '_' get the same selector
        self.rhypos = [
            Symbol('selv_{0}'.format(inp.symbol_name().split('_')[0]))
            for inp, _ in zip(self.inps, self.sample)
        ]

        # adding relaxed hypotheses to the oracle
        for inp, val, sel in zip(self.inps, self.sample, self.rhypos):
            if '_' not in inp.symbol_name():
                hypo = Implies(self.selv,
                               Implies(sel, Equals(inp, Real(float(val)))))
            else:
                hypo = Implies(self.selv, Implies(sel,
                                                  inp if val else Not(inp)))

            self.oracle.add_assertion(hypo)

        # propagating the true observation
        if not self.oracle.solve([self.selv] + self.rhypos):
            # without a model there is nothing to validate
            raise AssertionError(
                'Formula is unsatisfiable under given assumptions')
        model = self.oracle.get_model()

        # choosing the maximum
        outvals = [float(model.get_py_value(o)) for o in self.outs]
        maxoval = max(zip(outvals, range(len(outvals))))

        # correct class id (corresponds to the maximum computed)
        true_output = maxoval[1]

        # forcing a misclassification, i.e. a wrong observation
        disj = []
        for i in range(len(self.outs)):
            if i != true_output:
                disj.append(GT(self.outs[i], self.outs[true_output]))
        self.oracle.add_assertion(Implies(self.selv, Or(disj)))

        # removing all hypotheses except for those in the explanation
        hypos = []
        for i, hypo in enumerate(self.rhypos):
            j = self.ftids[self.xgb.transform_inverse_by_index(i)[0]]
            if j in expl:
                hypos.append(hypo)
        self.rhypos = hypos

        if self.verbose:
            inpvals = self.xgb.readable_sample(sample)
            preamble = self._rule_premise(inpvals)

            print('  explanation for:  "IF {0} THEN {1}"'.format(
                ' AND '.join(preamble), self.xgb.target_name[true_output]))

    def validate(self, sample, expl):
        """
            Make an effort to show that the explanation is too optimistic.

            :param sample: input sample
            :param expl: explanation, i.e. a collection of feature ids

            :return: ``None`` if no counterexample exists; otherwise, a pair
                comprising the input values of the counterexample and the
                id of the class it is assigned to
        """

        self.time = self._cpu_time()

        # adapt the solver to deal with the current sample
        self.prepare(sample, expl)

        # if satisfiable, then there is a counterexample
        if self.oracle.solve([self.selv] + self.rhypos):
            model = self.oracle.get_model()
            inpvals = [float(model.get_py_value(i)) for i in self.inps]
            outvals = [float(model.get_py_value(o)) for o in self.outs]
            maxoval = max(zip(outvals, range(len(outvals))))

            inpvals = self.xgb.transform_inverse(np.array(inpvals))[0]
            self.coex = tuple([inpvals, maxoval[1]])
            inpvals = self.xgb.readable_sample(inpvals)

            if self.verbose:
                preamble = self._rule_premise(inpvals)

                print('  explanation is incorrect')
                print('  counterexample: "IF {0} THEN {1}"'.format(
                    ' AND '.join(preamble), self.xgb.target_name[maxoval[1]]))
        else:
            self.coex = None

            if self.verbose:
                print('  explanation is correct')

        # time spent on the validation
        self.time = self._cpu_time() - self.time

        if self.verbose:
            print('  time: {0:.2f}'.format(self.time))

        return self.coex
Example #28
0
class Hitman(object):
    """
        A cardinality-/subset-minimal hitting set enumerator. The enumerator
        can be set up to use either a MaxSAT solver :class:`.RC2` or an MCS
        enumerator (either :class:`.LBX` or :class:`.MCSls`). In the former
        case, the hitting sets enumerated are ordered by size (smallest size
        hitting sets are computed first), i.e. *sorted*. In the latter case,
        subset-minimal hitting are enumerated in an arbitrary order, i.e.
        *unsorted*.

        This is handled with the use of parameter ``htype``, which is set to be
        ``'sorted'`` by default. The MaxSAT-based enumerator can be chosen by
        setting ``htype`` to one of the following values: ``'maxsat'``,
        ``'mxsat'``, or ``'rc2'``. Alternatively, by setting it to ``'mcs'`` or
        ``'lbx'``, a user can enforce using the :class:`.LBX` MCS enumerator.
        If ``htype`` is set to ``'mcsls'`` (or, in fact, to any other value),
        the :class:`.MCSls` enumerator is used.

        In either case, an underlying problem solver can use a SAT oracle
        specified as an input parameter ``solver``. The default SAT solver is
        Glucose3 (specified as ``g3``, see :class:`.SolverNames` for details).

        Objects of class :class:`Hitman` can be bootstrapped with an iterable
        of iterables, e.g. a list of lists. This is handled using the
        ``bootstrap_with`` parameter. Each set to hit can comprise elements of
        any type, e.g. integers, strings or objects of any Python class, as
        well as their combinations. The bootstrapping phase is done in
        :func:`init`.

        A few other optional parameters include the possible options for RC2
        as well as for LBX- and MCSls-like MCS enumerators that control the
        behaviour of the underlying solvers.

        :param bootstrap_with: input set of sets to hit (none by default)
        :param weights: a mapping from objects to their weights (if weighted)
        :param solver: name of SAT solver
        :param htype: enumerator type
        :param mxs_adapt: detect and process AtMost1 constraints in RC2
        :param mxs_exhaust: apply unsatisfiable core exhaustion in RC2
        :param mxs_minz: apply heuristic core minimization in RC2
        :param mxs_trim: trim unsatisfiable cores at most this number of times
        :param mcs_usecld: use clause-D heuristic in the MCS enumerator

        :type bootstrap_with: iterable(iterable(obj))
        :type weights: dict(obj)
        :type solver: str
        :type htype: str
        :type mxs_adapt: bool
        :type mxs_exhaust: bool
        :type mxs_minz: bool
        :type mxs_trim: int
        :type mcs_usecld: bool
    """

    def __init__(self,
                 bootstrap_with=None,
                 weights=None,
                 solver='g3',
                 htype='sorted',
                 mxs_adapt=False,
                 mxs_exhaust=False,
                 mxs_minz=False,
                 mxs_trim=0,
                 mcs_usecld=False):
        """
            Constructor.
        """

        # hitting set solver
        self.oracle = None

        # name of SAT solver
        self.solver = solver

        # various oracle options
        self.adapt = mxs_adapt
        self.exhaust = mxs_exhaust
        self.minz = mxs_minz
        self.trim = mxs_trim
        self.usecld = mcs_usecld

        # hitman type: either a MaxSAT solver or an MCS enumerator
        if htype in ('maxsat', 'mxsat', 'rc2', 'sorted'):
            self.htype = 'rc2'
        elif htype in ('mcs', 'lbx'):
            self.htype = 'lbx'
        else:  # 'mcsls'
            self.htype = 'mcsls'

        # pool of variable identifiers (for objects to hit)
        self.idpool = IDPool()

        # no mutable default argument: nothing to hit unless specified
        if bootstrap_with is None:
            bootstrap_with = []

        # initialize hitting set solver
        self.init(bootstrap_with, weights)

    def __del__(self):
        """
            Destructor.
        """

        self.delete()

    def __enter__(self):
        """
            'with' constructor.
        """

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """
            'with' destructor.
        """

        self.delete()

    def init(self, bootstrap_with, weights=None):
        """
            This method serves for initializing the hitting set solver with a
            given list of sets to hit. Concretely, each set to hit is encoded
            as a hard clause over the variables of its objects while each
            object gets a soft clause preferring it not to be selected. The
            resulting partial MaxSAT formula is then fed either to a MaxSAT
            solver or an MCS enumerator.

            An additional optional parameter is ``weights``, which can be used
            to specify non-unit weights for the target objects in the sets to
            hit. This only works if ``'sorted'`` enumeration of hitting sets
            is applied.

            :param bootstrap_with: input set of sets to hit
            :param weights: weights of the objects in case the problem is weighted
            :type bootstrap_with: iterable(iterable(obj))
            :type weights: dict(obj)
        """

        # formula encoding the sets to hit
        formula = WCNF()

        # hard clauses
        for to_hit in bootstrap_with:
            formula.append([self.idpool.id(obj) for obj in to_hit])

        # soft clauses
        for obj_id in six.iterkeys(self.idpool.id2obj):
            formula.append(
                [-obj_id],
                weight=1 if not weights else weights[self.idpool.obj(obj_id)])

        if self.htype == 'rc2':
            # the stratified variant is used only if the weights differ
            if not weights or min(weights.values()) == max(weights.values()):
                self.oracle = RC2(formula,
                                  solver=self.solver,
                                  adapt=self.adapt,
                                  exhaust=self.exhaust,
                                  minz=self.minz,
                                  trim=self.trim)
            else:
                self.oracle = RC2Stratified(formula,
                                            solver=self.solver,
                                            adapt=self.adapt,
                                            exhaust=self.exhaust,
                                            minz=self.minz,
                                            nohard=True,
                                            trim=self.trim)
        elif self.htype == 'lbx':
            self.oracle = LBX(formula,
                              solver_name=self.solver,
                              use_cld=self.usecld)
        else:
            self.oracle = MCSls(formula,
                                solver_name=self.solver,
                                use_cld=self.usecld)

    def delete(self):
        """
            Explicit destructor of the internal hitting set oracle.
        """

        if self.oracle:
            self.oracle.delete()
            self.oracle = None

    def get(self):
        """
            This method computes and returns a hitting set. The hitting set is
            obtained using the underlying oracle operating the MaxSAT problem
            formulation. The computed solution is mapped back to objects of the
            problem domain. The solution over the variables is also stored
            in ``self.hset``.

            If the oracle reports no solution, ``None`` is returned.

            :rtype: list(obj)
        """

        model = self.oracle.compute()

        if model is None:
            return None

        if self.htype == 'rc2':
            # extracting a hitting set; a list is made (rather than a lazy
            # filter object) so that self.hset can be inspected afterwards
            self.hset = [v for v in model if v > 0]
        else:
            self.hset = model

        return [self.idpool.id2obj[vid] for vid in self.hset]

    def _new_vars(self, vids):
        """
            Filter out the variables that are already known to the oracle.
        """

        return [vid for vid in vids if vid not in self.oracle.vmap.e2i]

    def _add_soft(self, vids, weights=None):
        """
            Add a unit soft clause preferring each given variable to be
            false; the weight is 1 unless ``weights`` is specified.
        """

        for vid in vids:
            self.oracle.add_clause(
                [-vid], 1 if not weights else weights[self.idpool.obj(vid)])

    def hit(self, to_hit, weights=None):
        """
            This method adds a new set to hit to the hitting set solver. This
            is done by translating the input iterable of objects into a list of
            Boolean variables in the MaxSAT problem formulation.

            Note that an optional parameter that can be passed to this method
            is ``weights``, which contains a mapping the objects under
            question into weights. Also note that the weight of an object must
            not change from one call of :meth:`hit` to another.

            :param to_hit: a new set to hit
            :param weights: a mapping from objects to weights

            :type to_hit: iterable(obj)
            :type weights: dict(obj)
        """

        # translating objects to variables
        to_hit = [self.idpool.id(obj) for obj in to_hit]

        # a soft clause should be added for each new object; the new
        # objects are determined before the hard clause is added
        new_obj = self._new_vars(to_hit)

        # new hard clause
        self.oracle.add_clause(to_hit)

        # new soft clauses
        self._add_soft(new_obj, weights)

    def block(self, to_block, weights=None):
        """
            The method serves for imposing a constraint forbidding the hitting
            set solver to compute a given hitting set. Each set to block is
            encoded as a hard clause in the MaxSAT problem formulation, which
            is then added to the underlying oracle.

            Note that an optional parameter that can be passed to this method
            is ``weights``, which contains a mapping the objects under
            question into weights. Also note that the weight of an object must
            not change from one call of :meth:`hit` to another.

            :param to_block: a set to block
            :param weights: a mapping from objects to weights

            :type to_block: iterable(obj)
            :type weights: dict(obj)
        """

        # translating objects to variables
        to_block = [self.idpool.id(obj) for obj in to_block]

        # a soft clause should be added for each new object; the new
        # objects are determined before the hard clause is added
        new_obj = self._new_vars(to_block)

        # new hard clause
        self.oracle.add_clause([-vid for vid in to_block])

        # new soft clauses
        self._add_soft(new_obj, weights)

    def enumerate(self):
        """
            The method can be used as a simple iterator computing and blocking
            the hitting sets on the fly. It essentially calls :func:`get`
            followed by :func:`block`. Each hitting set is reported as a list
            of objects in the original problem domain, i.e. it is mapped back
            from the solutions over Boolean variables computed by the
            underlying oracle.

            :rtype: list(obj)
        """

        while True:
            hset = self.get()

            # no more hitting sets
            if hset is None:
                break

            self.block(hset)
            yield hset

    def oracle_time(self):
        """
            Report the total SAT solving time.
        """

        return self.oracle.oracle_time()
Example #29
0
class SMTExplainer(object):
    """
        An SMT-inspired minimal explanation extractor for XGBoost models.
    """
    def __init__(self, formula, intvs, imaps, ivars, feats, nof_classes,
                 options, xgb):
        """
            Store the encoding data, create the oracle named by
            ``options.solver`` and assert the theory ``formula`` in it.
        """

        # data describing the encoding
        self.feats = feats
        self.intvs = intvs
        self.imaps = imaps
        self.ivars = ivars
        self.nofcl = nof_classes
        self.optns = options
        self.idmgr = IDPool()

        # the XGBooster object is kept for sample transformation/printing
        self.xgb = xgb

        self.verbose = self.optns.verb
        self.oracle = Solver(name=options.solver)

        # input (feature value) variables: a name containing '_' yields
        # a Boolean variable; any other name yields a real-valued one
        self.inps = [
            Symbol(name, typename=BOOL if '_' in name else REAL)
            for name in self.xgb.extended_feature_names_as_array_strings
        ]

        # output (class score) variables, one per class
        self.outs = [
            Symbol('class{0}_score'.format(cid), typename=REAL)
            for cid in range(self.nofcl)
        ]

        # the theory is asserted once and for all
        self.oracle.add_assertion(formula)

        # no sample is being explained yet
        self.selv = None

        # dual explanations are saved and used whenever needed
        self.dualx = []

        # counter of the oracle calls made
        self.calls = 0

    def prepare(self, sample):
        """
            Prepare the oracle for computing an explanation of ``sample``.

            The method (1) deactivates the selector of the previously
            prepared sample, if any, (2) creates a fresh sample selector
            and one relaxation selector per original feature, (3) asserts
            the hypotheses (feature values or intervals) guarded by these
            selectors, (4) queries the oracle to determine the class with
            the maximum score, and (5) asserts, under the sample selector,
            that some other class gets a strictly greater score.

            :param sample: feature values of the instance to explain

            :raises AssertionError: if the sample has been prepared before
                or the formula is unsatisfiable under the hypotheses
            :raises ValueError: if no interval matches a feature value
        """

        if self.selv:
            # disable the previous assumption if any
            self.oracle.add_assertion(Not(self.selv))

        # creating a fresh selector for a new sample
        sname = ','.join([str(v).strip() for v in sample])

        # the samples should not repeat; otherwise, they will be
        # inconsistent with the previously introduced selectors
        assert sname not in self.idmgr.obj2id, 'this sample has been considered before (sample {0})'.format(
            self.idmgr.id(sname))
        self.selv = Symbol('sample{0}_selv'.format(self.idmgr.id(sname)),
                           typename=BOOL)

        self.rhypos = []  # relaxed hypotheses

        # transformed sample
        self.sample = list(self.xgb.transform(sample)[0])

        self.sel2fid = {}  # selectors to original feature ids
        self.sel2vid = {}  # selectors to categorical feature ids

        # preparing the selectors; all inputs sharing the same name prefix
        # (the part before '_') share one selector
        for vid, (inp, _) in enumerate(zip(self.inps, self.sample)):
            feat = inp.symbol_name().split('_')[0]
            sel = Symbol('selv_{0}'.format(feat))

            self.rhypos.append(sel)
            if sel not in self.sel2fid:
                self.sel2fid[sel] = int(feat[1:])
                self.sel2vid[sel] = [vid]
            else:
                self.sel2vid[sel].append(vid)

        # adding relaxed hypotheses to the oracle
        if not self.intvs:
            for inp, val, sel in zip(self.inps, self.sample, self.rhypos):
                if '_' not in inp.symbol_name():
                    hypo = Implies(self.selv,
                                   Implies(sel, Equals(inp, Real(float(val)))))
                else:
                    hypo = Implies(self.selv,
                                   Implies(sel, inp if val else Not(inp)))

                self.oracle.add_assertion(hypo)
        else:
            for inp, val, sel in zip(self.inps, self.sample, self.rhypos):
                name = inp.symbol_name()

                # determining the right interval and the corresponding variable
                for ub, fvar in zip(self.intvs[name], self.ivars[name]):
                    if ub == '+' or val < ub:
                        hypo = Implies(self.selv, Implies(sel, fvar))
                        break
                else:
                    # without this check the hypothesis of the previous
                    # feature would silently be asserted once more
                    raise ValueError(
                        'no interval matches the value of {0}'.format(name))

                self.oracle.add_assertion(hypo)

        # in case of categorical data, there are selector duplicates
        # and we need to remove them
        self.rhypos = sorted(set(self.rhypos),
                             key=lambda x: int(x.symbol_name()[6:]))

        # propagating the true observation; an explicit raise is used
        # (rather than ``assert 0``) so that the check survives ``-O``
        if not self.oracle.solve([self.selv] + self.rhypos):
            raise AssertionError(
                'Formula is unsatisfiable under given assumptions')
        model = self.oracle.get_model()

        # choosing the maximum (ties are resolved by the larger class id)
        outvals = [float(model.get_py_value(o)) for o in self.outs]
        maxoval = max(zip(outvals, range(len(outvals))))

        # correct class id (corresponds to the maximum computed)
        self.out_id = maxoval[1]
        self.output = self.xgb.target_name[self.out_id]

        # forcing a misclassification, i.e. a wrong observation
        disj = [
            GT(out, self.outs[self.out_id])
            for cid, out in enumerate(self.outs) if cid != self.out_id
        ]
        self.oracle.add_assertion(Implies(self.selv, Or(disj)))

        if self.verbose:
            inpvals = self.xgb.readable_sample(sample)

            self.preamble = []
            for f, v in zip(self.xgb.feature_names, inpvals):
                if f not in v:
                    self.preamble.append('{0} = {1}'.format(f, v))
                else:
                    self.preamble.append(v)

            print('  explaining:  "IF {0} THEN {1}"'.format(
                ' AND '.join(self.preamble), self.output))
    def explain(self, sample, smallest):
        """
            Compute explanation(s) for the prediction made for ``sample``.

            The oracle is first prepared for the sample. If the hypotheses
            do not entail the prediction, a message and the model are
            printed and the process exits with status 1. Otherwise, the
            enumeration procedure is chosen based on ``self.optns.xtype``,
            ``self.optns.usemhs``, ``self.optns.xnum`` and ``smallest``.

            :param sample: feature values of the instance to explain
            :param smallest: whether cardinality-minimal explanations are
                requested

            :return: a list of explanations, each being a sorted list of
                feature ids; an entry is ``None`` if the enumerator
                reported that no explanation exists
            :rtype: list
        """

        def _to_fids(expl):
            # the enumerators may report "no explanation" as None, which
            # must be passed through instead of being iterated over
            if expl is None:
                return None
            return sorted([self.sel2fid[h] for h in expl])

        # reinitializing the number of used oracle calls
        # 1 because of the initial call checking the entailment
        self.calls = 1

        # adapt the solver to deal with the current sample
        self.prepare(sample)

        # saving external explanation to be minimized further
        self.to_consider = [True for h in self.rhypos]

        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
                resource.getrusage(resource.RUSAGE_SELF).ru_utime

        # if satisfiable, then the observation is not implied by the hypotheses
        if self.oracle.solve(
            [self.selv] +
            [h for h, c in zip(self.rhypos, self.to_consider) if c]):
            print('  no implication!')
            print(self.oracle.get_model())
            sys.exit(1)

        if self.optns.xtype == 'abductive':
            # abductive explanations => MUS computation and enumeration
            if not smallest and self.optns.xnum == 1:
                expls = [self.compute_minimal_abductive()]
            else:
                expls = self.enumerate_abductive(smallest=smallest)
        else:  # contrastive explanations => MCS enumeration
            if self.optns.usemhs:
                expls = self.enumerate_contrastive()
            elif not smallest:
                expls = self.enumerate_minimal_contrastive()
            else:
                expls = self.enumerate_contrastive()

        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
                resource.getrusage(resource.RUSAGE_SELF).ru_utime - self.time

        # selectors are replaced by the ids of the original features
        expls = [_to_fids(expl) for expl in expls]

        if self.dualx:
            self.dualx = [_to_fids(expl) for expl in self.dualx]

        if self.verbose:
            # nothing is printed if there is no explanation to show
            if expls and expls[0] is not None:
                for expl in expls:
                    preamble = [self.preamble[i] for i in expl]
                    if self.optns.xtype == 'abductive':
                        print('  explanation: "IF {0} THEN {1}"'.format(
                            ' AND '.join(preamble),
                            self.xgb.target_name[self.out_id]))
                    else:
                        print(
                            '  explanation: "IF NOT {0} THEN NOT {1}"'.format(
                                ' AND NOT '.join(preamble),
                                self.xgb.target_name[self.out_id]))
                    print('  # hypos left:', len(expl))

            print('  time: {0:.2f}'.format(self.time))

        # all the computed explanations are returned
        return expls

    def compute_minimal_abductive(self):
        """
            Extract one subset-minimal explanation by deletion-based
            linear search: a hypothesis is dropped whenever the oracle
            stays unsatisfiable without it.
        """

        # only the hypotheses marked in to_consider take part
        kept = [h for h, keep in zip(self.rhypos, self.to_consider) if keep]

        pos = 0
        while pos < len(kept):
            candidate = kept[:pos] + kept[pos + 1:]

            self.calls += 1
            if not self.oracle.solve([self.selv] + candidate):
                # still unsatisfiable => the hypothesis is redundant;
                # the position now refers to the next hypothesis
                kept = candidate
                continue

            # the hypothesis is necessary => move on
            pos += 1

        return kept

    def enumerate_minimal_contrastive(self):
        """
            Enumerate subset-minimal contrastive explanations.

            Each explanation is computed from a model of the oracle: the
            selectors are split into the ones satisfied by the model and
            the remaining ones, and the latter are then tested one by one
            (optionally with the help of the "clause D" check, if
            ``self.optns.usecld`` is set). Every explanation found is
            blocked by a clause before the next one is searched for. The
            enumeration stops once ``self.optns.xnum`` explanations are
            computed or the oracle becomes unsatisfiable.

            :return: a list of explanations (lists of selectors) or
                ``[None]`` if no explanation is found
            :rtype: list
        """
        def _overapprox():
            # splits the selectors with respect to the current model:
            # the satisfied ones go to ss_assumps, the others to setd
            model = self.oracle.get_model()

            for sel in self.rhypos:
                if int(model.get_py_value(sel)) > 0:
                    # soft clauses contain positive literals
                    # so if var is true then the clause is satisfied
                    self.ss_assumps.append(sel)
                else:
                    self.setd.append(sel)

        def _compute():
            # tests the selectors of setd one by one: a selector is kept
            # in ss_assumps if the oracle is satisfiable with it; if not,
            # its negation is recorded in bb_assumps
            i = 0
            while i < len(self.setd):
                if self.optns.usecld:
                    # the check replaces self.setd, hence the index reset
                    _do_cld_check(self.setd[i:])
                    i = 0

                if self.setd:
                    # it may be empty after the clause D check

                    self.calls += 1
                    self.ss_assumps.append(self.setd[i])
                    if not self.oracle.solve([self.selv] + self.ss_assumps +
                                             self.bb_assumps):
                        self.ss_assumps.pop()
                        self.bb_assumps.append(Not(self.setd[i]))

                i += 1

        def _do_cld_check(cld):
            # checks whether at least one literal of cld can be satisfied
            # together with the current assumptions; a fresh selector
            # makes it possible to deactivate the clause afterwards
            self.cldid += 1
            sel = Symbol('{0}_{1}'.format(self.selv.symbol_name(), self.cldid))
            cld.append(Not(sel))

            # adding clause D
            self.oracle.add_assertion(Or(cld))
            self.ss_assumps.append(sel)

            # setd is rebuilt below from the literals left unsatisfied
            self.setd = []
            st = self.oracle.solve([self.selv] + self.ss_assumps +
                                   self.bb_assumps)

            self.ss_assumps.pop()  # removing clause D assumption
            if st == True:
                model = self.oracle.get_model()

                # the last literal of cld is the negated selector
                for l in cld[:-1]:
                    # filtering all satisfied literals
                    if int(model.get_py_value(l)) > 0:
                        self.ss_assumps.append(l)
                    else:
                        self.setd.append(l)
            else:
                # clause D is unsatisfiable => all literals are backbones
                self.bb_assumps.extend([Not(l) for l in cld[:-1]])

            # deactivating clause D
            self.oracle.add_assertion(Not(sel))

        # sets of selectors to work with
        self.cldid = 0
        expls = []

        # detect and block unit-size MCSes immediately
        if self.optns.unitmcs:
            for i, hypo in enumerate(self.rhypos):
                self.calls += 1
                # satisfiable without hypo only => hypo alone is an MCS
                if self.oracle.solve([self.selv] + self.rhypos[:i] +
                                     self.rhypos[(i + 1):]):
                    expls.append([hypo])

                    if len(expls) != self.optns.xnum:
                        self.oracle.add_assertion(Or([Not(self.selv), hypo]))
                    else:
                        break

        self.calls += 1
        while self.oracle.solve([self.selv]):
            self.ss_assumps, self.bb_assumps, self.setd = [], [], []
            _overapprox()
            _compute()

            # an explanation consists of the variables of the negated
            # selectors collected in bb_assumps
            expl = [list(f.get_free_variables())[0] for f in self.bb_assumps]
            expls.append(expl)

            if len(expls) == self.optns.xnum:
                break

            # blocking the explanation: one of its selectors must hold
            self.oracle.add_assertion(Or([Not(self.selv)] + expl))
            self.calls += 1

        # each clause D check involves one more oracle call
        self.calls += self.cldid
        return expls if expls else [None]

    def enumerate_abductive(self, smallest=True):
        """
            Enumerate abductive explanations through hitting set duality.

            Candidate sets of hypotheses are produced by a :class:`Hitman`
            object (``'sorted'`` mode if ``smallest`` is true, ``'lbx'``
            otherwise). A candidate is reported as an explanation if the
            oracle is unsatisfiable under it. Otherwise, the candidate is
            extended with the other hypotheses as far as the oracle stays
            satisfiable, and the hypotheses left out are handed to the
            hitting set enumerator to hit. The enumeration stops once
            ``self.optns.xnum`` explanations are found or no candidate is
            left. The sets to hit are also saved in ``self.dualx``.

            :param smallest: selects the mode of the hitting set enumerator
            :type smallest: bool

            :return: a list of explanations (lists of selectors)
            :rtype: list
        """

        def _agrees(model, h):
            # checks whether the model assigns the sample's values to all
            # the inputs guarded by hypothesis h; the inputs are addressed
            # through sel2vid because self.inps and self.sample are indexed
            # by input position, which differs from the feature id stored
            # in sel2fid as soon as one selector guards several inputs
            for vid in self.sel2vid[self.rhypos[h]]:
                var = self.inps[vid]

                if '_' not in var.symbol_name():
                    # real-valued input: compare with a small tolerance
                    exp = self.sample[vid]
                    true_val = float(model.get_py_value(var))

                    if not exp - 0.001 <= true_val <= exp + 0.001:
                        return False
                else:
                    # Boolean input: the values must coincide
                    exp = int(self.sample[vid])
                    true_val = int(model.get_py_value(var))

                    if exp != true_val:
                        return False

            return True

        # result
        expls = []

        # just in case, let's save dual (contrastive) explanations
        self.dualx = []

        with Hitman(bootstrap_with=[[
                i for i in range(len(self.rhypos)) if self.to_consider[i]
        ]],
                    htype='sorted' if smallest else 'lbx') as hitman:
            # computing unit-size MCSes
            for i, hypo in enumerate(self.rhypos):
                if not self.to_consider[i]:
                    continue

                self.calls += 1
                if self.oracle.solve([self.selv] + self.rhypos[:i] +
                                     self.rhypos[(i + 1):]):
                    hitman.hit([i])
                    self.dualx.append([hypo])

            # main loop
            iters = 0
            while True:
                hset = hitman.get()
                iters += 1

                if self.verbose > 1:
                    print('iter:', iters)
                    print('cand:', hset)

                if hset is None:
                    break

                self.calls += 1
                if self.oracle.solve([self.selv] +
                                     [self.rhypos[i] for i in hset]):
                    to_hit = []
                    unsatisfied = []

                    removed = list(
                        set(range(len(self.rhypos))).difference(set(hset)))

                    # hypotheses agreeing with the model can be added to
                    # the candidate without calling the oracle
                    model = self.oracle.get_model()
                    for h in removed:
                        if _agrees(model, h):
                            hset.append(h)
                        else:
                            unsatisfied.append(h)

                    # computing an MCS (expensive)
                    for h in unsatisfied:
                        self.calls += 1
                        if self.oracle.solve([self.selv] +
                                             [self.rhypos[i] for i in hset] +
                                             [self.rhypos[h]]):
                            hset.append(h)
                        else:
                            to_hit.append(h)

                    if self.verbose > 1:
                        print('coex:', to_hit)

                    hitman.hit(to_hit)

                    self.dualx.append([self.rhypos[i] for i in to_hit])
                else:
                    if self.verbose > 1:
                        print('expl:', hset)

                    expl = [self.rhypos[i] for i in hset]
                    expls.append(expl)

                    if len(expls) != self.optns.xnum:
                        hitman.block(hset)
                    else:
                        break

        return expls

    def enumerate_smallest_contrastive(self):
        """
            Enumerate contrastive explanations in the order of increasing
            cardinality, by bounding the number of falsified hypotheses
            with integer cost literals (linear search unsat-sat).
        """

        # explanations found so far
        found = []

        # hypotheses that are unit-size MUSes, i.e. inconsistent on
        # their own with the sample selector
        unit_muses = set()
        for hypo in self.rhypos:
            self.calls += 1
            if not self.oracle.solve([self.selv, hypo]):
                unit_muses.add(hypo)

        # unit-size MUSes are discarded from consideration
        remaining = set(self.rhypos).difference(unit_muses)

        # an integer cost literal per remaining hypothesis:
        # 0 if the hypothesis holds and 1 otherwise
        costlits = []
        for idx, hypo in enumerate(remaining):
            lit_name = 'costlit_{0}_{1}'.format(self.selv.symbol_name(), idx)
            lit = Symbol(name=lit_name, typename=INT)
            costlits.append(lit)

            self.oracle.add_assertion(
                Ite(hypo, Equals(lit, Int(0)), Equals(lit, Int(1))))

        bound = 0
        while bound < len(remaining) and len(found) != self.optns.xnum:
            # a fresh selector activates the bound of this iteration
            sit = Symbol('iter_{0}_{1}'.format(self.selv.symbol_name(),
                                               bound))

            # at most `bound` hypotheses may be falsified
            self.oracle.add_assertion(
                Implies(sit, LE(Plus(costlits), Int(bound))))

            # every model gives an explanation for the current bound
            while self.oracle.solve([self.selv, sit]):
                self.calls += 1
                model = self.oracle.get_model()

                falsified = [
                    hypo for hypo in remaining
                    if int(model.get_py_value(hypo)) == 0
                ]

                # each MCS contains all unit-size MUSes
                found.append(list(unit_muses) + falsified)

                if len(found) == self.optns.xnum:
                    break

                # blocking clause: one of the hypotheses must hold
                self.oracle.add_assertion(
                    Implies(self.selv, Or(falsified)))

            bound += 1
            self.calls += 1

        return found

    def enumerate_contrastive(self, smallest=True):
        """
            Enumerate contrastive explanations through hitting set duality.

            Candidate sets of hypotheses to drop are produced by a
            :class:`Hitman` object (``'sorted'`` mode if ``smallest`` is
            true, ``'lbx'`` otherwise). A candidate is reported as an
            explanation if the oracle is satisfiable with the remaining
            hypotheses. Otherwise, an unsatisfiable core is extracted,
            optionally trimmed (``self.optns.trim`` rounds at most) and
            reduced (``self.optns.reduce``), and handed to the hitting set
            enumerator to hit. The enumeration stops once
            ``self.optns.xnum`` explanations are found or no candidate is
            left. The cores are also saved in ``self.dualx``.

            :param smallest: selects the mode of the hitting set enumerator
            :type smallest: bool

            :return: a list of explanations (lists of selectors)
            :rtype: list

            :raises AssertionError: if the solver in use is not Z3
        """

        # core extraction is done via calling Z3's internal API
        assert self.optns.solver == 'z3', 'This procedure requires Z3'

        # result
        expls = []

        # just in case, let's save dual (abductive) explanations
        self.dualx = []

        # mapping from hypothesis variables to their indices
        hmap = {h: i for i, h in enumerate(self.rhypos)}

        # mapping from internal Z3 variable into variables of PySMT
        vmap = {self.oracle.converter.convert(v): v for v in self.rhypos}
        vmap[self.oracle.converter.convert(self.selv)] = None

        def _get_core():
            # the core in terms of PySMT selectors, sorted by feature id;
            # the sample selector (mapped to None) is filtered out
            core = self.oracle.z3.unsat_core()
            return sorted(filter(lambda x: x is not None,
                                 map(lambda x: vmap[x], core)),
                          key=lambda x: int(x.symbol_name()[6:]))

        def _do_trimming(core):
            # re-solving under the core itself may give a smaller core;
            # every round must start from the core of the previous one,
            # otherwise all the rounds would repeat the same query
            for _ in range(self.optns.trim):
                self.calls += 1
                self.oracle.solve([self.selv] + core)
                new_core = _get_core()
                if len(new_core) == len(core):
                    # no progress => trimming has converged
                    break
                core = new_core
            return core

        def _reduce_lin(core):
            # deletion-based linear reduction of the core
            def _assump_needed(a):
                if len(to_test) > 1:
                    to_test.remove(a)
                    self.calls += 1
                    if not self.oracle.solve([self.selv] + list(to_test)):
                        return False
                    to_test.add(a)
                    return True
                else:
                    return True

            to_test = set(core)
            return list(filter(lambda a: _assump_needed(a), core))

        def _reduce_qxp(core):
            # reduction removing chunks of assumptions; the chunk size is
            # halved after every pass over the core
            coex = core[:]
            filt_sz = len(coex) / 2.0
            while filt_sz >= 1:
                i = 0
                while i < len(coex):
                    to_test = coex[:i] + coex[(i + int(filt_sz)):]
                    self.calls += 1
                    if to_test and not self.oracle.solve([self.selv] +
                                                         to_test):
                        # assumps are not needed
                        coex = to_test
                    else:
                        # assumps are needed => check the next chunk
                        i += int(filt_sz)
                # decreasing size of the set to filter
                filt_sz /= 2.0
                if filt_sz > len(coex) / 2.0:
                    # next size is too large => make it smaller
                    filt_sz = len(coex) / 2.0
            return coex

        def _reduce_coex(core):
            if self.optns.reduce == 'lin':
                return _reduce_lin(core)
            else:  # qxp
                return _reduce_qxp(core)

        with Hitman(bootstrap_with=[[
                i for i in range(len(self.rhypos)) if self.to_consider[i]
        ]],
                    htype='sorted' if smallest else 'lbx') as hitman:
            # computing unit-size MUSes
            for i, hypo in enumerate(self.rhypos):
                if not self.to_consider[i]:
                    continue

                self.calls += 1
                if not self.oracle.solve([self.selv, hypo]):
                    hitman.hit([i])
                    self.dualx.append([hypo])
                elif self.optns.unitmcs:
                    self.calls += 1
                    if self.oracle.solve([self.selv] + self.rhypos[:i] +
                                         self.rhypos[(i + 1):]):
                        # this is a unit-size MCS => block immediately
                        hitman.block([i])
                        expls.append([hypo])

            # main loop
            iters = 0
            while True:
                hset = hitman.get()
                iters += 1

                if self.verbose > 1:
                    print('iter:', iters)
                    print('cand:', hset)

                if hset is None:
                    break

                # the hypotheses outside of the candidate are assumed
                self.calls += 1
                if not self.oracle.solve([self.selv] + [
                        self.rhypos[h] for h in list(
                            set(range(len(self.rhypos))).difference(set(hset)))
                ]):
                    to_hit = _get_core()

                    if len(to_hit) > 1 and self.optns.trim:
                        to_hit = _do_trimming(to_hit)

                    if len(to_hit) > 1 and self.optns.reduce != 'none':
                        to_hit = _reduce_coex(to_hit)

                    self.dualx.append(to_hit)
                    to_hit = [hmap[h] for h in to_hit]

                    if self.verbose > 1:
                        print('coex:', to_hit)

                    hitman.hit(to_hit)
                else:
                    if self.verbose > 1:
                        print('expl:', hset)

                    expl = [self.rhypos[i] for i in hset]
                    expls.append(expl)

                    if len(expls) != self.optns.xnum:
                        hitman.block(hset)
                    else:
                        break

        return expls
Example #30
0
class Problem:
    def __init__(self, inputs):
        """Unpack ``inputs``, derive the grid dimensions and register all
        the predicates in a fresh ``IDPool``."""
        # raw problem data
        self.police = inputs["police"]
        self.medics = inputs["medics"]
        self.observations = inputs["observations"]
        self.queries = inputs["queries"]

        # time horizon: one observation per time step 0..t_max
        snapshots = self.observations
        self.num_observations = len(snapshots)
        self.t_max = self.num_observations - 1

        # grid dimensions are taken from the first observation
        first = snapshots[0]
        self.rows = len(first)
        self.cols = len(first[0])
        self.num_tiles = self.rows * self.cols
        self.tiles = {(i, j)
                      for i in range(self.rows) for j in range(self.cols)}

        # predicates: names are mapped to ids by the pool
        self.pool = IDPool()
        self.fill_predicates()
        self.obj2id = self.pool.obj2id

    def fill_predicates(self):
        """Register the name ``<kind>_<i>_<j>^<t>`` in ``self.pool`` for
        every kind of predicate, tile (i, j) and time step t."""
        # the registration order is part of the behavior: kinds are
        # registered in this exact order for every (t, i, j)
        kinds = (
            "U",
            "I0",  # vaccinated now
            "I",
            "S0",  # current S
            "S1",  # S minus 1 (prev)
            "S2",  # S minus 2 (prev-prev)
            "Q0",  # current Q
            "Q1",  # Q minus 1 (prev)
            "H",
        )

        for t in range(self.t_max + 1):
            for i in range(self.rows):
                for j in range(self.cols):
                    for kind in kinds:
                        self.pool.id(f"{kind}_{i}_{j}^{t}")

    def U_tile_dynamics(self):
        """Build the clauses relating the "U" variable of every tile to
        the "U" variables of the adjacent time steps.

        For every tile and time step t, ``U_t => U_{t+1}`` is added if
        t + 1 is within the horizon and ``U_t => U_{t-1}`` is added if
        t - 1 is. With a single observation (``t_max == 0``) there is no
        adjacent time step, hence no clause is produced and no variable
        outside of the registered time steps is referred to.

        Returns a ``CNF`` object built from the clauses.
        """
        clauses = []
        for i in range(self.rows):
            for j in range(self.cols):
                for t in range(self.t_max + 1):
                    current = self.obj2id[f"U_{i}_{j}^{t}"]

                    # forward implication: not applicable at the last step
                    if t < self.t_max:
                        clauses.append(
                            [-current, self.obj2id[f"U_{i}_{j}^{t + 1}"]])

                    # backward implication: not applicable at the first step
                    if t > 0:
                        clauses.append(
                            [-current, self.obj2id[f"U_{i}_{j}^{t - 1}"]])

        return CNF(from_clauses=clauses)

    def I_tile_dynamics(self):
        """Build the clauses relating the "I" and "I0" variables of every
        tile to the variables of the adjacent time steps.

        Nothing is generated for t == 0. For every later step t:
        ``I_t => I_{t+1}`` and ``I0_t => I_{t+1}`` (unless t is the last
        step), then ``I0_t => H_{t-1}`` and
        ``I_t => I_{t-1} v I0_{t-1}``.

        Returns a ``CNF`` object built from the clauses.
        """
        ids = self.obj2id
        clauses = []

        for i in range(self.rows):
            for j in range(self.cols):
                for t in range(1, self.t_max + 1):
                    immune = ids[f"I_{i}_{j}^{t}"]
                    fresh = ids[f"I0_{i}_{j}^{t}"]

                    # forward implications exist for non-final steps only
                    if t != self.t_max:
                        later = ids[f"I_{i}_{j}^{t + 1}"]
                        clauses.append([-immune, later])
                        clauses.append([-fresh, later])

                    # backward implications
                    clauses.append([-fresh, ids[f"H_{i}_{j}^{t - 1}"]])
                    clauses.append([
                        -immune,
                        ids[f"I_{i}_{j}^{t - 1}"],
                        ids[f"I0_{i}_{j}^{t - 1}"],
                    ])

        return CNF(from_clauses=clauses)

    def S_tile_dynamics(self):
        """Build the clauses describing the "S0"/"S1"/"S2" variables of
        every tile: their relation to the previous and the next time step
        of the same tile, and to the neighbouring tiles.

        Returns a ``CNF`` object built from the clauses.
        """
        def lit(kind, row, col, step):
            # id of the variable "<kind>_<row>_<col>^<step>"
            return self.obj2id[f"{kind}_{row}_{col}^{step}"]

        stages = ["S0", "S1", "S2"]
        clauses = []

        for i in range(self.rows):
            for j in range(self.cols):
                for t in range(self.t_max + 1):
                    neighbors = self.__get_neighbours_indices(i, j)
                    has_prev = 0 < t
                    has_next = t < self.num_observations - 1

                    # relation to the previous step of the same tile
                    if has_prev:
                        # S0_t => H_t-1
                        clauses.append(
                            [-lit("S0", i, j, t), lit("H", i, j, t - 1)])
                        # S1_t => S0_t-1
                        clauses.append(
                            [-lit("S1", i, j, t), lit("S0", i, j, t - 1)])
                        # S2_t => S1_t-1
                        clauses.append(
                            [-lit("S2", i, j, t), lit("S1", i, j, t - 1)])

                    # relation to the next step of the same tile
                    if has_next:
                        # S0_t => S1_t+1 v Q0_t+1
                        clauses.append([
                            -lit("S0", i, j, t),
                            lit("S1", i, j, t + 1),
                            lit("Q0", i, j, t + 1),
                        ])
                        # S1_t => S2_t+1 v Q0_t+1
                        clauses.append([
                            -lit("S1", i, j, t),
                            lit("S2", i, j, t + 1),
                            lit("Q0", i, j, t + 1),
                        ])
                        # S2_t => H_t+1 v Q0_t+1
                        clauses.append([
                            -lit("S2", i, j, t),
                            lit("H", i, j, t + 1),
                            lit("Q0", i, j, t + 1),
                        ])

                    # infected by someone:
                    # S0_t => V (S0n_t-1 v S1n_t-1 v S2n_t-1) for n in neighbors
                    if has_prev:
                        infected_by = [-lit("S0", i, j, t)]
                        for (n_row, n_col) in neighbors:
                            infected_by.extend(
                                lit(stage, n_row, n_col, t - 1)
                                for stage in stages)
                        clauses.append(infected_by)

                    # infecting others (Hn, I0n, S0n stand for the neighbor):
                    # Si_t /\ Hn_t /\ -Q0_t+1 /\ -I0n_t+1 => S0n_t+1
                    if has_next:
                        for (n_row, n_col) in neighbors:
                            for stage in stages:
                                clauses.append([
                                    -lit(stage, i, j, t),
                                    -lit("H", n_row, n_col, t),
                                    lit("Q0", i, j, t + 1),
                                    lit("I0", n_row, n_col, t + 1),
                                    lit("S0", n_row, n_col, t + 1),
                                ])

        return CNF(from_clauses=clauses)

    def Q_tile_dynamics(self):
        """Build the clauses relating the "Q0" and "Q1" variables of every
        tile to the variables of the adjacent time steps.

        Nothing is generated for t == 0. For every later step t:
        ``Q0_t => Q1_{t+1}`` and ``Q1_t => H_{t+1}`` (unless t is the last
        step), then ``Q1_t => Q0_{t-1}`` and
        ``Q0_t => S0_{t-1} v S1_{t-1} v S2_{t-1}``.

        Returns a ``CNF`` object built from the clauses.
        """
        ids = self.obj2id
        clauses = []

        for i in range(self.rows):
            for j in range(self.cols):
                for t in range(1, self.t_max + 1):
                    q_new = ids[f"Q0_{i}_{j}^{t}"]
                    q_old = ids[f"Q1_{i}_{j}^{t}"]

                    # forward implications exist for non-final steps only
                    if t != self.t_max:
                        clauses.append([-q_new, ids[f"Q1_{i}_{j}^{t + 1}"]])
                        clauses.append([-q_old, ids[f"H_{i}_{j}^{t + 1}"]])

                    # backward implications
                    clauses.append([-q_old, ids[f"Q0_{i}_{j}^{t - 1}"]])
                    clauses.append([
                        -q_new,
                        ids[f"S0_{i}_{j}^{t - 1}"],
                        ids[f"S1_{i}_{j}^{t - 1}"],
                        ids[f"S2_{i}_{j}^{t - 1}"],
                    ])

        return CNF(from_clauses=clauses)

    def H_tile_dynamics(self):
        """Build the clauses that constrain the H variable of each tile.

        For every tile and every turn ``t`` in ``0..self.t_max``:

        * if ``t > 0``: ``H_t => H_{t-1} v Q1_{t-1} v S2_{t-1}``
        * if ``t < self.num_observations - 1``:
          ``H_t => S0_{t+1} v I0_{t+1} v H_{t+1}``

        Returns:
            A ``CNF`` built from the collected clauses.
        """
        clauses = []
        last_forward_turn = self.num_observations - 1

        for i in range(self.rows):
            for j in range(self.cols):

                def var(state, turn):
                    # id of the variable "<state>_<i>_<j>^<turn>" for this tile
                    return self.obj2id[f"{state}_{i}_{j}^{turn}"]

                for t in range(self.t_max + 1):
                    if t > 0:
                        # H_t => H_{t-1} v Q1_{t-1} v S2_{t-1}
                        clauses.append([
                            -var("H", t),
                            var("H", t - 1),
                            var("Q1", t - 1),
                            var("S2", t - 1),
                        ])

                    if t < last_forward_turn:
                        # H_t => S0_{t+1} v I0_{t+1} v H_{t+1}
                        clauses.append([
                            -var("H", t),
                            var("S0", t + 1),
                            var("I0", t + 1),
                            var("H", t + 1),
                        ])

        return CNF(from_clauses=clauses)

    def unique_tile_dynamics(self):
        """Require every tile to be in exactly one state at every turn.

        For each tile and each turn in ``0..self.t_max`` the nine state
        variables (U, I0, I, S0, S1, S2, Q0, Q1, H) are passed to
        ``CardEnc.equals`` with a bound of 1; the resulting clauses are
        accumulated.

        Returns:
            A ``CNF`` built from the collected clauses.
        """
        state_names = ("U", "I0", "I", "S0", "S1", "S2", "Q0", "Q1", "H")
        clauses = []

        for i in range(self.rows):
            for j in range(self.cols):
                for t in range(self.t_max + 1):  # every turn, first and last included
                    tile_states = [
                        self.pool.obj2id[f"{name}_{i}_{j}^{t}"]
                        for name in state_names
                    ]
                    exactly_one = CardEnc.equals(lits=tile_states,
                                                 bound=1,
                                                 vpool=self.pool)
                    clauses.extend(exactly_one)

        return CNF(from_clauses=clauses)

    def first_turn_rules(self):
        """Forbid states that cannot occur during the first two turns.

        Emits one negative unit clause per forbidden (tile, turn, state):

        * turn 0: Q0, Q1, S1, S2, I, I0
        * turn 1: Q1, S2, I

        Only turns that exist are handled, i.e. turns below
        ``min(2, self.num_observations)``.

        Returns:
            A ``CNF`` built from the collected unit clauses.
        """
        # index in this tuple == turn number
        forbidden_by_turn = (
            ("Q0", "Q1", "S1", "S2", "I", "I0"),
            ("Q1", "S2", "I"),
        )
        constrained_turns = min(2, self.num_observations)
        clauses = []

        for i in range(self.rows):
            for j in range(self.cols):
                for t in range(constrained_turns):
                    clauses.extend(
                        [-self.obj2id[f"{state}_{i}_{j}^{t}"]]
                        for state in forbidden_by_turn[t])

        return CNF(from_clauses=clauses)

    def hadar_dynamics(self):
        """Build the clauses tying the I0 variables to ``self.medics``.

        Two families of clauses are produced:

        1. For every turn ``t``, at most ``self.medics`` of the ``I0``
           variables of turn ``t`` may be true.
        2. Unless ``self.medics == 0``: for every turn ``t`` that has a
           successor and for every subset of tiles, a guarded cardinality
           constraint.  The guard is falsified exactly when the chosen
           tiles are ``H`` at turn ``t`` and all remaining tiles are not;
           in that case exactly ``min(self.medics, len(subset))`` of the
           chosen tiles must have ``I0`` true at turn ``t + 1``.

        The number of generated clauses grows exponentially with the number
        of tiles because every subset of ``self.tiles`` is enumerated.

        Returns:
            A ``CNF`` built from the collected clauses.  This holds on the
            ``self.medics == 0`` path as well, which previously returned a
            bare list of clauses.
        """
        clauses = []

        # upper bound on newly set I0 variables per turn
        for t in range(self.num_observations):
            clauses.extend(
                CardEnc.atmost(self.__get_I0_predicates(t),
                               bound=self.medics,
                               vpool=self.pool).clauses)

        if self.medics == 0:
            # keep the return type identical to the regular path
            return CNF(from_clauses=clauses)

        for t in range(self.num_observations - 1):
            # NOTE(review): range() stops at cols * rows - 1, so the scenario
            # in which *every* tile is H is never constrained - confirm that
            # this is intended.
            for num_healthy in range(self.cols * self.rows):
                for healthy_tiles in itertools.combinations(
                        self.tiles, num_healthy):
                    other_tiles = [
                        tile for tile in self.tiles
                        if tile not in healthy_tiles
                    ]

                    # The guard is false only in the scenario "exactly the
                    # chosen tiles are H at turn t"; otherwise it satisfies
                    # the clause and the constraint below is switched off.
                    guard = [
                        -self.obj2id[f"H_{i}_{j}^{t}"]
                        for i, j in healthy_tiles
                    ]
                    guard += [
                        self.obj2id[f"H_{i}_{j}^{t}"] for i, j in other_tiles
                    ]

                    lits = [
                        self.obj2id[f"I0_{i}_{j}^{t + 1}"]
                        for i, j in healthy_tiles
                    ]
                    equals_clauses = CardEnc.equals(lits,
                                                    bound=min(
                                                        self.medics,
                                                        num_healthy),
                                                    vpool=self.pool).clauses

                    # clauses are flat lists of ints, so concatenation already
                    # yields an independent list (no deepcopy required)
                    clauses.extend(guard + sub_clause
                                   for sub_clause in equals_clauses)

        return CNF(from_clauses=clauses)

    def naveh_dynamics(self):
        """Build the clauses tying the Q0 variables to ``self.police``.

        Two families of clauses are produced:

        1. For every turn ``t >= 1``, at most ``self.police`` of the ``Q0``
           variables of turn ``t`` may be true.
        2. Unless ``self.police == 0``: for every turn ``t`` that has a
           successor, every subset of tiles and every assignment of a sick
           state (from ``possible_sick_states(t)``) to the tiles of that
           subset, a guarded cardinality constraint.  The guard is falsified
           exactly when the chosen tiles are in the assigned sick states and
           no other tile is in any sick state; in that case exactly
           ``min(self.police, len(subset))`` ``Q0`` variables of turn
           ``t + 1`` must be true.

        Sick-state assignments are enumerated with ``itertools.product`` so
        that every per-tile combination is covered.  The previous
        ``combinations_with_replacement`` only produced sorted tuples, which
        silently skipped assignments such as (S1, S0) for two sick tiles.
        The price is more clauses: k**n assignments for n sick tiles and k
        possible sick states.

        Returns:
            A ``CNF`` built from the collected clauses.  This holds on the
            ``self.police == 0`` path as well, which previously returned a
            bare list of clauses.
        """
        clauses = []

        # upper bound on newly set Q0 variables per turn (from turn 1 on)
        for t in range(1, self.num_observations):
            clauses.extend(
                CardEnc.atmost(self.__get_Q0_predicates(t),
                               bound=self.police,
                               vpool=self.pool).clauses)

        if self.police == 0:
            # keep the return type identical to the regular path
            return CNF(from_clauses=clauses)

        for t in range(self.num_observations - 1):
            sick_states = self.possible_sick_states(t)
            # NOTE(review): range() stops at cols * rows - 1, so the scenario
            # in which *every* tile is sick is never constrained - confirm
            # that this is intended.
            for num_sick in range(self.cols * self.rows):
                for sick_tiles in itertools.combinations(self.tiles, num_sick):
                    healthy_tiles = [
                        tile for tile in self.tiles if tile not in sick_tiles
                    ]

                    # "some non-chosen tile is sick" - identical for every
                    # state assignment of the chosen tiles, so build it once
                    others_sick = [
                        self.obj2id[f"{state}_{i}_{j}^{t}"]
                        for i, j in healthy_tiles
                        for state in sick_states
                    ]

                    for assignment in itertools.product(sick_states,
                                                        repeat=num_sick):
                        # false only when each chosen tile is in its
                        # assigned sick state
                        guard = [
                            -self.obj2id[f"{state}_{i}_{j}^{t}"]
                            for (i, j), state in zip(sick_tiles, assignment)
                        ]
                        guard += others_sick

                        equals_clauses = CardEnc.equals(
                            lits=self.__get_Q0_predicates(t + 1),
                            bound=min(self.police, num_sick),
                            vpool=self.pool).clauses

                        # clauses are flat lists of ints, so concatenation
                        # already yields an independent list
                        clauses.extend(guard + sub_clause
                                       for sub_clause in equals_clauses)

        return CNF(from_clauses=clauses)

    def world_dynamics(self):
        """Assemble the complete rule set of the world into one formula.

        The individual rule builders are invoked in a fixed order and their
        clauses are appended to a single ``CNF``:

        1. per-state tile dynamics (U, I, S, Q, H),
        2. first-turn restrictions and the exactly-one-state constraint,
        3. the team-related constraints (``hadar_dynamics``,
           ``naveh_dynamics``).

        Returns:
            The combined ``CNF``.
        """
        builders = (
            # single tile dynamics
            self.U_tile_dynamics,
            self.I_tile_dynamics,
            self.S_tile_dynamics,
            self.Q_tile_dynamics,
            self.H_tile_dynamics,
            # early-turn restrictions, then exactly one state per tile
            self.first_turn_rules,
            self.unique_tile_dynamics,
            # use all teams
            self.hadar_dynamics,
            self.naveh_dynamics,
        )

        dynamics = CNF()
        for build in builders:
            dynamics.extend(build())
        return dynamics

    def observations_to_assumptions(self) -> list:
        """Translate the observation maps into a list of assumed variables.

        Every cell yields at most one entry, looked up in ``self.obj2id``
        under the key ``"<state>_<row>_<col>^<turn>"``.  The state is derived
        from the observed letter and, where needed, from the same cell one or
        two turns earlier:

        * ``H`` / ``U``: the state itself, at any turn.
        * ``S`` at turn 0: ``S0``.
        * ``S`` later on: ``S0`` if the previous letter is known and not
          ``S``; if the previous letter is ``S``, then ``S1`` at turn 1, and
          from turn 2 on ``S2`` when the letter two turns back is ``S`` or
          ``S1`` when it is known and not ``S``.
        * ``Q`` after turn 0: ``Q1`` if the previous letter is ``Q``, ``Q0``
          if it is known and different.
        * ``I`` after turn 0: ``I`` if the previous letter is ``I``, ``I0``
          if it is known and different.

        A ``"?"`` in the relevant earlier cell makes the exact state
        undetermined, and nothing is emitted for that cell.

        Returns:
            The ids in turn-major, then row, then column order.
        """
        history = self.observations
        assumptions = []

        for t in range(self.num_observations):
            for i in range(self.rows):
                for j in range(self.cols):
                    seen = history[t][i][j]
                    state = None

                    if seen == "H" or seen == "U":
                        state = seen
                    elif seen == "S":
                        if t == 0:
                            # assuming no Q and I in first turn
                            state = "S0"
                        else:
                            before = history[t - 1][i][j]
                            if before == "S":
                                if t == 1:
                                    state = "S1"
                                else:
                                    earlier = history[t - 2][i][j]
                                    if earlier == "S":
                                        state = "S2"
                                    elif earlier != "?":
                                        state = "S1"
                            elif before != "?":
                                state = "S0"
                    elif (seen == "Q" or seen == "I") and t > 0:
                        before = history[t - 1][i][j]
                        if before == seen:
                            # same letter twice in a row: the later stage
                            state = "Q1" if seen == "Q" else "I"
                        elif before != "?":
                            # letter appeared this turn: the fresh stage
                            state = "Q0" if seen == "Q" else "I0"

                    if state is not None:
                        assumptions.append(
                            self.obj2id[f"{state}_{i}_{j}^{t}"])

        return assumptions

    def read_observations(self):
        """Encode each observed letter as a clause over its possible states.

        For every observed cell one clause is added that lists the internal
        state variables compatible with the letter:

        * ``S`` -> ``S0 v S1 v S2``
        * ``Q`` -> ``Q0 v Q1``
        * ``U`` -> ``U``
        * ``H`` -> ``H``
        * ``I`` -> ``I0 v I``

        Any other letter (e.g. ``"?"``) adds no clause.

        Returns:
            A ``CNF`` built from the collected clauses.
        """
        # (observed letter, compatible internal states), checked in this order
        letter_to_states = (
            ("S", ("S0", "S1", "S2")),
            ("Q", ("Q0", "Q1")),
            ("U", ("U",)),
            ("H", ("H",)),
            ("I", ("I0", "I")),
        )
        clauses = []

        for t in range(self.num_observations):
            for i in range(self.rows):
                for j in range(self.cols):
                    seen = self.observations[t][i][j]
                    for letter, states in letter_to_states:
                        if seen == letter:
                            clauses.append([
                                self.obj2id[f"{state}_{i}_{j}^{t}"]
                                for state in states
                            ])
                            break

        return CNF(from_clauses=clauses)

    def translate_query(self, query, state: bool):
        """Turn a query into clauses asserting or denying the queried letter.

        Args:
            query: ``((row, col), turn, letter)`` where ``letter`` is one of
                ``"U"``, ``"H"``, ``"I"``, ``"Q"``, ``"S"``.
            state: if truthy, a single clause is produced that is the
                disjunction of the internal states behind the letter; if
                falsy, one negative unit clause is produced per internal
                state.

        Returns:
            A ``CNF`` with the resulting clauses; it is empty when the letter
            is none of the five known ones.
        """
        (i, j), t, s = query

        # (queried letter, internal states it stands for)
        letter_to_states = (
            ("U", ("U",)),
            ("H", ("H",)),
            ("I", ("I", "I0")),
            ("Q", ("Q0", "Q1")),
            ("S", ("S0", "S1", "S2")),
        )

        clauses = []
        for letter, names in letter_to_states:
            if s == letter:
                variables = [
                    self.pool.obj2id[f"{name}_{i}_{j}^{t}"] for name in names
                ]
                if state:
                    clauses = [variables]
                else:
                    clauses = [[-variable] for variable in variables]

        return CNF(from_clauses=clauses)

    def solve(self, solver_name="m22"):
        """Answer every query in ``self.queries`` with 'T', 'F' or '?'.

        For each query ``((row, col), turn, letter)`` the world dynamics, the
        observation clauses and the query itself are handed to a fresh SAT
        solver, with the observation-derived assumptions:

        * ``'F'`` - the queried letter is unsatisfiable;
        * ``'?'`` - the queried letter is satisfiable, and so is at least one
          of the other four letters for the same tile and turn;
        * ``'T'`` - the queried letter is satisfiable and no other letter is.

        Every solver is used as a context manager so that its resources are
        released as soon as the check is done.  The observation clauses and
        assumptions do not depend on the query and are computed once.

        Args:
            solver_name: name of the PySAT backend passed to ``Solver``.

        Returns:
            A dict mapping each query to its answer string.

        Raises:
            ValueError: if a satisfiable query uses a letter other than
                Q, U, I, H or S.
        """
        world_dynamics = self.world_dynamics()
        observation_clauses = self.read_observations()
        assumptions = self.observations_to_assumptions()

        def is_satisfiable(query):
            # one throw-away solver per check; 'with' frees it on exit
            with Solver(name=solver_name) as solver:
                solver.append_formula(world_dynamics)
                solver.append_formula(observation_clauses)
                solver.append_formula(self.translate_query(query, state=True))
                return solver.solve(assumptions=assumptions)

        answers_dict = {}
        for q in self.queries:
            (i, j), t, s = q

            if not is_satisfiable(q):
                answers_dict[q] = 'F'
                continue

            other_states = ["Q", "U", "I", "H", "S"]
            other_states.remove(s)

            # the query is only certain if no competing letter is possible;
            # any() stops at the first satisfiable alternative
            ambiguous = any(
                is_satisfiable(((i, j), t, other_state))
                for other_state in other_states)
            answers_dict[q] = '?' if ambiguous else 'T'

        return answers_dict

    @staticmethod
    def possible_sick_states(t):
        """Return the sick-state names that can occur at turn ``t``.

        ``["S0"]`` for turn 0, ``["S0", "S1"]`` for turn 1 and all three
        names for any other value.  A new list is returned on every call.
        """
        stages = ["S0", "S1", "S2"]
        if t == 0:
            reachable = 1
        elif t == 1:
            reachable = 2
        else:
            reachable = 3
        return stages[:reachable]

    def __get_neighbours_indices(self, i, j):
        neighbours_indices = [(i - 1, j), (i + 1, j), (i, j - 1), (i, j + 1)]
        if i == 0:
            neighbours_indices.remove((i - 1, j))
        if i == self.rows - 1:
            neighbours_indices.remove((i + 1, j))
        if j == 0:
            neighbours_indices.remove((i, j - 1))
        if j == self.cols - 1:
            neighbours_indices.remove((i, j + 1))
        return neighbours_indices

    def __get_I0_predicates(self, t):
        I0_predicates = []
        for i in range(self.rows):
            for j in range(self.cols):
                I0_predicates.append(self.obj2id[f"I0_{i}_{j}^{t}"])
        return I0_predicates

    def __get_Q0_predicates(self, t):
        Q0_predicates = []
        for i in range(self.rows):
            for j in range(self.cols):
                Q0_predicates.append(self.obj2id[f"Q0_{i}_{j}^{t}"])
        return Q0_predicates