Ejemplo n.º 1
0
def test_term_hash_raises_exception():
    """Check which syntactic objects can be used as keys of associative containers.

    Terms (constants and compound terms) cannot be hashed directly, their symrefs can,
    and formulas such as atoms can be used as keys as they are.
    """
    lang = fs.language("test")
    tally = defaultdict(int)

    const = lang.constant('c', 'object')
    func = lang.function('f', 'object', 'object')

    # Using a bare term as a key raises the standard TypeError
    for build_term in (lambda: const, lambda: func(const)):
        with pytest.raises(TypeError):
            tally[build_term()] += 2

    # Wrapping the term in a symref makes it usable as a key
    for build_term in (lambda: const, lambda: func(const)):
        tally[symref(build_term())] += 2
        assert tally[symref(build_term())] == 2

    # Atoms (and formulas in general) can be used as keys directly
    equality = func(const) == const
    tally[equality] += 2
    assert tally[equality] == 2
Ejemplo n.º 2
0
def test_sort_id_assignment():
    """Check the per-sort object counts and the object ids assigned over a small sort hierarchy."""
    language = tsk.fstrips.language(theories=[])

    # Sort hierarchy: s1 and s2 under object; t1 and t2 under s2
    hierarchy = [("s1", "object"), ("s2", "object"), ("t1", "s2"), ("t2", "s2")]
    for sort_name, parent_name in hierarchy:
        language.sort(sort_name, parent_name)

    # Objects are declared in an order that deliberately alternates their sorts
    declarations = [
        ("a1", "s1"), ("o1", "object"), ("a2", "s1"), ("b3", "t1"), ("a3", "s1"),
        ("o2", "object"), ("b1", "s2"), ("b2", "t1"), ("b4", "s2"),
    ]
    for const_name, sort_name in declarations:
        language.constant(const_name, sort_name)

    direct_map = compute_direct_sort_map(language)
    counts = {sort.name: len(members) for sort, members in direct_map.items()}
    assert counts == {'object': 2, 's1': 3, 's2': 2, 't1': 2, 't2': 0}
    bounds, ids = compute_sort_id_assignment(language)

    def id_of(name):
        return ids[symref(language.get(name))]

    assert bounds[language.Object] == (0, 9)
    assert {id_of('o1'), id_of('o2')} == {7, 8}
    # The following relies on dict (the dict of sort parents) iterating over its elements in insertion order.
    # see https://stackoverflow.com/a/39980744
    assert {id_of(name) for name in ('a1', 'a2', 'a3')} == {0, 1, 2}
Ejemplo n.º 3
0
def test_simple_expression_substitutions():
    """Check substitute_expression both on a copy (inplace=False) and in place (inplace=True)."""
    lang = tarski.benchmarks.blocksworld.generate_strips_bw_language(nblocks=2)
    clear, b1, b2 = (lang.get(name) for name in ('clear', 'b1', 'b2'))
    x = lang.variable('x', 'object')
    y = lang.variable('y', 'object')

    atom = clear(x)
    copy_b1 = substitute_expression(atom, substitution={symref(x): b1}, inplace=False)
    copy_b2 = substitute_expression(atom, substitution={symref(x): b2}, inplace=False)

    # Substituting on a copy leaves the original formula untouched
    assert not atom.is_syntactically_equal(copy_b1)
    assert str(atom) == "clear(x)"
    assert str(copy_b1) == "clear(b1)"
    assert str(copy_b2) == "clear(b2)"

    # With inplace=True the original formula itself reflects the substitution
    in_place = substitute_expression(atom, substitution={symref(x): b1}, inplace=True)
    assert atom.is_syntactically_equal(in_place)
    assert str(atom) == str(in_place) == "clear(b1)"

    quantified = forall(x, clear(x) & clear(y))
    mapping = {symref(x): b1, symref(y): b2}
    substituted = substitute_expression(quantified, substitution=mapping, inplace=False)
    assert str(quantified) == "forall x : ((clear(x) and clear(y)))"
    assert str(substituted) == "forall b1 : ((clear(b1) and clear(b2)))"
