def test_term_hash_raises_exception():
    """Plain terms cannot be hashed, whereas their symref() wrappers and formulas can."""
    lang = fs.language("test")
    tally = defaultdict(int)
    cst = lang.constant('c', 'object')
    fun = lang.function('f', 'object', 'object')

    # Associative containers hash their keys, and hashing a raw term raises a standard TypeError
    with pytest.raises(TypeError):
        tally[cst] += 2
    with pytest.raises(TypeError):
        tally[fun(cst)] += 2

    # Wrapping the terms into symrefs makes them valid keys
    tally[symref(cst)] += 2
    assert tally[symref(cst)] == 2
    tally[symref(fun(cst))] += 2
    assert tally[symref(fun(cst))] == 2

    # Atoms (and formulas in general) are directly usable as keys
    eq_atom = fun(cst) == cst
    tally[eq_atom] += 2
    assert tally[eq_atom] == 2
def test_sort_id_assignment():
    """Object IDs are assigned contiguously per sort, and direct sort cardinalities are correct."""
    lang = tsk.fstrips.language(theories=[])

    # Sort hierarchy: s1 and s2 directly below object; t1 and t2 below s2
    for sort_name, parent_name in (("s1", "object"), ("s2", "object"), ("t1", "s2"), ("t2", "s2")):
        lang.sort(sort_name, parent_name)

    # Objects are declared in an order that deliberately alternates between sorts
    declarations = [
        ("a1", "s1"), ("o1", "object"), ("a2", "s1"), ("b3", "t1"), ("a3", "s1"),
        ("o2", "object"), ("b1", "s2"), ("b2", "t1"), ("b4", "s2"),
    ]
    for obj_name, sort_name in declarations:
        lang.constant(obj_name, sort_name)

    direct = compute_direct_sort_map(lang)
    cardinalities = {srt.name: len(members) for srt, members in direct.items()}
    assert cardinalities == {'object': 2, 's1': 3, 's2': 2, 't1': 2, 't2': 0}

    bounds, ids = compute_sort_id_assignment(lang)
    assert bounds[lang.Object] == (0, 9)

    def id_of(name):
        return ids[symref(lang.get(name))]

    assert {id_of('o1'), id_of('o2')} == {7, 8}
    # The following relies on dict (the dict of sort parents) iterating over its elements in insertion order.
    # See https://stackoverflow.com/a/39980744
    assert {id_of('a1'), id_of('a2'), id_of('a3')} == {0, 1, 2}
def test_simple_expression_substitutions():
    """substitute_expression replaces variables either on a copy or in place."""
    lang = tarski.benchmarks.blocksworld.generate_strips_bw_language(nblocks=2)
    clear, b1, b2 = (lang.get(name) for name in ('clear', 'b1', 'b2'))
    x = lang.variable('x', 'object')
    y = lang.variable('y', 'object')

    phi = clear(x)
    to_b1 = substitute_expression(phi, substitution={symref(x): b1}, inplace=False)
    to_b2 = substitute_expression(phi, substitution={symref(x): b2}, inplace=False)
    assert not phi.is_syntactically_equal(to_b1)
    assert str(phi) == "clear(x)"
    assert str(to_b1) == "clear(b1)"
    assert str(to_b2) == "clear(b2)"

    # Same substitution, now performed in place
    to_b1 = substitute_expression(phi, substitution={symref(x): b1}, inplace=True)
    assert phi.is_syntactically_equal(to_b1)
    assert str(phi) == str(to_b1) == "clear(b1)"

    phi = forall(x, clear(x) & clear(y))
    both = substitute_expression(phi, substitution={symref(x): b1, symref(y): b2}, inplace=False)
    assert str(phi) == "forall x : ((clear(x) and clear(y)))"
    assert str(both) == "forall b1 : ((clear(b1) and clear(b2)))"
def test_variables_classification():
    """free_variables reports only unbound variables; all_variables reports every variable."""
    world = tarskiworld.create_small_world()
    x, y = (world.variable(name, world.Object) for name in ('x', 'y'))
    inner = land(world.Tet(x), world.LeftOf(x, y))
    sentence = neg(land(world.Cube(x), exists(y, inner)))

    free = free_variables(sentence)
    assert len(free) == 1
    assert symref(free[0]) == symref(x)
    assert len(all_variables(sentence)) == 2
