Exemple #1
0
def k_sequence_WH(m, K, K_seq_len=100, count=100):
    """Yield up to ``count`` distinct 0/1 sequences obeying a window bound.

    A sequence of ``K_seq_len`` integer symbols, each constrained to 0 or 1,
    is asserted to have a sum of at most ``m`` over every slice of ``K``
    consecutive positions.

    Args:
        m: Upper bound on the sum of each window.
        K: Window length.
        K_seq_len: Number of symbols in the sequence.
        count: Maximum number of models to generate.

    Yields:
        A boolean array with one entry per sequence position.
    """
    k_seq = [Symbol('x_%i' % i, INT) for i in range(K_seq_len)]
    domain = And([Or(Equals(x, Int(0)), Equals(x, Int(1))) for x in k_seq])
    K_window = And([
        LE(Plus(k_seq[t:min(K_seq_len, t + K)]), Int(m))
        for t in range(max(1, K_seq_len - K + 1))
    ])
    formula = And(domain, K_window)
    solver = Solver(name='yices',
                    incremental=True,
                    random_seed=randint(2 << 30))
    solver.add_assertion(formula)

    # Blocking clauses of every model yielded so far.  They are replayed into
    # the fallback solver so that it cannot return an already-seen model.
    blocked = []
    for _ in range(count):
        if not solver.solve():
            # Retry with z3 on the same constraints, including the blocking
            # clauses accumulated so far.
            solver = Solver(name='z3',
                            incremental=True,
                            random_seed=randint(2 << 30))
            solver.add_assertion(formula)
            for clause in blocked:
                solver.add_assertion(clause)
            if not solver.solve():
                # No further model exists: stop instead of asking an
                # unsatisfiable solver for a model.
                return
        model = solver.get_model()
        bits = array([model.get_py_value(x) for x in k_seq], dtype=bool)
        yield bits
        # Convert each entry to a plain int before building the constant.
        clause = Or(
            [NotEquals(x, Int(int(bit))) for x, bit in zip(k_seq, bits)])
        blocked.append(clause)
        solver.add_assertion(clause)
Exemple #2
0
    def test_msat_partial_model(self):
        """The msat model contains the asserted symbol but not an unused one."""
        x, y = Symbol("x"), Symbol("y")
        # The context manager releases the solver even if an assertion fails.
        with Solver(name="msat") as msat:
            msat.add_assertion(x)
            self.assertTrue(msat.solve())

            model = msat.get_model()
            self.assertNotIn(y, model)
            self.assertIn(x, model)
Exemple #3
0
    def test_msat_partial_model(self):
        """The msat model contains the asserted symbol but not an unused one."""
        x, y = Symbol("x"), Symbol("y")
        # The context manager releases the solver even if an assertion fails.
        with Solver(name="msat") as msat:
            msat.add_assertion(x)
            self.assertTrue(msat.solve())

            model = msat.get_model()
            self.assertNotIn(y, model)
            self.assertIn(x, model)
Exemple #4
0
def solve(formula, n, max_models=None, solver="msat"):
    """Enumerate models of ``formula`` and yield each one through ``to_bn``.

    After every model a blocking clause over the symbols returned by
    ``variables(n)`` is asserted, so the next model differs on at least one
    of them.

    Args:
        formula: Formula to enumerate models of.
        n: Passed to ``variables`` and ``to_bn``.
        max_models: Maximum number of models to yield; a falsy value
            (``None`` or ``0``) means no limit.
        solver: Name of the solver backend.

    Yields:
        ``to_bn(model, n)`` for each model found.
    """
    with Solver(name=solver) as s:
        s.add_assertion(formula)
        vs = None
        k = 0
        # The limit is tested first so that no solver call is spent once
        # ``max_models`` results have already been produced.
        while (not max_models or k < max_models) and s.solve():
            if vs is None:
                # Only needed once at least one model exists.
                vs = [x for xs in variables(n) for x in xs]
            k += 1
            model = s.get_model()
            s.add_assertion(Not(And([EqualsOrIff(v, model[v]) for v in vs])))
            yield to_bn(model, n)
Exemple #5
0
def k_sequence_WH_worst_case(m, K, K_seq_len=100, count=100):
    """Yield up to ``count`` distinct 0/1 sequences under tightened window bounds.

    Every slice of ``K`` consecutive positions must sum to at most ``m``
    (``K_window``) and to more than ``m - 1`` (``violate_up``), and every
    slice of ``K + right_shift`` positions must sum to more than ``m``.
    Whenever the solver finds no model (or fails), ``right_shift`` is
    increased and a fresh z3 solver is built with the new constraint plus the
    blocking clauses of all models produced so far.

    Args:
        m: Window bound.
        K: Window length.
        K_seq_len: Number of symbols in the sequence.
        count: Maximum number of models to generate.

    Yields:
        A boolean array with one entry per sequence position.
    """
    k_seq = [Symbol('x_%i' % i, INT) for i in range(K_seq_len)]
    domain = And([Or(Equals(x, Int(0)), Equals(x, Int(1))) for x in k_seq])
    K_window = And([
        LE(Plus(k_seq[t:min(K_seq_len, t + K)]), Int(m))
        for t in range(max(1, K_seq_len - K + 1))
    ])
    violate_up = And([
        GT(Plus(k_seq[t:min(K_seq_len, t + K)]), Int(m - 1))
        for t in range(max(1, K_seq_len - K + 1))
    ])

    def violate_right_generator(n):
        """Require every slice of ``K + n`` positions to sum to more than ``m``."""
        return And([
            GT(Plus(k_seq[t:min(K_seq_len, t + K + n)]), Int(m))
            for t in range(max(1, K_seq_len - (K + n) + 1))
        ])

    def new_solver(assertion):
        """Create a z3 solver with a 5-minute timeout holding ``assertion``."""
        fresh = Solver(name='z3',
                       incremental=True,
                       random_seed=randint(2 << 30))
        fresh.z3.set('timeout', 5 * 60 * 1000)
        fresh.add_assertion(assertion)
        return fresh

    right_shift = 1
    solver = new_solver(
        And(domain, K_window, violate_up,
            violate_right_generator(right_shift)))
    solutions = And()  # conjunction of all blocking clauses so far
    for _ in range(count):
        result = None
        while right_shift + K < K_seq_len:
            try:
                result = solver.solve()
            except Exception:
                # A failing solver call is treated like "no model found";
                # KeyboardInterrupt / SystemExit are deliberately not caught.
                result = None
            if result:
                break
            right_shift += 1
            solver = new_solver(
                And(solutions, domain, K_window, violate_up,
                    violate_right_generator(right_shift)))
        if not result:
            # Shifts exhausted without a model: nothing more to yield.
            return
        model = solver.get_model()
        bits = array([model.get_py_value(x) for x in k_seq], dtype=bool)
        yield bits
        # Convert each entry to a plain int before building the constant.
        solution = Or(
            [NotEquals(x, Int(int(bit))) for x, bit in zip(k_seq, bits)])
        solutions = And(solutions, solution)
        solver.add_assertion(solution)
Exemple #6
0
class SATSolver:
    """Thin wrapper around a solver created with the ``QF_BOOL`` logic.

    Every formula asserted through this wrapper is also recorded, one list
    per solver stack frame, so that ``unsat_core`` can rebuild the currently
    asserted formula.
    """

    def __init__(self, solver_name):
        """Create the underlying solver named ``solver_name``."""
        self.solver = Solver(solver_name, logic=QF_BOOL)
        # _frames[0] holds permanent assertions; each pushed level adds a list.
        self._frames = [[]]

    def _record(self, formula):
        """Remember ``formula`` as asserted in the current stack frame."""
        self._frames[-1].append(formula)

    def _solve_for_model(self):
        """Return a model if the solver is satisfiable, otherwise ``None``."""
        if self.solver.solve():
            return self.solver.get_model()
        return None

    def get_model(self):
        """Solve and return a model, or ``None`` when unsatisfiable."""
        return self._solve_for_model()

    def enum_model(self, blocking_cls):
        """Assert ``Not(And(blocking_cls))`` and return the next model or ``None``."""
        blocking = Not(And(blocking_cls))
        self.solver.add_assertion(blocking)
        self._record(blocking)
        return self._solve_for_model()

    def assert_cls(self, _cls):
        """Assert ``_cls`` in the current stack frame."""
        self.solver.add_assertion(_cls)
        self._record(_cls)

    def add_cls(self, _cls, level):
        """Push ``level`` frames on the solver, then assert ``_cls``."""
        self.solver.push(level)
        self._frames.extend([] for _ in range(level))
        self.solver.add_assertion(_cls)
        self._record(_cls)

    def remove_cls(self, level):
        """Pop ``level`` frames from the solver and drop their recorded formulas."""
        self.solver.pop(level)
        if level > 0:
            # Never drop the base frame holding the permanent assertions.
            del self._frames[max(1, len(self._frames) - level):]

    def unsat_core(self):
        """Print an unsat core of the asserted formula when it is unsatisfiable.

        Does nothing when the solver reports satisfiable.  If a caller has
        set ``self.fml`` explicitly, that formula is used; otherwise the
        conjunction of the recorded assertions is used.
        """
        if self.solver.solve():
            return

        fml = getattr(self, 'fml', None)
        if fml is None:
            fml = And([f for frame in self._frames for f in frame])

        print('Assertions:', fml)
        conj = conjunctive_partition(fml)
        ucore = get_unsat_core(conj)
        print("UNSAT-Core size '%d'" % len(ucore))
        for f in ucore:
            print(f.serialize())
