예제 #1
0
파일: syntax.py 프로젝트: aman-goel/ic3po
 def build_prospectives(self):
     """Record, per next-state variable, the single action/condition defining it.

     A next-state variable qualifies when exactly one action is associated
     with it in ``self.nex2actions``, that action has an entry in
     ``self.action2prefix``, and the conjunction of the action's prefix does
     not mention the variable's pre-state counterpart
     (``self.system._nex2pre[nv]``).  For such a variable, the suffix
     condition of the action that mentions it is stored in
     ``self.nexinfer[nv]`` as ``(action, condition)``.

     Raises:
         AssertionError: if more than one suffix condition of the action
             mentions the same next-state variable.
     """
     for next_var, candidates in self.nex2actions.items():
         # only variables driven by exactly one action are of interest
         if len(candidates) != 1:
             continue
         (sole_action,) = candidates
         if sole_action not in self.action2prefix:
             continue

         guard_vars = And(self.action2prefix[sole_action]).get_free_variables()
         if self.system._nex2pre[next_var] in guard_vars:
             # the guard already constrains the pre-state copy; skip
             continue

         for post in self.action2suffix[sole_action]:
             if next_var not in post.get_free_variables():
                 continue
             # at most one suffix condition may define the variable
             assert (next_var not in self.nexinfer)
             self.nexinfer[next_var] = (sole_action, post)
예제 #2
0
    def print_as_smtlib(self, smt_formulas, comments, cout):
        """Write a list of formulas to ``cout`` as an SMT-LIB script.

        The output consists of a timestamp comment, a ``set-logic`` command
        for ``QF_UFIDL``, the declarations of all custom sorts and free
        variables occurring in the formulas, one ``assert`` per formula and a
        final ``check-sat``.  The approach follows pySMT's
        ``smtlibscript_from_formula`` but operates on a list of formulas and
        prints each command directly instead of building a script object.

        Args:
            smt_formulas: the formulas to assert, in output order.
            comments: container queried with ``i in comments`` and
                ``comments[i]`` for the 0-based position ``i`` of a formula;
                a hit is printed on its own line before that assertion
                (presumably a dict keyed by position - confirm with callers).
            cout: writable text stream receiving the script.
        """
        stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print(f';; File automatically generated on {stamp}', file=cout)

        print_script_command_line(cout,
                                  name=smtcmd.SET_LOGIC,
                                  args=["QF_UFIDL"])
        print("", file=cout)

        # A single conjunction lets us collect the sorts and free variables
        # of all formulas in one pass.
        conjunction = And(smt_formulas)

        # sort declarations
        for sort in get_env().typeso.get_types(conjunction, custom_only=True):
            print_script_command_line(cout,
                                      name=smtcmd.DECLARE_SORT,
                                      args=[sort.decl])
        print("", file=cout)

        # variable / function declarations
        for var in conjunction.get_free_variables():
            assert var.is_symbol()
            print_script_command_line(cout,
                                      name=smtcmd.DECLARE_FUN,
                                      args=[var])
        print("", file=cout)

        # assertions, each optionally preceded by its comment
        for position, assertion in enumerate(smt_formulas):
            if position in comments:
                print(f"\n{comments[position]}", file=cout)
            print_script_command_line(cout,
                                      name=smtcmd.ASSERT,
                                      args=[assertion])

        print("\n", file=cout)

        print_script_command_line(cout, name=smtcmd.CHECK_SAT, args=[])