def test_symbol_ref_in_sets_equality_is_exact_syntactic_match():
    """Set membership of symrefs distinguishes syntactically different variables."""
    world = tarskiworld.create_small_world()
    ref_x = symref(world.variable('x', world.Object))
    ref_y = symref(world.variable('y', world.Object))

    refs = {ref_x}
    assert ref_y not in refs
    assert ref_x in refs
    refs.remove(ref_x)
    assert not refs
def test_simplifier():
    """Simplify evaluates static subexpressions and compiles away static symbols."""
    problem = generate_fstrips_counters_problem(ncounters=3)
    lang = problem.language
    value, max_int, counter, val_t, c1 = lang.get('value', 'max_int', 'counter', 'val', 'c1')
    x = lang.variable('x', counter)
    two, three, six = (lang.constant(num, val_t) for num in (2, 3, 6))
    simplifier = Simplify(problem, problem.init)
    simp_expr = simplifier.simplify_expression

    assert symref(simp_expr(x)) == symref(x)
    # max_int evaluates to 6
    assert symref(simp_expr(value(c1) < max_int())) == symref(value(c1) < six)
    assert simp_expr(two < max_int()) is True
    assert simp_expr(two > three) is False

    # The conjunction is false because its first conjunct is false
    always_false = land(two > three, value(c1) < max_int())
    assert simp_expr(always_false) is False
    assert simp_expr(neg(always_false)) is True

    # The (true) first conjunct gets dropped
    assert str(simp_expr(land(two < three, value(c1) < max_int()))) == '<(value(c1),6)'
    # The (false) first disjunct gets dropped
    assert str(simp_expr(lor(two > three, value(c1) < max_int()))) == '<(value(c1),6)'
    assert str(simp_expr(forall(x, value(x) < max_int()))) == 'forall x : (<(value(x),6))'
    assert simp_expr(forall(x, two + three <= 6)) is True

    increment = problem.get_action('increment')
    simplified_inc = simplifier.simplify_action(increment)
    assert str(simplified_inc.precondition) == '<(value(c),6)'
    assert str(simplified_inc.effects) == str(increment.effects)

    universal = UniversalEffect(x, [value(x) << three])
    assert str(simplifier.simplify_effect(universal)) == '(T -> forall (x) : ((T -> value(x) := 3)))'

    simplified_problem = simplifier.simplify()
    assert str(simplified_problem.get_action('increment').precondition) == '<(value(c),6)'
    # The compiled-away "max_int" symbol must no longer be part of the language...
    assert not simplified_problem.language.has_function("max_int")
    # ...nor of the initial state
    extensions = list(simplified_problem.init.list_all_extensions().keys())
    assert ('max_int', 'val') not in extensions
def test_term_refs():
    """symrefs of the same constant compare equal; symrefs of different constants do not."""
    lang = fstrips.language(theories=[Theory.ARITHMETIC])
    lang.function('f', lang.Object, lang.Integer)
    first, second = (lang.constant(name, lang.Object) for name in ("o1", "o2"))

    ref_a, ref_b, ref_other = symref(first), symref(first), symref(second)
    assert ref_a == ref_b
    assert ref_a != ref_other
def test_model_list_extensions():
    """`Model.list_all_extensions` returns plain Tarski terms, not internal TermReference wrappers."""
    lang = tarski.language(theories=[])
    p = lang.predicate('p', lang.Object, lang.Object)
    f = lang.function('f', lang.Object, lang.Object)
    o1 = lang.constant("o1", lang.Object)
    o2 = lang.constant("o2", lang.Object)

    model = Model(lang)
    model.evaluator = evaluate
    model.set(f, o1, o2)
    model.add(p, o1, o2)

    extensions = model.list_all_extensions()
    ext_f = extensions[f.signature]
    ext_p = extensions[p.signature]

    # We want to test that the `list_all_extensions` method correctly unwraps all TermReferences in the internal
    # representation of the model and returns _only_ actual Tarski terms. Testing this is a bit involved, as of
    # course we cannot just check for (o1, o2) in ext_f, because that will trigger the wrong __eq__ and __hash__
    # methods - in other words, to test this we precisely need to wrap things back into TermReferences, so that we can
    # compare them properly
    assert len(ext_f) == 1 and len(ext_p) == 1

    # Use dedicated names so that the predicate `p` is not shadowed by an unpacked term
    f_arg, f_value = ext_f[0]
    assert symref(f_arg) == symref(o1) and symref(f_value) == symref(o2)

    p_arg1, p_arg2 = ext_p[0]
    assert symref(p_arg1) == symref(o1) and symref(p_arg2) == symref(o2)