Ejemplo n.º 4
0
def test_variables_classification():
    """free_variables reports only x (y is bound by the existential); all_variables reports both."""
    world = tarskiworld.create_small_world()
    x, y = (world.variable(name, world.Object) for name in ('x', 'y'))

    cube = world.Cube(x)
    sentence = neg(land(cube, exists(y, land(world.Tet(x), world.LeftOf(x, y)))))

    free = free_variables(sentence)
    assert len(free) == 1
    assert symref(free[0]) == symref(x)
    assert len(all_variables(sentence)) == 2
Ejemplo n.º 5
0
def test_symbol_ref_in_sets_equality_is_exact_syntactic_match():
    """Set membership of symrefs distinguishes different variables of the same sort."""
    world = tarskiworld.create_small_world()
    x = world.variable('x', world.Object)
    y = world.variable('y', world.Object)
    x_key, y_key = symref(x), symref(y)

    refs = {x_key}
    assert y_key not in refs
    assert x_key in refs

    refs.remove(x_key)
    assert not refs
Ejemplo n.º 6
0
def test_simplifier():
    """Check how Simplify evaluates static parts of expressions, actions, effects and whole problems."""
    problem = generate_fstrips_counters_problem(ncounters=3)
    lang = problem.language
    value, max_int, counter, val_t, c1 = lang.get('value', 'max_int', 'counter', 'val', 'c1')
    x = lang.variable('x', counter)
    two, three, six = (lang.constant(k, val_t) for k in (2, 3, 6))

    simplifier = Simplify(problem, problem.init)
    simplify = simplifier.simplify_expression

    # Variables are left untouched; max_int() evaluates to 6 in the initial state
    assert symref(simplify(x)) == symref(x)
    assert symref(simplify(value(c1) < max_int())) == symref(value(c1) < six)
    assert simplify(two < max_int()) is True
    assert simplify(two > three) is False

    # A conjunction whose first conjunct is false evaluates to false
    false_conjunction = land(two > three, value(c1) < max_int())
    assert simplify(false_conjunction) is False
    assert simplify(neg(false_conjunction)) is True

    # A true conjunct is dropped...
    conj = land(two < three, value(c1) < max_int())
    assert str(simplify(conj)) == '<(value(c1),6)'

    # ...and so is a false disjunct
    disj = lor(two > three, value(c1) < max_int())
    assert str(simplify(disj)) == '<(value(c1),6)'

    assert str(simplify(forall(x, value(x) < max_int()))) == 'forall x : (<(value(x),6))'
    assert simplify(forall(x, two + three <= 6)) is True

    increment = problem.get_action('increment')
    simplified_action = simplifier.simplify_action(increment)
    assert str(simplified_action.precondition) == '<(value(c),6)'
    assert str(simplified_action.effects) == str(increment.effects)

    universal = UniversalEffect(x, [value(x) << three])
    expected_effect = '(T -> forall (x) : ((T -> value(x) := 3)))'
    assert str(simplifier.simplify_effect(universal)) == expected_effect

    simplified_problem = simplifier.simplify()
    assert str(simplified_problem.get_action('increment').precondition) == '<(value(c),6)'

    # The compiled-away "max_int" symbol must be absent from the language...
    assert not simplified_problem.language.has_function("max_int")

    # ...and from the initial state
    extension_keys = list(simplified_problem.init.list_all_extensions().keys())
    assert ('max_int', 'val') not in extension_keys
Ejemplo n.º 7
0
def test_term_refs():
    """Symrefs of the same constant compare equal; symrefs of different constants do not."""
    lang = fstrips.language(theories=[Theory.ARITHMETIC])
    lang.function('f', lang.Object, lang.Integer)  # declared only to populate the language
    first, second = (lang.constant(name, lang.Object) for name in ("o1", "o2"))

    ref_a, ref_b, ref_c = symref(first), symref(first), symref(second)

    assert ref_a == ref_b
    assert ref_a != ref_c