Exemple #7
0
 def _model_iterator_base(formula):
     """Finds all the total truth assignments that satisfy the given formula.

     After each model is yielded, a clause excluding that model's assignment
     to the atoms of the formula is asserted, so every later model differs on
     at least one atom.

     Args:
         formula (FNode): The pysmt formula to examine.

     Yields:
         model: The model representing the next total truth assignment that satisfies the formula.

     """
     # The atoms do not change between iterations, so collect them once.
     atoms = formula.get_atoms()
     # The context manager releases the solver when the generator finishes
     # or is closed early.
     with Solver(name="msat") as solver:
         solver.add_assertion(formula)
         while solver.solve():
             model = solver.get_model()
             yield model

             # Constrain the solver to find a different assignment
             solver.add_assertion(
                 Not(And([Iff(atom, model.get_value(atom))
                          for atom in atoms])))
Exemple #8
0
class SMTValidator(object):
    """
        Validating Anchor's explanations using SMT solving.

        The oracle holds the model encoding plus, per validated sample, a
        fresh selector variable guarding that sample's hypotheses and the
        forced misclassification constraint.
    """
    def __init__(self, formula, feats, nof_classes, xgb):
        """
            Constructor.

            formula     -- encoding asserted into the oracle
            feats       -- feature names; their positions become feature ids
            nof_classes -- number of class score variables to create
            xgb         -- object providing options, feature names and the
                           sample transformations used below
        """

        self.ftids = {f: i for i, f in enumerate(feats)}
        self.nofcl = nof_classes
        self.idmgr = IDPool()
        self.optns = xgb.options

        # xgbooster will also be needed
        self.xgb = xgb

        self.verbose = self.optns.verb
        self.oracle = Solver(name=self.xgb.options.solver)

        self.inps = []  # input (feature value) variables
        for f in self.xgb.extended_feature_names_as_array_strings:
            # names without '_' get a REAL variable, all others a BOOL one
            if '_' not in f:
                self.inps.append(Symbol(f, typename=REAL))
            else:
                self.inps.append(Symbol(f, typename=BOOL))

        self.outs = []  # output (class  score) variables
        for c in range(self.nofcl):
            self.outs.append(Symbol('class{0}_score'.format(c), typename=REAL))

        # theory
        self.oracle.add_assertion(formula)

        # current selector
        self.selv = None

    def _rule_preamble(self, inpvals):
        """
            Build the 'feature = value' parts of a printed rule.

            A value that already contains its feature name is used as is.
        """

        preamble = []
        for f, v in zip(self.xgb.feature_names, inpvals):
            if f not in v:
                preamble.append('{0} = {1}'.format(f, v))
            else:
                preamble.append(v)
        return preamble

    def prepare(self, sample, expl):
        """
            Prepare the oracle for validating an explanation given a sample.

            Raises AssertionError if the sample was already processed or if
            the encoding is unsatisfiable under the sample's hypotheses.
        """

        if self.selv:
            # disable the previous assumption if any
            self.oracle.add_assertion(Not(self.selv))

        # creating a fresh selector for a new sample
        sname = ','.join([str(v).strip() for v in sample])

        # the samples should not repeat; otherwise, they will be
        # inconsistent with the previously introduced selectors
        # (raised explicitly so the check is not removed under `python -O`)
        if sname in self.idmgr.obj2id:
            raise AssertionError(
                'this sample has been considered before (sample {0})'.format(
                    self.idmgr.id(sname)))
        self.selv = Symbol('sample{0}_selv'.format(self.idmgr.id(sname)),
                           typename=BOOL)

        self.rhypos = []  # relaxed hypotheses

        # transformed sample
        self.sample = list(self.xgb.transform(sample)[0])

        # preparing the selectors (one per input variable that has a value)
        for inp, _ in zip(self.inps, self.sample):
            feat = inp.symbol_name().split('_')[0]
            self.rhypos.append(Symbol('selv_{0}'.format(feat)))

        # adding relaxed hypotheses to the oracle
        for inp, val, sel in zip(self.inps, self.sample, self.rhypos):
            if '_' not in inp.symbol_name():
                hypo = Implies(self.selv,
                               Implies(sel, Equals(inp, Real(float(val)))))
            else:
                hypo = Implies(self.selv, Implies(sel,
                                                  inp if val else Not(inp)))

            self.oracle.add_assertion(hypo)

        # propagating the true observation
        if not self.oracle.solve([self.selv] + self.rhypos):
            # raised explicitly: with a plain `assert` and `python -O` the
            # code below would fail on an unbound `model` instead
            raise AssertionError(
                'Formula is unsatisfiable under given assumptions')
        model = self.oracle.get_model()

        # choosing the maximum
        outvals = [float(model.get_py_value(o)) for o in self.outs]
        maxoval = max(zip(outvals, range(len(outvals))))

        # correct class id (corresponds to the maximum computed)
        true_output = maxoval[1]

        # forcing a misclassification, i.e. a wrong observation
        disj = []
        for i in range(len(self.outs)):
            if i != true_output:
                disj.append(GT(self.outs[i], self.outs[true_output]))
        self.oracle.add_assertion(Implies(self.selv, Or(disj)))

        # removing all hypotheses except for those in the explanation
        hypos = []
        for i, hypo in enumerate(self.rhypos):
            j = self.ftids[self.xgb.transform_inverse_by_index(i)[0]]
            if j in expl:
                hypos.append(hypo)
        self.rhypos = hypos

        if self.verbose:
            inpvals = self.xgb.readable_sample(sample)
            preamble = self._rule_preamble(inpvals)

            print('  explanation for:  "IF {0} THEN {1}"'.format(
                ' AND '.join(preamble), self.xgb.target_name[true_output]))

    def validate(self, sample, expl):
        """
            Make an effort to show that the explanation is too optimistic.

            Returns a (counterexample inputs, class id) tuple if the oracle
            is satisfiable under the explanation's hypotheses, else None.
            The user CPU time spent is stored in self.time.
        """

        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
                resource.getrusage(resource.RUSAGE_SELF).ru_utime

        # adapt the solver to deal with the current sample
        self.prepare(sample, expl)

        # if satisfiable, then there is a counterexample
        if self.oracle.solve([self.selv] + self.rhypos):
            model = self.oracle.get_model()
            inpvals = [float(model.get_py_value(i)) for i in self.inps]
            outvals = [float(model.get_py_value(o)) for o in self.outs]
            maxoval = max(zip(outvals, range(len(outvals))))

            inpvals = self.xgb.transform_inverse(np.array(inpvals))[0]
            self.coex = tuple([inpvals, maxoval[1]])
            inpvals = self.xgb.readable_sample(inpvals)

            if self.verbose:
                preamble = self._rule_preamble(inpvals)

                print('  explanation is incorrect')
                print('  counterexample: "IF {0} THEN {1}"'.format(
                    ' AND '.join(preamble), self.xgb.target_name[maxoval[1]]))
        else:
            self.coex = None

            if self.verbose:
                print('  explanation is correct')

        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
                resource.getrusage(resource.RUSAGE_SELF).ru_utime - self.time

        if self.verbose:
            print('  time: {0:.2f}'.format(self.time))

        return self.coex