def analyze_action_effects(lang, schemas):
    """
    Index the effects of the given action schemas by effect kind and by affected symbol name.

    Returns a dict with keys "add", "del" and "fun". Each maps a symbol name to the list of
    `(schema, effect)` pairs that affect that symbol. In each effect, the action parameters have
    been replaced by the arguments produced by `generate_action_arguments(lang, schema)`.

    Raises TransformationError for any effect that is not an add, delete or functional effect.
    """
    index = {kind: defaultdict(list) for kind in ("add", "del", "fun")}

    for schema in schemas:
        # Substitute the action parameters by standard variable names such as z1, z2, ... so that
        # later on in the compilation we can use them off the shelf.
        standard_args = generate_action_arguments(lang, schema)
        substitution = {symref(param): arg for param, arg in zip(schema.parameters, standard_args)}

        for eff in schema.effects:
            if not isinstance(eff, (fs.AddEffect, fs.DelEffect, fs.FunctionalEffect)):
                raise TransformationError(f'Cannot handle effect "{eff}"')

            eff = term_substitution(eff, substitution)
            if isinstance(eff, fs.AddEffect):
                kind, target = "add", eff.atom
            elif isinstance(eff, fs.DelEffect):
                kind, target = "del", eff.atom
            else:
                # Functional effects act on their left-hand-side term
                kind, target = "fun", eff.lhs
            index[kind][target.symbol.name].append((schema, eff))

    return index
def assert_frame_axioms(self):
    """
    Append frame axioms for every fluent symbol of the object language to `self.theory`.

    A fluent predicate p gets two axioms, universally quantified over its arguments and the timestep:
        pos: not p(x, t) and p(x, t+1) => gamma_p^+(x, t)
        neg: p(x, t) and not p(x, t+1) => gamma_p^-(x, t)
    Every other fluent symbol (a function f) gets one axiom:
        fun: f(x, t) != f(x, t+1) => gamma_f[y/f(x, t+1)]
    Before adding each symbol's axioms, a comment string is stored in `self.comments`, keyed by
    the position in `self.theory` where those axioms start.
    """
    ml = self.metalang
    tvar = _get_timestep_var(ml)

    # Iterate over all non-builtin symbols (type_="all", i.e. predicates and functions); skip non-fluents
    for p in get_symbols(self.lang, type_="all", include_builtin=False):
        if not self.symbol_is_fluent(p):
            continue

        # Key the comment by the index the next axiom will occupy in self.theory
        self.comments[len(self.theory)] = f";; Frame axiom for symbol {p}:"
        lvars = generate_symbol_arguments(self.lang, p)
        atom = p(*lvars)
        # Quantified variables: the symbol's arguments generated in the metalanguage, plus the timestep
        fquant = generate_symbol_arguments(ml, p) + [tvar]

        if isinstance(p, Predicate):
            # pos: not p(x, t) and p(x, t+1) => \gamma_p^+(x, t)
            # neg: p(x, t) and not p(x, t+1) => \gamma_p^-(x, t)
            at_t = self.to_metalang(atom, tvar)
            at_t1 = self.to_metalang(atom, tvar + 1)
            pos = forall(*fquant, implies(~at_t & at_t1, self.gamma_pos[p.name]))
            neg = forall(*fquant, implies(at_t & ~at_t1, self.gamma_neg[p.name]))
            self.theory += [pos, neg]

        else:
            # fun: f(x, t) != f(x, t+1) => \gamma_f[y/f(x, t+1)]
            # NOTE(review): presumably `y` is the variable gamma_fun uses for the new value of f — confirm
            yvar = ml.variable("y", ml.get_sort(p.codomain.name))
            at_t = self.to_metalang(atom, tvar)
            at_t1 = self.to_metalang(atom, tvar + 1)
            gamma_replaced = term_substitution(self.gamma_fun[p.name], {symref(yvar): at_t1})
            fun = forall(*fquant, implies(at_t != at_t1, gamma_replaced))
            self.theory += [fun]