예제 #3
0
class SMTEncoder(object):
    """
        Encoder of XGBoost tree ensembles into SMT.

        The encoding is accumulated in ``self.enc`` (a list of formulas while
        encoding, a single conjunction afterwards).  When the interval-based
        encoding is used, ``self.intvs``, ``self.imaps`` and ``self.ivars``
        hold, per feature, the sorted interval upper bounds, the map from an
        upper bound to its position, and the Boolean interval variables.
    """
    def __init__(self, model, feats, nof_classes, xgb, from_file=None):
        """
            Constructor.

            ``feats`` is the sequence of feature names; it may be None, in
            which case the feature map starts empty (useful when the
            encoding is restored with ``from_file``).  ``xgb`` is the
            owning booster object; its ``options`` attribute is read here
            and other attributes are used by the encoding methods.  If
            ``from_file`` is given, the encoding is loaded from that file.
        """

        self.model = model
        # feature name -> position; tolerate a missing feature list so that
        # an encoder can be built purely from a saved file
        self.feats = {
            f: i
            for i, f in enumerate(feats if feats is not None else [])
        }
        self.nofcl = nof_classes
        self.idmgr = IDPool()
        self.optns = xgb.options

        # xgbooster will also be needed
        self.xgb = xgb

        # for interval-based encoding
        self.intvs, self.imaps, self.ivars = None, None, None

        if from_file:
            self.load_from(from_file)

    def traverse(self, tree, tvar, prefix=None):
        """
            Traverse a tree and encode each node.

            ``prefix`` is the list of branch literals on the path from the
            root to ``tree`` (empty at the root).  For every leaf, a
            constraint "path condition implies tvar == leaf value" is
            appended to ``self.enc``.
        """

        # avoid a shared mutable default; an omitted prefix means "root"
        if prefix is None:
            prefix = []

        if tree.children:
            pos, neg = self.encode_node(tree)

            self.traverse(tree.children[0], tvar, prefix + [pos])
            self.traverse(tree.children[1], tvar, prefix + [neg])
        else:  # leaf node
            if prefix:
                self.enc.append(
                    Implies(And(prefix), Equals(tvar, Real(tree.values))))
            else:
                self.enc.append(Equals(tvar, Real(tree.values)))

    def encode_node(self, node):
        """
            Encode a node of a tree.

            Returns a pair of literals: the one that holds on the first
            (left) branch and the one that holds on the second (right)
            branch.
        """

        if '_' not in node.name:
            # continuous features => expecting an upper bound
            # feature and its upper bound (value)
            f, v = node.name, node.threshold

            # the defining constraints are added only the first time a
            # (feature, threshold) pair is seen
            existing = (f, v) in self.idmgr.obj2id
            vid = self.idmgr.id((f, v))
            bv = Symbol('bvar{0}'.format(vid), typename=BOOL)

            if not existing:
                if self.intvs:
                    # interval variables below/above the threshold
                    d = self.imaps[f][v] + 1
                    pos, neg = self.ivars[f][:d], self.ivars[f][d:]
                    self.enc.append(Iff(bv, Or(pos)))
                    self.enc.append(Iff(Not(bv), Or(neg)))
                else:
                    fvar, fval = Symbol(f, typename=REAL), Real(v)
                    self.enc.append(Iff(bv, LT(fvar, fval)))

            return bv, Not(bv)
        else:
            # all features are expected to be categorical and
            # encoded with one-hot encoding into Booleans
            # each node is expected to be of the form: f_i < 0.5
            bv = Symbol(node.name, typename=BOOL)

            # left branch is positive,  i.e. bv is true
            # right branch is negative, i.e. bv is false
            return Not(bv), bv

    def compute_intervals(self):
        """
            Traverse all trees in the ensemble and extract intervals for each
            feature.

            Fills ``self.intvs`` (sorted thresholds followed by '+'),
            ``self.imaps`` (threshold -> index) and ``self.ivars`` (one
            Boolean variable per interval).

            At this point, the method only works for numerical datasets!
        """
        def traverse_intervals(tree):
            """
                Auxiliary function. Recursive tree traversal.
            """

            if tree.children:
                f = tree.name
                v = tree.threshold
                self.intvs[f].add(v)

                traverse_intervals(tree.children[0])
                traverse_intervals(tree.children[1])

        # initializing the intervals
        self.intvs = {
            'f{0}'.format(i): set()
            for i in range(len(self.feats))
        }

        for tree in self.ensemble.trees:
            traverse_intervals(tree)

        # OK, we got all intervals; let's sort the values
        # ('+' stands for the unbounded last interval)
        self.intvs = {
            f: sorted(vals) + ['+']
            for f, vals in six.iteritems(self.intvs)
        }

        self.imaps, self.ivars = {}, {}
        for feat, intvs in six.iteritems(self.intvs):
            self.imaps[feat] = {}
            self.ivars[feat] = []
            for i, ub in enumerate(intvs):
                self.imaps[feat][ub] = i

                ivar = Symbol(name='{0}_intv{1}'.format(feat, i),
                              typename=BOOL)
                self.ivars[feat].append(ivar)

    def encode(self):
        """
            Do the job.

            Builds the tree ensemble, encodes every tree and the per-class
            score sums, adds exactly-one constraints for one-hot encoded
            features, and returns the tuple
            ``(enc, intvs, imaps, ivars)``.
        """

        self.enc = []

        # getting a tree ensemble
        self.ensemble = TreeEnsemble(
            self.model,
            self.xgb.extended_feature_names_as_array_strings,
            nb_classes=self.nofcl)

        # introducing class score variables
        csum = []
        for j in range(self.nofcl):
            cvar = Symbol('class{0}_score'.format(j), typename=REAL)
            csum.append((cvar, []))

        # if targeting interval-based encoding,
        # traverse all trees and extract all possible intervals
        # for each feature
        if self.optns.encode == 'smtbool':
            self.compute_intervals()

        # traversing and encoding each tree
        for i, tree in enumerate(self.ensemble.trees):
            # getting class id
            clid = i % self.nofcl

            # encoding the tree
            tvar = Symbol('tr{0}_score'.format(i + 1), typename=REAL)
            self.traverse(tree, tvar, prefix=[])

            # this tree contributes to class with clid
            csum[clid][1].append(tvar)

        # encoding the sums
        for cvar, tvars in csum:
            self.enc.append(Equals(cvar, Plus(tvars)))

        # enforce exactly one of the feature values to be chosen
        # (for categorical features)
        categories = collections.defaultdict(list)
        for f in self.xgb.extended_feature_names_as_array_strings:
            if '_' in f:
                categories[f.split('_')[0]].append(
                    Symbol(name=f, typename=BOOL))
        for feats in six.itervalues(categories):
            self.enc.append(ExactlyOne(feats))

        # number of assertions
        nof_asserts = len(self.enc)

        # making conjunction
        self.enc = And(self.enc)

        # number of variables
        nof_vars = len(self.enc.get_free_variables())

        if self.optns.verb:
            print('encoding vars:', nof_vars)
            print('encoding asserts:', nof_asserts)

        return self.enc, self.intvs, self.imaps, self.ivars

    def test_sample(self, sample):
        """
            Check whether or not the encoding "predicts" the same class
            as the classifier given an input sample.

            Raises AssertionError if a class score obtained from the
            encoding differs from the ensemble's score by more than 0.001.
        """

        # first, compute the scores for all classes as would be
        # predicted by the classifier

        # score arrays computed for each class
        csum = [[] for _ in range(self.nofcl)]

        if self.optns.verb:
            print('testing sample:', list(sample))

        sample_internal = list(self.xgb.transform(sample)[0])

        # traversing all trees
        for i, tree in enumerate(self.ensemble.trees):
            # getting class id
            clid = i % self.nofcl

            # a score computed by the current tree
            score = scores_tree(tree, sample_internal)

            # this tree contributes to class with clid
            csum[clid].append(score)

        # final scores for each class
        cscores = [sum(scores) for scores in csum]

        # second, get the scores computed with the use of the encoding

        # asserting the sample
        hypos = []

        if not self.intvs:
            for i, fval in enumerate(sample_internal):
                feat, vid = self.xgb.transform_inverse_by_index(i)
                fid = self.feats[feat]

                # no value id => numerical feature
                if vid is None:
                    fvar = Symbol('f{0}'.format(fid), typename=REAL)
                    hypos.append(Equals(fvar, Real(float(fval))))
                else:
                    fvar = Symbol('f{0}_{1}'.format(fid, vid), typename=BOOL)
                    if int(fval) == 1:
                        hypos.append(fvar)
                    else:
                        hypos.append(Not(fvar))
        else:
            for i, fval in enumerate(sample_internal):
                feat, _ = self.xgb.transform_inverse_by_index(i)
                feat = 'f{0}'.format(self.feats[feat])

                # determining the right interval and the corresponding variable
                for ub, fvar in zip(self.intvs[feat], self.ivars[feat]):
                    if ub == '+' or fval < ub:
                        hypos.append(fvar)
                        break
                else:
                    assert 0, 'No proper interval found for {0}'.format(feat)

        # now, getting the model
        escores = []
        model = get_model(And(self.enc, *hypos), solver_name=self.optns.solver)
        for c in range(self.nofcl):
            v = Symbol('class{0}_score'.format(c), typename=REAL)
            escores.append(float(model.get_py_value(v)))

        assert all(map(lambda c, e: abs(c - e) <= 0.001, cscores, escores)), \
                'wrong prediction: {0} vs {1}'.format(cscores, escores)

        if self.optns.verb:
            print('xgb scores:', cscores)
            print('enc scores:', escores)

    def save_to(self, outfile):
        """
            Save the encoding into a file with a given name.

            A trailing '.txt' extension is replaced by '.smt2'.  The SMT-LIB
            dump is prefixed with comment lines listing the features, the
            number of classes and (if present) the interval variables; these
            are the lines read back by ``load_from``.
        """

        if outfile.endswith('.txt'):
            outfile = outfile[:-3] + 'smt2'

        write_smtlib(self.enc, outfile)

        # appending additional information
        with open(outfile, 'r') as fp:
            contents = fp.readlines()

        # comments
        comments = [
            '; features: {0}\n'.format(', '.join(self.feats)),
            '; classes: {0}\n'.format(self.nofcl)
        ]

        if self.intvs:
            for f in self.xgb.extended_feature_names_as_array_strings:
                c = '; i {0}: '.format(f)
                c += ', '.join([
                    '{0}<->{1}'.format(u, v)
                    for u, v in zip(self.intvs[f], self.ivars[f])
                ])
                comments.append(c + '\n')

        contents = comments + contents
        with open(outfile, 'w') as fp:
            fp.writelines(contents)

    def load_from(self, infile):
        """
            Loads the encoding from an input file.

            The leading ';' comment lines are parsed for the feature list,
            the number of classes and the interval variables; the formula
            itself is obtained from the SMT-LIB script.  Note that
            ``self.feats`` becomes a list of names here.
        """

        with open(infile, 'r') as fp:
            file_content = fp.readlines()

        # empty intervals for the standard encoding
        self.intvs, self.imaps, self.ivars = {}, {}, {}

        for line in file_content:
            # the header ends at the first non-comment line
            if line[0] != ';':
                break
            elif line.startswith('; i '):
                f, arr = line[4:].strip().split(': ', 1)
                f = f.replace('-', '_')
                self.intvs[f], self.imaps[f], self.ivars[f] = [], {}, []

                for i, pair in enumerate(arr.split(', ')):
                    ub, symb = pair.split('<->')

                    # '+' marks the unbounded last interval
                    if ub[0] != '+':
                        ub = float(ub)
                    symb = Symbol(symb, typename=BOOL)

                    self.intvs[f].append(ub)
                    self.ivars[f].append(symb)
                    self.imaps[f][ub] = i

            elif line.startswith('; features:'):
                self.feats = line[11:].strip().split(', ')
            elif line.startswith('; classes:'):
                self.nofcl = int(line[10:].strip())

        parser = SmtLibParser()
        script = parser.get_script(StringIO(''.join(file_content)))

        self.enc = script.get_last_formula()

    def access(self):
        """
            Get access to the encoding, features names, and the number of
            classes.

            Returns the tuple
            ``(enc, intvs, imaps, ivars, feats, nofcl)``.
        """

        return self.enc, self.intvs, self.imaps, self.ivars, self.feats, self.nofcl