Exemple #9
0
class SMTExplainer(object):
    """
        An SMT-inspired minimal explanation extractor for XGBoost models.

        The oracle holds the model encoding plus, per explained sample, a
        fresh selector variable guarding that sample's hypotheses and the
        forced misclassification constraint.
    """
    def __init__(self, formula, intvs, imaps, ivars, feats, nof_classes,
                 options, xgb):
        """
            Constructor.

            formula     -- encoding asserted into the oracle
            intvs       -- per-input interval upper bounds (may be empty)
            imaps       -- stored as is; not used by this class
            ivars       -- per-input interval variables, parallel to intvs
            feats       -- feature names
            nof_classes -- number of class score variables to create
            options     -- options object (solver, verb, xtype, xnum, ...)
            xgb         -- object providing feature names and the sample
                           transformations used below
        """

        self.feats = feats
        self.intvs = intvs
        self.imaps = imaps
        self.ivars = ivars
        self.nofcl = nof_classes
        self.optns = options
        self.idmgr = IDPool()

        # saving XGBooster
        self.xgb = xgb

        self.verbose = self.optns.verb
        self.oracle = Solver(name=options.solver)

        self.inps = []  # input (feature value) variables
        for f in self.xgb.extended_feature_names_as_array_strings:
            # names without '_' get a REAL variable, all others a BOOL one
            if '_' not in f:
                self.inps.append(Symbol(f, typename=REAL))
            else:
                self.inps.append(Symbol(f, typename=BOOL))

        self.outs = []  # output (class  score) variables
        for c in range(self.nofcl):
            self.outs.append(Symbol('class{0}_score'.format(c), typename=REAL))

        # theory
        self.oracle.add_assertion(formula)

        # current selector
        self.selv = None

        # save and use dual explanations whenever needed
        self.dualx = []

        # number of oracle calls involved
        self.calls = 0

    def prepare(self, sample):
        """
            Prepare the oracle for computing an explanation.

            Raises AssertionError if the sample was already processed or if
            the encoding is unsatisfiable under the sample's hypotheses.
        """

        if self.selv:
            # disable the previous assumption if any
            self.oracle.add_assertion(Not(self.selv))

        # creating a fresh selector for a new sample
        sname = ','.join([str(v).strip() for v in sample])

        # the samples should not repeat; otherwise, they will be
        # inconsistent with the previously introduced selectors
        # (raised explicitly so the check is not removed under `python -O`)
        if sname in self.idmgr.obj2id:
            raise AssertionError(
                'this sample has been considered before (sample {0})'.format(
                    self.idmgr.id(sname)))
        self.selv = Symbol('sample{0}_selv'.format(self.idmgr.id(sname)),
                           typename=BOOL)

        self.rhypos = []  # relaxed hypotheses

        # transformed sample
        self.sample = list(self.xgb.transform(sample)[0])

        self.sel2fid = {}  # selectors to original feature ids
        self.sel2vid = {}  # selectors to categorical feature ids

        # preparing the selectors
        for vid, (inp, _) in enumerate(zip(self.inps, self.sample)):
            feat = inp.symbol_name().split('_')[0]
            selv = Symbol('selv_{0}'.format(feat))

            self.rhypos.append(selv)
            if selv not in self.sel2fid:
                self.sel2fid[selv] = int(feat[1:])
                self.sel2vid[selv] = [vid]
            else:
                self.sel2vid[selv].append(vid)

        # adding relaxed hypotheses to the oracle
        if not self.intvs:
            for inp, val, sel in zip(self.inps, self.sample, self.rhypos):
                if '_' not in inp.symbol_name():
                    hypo = Implies(self.selv,
                                   Implies(sel, Equals(inp, Real(float(val)))))
                else:
                    hypo = Implies(self.selv,
                                   Implies(sel, inp if val else Not(inp)))

                self.oracle.add_assertion(hypo)
        else:
            for inp, val, sel in zip(self.inps, self.sample, self.rhypos):
                inp = inp.symbol_name()
                # determining the right interval and the corresponding variable
                # NOTE(review): if no interval matches, `hypo` keeps its value
                # from the previous input (or is unbound) -- confirm that the
                # interval list always ends with the '+' bound
                for ub, fvar in zip(self.intvs[inp], self.ivars[inp]):
                    if ub == '+' or val < ub:
                        hypo = Implies(self.selv, Implies(sel, fvar))
                        break

                self.oracle.add_assertion(hypo)

        # in case of categorical data, there are selector duplicates
        # and we need to remove them
        self.rhypos = sorted(set(self.rhypos),
                             key=lambda x: int(x.symbol_name()[6:]))

        # propagating the true observation
        if not self.oracle.solve([self.selv] + self.rhypos):
            # raised explicitly: with a plain `assert` and `python -O` the
            # code below would fail on an unbound `model` instead
            raise AssertionError(
                'Formula is unsatisfiable under given assumptions')
        model = self.oracle.get_model()

        # choosing the maximum
        outvals = [float(model.get_py_value(o)) for o in self.outs]
        maxoval = max(zip(outvals, range(len(outvals))))

        # correct class id (corresponds to the maximum computed)
        self.out_id = maxoval[1]
        self.output = self.xgb.target_name[self.out_id]

        # forcing a misclassification, i.e. a wrong observation
        disj = []
        for i in range(len(self.outs)):
            if i != self.out_id:
                disj.append(GT(self.outs[i], self.outs[self.out_id]))
        self.oracle.add_assertion(Implies(self.selv, Or(disj)))

        if self.verbose:
            inpvals = self.xgb.readable_sample(sample)

            self.preamble = []
            for f, v in zip(self.xgb.feature_names, inpvals):
                if f not in v:
                    self.preamble.append('{0} = {1}'.format(f, v))
                else:
                    self.preamble.append(v)

            print('  explaining:  "IF {0} THEN {1}"'.format(
                ' AND '.join(self.preamble), self.output))

    def explain(self, sample, smallest):
        """
            Hypotheses minimization.

            Returns a list of explanations, each a sorted list of feature
            ids. A None entry (produced when no contrastive explanation
            exists) is passed through unchanged. Exits the process if the
            prediction is not implied by the full set of hypotheses.
        """

        # reinitializing the number of used oracle calls
        # 1 because of the initial call checking the entailment
        self.calls = 1

        # adapt the solver to deal with the current sample
        self.prepare(sample)

        # saving external explanation to be minimized further
        self.to_consider = [True for h in self.rhypos]

        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
                resource.getrusage(resource.RUSAGE_SELF).ru_utime

        # if satisfiable, then the observation is not implied by the hypotheses
        if self.oracle.solve(
            [self.selv] +
            [h for h, c in zip(self.rhypos, self.to_consider) if c]):
            print('  no implication!')
            print(self.oracle.get_model())
            sys.exit(1)

        if self.optns.xtype == 'abductive':
            # abductive explanations => MUS computation and enumeration
            if not smallest and self.optns.xnum == 1:
                expls = [self.compute_minimal_abductive()]
            else:
                expls = self.enumerate_abductive(smallest=smallest)
        else:  # contrastive explanations => MCS enumeration
            if self.optns.usemhs:
                expls = self.enumerate_contrastive()
            else:
                if not smallest:
                    expls = self.enumerate_minimal_contrastive()
                else:
                    # expls = self.enumerate_smallest_contrastive()
                    expls = self.enumerate_contrastive()

        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
                resource.getrusage(resource.RUSAGE_SELF).ru_utime - self.time

        # selectors -> feature ids; the None placeholder must survive this
        # step instead of being iterated over
        expls = [
            None if expl is None else sorted(self.sel2fid[h] for h in expl)
            for expl in expls
        ]

        if self.dualx:
            self.dualx = [
                sorted(self.sel2fid[h] for h in expl) for expl in self.dualx
            ]

        if self.verbose:
            # nothing to print for an empty result or the None placeholder
            if expls and expls[0] is not None:
                for expl in expls:
                    preamble = [self.preamble[i] for i in expl]
                    if self.optns.xtype == 'abductive':
                        print('  explanation: "IF {0} THEN {1}"'.format(
                            ' AND '.join(preamble),
                            self.xgb.target_name[self.out_id]))
                    else:
                        print(
                            '  explanation: "IF NOT {0} THEN NOT {1}"'.format(
                                ' AND NOT '.join(preamble),
                                self.xgb.target_name[self.out_id]))
                    print('  # hypos left:', len(expl))

            print('  time: {0:.2f}'.format(self.time))

        # here we return the last computed explanation
        return expls

    def compute_minimal_abductive(self):
        """
            Compute any subset-minimal explanation.

            A hypothesis is dropped whenever the oracle stays unsatisfiable
            without it; every oracle call is counted in self.calls.
        """

        i = 0

        # filtering out unnecessary features if external explanation is given
        rhypos = [h for h, c in zip(self.rhypos, self.to_consider) if c]

        # simple deletion-based linear search
        while i < len(rhypos):
            to_test = rhypos[:i] + rhypos[(i + 1):]

            self.calls += 1
            if self.oracle.solve([self.selv] + to_test):
                # hypothesis i is needed => keep it and move on
                i += 1
            else:
                rhypos = to_test

        return rhypos

    def enumerate_minimal_contrastive(self):
        """
            Compute a subset-minimal contrastive explanation.

            Returns a list of explanations (lists of selectors), or [None]
            if none was found.
        """
        def _overapprox():
            # split selectors by their value in the current model
            model = self.oracle.get_model()

            for sel in self.rhypos:
                if int(model.get_py_value(sel)) > 0:
                    # soft clauses contain positive literals
                    # so if var is true then the clause is satisfied
                    self.ss_assumps.append(sel)
                else:
                    self.setd.append(sel)

        def _compute():
            i = 0
            while i < len(self.setd):
                if self.optns.usecld:
                    _do_cld_check(self.setd[i:])
                    i = 0

                if self.setd:
                    # it may be empty after the clause D check

                    self.calls += 1
                    self.ss_assumps.append(self.setd[i])
                    if not self.oracle.solve([self.selv] + self.ss_assumps +
                                             self.bb_assumps):
                        self.ss_assumps.pop()
                        self.bb_assumps.append(Not(self.setd[i]))

                i += 1

        def _do_cld_check(cld):
            self.cldid += 1
            sel = Symbol('{0}_{1}'.format(self.selv.symbol_name(), self.cldid))
            cld.append(Not(sel))

            # adding clause D
            self.oracle.add_assertion(Or(cld))
            self.ss_assumps.append(sel)

            self.setd = []
            st = self.oracle.solve([self.selv] + self.ss_assumps +
                                   self.bb_assumps)

            self.ss_assumps.pop()  # removing clause D assumption
            if st:
                model = self.oracle.get_model()

                for l in cld[:-1]:
                    # filtering all satisfied literals
                    if int(model.get_py_value(l)) > 0:
                        self.ss_assumps.append(l)
                    else:
                        self.setd.append(l)
            else:
                # clause D is unsatisfiable => all literals are backbones
                self.bb_assumps.extend([Not(l) for l in cld[:-1]])

            # deactivating clause D
            self.oracle.add_assertion(Not(sel))

        # sets of selectors to work with
        self.cldid = 0
        expls = []

        # detect and block unit-size MCSes immediately
        if self.optns.unitmcs:
            for i, hypo in enumerate(self.rhypos):
                self.calls += 1
                if self.oracle.solve([self.selv] + self.rhypos[:i] +
                                     self.rhypos[(i + 1):]):
                    expls.append([hypo])

                    if len(expls) != self.optns.xnum:
                        self.oracle.add_assertion(Or([Not(self.selv), hypo]))
                    else:
                        break

        self.calls += 1
        while self.oracle.solve([self.selv]):
            self.ss_assumps, self.bb_assumps, self.setd = [], [], []
            _overapprox()
            _compute()

            expl = [list(f.get_free_variables())[0] for f in self.bb_assumps]
            expls.append(expl)

            if len(expls) == self.optns.xnum:
                break

            self.oracle.add_assertion(Or([Not(self.selv)] + expl))
            self.calls += 1

        self.calls += self.cldid
        return expls if expls else [None]

    def enumerate_abductive(self, smallest=True):
        """
            Compute a cardinality-minimal explanation.

            Candidates come from the hitting set enumerator; the dual
            (contrastive) explanations found on the way are kept in
            self.dualx.
        """

        # result
        expls = []

        # just in case, let's save dual (contrastive) explanations
        self.dualx = []

        with Hitman(bootstrap_with=[[
                i for i in range(len(self.rhypos)) if self.to_consider[i]
        ]],
                    htype='sorted' if smallest else 'lbx') as hitman:
            # computing unit-size MCSes
            for i, hypo in enumerate(self.rhypos):
                if not self.to_consider[i]:
                    continue

                self.calls += 1
                if self.oracle.solve([self.selv] + self.rhypos[:i] +
                                     self.rhypos[(i + 1):]):
                    hitman.hit([i])
                    self.dualx.append([self.rhypos[i]])

            # main loop
            iters = 0
            while True:
                hset = hitman.get()
                iters += 1

                if self.verbose > 1:
                    print('iter:', iters)
                    print('cand:', hset)

                if hset is None:
                    break

                self.calls += 1
                if self.oracle.solve([self.selv] +
                                     [self.rhypos[i] for i in hset]):
                    to_hit = []
                    unsatisfied = []

                    removed = list(
                        set(range(len(self.rhypos))).difference(set(hset)))

                    model = self.oracle.get_model()
                    for h in removed:
                        i = self.sel2fid[self.rhypos[h]]
                        if '_' not in self.inps[i].symbol_name():
                            # feature variable and its expected value
                            var, exp = self.inps[i], self.sample[i]

                            # true value
                            true_val = float(model.get_py_value(var))

                            if not exp - 0.001 <= true_val <= exp + 0.001:
                                unsatisfied.append(h)
                            else:
                                hset.append(h)
                        else:
                            for vid in self.sel2vid[self.rhypos[h]]:
                                var, exp = self.inps[vid], int(
                                    self.sample[vid])

                                # true value
                                true_val = int(model.get_py_value(var))

                                if exp != true_val:
                                    unsatisfied.append(h)
                                    break
                            else:
                                # every variable of this selector agrees
                                # with the sample
                                hset.append(h)

                    # computing an MCS (expensive)
                    for h in unsatisfied:
                        self.calls += 1
                        if self.oracle.solve([self.selv] +
                                             [self.rhypos[i] for i in hset] +
                                             [self.rhypos[h]]):
                            hset.append(h)
                        else:
                            to_hit.append(h)

                    if self.verbose > 1:
                        print('coex:', to_hit)

                    hitman.hit(to_hit)

                    self.dualx.append([self.rhypos[i] for i in to_hit])
                else:
                    if self.verbose > 1:
                        print('expl:', hset)

                    expl = [self.rhypos[i] for i in hset]
                    expls.append(expl)

                    if len(expls) != self.optns.xnum:
                        hitman.block(hset)
                    else:
                        break

        return expls

    def enumerate_smallest_contrastive(self):
        """
            Compute a cardinality-minimal contrastive explanation.

            Uses integer cost literals and a growing cardinality bound.
        """

        # result
        expls = []

        # computing unit-size MUSes
        muses = set()
        for hypo in self.rhypos:
            self.calls += 1
            if not self.oracle.solve([self.selv, hypo]):
                muses.add(hypo)

        # we are going to discard unit-size MUSes from consideration
        rhypos = set(self.rhypos).difference(muses)

        # introducing integer cost literals for rhypos
        costlits = []
        for i, hypo in enumerate(rhypos):
            costlit = Symbol(name='costlit_{0}_{1}'.format(
                self.selv.symbol_name(), i),
                             typename=INT)
            costlits.append(costlit)

            self.oracle.add_assertion(
                Ite(hypo, Equals(costlit, Int(0)), Equals(costlit, Int(1))))

        # main loop (linear search unsat-sat)
        i = 0
        while i < len(rhypos) and len(expls) != self.optns.xnum:
            # fresh selector for the current iteration
            sit = Symbol('iter_{0}_{1}'.format(self.selv.symbol_name(), i))

            # adding cardinality constraint
            self.oracle.add_assertion(Implies(sit, LE(Plus(costlits), Int(i))))

            # extracting explanations from MaxSAT models
            while self.oracle.solve([self.selv, sit]):
                self.calls += 1
                model = self.oracle.get_model()

                expl = []
                for hypo in rhypos:
                    if int(model.get_py_value(hypo)) == 0:
                        expl.append(hypo)

                # each MCS contains all unit-size MUSes
                expls.append(list(muses) + expl)

                # either stop or add a blocking clause
                if len(expls) != self.optns.xnum:
                    self.oracle.add_assertion(Implies(self.selv, Or(expl)))
                else:
                    break

            i += 1
            self.calls += 1

        return expls

    def enumerate_contrastive(self, smallest=True):
        """
            Compute a cardinality-minimal contrastive explanation.

            Requires the Z3 solver. The dual (abductive) explanations found
            on the way are kept in self.dualx.
        """

        # core extraction is done via calling Z3's internal API
        assert self.optns.solver == 'z3', 'This procedure requires Z3'

        # result
        expls = []

        # just in case, let's save dual (abductive) explanations
        self.dualx = []

        # mapping from hypothesis variables to their indices
        hmap = {h: i for i, h in enumerate(self.rhypos)}

        # mapping from internal Z3 variable into variables of PySMT
        vmap = {self.oracle.converter.convert(v): v for v in self.rhypos}
        vmap[self.oracle.converter.convert(self.selv)] = None

        def _get_core():
            # the sample selector maps to None and is dropped from the core
            core = self.oracle.z3.unsat_core()
            return sorted(filter(lambda x: x is not None,
                                 map(lambda x: vmap[x], core)),
                          key=lambda x: int(x.symbol_name()[6:]))

        def _do_trimming(core):
            for _ in range(self.optns.trim):
                self.calls += 1
                self.oracle.solve([self.selv] + core)
                new_core = _get_core()
                converged = len(new_core) == len(core)
                # the next round must start from the trimmed core; otherwise
                # every round would re-solve the same assumptions
                core = new_core
                if converged:
                    break
            return core

        def _reduce_lin(core):
            def _assump_needed(a):
                if len(to_test) > 1:
                    to_test.remove(a)
                    self.calls += 1
                    if not self.oracle.solve([self.selv] + list(to_test)):
                        return False
                    to_test.add(a)
                    return True
                else:
                    return True

            to_test = set(core)
            return list(filter(lambda a: _assump_needed(a), core))

        def _reduce_qxp(core):
            coex = core[:]
            filt_sz = len(coex) / 2.0
            while filt_sz >= 1:
                i = 0
                while i < len(coex):
                    to_test = coex[:i] + coex[(i + int(filt_sz)):]
                    self.calls += 1
                    if to_test and not self.oracle.solve([self.selv] +
                                                         to_test):
                        # assumps are not needed
                        coex = to_test
                    else:
                        # assumps are needed => check the next chunk
                        i += int(filt_sz)
                # decreasing size of the set to filter
                filt_sz /= 2.0
                if filt_sz > len(coex) / 2.0:
                    # next size is too large => make it smaller
                    filt_sz = len(coex) / 2.0
            return coex

        def _reduce_coex(core):
            if self.optns.reduce == 'lin':
                return _reduce_lin(core)
            else:  # qxp
                return _reduce_qxp(core)

        with Hitman(bootstrap_with=[[
                i for i in range(len(self.rhypos)) if self.to_consider[i]
        ]],
                    htype='sorted' if smallest else 'lbx') as hitman:
            # computing unit-size MUSes
            for i, hypo in enumerate(self.rhypos):
                if not self.to_consider[i]:
                    continue

                self.calls += 1
                if not self.oracle.solve([self.selv, self.rhypos[i]]):
                    hitman.hit([i])
                    self.dualx.append([self.rhypos[i]])
                elif self.optns.unitmcs:
                    self.calls += 1
                    if self.oracle.solve([self.selv] + self.rhypos[:i] +
                                         self.rhypos[(i + 1):]):
                        # this is a unit-size MCS => block immediately
                        hitman.block([i])
                        expls.append([self.rhypos[i]])

            # main loop
            iters = 0
            while True:
                hset = hitman.get()
                iters += 1

                if self.verbose > 1:
                    print('iter:', iters)
                    print('cand:', hset)

                if hset is None:
                    break

                self.calls += 1
                if not self.oracle.solve([self.selv] + [
                        self.rhypos[h] for h in list(
                            set(range(len(self.rhypos))).difference(set(hset)))
                ]):
                    to_hit = _get_core()

                    if len(to_hit) > 1 and self.optns.trim:
                        to_hit = _do_trimming(to_hit)

                    if len(to_hit) > 1 and self.optns.reduce != 'none':
                        to_hit = _reduce_coex(to_hit)

                    self.dualx.append(to_hit)
                    to_hit = [hmap[h] for h in to_hit]

                    if self.verbose > 1:
                        print('coex:', to_hit)

                    hitman.hit(to_hit)
                else:
                    if self.verbose > 1:
                        print('expl:', hset)

                    expl = [self.rhypos[i] for i in hset]
                    expls.append(expl)

                    if len(expls) != self.optns.xnum:
                        hitman.block(hset)
                    else:
                        break

        return expls