def resolve_constant(self, c: Constant, sort: Sort = None):
    """
    Return the SMT-LIB text for constant `c`, interpreted under `sort` (defaults to `c.sort`).

    - Integer and Real constants are rendered through the sort's `literal` method.
    - Interval sorts are resolved against their parent sort.
    - Set constants (whose `symbol` is a Python set) become `emptyset`, `singleton` or `insert` terms.
    - Any other constant is treated as an enumerated object and rendered as its object ID.
    """
    sort = c.sort if sort is None else sort

    if sort in (self.smtlang.Integer, self.smtlang.Real):
        return str(sort.literal(c))

    if isinstance(sort, Interval):
        return self.resolve_constant(c, parent(sort))

    if not isinstance(sort, Set):
        # Enumerated type: the constant is identified by its object ID
        return str(self.object_ids[symref(c)])

    # Set denotations are encoded with strings (or plain values), not with Constant objects
    assert isinstance(c.symbol, set)
    elems = [
        str(x) if not isinstance(x, str) else self.resolve_constant(self.smtlang.get(x))
        for x in c.symbol
    ]
    cardinality = len(c.symbol)
    if cardinality == 0:
        # NOTE(review): the type annotation is computed from c.sort, not from `sort` — confirm intended
        return f"(as emptyset {resolve_type_for_sort(self.smtlang, c.sort)})"
    if cardinality == 1:
        return f"(singleton {' '.join(elems)})"
    # e.g. if the set is {1, 2, 3, 4}, we output: (insert 1 2 3 (singleton 4))
    return f'(insert {" ".join(elems[:-1])} (singleton {elems[-1]}))'
def test_term_refs_compound():
    """symrefs of compound terms are equal exactly when the terms are syntactically identical."""
    lang = fstrips.language(theories=[Theory.ARITHMETIC])
    f = lang.function('f', lang.Object, lang.Integer)
    o1, o2 = (lang.constant(name, lang.Object) for name in ("o1", "o2"))
    lang.get('f')

    same_a, same_b, different = f(o1), f(o1), f(o2)
    assert same_a.symbol == same_b.symbol

    ref_a, ref_b, ref_diff = (symref(term) for term in (same_a, same_b, different))
    assert ref_a == ref_b
    assert ref_a != ref_diff
def test_formula_refs():
    """symrefs of formulas compare by syntactic structure."""
    lang = fstrips.language('arith', [Theory.EQUALITY, Theory.ARITHMETIC])
    lang.constant(1, lang.Integer)
    x, y = (lang.function(name, lang.Integer) for name in ('x', 'y'))

    phi = (x() <= y()) & (y() <= x())
    psi = (x() >= y()) & (y() <= x())
    gamma = (x() <= y()) & (y() <= x())

    ref_phi, ref_psi, ref_gamma = map(symref, (phi, psi, gamma))
    assert ref_phi == ref_gamma  # identical structure
    assert ref_phi != ref_psi    # first conjunct differs
def smt_fun_application(self, phi, varmap):
    """
    Return the (cached) pysmt function application for the compound term `phi`.

    The subterms are rewritten with `self.rewrite(subterm, varmap)` and passed to the SMT function
    registered in `self.smt_functions` under `phi.symbol.name`. The result is memoized in
    `self.vars`, keyed by `symref(phi)`.
    """
    key = symref(phi)
    try:
        cached = self.vars[key]
    except KeyError:
        pass
    else:
        return cached

    args = [self.rewrite(subterm, varmap) for subterm in phi.subterms]
    smt_fun, _ftype = self.smt_functions[phi.symbol.name]
    result = pysmt.shortcuts.Function(smt_fun, args)
    self.vars[key] = result
    return result
def test_action_grounding_bw():
    """Grounding unstack on (b1, b2) yields a plain operator with the expected precondition."""
    problem = generate_strips_blocksworld_problem()
    lang = problem.language
    b1, b2, b3, clear, on, ontable, handempty, holding = lang.get(
        'b1', 'b2', 'b3', 'clear', 'on', 'ontable', 'handempty', 'holding')

    unstack = problem.get_action("unstack")
    # unstack has exactly two parameters
    first_param, second_param = map(symref, unstack.parameters)

    # i.e. the operator unstack(b1, b2)
    ground = ground_schema_into_plain_operator(unstack, {first_param: b1, second_param: b2})
    assert isinstance(ground, PlainOperator)
    assert str(ground.precondition) == '(on(b1,b2) and clear(b1) and handempty())'