def test_model_list_extensions():
    """Model.list_all_extensions must return plain Tarski terms, not TermReference wrappers.

    Plain terms cannot be compared with `==`/hashing (that would trigger the wrong __eq__ and
    __hash__), so the check wraps the returned terms back into symrefs before comparing them.
    """
    lang = tarski.language(theories=[])
    pred = lang.predicate('p', lang.Object, lang.Object)
    func = lang.function('f', lang.Object, lang.Object)
    o1 = lang.constant("o1", lang.Object)
    o2 = lang.constant("o2", lang.Object)

    model = Model(lang)
    model.evaluator = evaluate
    model.set(func, o1, o2)
    model.add(pred, o1, o2)

    extensions = model.list_all_extensions()
    fun_ext, pred_ext = extensions[func.signature], extensions[pred.signature]

    assert len(fun_ext) == 1 and len(pred_ext) == 1

    arg, result = fun_ext[0]
    assert symref(arg) == symref(o1) and symref(result) == symref(o2)

    first, second = pred_ext[0]
    assert symref(first) == symref(o1) and symref(second) == symref(o2)
Ejemplo n.º 9
0
def analyze_action_effects(lang, schemas):
    """ Compile an index of action effects according to the type of effect (add/del/functional) and the
    symbol they affect.

    :param lang: language used to generate the standard argument terms of each schema.
    :param schemas: action schemas whose effects are indexed.
    :return: a dict with keys "add", "del" and "fun"; each maps the name of the affected symbol to a
        list of (schema, effect) pairs, where the effect has the schema parameters substituted.
    :raises TransformationError: if some effect is not an add, delete or functional effect.
    """
    index = {
        "add": defaultdict(list),
        "del": defaultdict(list),
        "fun": defaultdict(list)
    }

    for schema in schemas:
        substitution = {
            symref(param): arg
            for param, arg in zip(schema.parameters, generate_action_arguments(lang, schema))
        }

        for eff in schema.effects:
            if not isinstance(eff, (fs.AddEffect, fs.DelEffect, fs.FunctionalEffect)):
                raise TransformationError(f'Cannot handle effect "{eff}"')

            # Let's substitute the action parameters for some standard variable names such as z1, z2, ... so that
            # later on in the compilation we can use them off the shelf.
            eff = term_substitution(eff, substitution)

            # Add/delete effects affect the symbol of their atom; functional effects that of their lhs.
            if isinstance(eff, fs.AddEffect):
                index["add"][eff.atom.symbol.name].append((schema, eff))
            elif isinstance(eff, fs.DelEffect):
                index["del"][eff.atom.symbol.name].append((schema, eff))
            else:
                index["fun"][eff.lhs.symbol.name].append((schema, eff))

    return index