예제 #4
0
파일: syntax.py 프로젝트: aman-goel/ic3po
    def process_prospectives(self):
        """Turn the entries of ``self.nexinfer`` into inference formulas.

        For every ``nv -> (action, cond)`` entry, the method tries to build
        implications of the form ``premise -> conclusion`` where the premise
        is derived from the assignment ``cond`` (rewritten over the pre-state
        copy of ``nv``) and the conclusion is a conjunct of the action's
        prefix that does not mention any pre-state variable whose next-state
        copy occurs in the action's suffix.  Each resulting formula is
        stored in ``self.system._infers`` under a name ``syntax<N>`` and
        printed.  Entries whose condition does not have the expected shape
        are silently skipped.
        """
        print("\nProcessing prospectives")
        for nv, (action, cond) in self.nexinfer.items():
            print("\tprocessing %s in %s" % (nv, action))
            # next-state variables occurring in the action's suffix ...
            action_suffix = self.action2suffix[action]
            action_nvars = And(action_suffix).get_free_variables()
            action_nvars = action_nvars.intersection(
                self.system._nex2pre.keys())
            # ... and their pre-state counterparts
            action_pvars = set()
            for n in action_nvars:
                action_pvars.add(self.system._nex2pre[n])
            action_prefix = self.action2prefix[action]
            # Candidate conclusions: prefix conjuncts sharing no variable
            # with action_pvars.  Those without any non-global state
            # variable go to concs_global, the rest to concs.
            concs_global = []
            concs = []
            for c in action_prefix:
                cvars = c.get_free_variables()
                common = cvars.intersection(action_pvars)
                if len(common) == 0:
                    statevars = cvars.intersection(self.system._states)
                    statevars = statevars.difference(self.system._globals)
                    if len(statevars) == 0:
                        concs_global.append(c)
                    else:
                        concs.append(c)
            if (len(concs) + len(concs_global)) == 0:
                print("\t\tskipping %s since no static precondition found" %
                      nv)
                continue
            # Strip an outer universal quantifier, remembering its variables.
            qvars = set()
            if cond.is_forall():
                cqvars = cond.quantifier_vars()
                for v in cqvars:
                    qvars.add(v)
                cond = cond.arg(0)
            # Only (bi-)equalities are handled.
            if not (cond.is_iff() or cond.is_equals()):
                continue
            lhs = cond.arg(0)
            rhs = cond.arg(1)
            lvars = lhs.get_free_variables()
            rvars = rhs.get_free_variables()
            # Orient the equality so that nv is on the left-hand side ...
            if nv in rvars:
                lhs, rhs = rhs, lhs
                lvars, rvars = rvars, lvars
            # ... and give up if nv occurs on both sides.
            if nv in rvars:
                continue
            # Apart from quantified variables, the lhs must mention nv only.
            ldvars = lvars.difference(qvars)
            if len(ldvars) != 1:
                continue
            ldvar = next(iter(ldvars))
            if nv != ldvar:
                continue
            # Under a quantifier, the lhs must be an application of nv.
            if len(qvars) != 0:
                if not lhs.is_function_application():
                    continue
                nsym = lhs.function_name()
                if nv != nsym:
                    continue
            if len(self.action2def) != 0:
                rhs = rhs.substitute(self.action2def[action])
            # NOTE(review): process_assign is defined elsewhere; it
            # presumably returns the assigned value and collects side
            # conditions into rconds - confirm.
            rconds = []
            rval = self.process_assign(rhs, rconds)
            premise = []
            qsubs = {}
            for c in rconds:
                if c.is_iff() or c.is_equals():
                    lc = c.arg(0)
                    rc = c.arg(1)
                    # put a quantified variable, if any, on the left
                    if rc.is_symbol():
                        if rc in qvars:
                            lc, rc = rc, lc
                    if lc.is_symbol():
                        if lc in qvars:
                            # "qvar == other": replace the other symbol by
                            # the quantified variable instead of keeping
                            # the equality in the premise
                            if rc in self.system._others:
                                qsubs[rc] = lc
                                continue
                premise.append(c)
            eq = EqualsOrIff(lhs, rval)
            premise.insert(0, eq)
            prem = And(premise)
            # express the premise over the pre-state copy of nv
            qsubs[nv] = self.system._nex2pre[nv]

            prem = prem.substitute(qsubs)
            # Remaining symbols from system._others are replaced by fresh
            # "Q"-prefixed symbols, which are added to the quantified set.
            ivars = prem.get_free_variables()
            ivars = ivars.intersection(self.system._others)
            if len(ivars) != 0:
                for v in ivars:
                    vname = "Q" + str(v)
                    vname = vname.replace(":", "")
                    vnew = Symbol(vname, v.symbol_type())
                    qsubs[v] = vnew
                    qvars.add(vnew)

            prem = prem.substitute(qsubs)
            # None is a sentinel meaning "the global conclusions on their own"
            if len(concs_global) != 0:
                concs.append(None)
            for conc in concs:
                if conc == None:
                    conc = And(concs_global)
                else:
                    if len(concs_global) != 0:
                        conc = And(conc, And(concs_global))
                conc = conc.substitute(qsubs)

                # Symbols from system._others still left in the conclusion
                # are replaced by fresh symbols collected in evars.
                ivars = conc.get_free_variables()
                ivars = ivars.intersection(self.system._others)
                evars = []
                if len(ivars) != 0:
                    esubs = {}
                    for v in ivars:
                        vname = "Q" + str(v)
                        vname = vname.replace(":", "")
                        vnew = Symbol(vname, v.symbol_type())
                        esubs[v] = vnew
                        evars.append(vnew)
                    conc = conc.substitute(esubs)

    #                 conc = Exists(evars, conc)
    #                 evars = []