Exemple #10
0
def get_makespan_optimal_weakly_hard_schedule(g,
                                              network,
                                              feasibility_timeout=None,
                                              optimization_timeout=10 * 60 *
                                              1000):
    """Search for a schedule with the smallest bound on all `zeta` values.

    An integer formula is built over the symbolic variables `label`, `chi`,
    `duration`, `zeta` and three families of 0/1 indicator variables. The
    formula is first checked for satisfiability with z3; then a bisection
    on a common strict upper bound for every `zeta` variable is run, each
    step being a solver call under that bound as an assumption.

    Args:
        g: graph object. This function reads `g.vertices()`,
            `g.num_vertices()`, `g.vertex_properties['durations']`,
            `g.vertex_properties['deadlines']`,
            `g.vertex_properties['weakly-hard']` and
            `g.edge_properties['widths']`.
        network: mapping providing the keys 'A', 'B', 'C', 'D', 'GAMMA' and
            'LAMBDA'. `LAMBDA` is called with an index in
            `range(JUMPTABLE_MAX)` and its result is indexed with [0]
            and [1].
        feasibility_timeout: when truthy, set as the z3 'timeout' option
            before the first solver call.
        optimization_timeout: set as the z3 'timeout' option before the
            bisection (presumably milliseconds -- TODO confirm).

    Returns:
        `[None] * 4` when the first solver call does not report a model
        (unsatisfiable or unknown result); otherwise the tuple
        `(zeta, chi, duration, label)` of lists holding the Python values
        that the last model found assigns to the respective symbols.
    """
    vprint('*computing optimal weakly-hard real-time schedule via SMT*')
    # SMT formulation
    tc = transitive_closure(g)
    logical_edges = get_logical_edges(g)
    JUMPTABLE_MAX = 6
    K_MAX = 5001  # LAMBDA(i)[1] < K_MAX for all i < JUMPTABLE_MAX

    A, B, C, D, GAMMA, LAMBDA = (network[key] for key in ('A', 'B', 'C', 'D',
                                                          'GAMMA', 'LAMBDA'))

    # K_MAX is used below as the "no contribution" value inside Min(...),
    # so every real LAMBDA(i)[1] has to stay strictly below it.
    assert (all(map(lambda x: LAMBDA(x)[1] < K_MAX, range(JUMPTABLE_MAX))))

    vprint('\tinstantiating symvars...')
    # one label per logical edge; constrained below to 1..len(logical_edges)
    label = [Symbol('label_%i' % i, INT) for i in range(len(logical_edges))]
    # first half for slot, second half for beacons
    chi = [Symbol('chi_%i' % i, INT) for i in range(2 * len(logical_edges))]
    # one duration per index r in range(len(logical_edges)) (the code below
    # calls these indices "rounds")
    duration = [
        Symbol('duration_%i' % i, INT) for i in range(len(logical_edges))
    ]
    # indices [0, num_vertices) are addressed with int(vertex); indices from
    # num_vertices onwards are addressed with r + g.num_vertices()
    # (presumably completion times of tasks and rounds -- TODO confirm)
    zeta = [
        Symbol('zeta_%i' % i, INT)
        for i in range(g.num_vertices() + len(logical_edges))
    ]
    # 0/1 indicators; delta_e_in_r[e][r] is tied to label[e] by
    # label_to_delta below
    delta_e_in_r = [[
        Symbol('delta_e_in_r-%i_%i' % (i, j), INT)
        for j in range(len(logical_edges))
    ] for i in range(len(logical_edges))]
    # 0/1 indicators; delta_chi_eq_i[c][i] is tied to chi[c] by chi_to_delta
    delta_chi_eq_i = [[
        Symbol('delta_chi_eq_i-%i_%i' % (i, j), INT)
        for j in range(JUMPTABLE_MAX)
    ] for i in range(2 * len(logical_edges))]
    # 0/1 indicators selecting which of the two implications in `exclusion`
    # applies to a (zeta index, round) pair
    delta_tau_before_r = [[
        Symbol('delta_tau_before_r-%i_%i' % (i, j), INT)
        for j in range(len(logical_edges))
    ] for i in range(g.num_vertices() + len(logical_edges))]

    vprint('\tgenerating constraint clauses...')
    # value ranges: label in [1, #edges], chi in [1, JUMPTABLE_MAX),
    # zeta >= 0 and every indicator in {0, 1}
    domain = And([
        And([
            And(LE(Int(1), sym), LE(sym, Int(len(logical_edges))))
            for sym in label
        ]),
        And([And(LE(Int(1), sym), LT(sym, Int(JUMPTABLE_MAX)))
             for sym in chi]),
        And([LE(Int(0), sym) for sym in zeta]),
        And([
            And(LE(Int(0), sym), LE(sym, Int(1)))
            for sym in chain.from_iterable(delta_e_in_r + delta_chi_eq_i +
                                           delta_tau_before_r)
        ])
    ])
    # each row of delta_e_in_r and of delta_chi_eq_i sums to exactly 1
    one_hot = And([
        And([
            Equals(
                Plus([delta_e_in_r[e][r] for r in range(len(logical_edges))]),
                Int(1)) for e in range(len(logical_edges))
        ]),
        And([
            Equals(
                Plus([delta_chi_eq_i[chir][i] for i in range(JUMPTABLE_MAX)]),
                Int(1)) for chir in range(2 * len(logical_edges))
        ])
    ])
    # label[r] < label[s] whenever the source of r is an in-neighbour of the
    # source of s in the transitive closure
    CFOP = And([
        LT(label[logical_edges.index(r)], label[logical_edges.index(s)])
        for r, s in product(logical_edges, repeat=2)
        if r.source() in tc.get_in_neighbors(s.source())
    ])
    # first part: indicators are monotone along transitive-closure edges;
    # second part: the round rows are fixed constants (0 if r < s, else 1)
    task_partitioning_by_round = And(
        And([
            LE(delta_tau_before_r[int(tau)][r], delta_tau_before_r[int(mu)][r])
            for tau, mu, r in product(tc.vertices(), tc.vertices(),
                                      range(len(logical_edges)))
            if tau in tc.get_in_neighbors(mu)
        ]),
        And([
            Equals(delta_tau_before_r[r + g.num_vertices()][s], Int(0)) if
            r < s else Equals(delta_tau_before_r[r +
                                                 g.num_vertices()][s], Int(1))
            for r, s in product(range(len(logical_edges)), repeat=2)
        ]))
    # a round with no edge assigned forces the second-half chi of that round
    # to 1
    round_empty = And([
        Implies(
            Equals(
                Plus([delta_e_in_r[e][r] for e in range(len(logical_edges))]),
                Int(0)), Equals(chi[len(logical_edges) + r], Int(1)))
        for r in range(len(logical_edges))
    ])
    # duration[r] is the sum of three terms: a base term built from the
    # second-half chi of round r, a correction of -(A + (2 + B)(C + D*GAMMA))
    # applied only when no edge is in round r, and one per-edge term for
    # every edge assigned to round r. With round_empty forcing chi to 1 the
    # first two terms cancel for an empty round.
    durations = And([
        Equals(
            duration[r],
            Plus(
                Int(A),
                Times(Plus(Times(Int(2), chi[r + len(logical_edges)]), Int(B)),
                      Int(C + D * GAMMA)),
                Times(
                    Ite(
                        GE(
                            Plus([
                                delta_e_in_r[e][r]
                                for e in range(len(logical_edges))
                            ]), Int(1)), Int(0), Int(-1)),
                    Int(A + (2 + B) * (C + D * GAMMA))),
                Plus([
                    Ite(
                        Equals(delta_e_in_r[e][r], Int(1)),
                        Plus(
                            Int(A),
                            Times(
                                Plus(Times(Int(2), chi[e]), Int(B)),
                                Int(C + D * g.edge_properties['widths'][
                                    logical_edges[e]]))), Int(0))
                    for e in range(len(logical_edges))
                ]))) for r in range(len(logical_edges))
    ])
    # label[e] == i  <=>  delta_e_in_r[e][i - 1] == 1 (given one_hot)
    label_to_delta = And([
        Equals(
            label[e],
            Plus([
                Times(delta_e_in_r[e][i - 1], Int(i))
                for i in range(1, 1 + len(logical_edges))
            ])) for e in range(len(logical_edges))
    ])
    # chi[c] == i  <=>  delta_chi_eq_i[c][i] == 1 (given one_hot)
    chi_to_delta = And([
        Equals(
            chi[chir],
            Plus([
                Times(delta_chi_eq_i[chir][i], Int(i))
                for i in range(JUMPTABLE_MAX)
            ])) for chir in range(2 * len(logical_edges))
    ])
    # ordering constraints on zeta, in four groups:
    #   1. vertex tau before vertex mu when tau is an in-neighbour of mu
    #   2. consecutive rounds r and r + 1
    #   3. a round carrying edge e before every out-neighbour of e's source
    #   4. e's source and all of its in-neighbours before the round carrying e
    order = And(
        And([
            LT(zeta[int(tau)],
               Minus(zeta[int(mu)], Int(g.vertex_properties['durations'][mu])))
            for tau, mu in product(g.vertices(), repeat=2)
            if tau in tc.get_in_neighbors(mu)
        ]),
        And([
            LT(zeta[r + g.num_vertices()],
               Minus(zeta[r + 1 + g.num_vertices()], duration[r + 1]))
            for r in range(len(logical_edges) - 1)
        ]),
        And([
            Implies(
                Equals(delta_e_in_r[e][r], Int(1)),
                GT(
                    Minus(zeta[int(tau)],
                          Int(g.vertex_properties['durations'][tau])),
                    zeta[r + g.num_vertices()])) for tau in g.vertices()
            for r in range(len(logical_edges))
            for e in range(len(logical_edges))
            if tau in tc.get_out_neighbors(logical_edges[e].source())
        ]),
        And([
            Implies(
                Equals(delta_e_in_r[e][r], Int(1)),
                GT(Minus(zeta[r + g.num_vertices()], duration[r]),
                   zeta[int(tau)])) for tau in g.vertices()
            for r in range(len(logical_edges))
            for e in range(len(logical_edges))
            if tau in tc.get_in_neighbors(logical_edges[e].source())
            or tau == logical_edges[e].source()
        ]))
    # for every (vertex, round) pair the indicator picks one of two strict
    # orderings: indicator 0 puts zeta[tau] below zeta[round] - duration,
    # indicator 1 puts zeta[tau] - durations[tau] above zeta[round]
    exclusion = And([
        And(
            Implies(
                Equals(delta_tau_before_r[int(tau)][r], Int(0)),
                GT(Minus(zeta[r + g.num_vertices()], duration[r]),
                   zeta[int(tau)])),
            Implies(
                Equals(delta_tau_before_r[int(tau)][r], Int(1)),
                GT(
                    Minus(zeta[int(tau)],
                          Int(g.vertex_properties['durations'][tau])),
                    zeta[g.num_vertices() + r]))) for tau in g.vertices()
        for r in range(len(logical_edges))
    ])
    # negative 'deadlines' entries are skipped (no constraint is emitted)
    deadline = And([
        LE(zeta[int(tau)], Int(g.vertex_properties['deadlines'][tau]))
        for tau in g.vertices() if g.vertex_properties['deadlines'][tau] >= 0
    ])

    def sum_m(tau):
        # Sum of LAMBDA(i)[0] over the edges whose source is an in-neighbour
        # of tau: once for the chi value i of the edge itself, once for the
        # second-half chi value i of the round the edge is assigned to.
        return Plus([Int(0)] + [
            Plus(
                Ite(Equals(delta_chi_eq_i[e][i], Int(1)), Int(LAMBDA(i)[0]),
                    Int(0)),
                Plus([
                    Ite(
                        Equals(delta_chi_eq_i[len(logical_edges) +
                                              r][i], Int(1)),
                        Ite(Equals(delta_e_in_r[e][r], Int(1)),
                            Int(LAMBDA(i)[0]), Int(0)), Int(0))
                    for r in range(len(logical_edges))
                ])) for i in range(JUMPTABLE_MAX)
            for e in range(len(logical_edges))
            if logical_edges[e].source() in tc.get_in_neighbors(tau)
        ])

    def min_K(tau):
        # Same selection as sum_m, but taking the minimum of LAMBDA(i)[1],
        # with K_MAX standing in for "not selected".
        # NOTE(review): the outer `for r` clause below only repeats
        # identical terms (the inner comprehension rebinds r and nothing
        # else uses it); it looks redundant -- confirm before removing.
        return Min([Int(K_MAX)] + [
            Min(
                Ite(Equals(delta_chi_eq_i[e][i], Int(1)), Int(LAMBDA(i)[1]),
                    Int(K_MAX)),
                Min([
                    Ite(
                        Equals(delta_chi_eq_i[len(logical_edges) +
                                              r][i], Int(1)),
                        Ite(Equals(delta_e_in_r[e][r], Int(1)),
                            Int(LAMBDA(i)[1]), Int(K_MAX)), Int(K_MAX))
                    for r in range(len(logical_edges))
                ])) for i in range(JUMPTABLE_MAX)
            for r in range(len(logical_edges))
            for e in range(len(logical_edges))
            if logical_edges[e].source() in tc.get_in_neighbors(tau)
        ])

    # per-vertex bounds taken from 'weakly-hard'[tau] = (first, second):
    # first >= min(sum_m, min_K) and second <= min_K. Vertices whose first
    # entry is negative are skipped.
    WH = And([
        And(
            GE(Int(g.vertex_properties['weakly-hard'][tau][0]),
               Min(sum_m(tau), min_K(tau))),
            LE(Int(g.vertex_properties['weakly-hard'][tau][1]), min_K(tau)))
        for tau in g.vertices()
        if g.vertex_properties['weakly-hard'][tau][0] >= 0
    ])

    formula = And([
        domain, one_hot, CFOP, task_partitioning_by_round, round_empty,
        durations, label_to_delta, chi_to_delta, order, exclusion, deadline, WH
    ])
    vprint('\tchecking feasibility...')
    solver = Solver(name='z3', incremental=True, logic='LIA')
    if feasibility_timeout:
        solver.z3.set('timeout', feasibility_timeout)
    solver.add_assertion(formula)
    try:
        result = solver.solve()
    except SolverReturnedUnknownResultError:
        # an unknown result (e.g. after the timeout set above) is handled
        # exactly like unsat below
        result = None
    if not result:
        # NOTE(review): this message is also printed when the result was
        # unknown rather than proven infeasible.
        vprint('\tsolver returned infeasible!')
        return [None] * 4
    else:
        models = [solver.get_model()]
        vprint('\tsolver found a feasible solution, optimizing...')

    solver.z3.set('timeout', optimization_timeout)
    # Bisection on a strict upper bound for all zeta values: UB is a bound
    # known to be attainable (taken from a model), LB is a bound for which
    # the solver did not return a model.
    LB = 0
    UB = max(map(lambda x: models[-1].get_py_value(x), zeta))
    curr_B = UB // 2
    # loop while at least one integer lies strictly between LB and UB
    while range(LB + 1, UB):
        try:
            # the bound is passed as an assumption, so it is not kept in the
            # solver's assertion stack after the call
            result = solver.solve(
                [And([LT(zeta_tau, Int(curr_B)) for zeta_tau in zeta])])
        except SolverReturnedUnknownResultError:
            vprint('\t(timeout, not necessarily unsat)')
            result = None
        if result:
            vprint('\tfound feasible solution of length %i, optimizing...' %
                   curr_B)
            models.append(solver.get_model())
            UB = curr_B
        else:
            # NOTE(review): an unknown result also raises LB, so after a
            # timeout the final answer need not be the true optimum.
            vprint('\tnew lower bound %i, optimizing...' % curr_B)
            LB = curr_B
        curr_B = LB + int(ceil((UB - LB) / 2))
    vprint('\tsolver returned optimal (under composition+P.O. abstractions)!')
    # the last model appended is the one found under the tightest bound
    best_model = models[-1]
    # replace the symbol lists by the concrete values the model assigns
    zeta = list(map(lambda x: best_model.get_py_value(x), zeta))
    chi = list(map(lambda x: best_model.get_py_value(x), chi))
    duration = list(map(lambda x: best_model.get_py_value(x), duration))
    label = list(map(lambda x: best_model.get_py_value(x), label))
    return zeta, chi, duration, label