Ejemplo n.º 10
0
    def assert_frame_axioms(self):
        """Append frame axioms to self.theory for every fluent, non-builtin symbol of self.lang.

        For a predicate p, two axioms are added, universally quantified over the symbol arguments and
        the timestep: a change of p from false to true between t and t+1 implies self.gamma_pos[p],
        and a change from true to false implies self.gamma_neg[p]. For a function f, a change of value
        implies self.gamma_fun[f] with the metalanguage variable y replaced by f(x, t+1).
        A comment line for each symbol is stored in self.comments, keyed by the index in self.theory
        at which that symbol's axioms start.
        """
        ml = self.metalang
        tvar = _get_timestep_var(ml)

        # Both predicate and function symbols are handled; non-fluent symbols get no frame axiom.
        for symbol in get_symbols(self.lang, type_="all", include_builtin=False):
            if not self.symbol_is_fluent(symbol):
                continue

            self.comments[len(self.theory)] = f";; Frame axiom for symbol {symbol}:"
            lvars = generate_symbol_arguments(self.lang, symbol)
            atom = symbol(*lvars)
            fquant = generate_symbol_arguments(ml, symbol) + [tvar]

            # The same symbol application, translated to the metalanguage at timesteps t and t+1
            at_t = self.to_metalang(atom, tvar)
            at_t1 = self.to_metalang(atom, tvar + 1)

            if isinstance(symbol, Predicate):
                # pos: not p(x, t) and p(x, t+1)  => \gamma_p^+(x, t)
                # neg: p(x, t) and not p(x, t+1)  => \gamma_p^-(x, t)
                pos = forall(*fquant, implies(~at_t & at_t1, self.gamma_pos[symbol.name]))
                neg = forall(*fquant, implies(at_t & ~at_t1, self.gamma_neg[symbol.name]))
                self.theory += [pos, neg]
            else:
                # fun: f(x, t) != f(x, t+1) => \gamma_f[y/f(x, t+1)]
                yvar = ml.variable("y", ml.get_sort(symbol.codomain.name))
                gamma_replaced = term_substitution(self.gamma_fun[symbol.name], {symref(yvar): at_t1})
                fun = forall(*fquant, implies(at_t != at_t1, gamma_replaced))
                self.theory += [fun]
Ejemplo n.º 11
0
    def resolve_constant(self, c: Constant, sort: Sort = None):
        """Return the SMT-LIB text for constant `c`, interpreted in `sort` (by default, c's own sort).

        Integer/Real constants are printed as literals, interval constants are resolved in the
        parent sort, set constants become emptyset/singleton/insert expressions, and any other
        constant is printed as its numeric object id.
        """
        target = c.sort if sort is None else sort

        if target in (self.smtlang.Integer, self.smtlang.Real):
            return str(target.literal(c))

        if isinstance(target, Interval):
            return self.resolve_constant(c, parent(target))

        if not isinstance(target, Set):
            # Enumerated type: the constant is represented by its object ID
            return str(self.object_ids[symref(c)])

        # Set denotations are encoded with strings, not with Constant objects
        assert isinstance(c.symbol, set)
        rendered = []
        for element in c.symbol:
            if isinstance(element, str):
                rendered.append(self.resolve_constant(self.smtlang.get(element)))
            else:
                rendered.append(str(element))

        size = len(c.symbol)
        if size == 0:
            return f"(as emptyset {resolve_type_for_sort(self.smtlang, c.sort)})"
        if size == 1:
            return f"(singleton {' '.join(rendered)})"
        # e.g. for the set {1, 2, 3, 4} we output: (insert 1 2 3 (singleton 4))
        *leading, final = rendered
        return f'(insert {" ".join(leading)} (singleton {final}))'
Ejemplo n.º 12
0
def test_term_refs_compound():
    """Symrefs of structurally identical compound terms are equal; of different ones, unequal."""
    lang = fstrips.language(theories=[Theory.ARITHMETIC])
    f = lang.function('f', lang.Object, lang.Integer)
    o1 = lang.constant("o1", lang.Object)
    o2 = lang.constant("o2", lang.Object)
    lang.get('f')  # the function must be retrievable by name

    t1, t2, t3 = f(o1), f(o1), f(o2)
    assert t1.symbol == t2.symbol

    refs = [symref(term) for term in (t1, t2, t3)]
    assert refs[0] == refs[1]
    assert refs[0] != refs[2]
Ejemplo n.º 13
0
def test_formula_refs():
    """Symrefs of syntactically identical formulas are equal; of different formulas, unequal."""
    lang = fstrips.language('arith', [Theory.EQUALITY, Theory.ARITHMETIC])
    lang.constant(1, lang.Integer)

    x, y = (lang.function(name, lang.Integer) for name in ('x', 'y'))

    phi = (x() <= y()) & (y() <= x())
    psi = (x() >= y()) & (y() <= x())
    gamma = (x() <= y()) & (y() <= x())

    ref_phi, ref_psi, ref_gamma = map(symref, (phi, psi, gamma))

    assert ref_phi == ref_gamma
    assert ref_phi != ref_psi