#                 print("evars ", evars)
#                 print("conc ", conc)
                inference = Implies(prem, conc)

                # qvars2 is never filled (the code that did so is disabled
                # below), so uqvars equals qvars and the ForAll(qvars2, ...)
                # branch is not taken.
                qvars2 = []
                #                 if len(qvars) != 0:
                #                     for u in qvars:
                #                         ut = u.symbol_type()
                #                         for e in evars:
                #                             et = e.symbol_type()
                #                             if not self.strat.allowed_arc(ut, et):
                #                                 qvars2.append(u)
                uqvars = qvars.difference(qvars2)
                #                     for u in qvars2:
                #                         if u in qvars:
                #                             qvars.remove(u)

                #                 print("qvars2 ", qvars2)
                #                 print("uqvars ", uqvars)

                if len(qvars2) != 0:
                    inference = ForAll(qvars2, inference)

                # quantifier prefix: forall uqvars . exists evars . (prem -> conc)
                if len(evars) != 0:
                    inference = Exists(evars, inference)

                if len(uqvars) != 0:
                    inference = ForAll(uqvars, inference)

                # register the inference under the next free "syntax<N>" name
                iname = "syntax" + str(len(self.system._infers) + 1)
                self.system._infers[inference] = iname
                print("\t\tinferred %s: %s" % (iname, inference))