def smt_variable(self, expr):
    """
    Return the (possibly cached) SMT theory variable that corresponds to the given Tarski logical expression,
    which can be either an atom (e.g. clear(b1)) or a compound term representing a state variable (e.g. value(c1)).
    """
    # TODO This code is currently unused and needs to be revised / removed
    assert isinstance(expr, (Atom, CompoundTerm, Variable))
    key = symref(expr)
    try:
        return self.vars[key]
    except KeyError:
        pass

    # Atoms become boolean SMT terms; everything else becomes an SMT variable
    make = self.create_bool_term if isinstance(expr, Atom) else self.create_variable
    fresh = make(expr, name=str(expr))
    self.vars[key] = fresh
    return fresh
def visit(phi, subst):
    """
    Return a copy of `phi` in which each subterm whose symref is a key of `subst` is replaced.

    Compound formulas, quantified formulas, atoms and compound terms are rebuilt recursively. Any
    other expression is returned unchanged. A replaced subterm is not visited any further.
    Raises SubstitutionError if `subst` targets a variable bound by a quantifier.
    """
    def substituted_subterms(node):
        # Replace each direct subterm by its substitute if it has one; otherwise recurse into it
        result = []
        for sub in node.subterms:
            replacement = subst.get(symref(sub), None)
            result.append(visit(sub, subst) if replacement is None else replacement)
        return tuple(result)

    if isinstance(phi, CompoundFormula):
        return CompoundFormula(phi.connective, [visit(sub, subst) for sub in phi.subformulas])

    if isinstance(phi, QuantifiedFormula):
        if any(symref(var) in subst for var in phi.variables):
            raise SubstitutionError(
                phi, subst, 'Attempted to substitute variable bound by quantifier')
        return QuantifiedFormula(phi.quantifier, phi.variables, visit(phi.formula, subst))

    if isinstance(phi, Atom):
        rebuilt = Atom(phi.symbol, phi.subterms)
        rebuilt.subterms = substituted_subterms(phi)
        return rebuilt

    if isinstance(phi, CompoundTerm):
        rebuilt = CompoundTerm(phi.symbol, phi.subterms)
        rebuilt.subterms = substituted_subterms(phi)
        return rebuilt

    return phi
def assert_action(self, op):
    """
    For given operator op and timestep t, assert the SMT expression:
        op@t --> op.precondition@t
        op@t --> op.effects@(t+1)

    Each implication is universally quantified over the generated action arguments plus the
    timestep variable, and appended to `self.theory` (one axiom for the precondition, one per effect).

    Raises TransformationError for effects that are not add, delete or functional effects.
    """
    ml = self.metalang
    vart = _get_timestep_var(ml)
    apred = ml.get_predicate(op.name)

    vars_ = generate_action_arguments(ml, op)  # Don't use the timestep arg
    # Replace the operator's parameters by the generated metalanguage arguments
    substitution = {
        symref(param): arg for param, arg in zip(op.parameters, vars_)
    }
    args = vars_ + [vart]
    # happens: the action predicate applied to the arguments and the timestep
    happens = apred(*args)

    prec = term_substitution(flatten(op.precondition), substitution)
    a_implies_prec = forall(*args, implies(happens, self.to_metalang(prec, vart)))
    self.theory.append(a_implies_prec)

    for eff in op.effects:
        eff = term_substitution(eff, substitution)
        antec = happens

        # Prepend the effect condition, if necessary:
        if not isinstance(eff.condition, Tautology):
            antec = land(antec, self.to_metalang(eff.condition, vart))

        # Effects are asserted at t+1; subterms (and the functional RHS) are evaluated at t
        if isinstance(eff, fs.AddEffect):
            a_implies_eff = implies(
                antec, self.to_metalang(eff.atom, vart + 1, subt=vart))
        elif isinstance(eff, fs.DelEffect):
            a_implies_eff = implies(
                antec, self.to_metalang(~eff.atom, vart + 1, subt=vart))
        elif isinstance(eff, fs.FunctionalEffect):
            lhs = self.to_metalang(eff.lhs, vart + 1, subt=vart)
            rhs = self.to_metalang(eff.rhs, vart, subt=vart)
            a_implies_eff = implies(antec, lhs == rhs)
        else:
            raise TransformationError(f"Can't compile effect {eff}")
        self.theory.append(forall(*args, a_implies_eff))