Ejemplo n.º 14
0
 def smt_fun_application(self, phi, varmap):
     """Return the cached pySMT function application corresponding to compound term `phi`.

     On a cache miss, the subterms of `phi` are rewritten with `self.rewrite(..., varmap)`, the SMT
     function registered under `phi.symbol.name` in self.smt_functions is applied to them, and the
     result is stored in self.vars under the symref of `phi`.
     """
     key = symref(phi)
     try:
         return self.vars[key]
     except KeyError:
         pass

     # Build outside the except clause so that errors raised here are not chained to the KeyError
     params = [self.rewrite(st, varmap) for st in phi.subterms]
     fun, _ = self.smt_functions[phi.symbol.name]
     self.vars[key] = res = pysmt.shortcuts.Function(fun, params)
     return res
Ejemplo n.º 15
0
def test_action_grounding_bw():
    """Grounding unstack with parameters (b1, b2) yields a PlainOperator with the expected precondition."""
    problem = generate_strips_blocksworld_problem()
    b1, b2, *_unused = problem.language.get(
        'b1', 'b2', 'b3', 'clear', 'on', 'ontable', 'handempty', 'holding')
    unstack = problem.get_action("unstack")

    # unstack has exactly two parameters; bind them to b1 and b2, i.e. the operator unstack(b1, b2)
    first_param, second_param = (symref(param) for param in unstack.parameters)
    ground = ground_schema_into_plain_operator(unstack, {first_param: b1, second_param: b2})

    assert isinstance(ground, PlainOperator)
    assert str(ground.precondition) == '(on(b1,b2) and clear(b1) and handempty())'
Ejemplo n.º 16
0
 def smt_variable(self, expr):
     """ Return the (possibly cached) SMT theory variable that corresponds to the given Tarski
     logical expression, which can be an atom (e.g. clear(b1)), a compound term representing
     a state variable (e.g. value(c1)), or a variable.

     Atoms are created through self.create_bool_term and anything else through
     self.create_variable; the result is cached in self.vars under the symref of `expr`.
     """
     # TODO This code is currently unused and needs to be revised / removed
     assert isinstance(expr, (Atom, CompoundTerm, Variable))
     key = symref(expr)
     try:
         return self.vars[key]
     except KeyError:
         pass

     # Create outside the except clause so that errors raised here are not chained to the KeyError
     creator = self.create_bool_term if isinstance(expr, Atom) else self.create_variable
     self.vars[key] = res = creator(expr, name=str(expr))
     return res
Ejemplo n.º 17
0
def visit(phi, subst):
    """Return `phi` with the substitution `subst` applied to its subterms.

    Compound formulas, quantified formulas, atoms and compound terms are rebuilt; any other
    expression is returned unchanged. Only subterms are looked up in `subst` — `phi` itself is
    never replaced.

    :param phi: the formula or term to traverse.
    :param subst: mapping from symrefs of terms to the terms that replace them.
    :raises SubstitutionError: if a quantifier binds a variable that `subst` would replace.
    """
    if isinstance(phi, CompoundFormula):
        subformulas = [visit(f, subst) for f in phi.subformulas]
        return CompoundFormula(phi.connective, subformulas)

    elif isinstance(phi, QuantifiedFormula):
        if any(symref(x) in subst for x in phi.variables):
            raise SubstitutionError(
                phi, subst,
                'Attempted to substitute variable bound by quantifier')
        formula = visit(phi.formula, subst)
        return QuantifiedFormula(phi.quantifier, phi.variables, formula)

    elif isinstance(phi, Atom):
        # Build with the original subterms, then install the substituted ones
        result = Atom(phi.symbol, phi.subterms)
        result.subterms = _visit_subterms(phi.subterms, subst)
        return result

    elif isinstance(phi, CompoundTerm):
        result = CompoundTerm(phi.symbol, phi.subterms)
        result.subterms = _visit_subterms(phi.subterms, subst)
        return result

    return phi