예제 #5
0
    def add_trel_new(self):
        """Build the transition relation using one Boolean enable per action.

        Every action in ``self._actions`` is completed with the no-op
        constraints (from ``self.add_action_noop()``) of all next-state
        variables it does not mention and that are not definitions, and is
        guarded by a fresh ``en_<name>`` symbol.  The resulting relation,
        stored in ``self._trel``, is the conjunction of

        * ``en_i -> action_i`` for every action,
        * the disjunction of all enables, and
        * pairwise mutual exclusion of the enables.

        Side effects: fills ``self._axiomvars`` with the free variables of
        the axioms, records ``self._action_en2idx``, and overwrites slots 0
        and 2 of every ``self._actions`` entry with the completed action
        and its enable symbol.

        Raises:
            AssertionError: if no action has been registered.
        """
        count_msg = "\t(found #%d actions)" % len(self._actions)
        eprint(count_msg)
        print(count_msg)
        if not self._actions:
            error_msg = "\t(error: no action found)"
            eprint(error_msg)
            print(error_msg)
            assert (0)

        if len(self._axiom) != 0:
            for axiom_var in And(self._axiom).get_free_variables():
                self._axiomvars.add(axiom_var)

        noop = self.add_action_noop()
        enables = []
        clauses = []

        for idx, entry in enumerate(self._actions):
            body, action_name = entry[0], entry[1]
            # free variables are taken before a leading Exists is peeled off
            body_vars = body.get_free_variables()

            en = Symbol("en_" + action_name, BOOL)
            self._action_en2idx[en] = idx
            enables.append(en)

            bound = None
            if body.is_exists():
                bound = body.quantifier_vars()
                body = body.arg(0)

            # next-state variables the action leaves unconstrained
            untouched = [
                n for n in self._nex2pre.keys()
                if n not in body_vars and str(n) not in self._definitions
            ]
            conjuncts = [body] + [noop[n] for n in untouched]
            if untouched:
                print("adding #%d noops to action %s" %
                      (len(untouched), action_name))
                for n in untouched:
                    print("\tnoop(%s)" % n)

            body = And(conjuncts)
            if bound is not None:
                body = Exists(bound, body)

            self._actions[idx][0] = body
            self._actions[idx][2] = en

            clauses.append(Implies(en, body))

        # at least one action is enabled ...
        clauses.append(Or(list(enables)))
        # ... and no two actions are enabled together
        for i, first in enumerate(enables[:-1]):
            for second in enables[i + 1:]:
                clauses.append(Or(Not(first), Not(second)))

        self._trel = And(clauses)