def resolve_constant(self, c: Constant, sort: Sort = None):
    """
    Return the pysmt term for constant `c`, interpreted under `sort` (defaults to `c.sort`).

    Integer and Real constants become pysmt `Int` / `Real` literals of `c.symbol`. Interval and Set
    sorts are resolved against their parent sort. Any other constant is treated as an enumerated
    object and encoded as `Int` of its object ID.
    """
    if sort is None:
        sort = c.sort

    # The Integer / Real checks must come first, before the Interval check below
    if sort == self.smtlang.Integer:
        return Int(c.symbol)
    if sort == self.smtlang.Real:
        return Real(c.symbol)
    if isinstance(sort, (Interval, Set)):
        return self.resolve_constant(c, parent(sort))

    # Enumerated type: encode the object by its ID
    return Int(self.object_ids[symref(c)])
def _index_state_variables(statevars): indexed = dict() for v in statevars: indexed[symref(v.to_atom())] = v return indexed
def compute_interferences(self, operators):
    """
    Compute the pairs of mutually exclusive operators among `operators`.

    Operators are identified by their string representation. For every effect of an operator on a
    state variable v, the operator is made mutex with every other operator that:
      - add effect on v: has v in a negative precondition, deletes v, or deletes some
        non-state-variable atom of the same symbol;
      - delete effect on v: has v in a positive precondition, adds v, or adds some
        non-state-variable atom of the same symbol;
      - functional effect on v: mentions v functionally in its precondition, assigns v, or assigns
        some non-state-variable term of the same function symbol.

    Returns `(interferences, mutexes)`. `mutexes` is a set of two-element frozensets of operator
    names. `interferences` is currently never populated (always an empty `defaultdict(list)`).

    Raises TransformationError if some effect is not an add, delete or functional effect.
    """
    # TODO Deprecated - to be removed
    posprec = defaultdict(list)
    negprec = defaultdict(list)
    funprec = defaultdict(list)
    # Effects on state variables, keyed by the symref of the affected atom / term
    addeff = defaultdict(list)
    deleff = defaultdict(list)
    funeff = defaultdict(list)
    # Effects on non-state-variable atoms / terms, keyed by their (predicate or function) symbol
    addalleff = defaultdict(list)
    delalleff = defaultdict(list)
    funalleff = defaultdict(list)
    mutexes = set()
    interferences = defaultdict(list)

    def affected(eff):
        # Add / delete effects act on an atom; functional effects act on their LHS term
        return eff.atom if isinstance(eff, (fs.AddEffect, fs.DelEffect)) else eff.lhs

    def add_mutex(op1, op2):
        if str(op1) != str(op2):
            mutexes.add(frozenset({str(op1), str(op2)}))

    # Classify precondition atoms
    for op in operators:
        name = str(op)
        pos, neg, fun = classify_atom_occurrences_in_formula(op.precondition)
        for a in pos:
            posprec[a].append(name)
        for a in neg:
            negprec[a].append(name)
        for a in fun:
            funprec[a].append(name)

    # Analyze effects
    for op in operators:
        name = str(op)
        for eff in op.effects:
            if not isinstance(eff, (fs.AddEffect, fs.DelEffect, fs.FunctionalEffect)):
                raise TransformationError(f'Cannot handle effect "{eff}"')

            atom = affected(eff)
            if self.is_state_variable(atom):
                key = symref(atom)
                target = addeff if isinstance(eff, fs.AddEffect) else \
                    deleff if isinstance(eff, fs.DelEffect) else funeff
            else:
                # `.symbol` works for both atoms and compound terms (functional LHS); `.predicate`
                # is only meaningful for atoms.
                key = atom.symbol
                target = addalleff if isinstance(eff, fs.AddEffect) else \
                    delalleff if isinstance(eff, fs.DelEffect) else funalleff
            target[key].append(name)

    # Compute mutexes
    for op in operators:
        for eff in op.effects:
            atom = affected(eff)
            if not self.is_state_variable(atom):
                continue
            ref = symref(atom)
            if isinstance(eff, fs.AddEffect):
                conflicts = itertools.chain(negprec[ref], deleff[ref], delalleff[atom.symbol])
            elif isinstance(eff, fs.DelEffect):
                conflicts = itertools.chain(posprec[ref], addeff[ref], addalleff[atom.symbol])
            else:
                # Look up the operators for this function symbol; iterating `funalleff` itself would
                # yield its keys (symbols), not operator names.
                # TODO We need to take into account the RHS !!
                conflicts = itertools.chain(funprec[ref], funeff[ref], funalleff[atom.symbol])
            for other in conflicts:
                add_mutex(op, other)

    return interferences, mutexes
def is_state_variable(self, expression):
    """Tell whether the symref of `expression` is a key of the state-variable index `self.statevaridx`."""
    ref = symref(expression)
    return ref in self.statevaridx