def _visit_subterms(subterms, subst):
    """Return a tuple of `subterms` where each term found in `subst` (by symref) is replaced and
    every other term is recursively visited."""
    replaced = []
    for t in subterms:
        rep = subst.get(symref(t), None)
        replaced.append(visit(t, subst) if rep is None else rep)
    return tuple(replaced)
Ejemplo n.º 18
0
    def assert_action(self, op):
        """Append to self.theory the axioms linking an occurrence of `op` to its precondition and effects.

        With happens = op(args, t) over the metalanguage timestep variable t, the axioms are
        (universally quantified over args and t):
            happens --> op.precondition@t
            happens [and effect.condition@t] --> effect@(t+1)    (one axiom per effect)

        :raises TransformationError: if `op` has an effect that is not an add, delete or functional effect.
        """
        ml = self.metalang
        timestep = _get_timestep_var(ml)
        action_pred = ml.get_predicate(op.name)

        # Standard argument terms that replace the schema parameters (the timestep is not one of them)
        arguments = generate_action_arguments(ml, op)
        substitution = {symref(param): arg for param, arg in zip(op.parameters, arguments)}
        quantified = arguments + [timestep]
        happens = action_pred(*quantified)

        precondition = term_substitution(flatten(op.precondition), substitution)
        precondition_axiom = forall(*quantified, implies(happens, self.to_metalang(precondition, timestep)))
        self.theory.append(precondition_axiom)

        for effect in op.effects:
            effect = term_substitution(effect, substitution)

            antecedent = happens
            if not isinstance(effect.condition, Tautology):
                # Conditional effect: its condition must also hold at t
                antecedent = land(antecedent, self.to_metalang(effect.condition, timestep))

            if isinstance(effect, fs.AddEffect):
                consequent = self.to_metalang(effect.atom, timestep + 1, subt=timestep)
            elif isinstance(effect, fs.DelEffect):
                consequent = self.to_metalang(~effect.atom, timestep + 1, subt=timestep)
            elif isinstance(effect, fs.FunctionalEffect):
                lhs = self.to_metalang(effect.lhs, timestep + 1, subt=timestep)
                rhs = self.to_metalang(effect.rhs, timestep, subt=timestep)
                consequent = lhs == rhs
            else:
                raise TransformationError(f"Can't compile effect {effect}")

            self.theory.append(forall(*quantified, implies(antecedent, consequent)))
Ejemplo n.º 19
0
    def resolve_constant(self, c: Constant, sort: Sort = None):
        """Return the pySMT term for constant `c` interpreted in `sort` (by default, c's own sort).

        Integer and Real constants become Int/Real literals of their symbol; interval and set
        sorts defer to their parent sort; any other constant becomes the Int of its object id.
        """
        target = sort if sort is not None else c.sort

        if target == self.smtlang.Integer:
            return Int(c.symbol)
        if target == self.smtlang.Real:
            return Real(c.symbol)
        if isinstance(target, (Interval, Set)):
            return self.resolve_constant(c, parent(target))

        # Enumerated type: the constant is represented by its object ID
        return Int(self.object_ids[symref(c)])
Ejemplo n.º 20
0
def _index_state_variables(statevars):
    """Map the symref of each state variable's atom (``v.to_atom()``) to the state variable itself.

    When two state variables yield equal symrefs, the later one wins.
    """
    return {symref(statevar.to_atom()): statevar for statevar in statevars}