예제 #6
0
    def add_trel(self):
        """Build the transition relation as an if-then-else chain over actions.

        An integer input symbol (``self._input_action``) selects the action:
        starting from the conjunction of all no-op constraints returned by
        ``self.get_action_noop()``, each action ``idx`` wraps the relation
        built so far as ``Ite(input == idx, action, previous)``.  Every
        action is first completed with the no-op constraints of the
        next-state variables it does not mention and that are not
        definitions.

        Side effects: fills ``self._axiomvars`` with the free variables of
        the axioms, registers the input symbol through ``self.add_var``,
        overwrites slot 0 and the last slot of every ``self._actions``
        entry with the completed action and its integer id, and finally
        registers the all-no-op action through ``self.add_action``.

        Raises:
            AssertionError: if no action has been registered.
        """
        count_msg = "\t(found #%d actions)" % len(self._actions)
        eprint(count_msg)
        print(count_msg)
        if not self._actions:
            error_msg = "\t(error: no action found)"
            eprint(error_msg)
            print(error_msg)
            assert (0)

        if len(self._axiom) != 0:
            for axiom_var in And(self._axiom).get_free_variables():
                self._axiomvars.add(axiom_var)

        noop_name, noop = self.get_action_noop()
        noop_all = And(list(noop.values()))
        # default branch of the chain: nothing changes
        self._trel = noop_all
        self._input_action = Symbol(self.input_action_name(), INT)
        self.add_var(self._input_action)

        for idx, entry in enumerate(self._actions):
            body, action_name = entry[0], entry[1]
            # free variables are taken before a leading Exists is peeled off
            body_vars = body.get_free_variables()

            bound = None
            if body.is_exists():
                bound = body.quantifier_vars()
                body = body.arg(0)

            # next-state variables the action leaves unconstrained
            untouched = [
                n for n in self._nex2pre.keys()
                if n not in body_vars and str(n) not in self._definitions
            ]
            conjuncts = [body] + [noop[n] for n in untouched]
            if untouched:
                print("adding #%d noops to action %s" %
                      (len(untouched), action_name))
                for n in untouched:
                    print("\tnoop(%s)" % n)

            body = And(conjuncts)
            if bound is not None:
                body = Exists(bound, body)
            self._actions[idx][0] = body

            # the action is taken when the input equals its index
            selector = Int(idx)
            self._actions[idx][-1] = selector
            chosen = EqualsOrIff(self._input_action, selector)
            self._trel = Ite(chosen, body, self._trel)

        self.add_action(noop_all, noop_name)