Exemple #11
0
class SMTExplainer(object):
    """
        An SMT-inspired minimal explanation extractor for XGBoost models.

        The encoding of the model (``formula``) is asserted once in the
        oracle. Every sample to be explained gets a fresh Boolean selector
        (``self.selv``) guarding the sample-specific constraints, so the
        same oracle can be reused for several samples.
    """
    def __init__(self, formula, intvs, imaps, ivars, feats, nof_classes,
                 options, xgb):
        """
            Constructor.

            Stores the arguments, creates the oracle named by
            ``options.solver``, declares one input variable per entry of
            ``xgb.extended_feature_names_as_array_strings`` (REAL if the
            name has no underscore, BOOL otherwise), one REAL score
            variable per class, and asserts ``formula`` in the oracle.
        """

        self.feats = feats
        self.intvs = intvs
        self.imaps = imaps
        self.ivars = ivars
        self.nofcl = nof_classes
        self.optns = options
        self.idmgr = IDPool()

        # saving XGBooster
        self.xgb = xgb

        self.verbose = self.optns.verb
        self.oracle = Solver(name=options.solver)

        # input (feature value) variables
        self.inps = [
            Symbol(f, typename=BOOL if '_' in f else REAL)
            for f in self.xgb.extended_feature_names_as_array_strings
        ]

        # output (class score) variables
        self.outs = [
            Symbol('class{0}_score'.format(c), typename=REAL)
            for c in range(self.nofcl)
        ]

        # theory
        self.oracle.add_assertion(formula)

        # current selector
        self.selv = None

    def prepare(self, sample):
        """
            Prepare the oracle for computing an explanation.

            Disables the previous sample's selector, creates a new one,
            asserts the (relaxed) hypotheses describing ``sample``,
            determines the predicted class from a model of the oracle and
            finally asserts that, under the new selector, some other class
            scores strictly higher.

            Raises AssertionError if ``sample`` was already prepared on
            this object or if the oracle is unsatisfiable under the
            sample's hypotheses.
        """

        if self.selv:
            # disable the previous assumption if any
            self.oracle.add_assertion(Not(self.selv))

        # creating a fresh selector for a new sample
        sname = ','.join([str(v).strip() for v in sample])

        # the samples should not repeat; otherwise, they will be
        # inconsistent with the previously introduced selectors
        # (raised explicitly so that the check survives `python -O`)
        if sname in self.idmgr.obj2id:
            raise AssertionError(
                'this sample has been considered before (sample {0})'.format(
                    self.idmgr.id(sname)))
        self.selv = Symbol('sample{0}_selv'.format(self.idmgr.id(sname)),
                           typename=BOOL)

        self.rhypos = []  # relaxed hypotheses

        # transformed sample
        self.sample = list(self.xgb.transform(sample)[0])

        self.sel2fid = {}  # selectors to original feature ids
        self.sel2vid = {}  # selectors to categorical feature ids

        # preparing the selectors: all input variables sharing the prefix
        # before the first underscore share a single selector
        for vid, (inp, _) in enumerate(zip(self.inps, self.sample)):
            feat = inp.symbol_name().split('_')[0]
            selv = Symbol('selv_{0}'.format(feat))

            self.rhypos.append(selv)
            if selv not in self.sel2fid:
                self.sel2fid[selv] = int(feat[1:])
                self.sel2vid[selv] = [vid]
            else:
                self.sel2vid[selv].append(vid)

        # adding relaxed hypotheses to the oracle
        if not self.intvs:
            for inp, val, sel in zip(self.inps, self.sample, self.rhypos):
                if '_' not in inp.symbol_name():
                    hypo = Implies(self.selv,
                                   Implies(sel, Equals(inp, Real(float(val)))))
                else:
                    hypo = Implies(self.selv,
                                   Implies(sel, inp if val else Not(inp)))

                self.oracle.add_assertion(hypo)
        else:
            for inp, val, sel in zip(self.inps, self.sample, self.rhypos):
                inp = inp.symbol_name()
                # determining the right interval and the corresponding variable
                for ub, fvar in zip(self.intvs[inp], self.ivars[inp]):
                    if ub == '+' or val < ub:
                        hypo = Implies(self.selv, Implies(sel, fvar))
                        break

                self.oracle.add_assertion(hypo)

        # in case of categorical data, there are selector duplicates
        # and we need to remove them
        self.rhypos = sorted(set(self.rhypos),
                             key=lambda x: int(x.symbol_name()[6:]))

        # propagating the true observation
        if not self.oracle.solve([self.selv] + self.rhypos):
            # explicit raise instead of `assert 0`: with assertions
            # disabled the code below would otherwise hit an unbound `model`
            raise AssertionError(
                'Formula is unsatisfiable under given assumptions')
        model = self.oracle.get_model()

        # choosing the maximum (ties are broken towards the larger class id)
        outvals = [float(model.get_py_value(o)) for o in self.outs]
        maxoval = max(zip(outvals, range(len(outvals))))

        # correct class id (corresponds to the maximum computed)
        self.out_id = maxoval[1]
        self.output = self.xgb.target_name[self.out_id]

        # forcing a misclassification, i.e. a wrong observation
        disj = [
            GT(out, self.outs[self.out_id])
            for cid, out in enumerate(self.outs) if cid != self.out_id
        ]
        self.oracle.add_assertion(Implies(self.selv, Or(disj)))

        if self.verbose:
            inpvals = self.xgb.readable_sample(sample)

            self.preamble = []
            for f, v in zip(self.xgb.feature_names, inpvals):
                if f not in str(v):
                    self.preamble.append('{0} = {1}'.format(f, v))
                else:
                    self.preamble.append(str(v))

            print('  explaining:  "IF {0} THEN {1}"'.format(
                ' AND '.join(self.preamble), self.output))

    def explain(self, sample, smallest, expl_ext=None, prefer_ext=False):
        """
            Hypotheses minimization.

            Prepares the oracle for ``sample`` and reduces the hypotheses
            to a subset-minimal explanation, or to a cardinality-minimal
            one if ``smallest`` is true. ``expl_ext`` is an optional
            external explanation given as indices into the hypotheses:
            unless ``prefer_ext`` is set, only those hypotheses are
            considered; with ``prefer_ext`` they are all considered but the
            external ones are tried for removal last.

            Returns the sorted list of feature ids left in the explanation.
            Terminates the process (exit status 1) if the considered
            hypotheses do not entail the prediction. Also records CPU time
            in ``self.time`` and the max-RSS delta in ``self.used_mem``.
        """

        start_mem = resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss + \
                    resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
                resource.getrusage(resource.RUSAGE_SELF).ru_utime

        # adapt the solver to deal with the current sample
        self.prepare(sample)

        # saving external explanation to be minimized further
        # (identity test: `== None` misbehaves for array-like arguments)
        if expl_ext is None or prefer_ext:
            self.to_consider = [True for h in self.rhypos]
        else:
            eexpl = set(expl_ext)
            self.to_consider = [i in eexpl for i in range(len(self.rhypos))]

        # if satisfiable, then the observation is not implied by the hypotheses
        if self.oracle.solve(
            [self.selv] +
            [h for h, c in zip(self.rhypos, self.to_consider) if c]):
            print('  no implication!')
            print(self.oracle.get_model())
            sys.exit(1)

        if not smallest:
            self.compute_minimal(prefer_ext=prefer_ext)
        else:
            self.compute_smallest()

        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
                resource.getrusage(resource.RUSAGE_SELF).ru_utime - self.time

        self.used_mem = resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss + \
                    resource.getrusage(resource.RUSAGE_SELF).ru_maxrss - start_mem

        expl = sorted([self.sel2fid[h] for h in self.rhypos])

        if self.verbose:
            self.preamble = [self.preamble[i] for i in expl]
            print('  explanation: "IF {0} THEN {1}"'.format(
                ' AND '.join(self.preamble),
                self.xgb.target_name[self.out_id]))
            print('  # hypos left:', len(self.rhypos))
            print('  time: {0:.2f}'.format(self.time))

        return expl

    def compute_minimal(self, prefer_ext=False):
        """
            Compute any subset-minimal explanation.

            Works in place on ``self.rhypos``: a hypothesis is dropped
            whenever the oracle stays unsatisfiable without it.
        """

        if not prefer_ext:
            # here, we want to reduce external explanation

            # filtering out unnecessary features if external explanation is given
            self.rhypos = [
                h for h, c in zip(self.rhypos, self.to_consider) if c
            ]
        else:
            # here, we want to compute an explanation that is preferred
            # to be similar to the given external one
            # for that, we try to postpone removing features that are
            # in the external explanation provided

            rhypos = [
                h for h, c in zip(self.rhypos, self.to_consider) if not c
            ]
            rhypos += [h for h, c in zip(self.rhypos, self.to_consider) if c]
            self.rhypos = rhypos

        # simple deletion-based linear search
        i = 0
        while i < len(self.rhypos):
            to_test = self.rhypos[:i] + self.rhypos[(i + 1):]

            if self.oracle.solve([self.selv] + to_test):
                # hypothesis i is necessary: keep it and move on
                i += 1
            else:
                # still unsatisfiable without it: drop it, same index next
                self.rhypos = to_test

    def compute_smallest(self):
        """
            Compute a cardinality-minimal explanation.

            Uses a hitting-set loop: candidate sets of hypothesis indices
            come from Hitman; a candidate under which the oracle is
            unsatisfiable becomes the explanation, otherwise a set of
            hypotheses to hit is extracted from the model and fed back.
            The result is stored in ``self.rhypos``.
        """

        with Hitman(bootstrap_with=[[
                i for i in range(len(self.rhypos)) if self.to_consider[i]
        ]]) as hitman:
            # computing unit-size MCSes
            for i in range(len(self.rhypos)):
                if not self.to_consider[i]:
                    continue

                if self.oracle.solve([self.selv] + self.rhypos[:i] +
                                     self.rhypos[(i + 1):]):
                    hitman.hit([i])

            # main loop
            iters = 0
            while True:
                hset = hitman.get()
                iters += 1

                if self.verbose > 1:
                    print('iter:', iters)
                    print('cand:', hset)

                if self.oracle.solve([self.selv] +
                                     [self.rhypos[i] for i in hset]):
                    to_hit = []
                    unsatisfied = []

                    removed = list(
                        set(range(len(self.rhypos))).difference(set(hset)))

                    # hypotheses outside the candidate that the model
                    # happens to satisfy are added to the candidate for free
                    model = self.oracle.get_model()
                    for h in removed:
                        fid = self.sel2fid[self.rhypos[h]]
                        if '_' not in self.inps[fid].symbol_name():
                            # feature variable and its expected value
                            var, exp = self.inps[fid], self.sample[fid]

                            # true value
                            true_val = float(model.get_py_value(var))

                            if not exp - 0.001 <= true_val <= exp + 0.001:
                                unsatisfied.append(h)
                            else:
                                hset.append(h)
                        else:
                            for vid in self.sel2vid[self.rhypos[h]]:
                                var, exp = self.inps[vid], int(
                                    self.sample[vid])

                                # true value
                                true_val = int(model.get_py_value(var))

                                if exp != true_val:
                                    unsatisfied.append(h)
                                    break
                            else:
                                # every variable of the feature matches
                                hset.append(h)

                    # computing an MCS (expensive)
                    for h in unsatisfied:
                        if self.oracle.solve([self.selv] +
                                             [self.rhypos[i] for i in hset] +
                                             [self.rhypos[h]]):
                            hset.append(h)
                        else:
                            to_hit.append(h)

                    if self.verbose > 1:
                        print('coex:', to_hit)

                    hitman.hit(to_hit)
                else:
                    self.rhypos = [self.rhypos[i] for i in hset]
                    break