Ejemplo n.º 21
0
    def compute_interferences(self, operators):
        """Compute the pairs of operators that are mutually exclusive.

        Preconditions are indexed with classify_atom_occurrences_in_formula. Effects are indexed by
        the symref of the state variable they affect or, when the affected expression is not a state
        variable, by its symbol. Two operators with different string representations are mutex when
        one of them has an effect on a state variable that:
          - adds it, while the other has it as a negative precondition, deletes it, or deletes a
            non-state-variable atom of the same predicate;
          - deletes it, while the other has it as a positive precondition, adds it, or adds a
            non-state-variable atom of the same predicate;
          - changes its value, while the other mentions it functionally in its precondition,
            changes it, or changes a non-state-variable term with the same function symbol.

        NOTE(review): precondition indexes are keyed by whatever classify_atom_occurrences_in_formula
        yields but looked up with symref(atom); this assumes that helper yields symrefs — verify.

        :param operators: the operators to analyze.
        :return: a pair (interferences, mutexes). `interferences` is never populated here and is
            always an empty defaultdict; `mutexes` is a set of frozensets of two operator strings.
        :raises TransformationError: if some effect is not an add, delete or functional effect.
        """
        # TODO Deprecated - to be removed
        posprec = defaultdict(list)
        negprec = defaultdict(list)
        funprec = defaultdict(list)
        addeff = defaultdict(list)
        deleff = defaultdict(list)
        funeff = defaultdict(list)
        addalleff = defaultdict(list)
        delalleff = defaultdict(list)
        funalleff = defaultdict(list)

        mutexes = set()
        interferences = defaultdict(list)

        def add_mutex(op1, op2):
            if str(op1) != str(op2):
                mutexes.add(frozenset({str(op1), str(op2)}))

        # Classify precondition atoms
        for op in operators:
            pos, neg, fun = classify_atom_occurrences_in_formula(op.precondition)
            for index, occurrences in ((posprec, pos), (negprec, neg), (funprec, fun)):
                for a in occurrences:
                    index[a].append(str(op))

        # Analyze effects
        for op in operators:
            for eff in op.effects:
                if not isinstance(eff, (fs.AddEffect, fs.DelEffect, fs.FunctionalEffect)):
                    raise TransformationError(f'Cannot handle effect "{eff}"')
                atom = eff.atom if isinstance(eff, (fs.AddEffect, fs.DelEffect)) else eff.lhs

                if self.is_state_variable(atom):
                    if isinstance(eff, fs.AddEffect):
                        addeff[symref(atom)].append(str(op))
                    elif isinstance(eff, fs.DelEffect):
                        deleff[symref(atom)].append(str(op))
                    else:
                        funeff[symref(atom)].append(str(op))
                else:
                    if isinstance(eff, fs.AddEffect):
                        addalleff[atom.predicate].append(str(op))
                    elif isinstance(eff, fs.DelEffect):
                        delalleff[atom.predicate].append(str(op))
                    else:
                        # The lhs of a functional effect is a compound term, which has no `predicate`
                        funalleff[atom.symbol].append(str(op))

        # Compute mutexes
        for op in operators:
            for eff in op.effects:
                atom = eff.atom if isinstance(eff, (fs.AddEffect, fs.DelEffect)) else eff.lhs
                if not self.is_state_variable(atom):
                    continue

                key = symref(atom)
                if isinstance(eff, fs.AddEffect):
                    conflicts = itertools.chain(negprec[key], deleff[key], delalleff[atom.predicate])
                elif isinstance(eff, fs.DelEffect):
                    conflicts = itertools.chain(posprec[key], addeff[key], addalleff[atom.predicate])
                else:
                    # TODO We need to take into account the RHS !!
                    conflicts = itertools.chain(funprec[key], funeff[key], funalleff[atom.symbol])

                for conflict in conflicts:
                    add_mutex(op, conflict)

        return interferences, mutexes
Ejemplo n.º 22
0
 def is_state_variable(self, expression):
     """Tell whether `expression` is indexed in self.statevaridx (looked up by its symref)."""
     key = symref(expression)
     return key in self.statevaridx