Exemple #12
0
from pysmt.shortcuts import Symbol, And, Or, BOOL, Solver, BVType, EqualsOrIff, TRUE, simplify

# Three Boolean propositions; only `a` and `b` appear in the formula below.
a, b, c = [Symbol(name, BOOL) for name in ("a", "b", "c")]

# Two 4-bit bit-vector symbols.
bv1, bv2 = [Symbol(name, BVType(4)) for name in ("bv1", "bv2")]

# ((a & b) | True)  <->  (bv1 = bv2)
bool_side = Or(And(a, b), TRUE())
bv_side = EqualsOrIff(bv1, bv2)
f = EqualsOrIff(bool_side, bv_side)

# Show the formula next to its simplified form.
print("f=", f, "f'=", simplify(f))

# Check satisfiability with MathSAT and print the verdict with the model.
solver = Solver(name="msat")
solver.add_assertion(f)
verdict = solver.solve()
print(verdict, solver.get_model())

# Free variables of the formula paired with their types.
print(list(map(lambda v: (v, v.symbol_type()), f.get_free_variables())))

# Top-level arguments of the formula and whether its root is a conjunction.
print(f.args(), f.is_and())
class TestCounterEnc(unittest.TestCase):
    """Checks the bit-level encoding produced by CounterEnc.

    For counters with maximum value 0..4 the tests compare `eq_val` and
    `get_mask` against hand-written formulas over the counter's bit
    variables, and check that values read back from a model match.
    """

    def setUp(self):
        self.enc = CounterEnc(get_env(), False)
        self.solver = Solver(logic=pysmt.logics.BOOL)
        # a new solver is created for every test: release it afterwards
        self.addCleanup(self.solver.exit)

    def _is_eq(self, a, b):
        """Assert that formulas `a` and `b` are logically equivalent."""
        f = Iff(a, b)
        self.assertTrue(self.solver.is_valid(f))

    def test_0(self):
        """Counter with maximum 0: one bit, which must stay false."""
        var_name = "counter_0"
        self.enc.add_var(var_name, 0)

        b0 = self.enc._get_bitvar(var_name,0)

        e = self.enc.eq_val(var_name, 0)
        self._is_eq(e, Not(b0))

        with self.assertRaises(AssertionError):
            # out of the counter bound
            self.enc.eq_val(var_name, 1)

        mask = self.enc.get_mask(var_name)
        self._is_eq(mask, Not(b0))

    def test_1(self):
        """Counter with maximum 1: one bit, no value excluded by the mask."""
        var_name = "counter_1"
        self.enc.add_var(var_name, 1)

        b0 = self.enc._get_bitvar(var_name,0)

        e = self.enc.eq_val(var_name, 0)
        self._is_eq(e, Not(b0))
        e = self.enc.eq_val(var_name, 1)
        self._is_eq(e, b0)

        with self.assertRaises(AssertionError):
            # out of the counter bound
            self.enc.eq_val(var_name, 2)

        mask = self.enc.get_mask(var_name)
        self._is_eq(mask, TRUE())


    def test_2(self):
        """Counter with maximum 2: two bits, the mask excludes value 3."""
        # need 2 bits
        var_name = "counter_2"
        self.enc.add_var(var_name, 2)

        b0 = self.enc._get_bitvar(var_name,0)
        b1 = self.enc._get_bitvar(var_name,1)

        e = self.enc.eq_val(var_name, 0)
        self._is_eq(e, And(Not(b0), Not(b1)))
        e = self.enc.eq_val(var_name, 1)
        self._is_eq(e, And(b0, Not(b1)))
        e = self.enc.eq_val(var_name, 2)
        self._is_eq(e, And(Not(b0), b1))

        with self.assertRaises(AssertionError):
            # out of the counter bound
            self.enc.eq_val(var_name, 3)

        mask = self.enc.get_mask(var_name)
        self._is_eq(mask, Not(And(b0, b1)))

    def test_3(self):
        """Counter with maximum 3: two bits, every bit pattern is valid."""
        # need 2 bits
        var_name = "counter_3"
        self.enc.add_var(var_name, 3)

        b0 = self.enc._get_bitvar(var_name,0)
        b1 = self.enc._get_bitvar(var_name,1)

        e = self.enc.eq_val(var_name, 0)
        self._is_eq(e, And(Not(b0), Not(b1)))
        e = self.enc.eq_val(var_name, 1)
        self._is_eq(e, And(b0, Not(b1)))
        e = self.enc.eq_val(var_name, 2)
        self._is_eq(e, And(Not(b0), b1))
        e = self.enc.eq_val(var_name, 3)
        self._is_eq(e, And(b0, b1))

        mask = self.enc.get_mask(var_name)
        self._is_eq(mask, TRUE())


    def test_4(self):
        """Counter with maximum 4: three bits, the mask excludes 5, 6, 7."""
        # need 3 bits
        var_name = "counter_4"
        self.enc.add_var(var_name, 4)

        b0 = self.enc._get_bitvar(var_name,0)
        b1 = self.enc._get_bitvar(var_name,1)
        b2 = self.enc._get_bitvar(var_name,2)

        e = self.enc.eq_val(var_name, 0)
        self._is_eq(e, And([Not(b0), Not(b1), Not(b2)]))
        e = self.enc.eq_val(var_name, 1)
        self._is_eq(e, And([b0, Not(b1), Not(b2)]))

        e = self.enc.eq_val(var_name, 4)
        self._is_eq(e, And([Not(b0), Not(b1), b2]))

        with self.assertRaises(AssertionError):
            # out of the counter bound
            self.enc.eq_val(var_name, 5)

        mask = self.enc.get_mask(var_name)
        # bit patterns of the out-of-range values 5, 6 and 7
        models = Or([And([b0, Not(b1), b2]),
                     And([Not(b0), b1, b2]),
                     And([b0, b1, b2])])
        self._is_eq(mask, Not(models))


    def test_value(self):
        """Values decoded from a model of `eq_val` match the encoded value."""
        def eq_value(self, var_name, value):
            eq_val = self.enc.eq_val(var_name, value)
            self.solver.is_sat(eq_val)
            model = self.solver.get_model()
            res = self.enc.get_counter_value(var_name, model, False)
            # assertEqual reports both values when the check fails
            self.assertEqual(res, value)

        var_name = "counter_4"

        self.enc.add_var(var_name, 4)
        eq_value(self, var_name, 0)
        eq_value(self, var_name, 1)
        eq_value(self, var_name, 2)
        eq_value(self, var_name, 3)
        eq_value(self, var_name, 4)