예제 #1
0
    def setup_metalang(self, problem):
        """ Build the Tarski metalanguage in which the SMT compilation is expressed.

        A new language named ``"<name>-smt"`` is created with the theories of the
        problem language plus equality and arithmetic. The sorts, constants,
        symbols and actions of ``problem`` are then re-declared in it. Symbols
        for which ``self.symbol_is_fluent`` holds, as well as every action,
        receive one extra argument obtained from ``_get_timestep_sort``.

        :param problem: the problem whose ``language`` and ``actions`` are mirrored.
        :return: the newly created metalanguage. As a side effect,
            ``self.sort_map`` is filled with the problem-sort to
            metalanguage-sort correspondence.
        """
        source = problem.language
        meta = tarski.fstrips.language(
            f"{source.name}-smt",
            theories=source.theories | {Theory.EQUALITY, Theory.ARITHMETIC})

        # Re-declare the user-defined sorts; builtin sorts and "object" are skipped here
        for srt in source.sorts:
            if srt.builtin or srt.name == "object":
                continue
            if isinstance(srt, Interval):
                mirrored = meta.interval(srt.name, parent(srt).name,
                                         srt.lower_bound, srt.upper_bound)
            else:
                mirrored = meta.sort(srt.name, parent(srt).name)
            self.sort_map[srt] = mirrored

        # Sorts that were not declared above are mapped onto their metalanguage counterparts
        self.sort_map[source.Object] = meta.Object

        if Theory.ARITHMETIC in source.theories:
            for numeric in ("Integer", "Natural", "Real"):
                self.sort_map[getattr(source, numeric)] = getattr(meta, numeric)

        if Theory.SETS in source.theories:
            for element in ("Object", "Integer"):
                self.sort_map[sorts.Set(source, getattr(source, element))] = \
                    sorts.Set(meta, getattr(meta, element))

        # Extra "timestep" sort, declared with a large range to be adjusted once the horizon is known
        meta.Timestep = meta.interval("timestep", meta.Natural, 0, 99999)

        # Every object of the problem language becomes a constant of the metalanguage
        for obj in source.constants():
            meta.constant(obj.symbol, obj.sort.name)

        # Non-builtin predicates and functions; fluent ones get the extra timestep argument
        for sym in get_symbols(source, type_="all", include_builtin=False):
            if self.symbol_is_fluent(sym):
                time_arg = [_get_timestep_sort(meta)]
            else:
                time_arg = []

            if isinstance(sym, Predicate):
                meta.predicate(sym.name, *(t.name for t in sym.sort), *time_arg)
            else:
                # For functions the timestep goes after the domain and before the codomain
                meta.function(sym.name, *(t.name for t in sym.domain), *time_arg,
                              sym.codomain.name)

        # Each action is represented by a predicate over its parameters plus a timestep
        for act in problem.actions.values():
            signature = [param.sort.name for param in act.parameters]
            signature.append(_get_timestep_sort(meta))
            meta.predicate(act.name, *signature)

        return meta
예제 #2
0
def test_parent_types():
    """Check parent() and ancestors() on the human -> animal -> being -> object chain."""
    lang, human, animal, being = get_children_parent_types()

    # Each sort in the chain must report the next one as its parent
    chain = [human, animal, being, lang.Object]
    for child, expected_parent in zip(chain, chain[1:]):
        assert parent(child) == expected_parent
    assert parent(lang.Object) is None

    human_ancestors = ancestors(human)
    assert len(human_ancestors) == 3  # i.e. including the top "object" sort
    for expected in (being, lang.Object):
        assert expected in human_ancestors

    # Calling set_parent on a sort that already has a parent must raise an error
    with pytest.raises(err.LanguageError):
        lang.set_parent(being, lang.Object)
예제 #3
0
    def resolve_constant(self, c: Constant, sort: Sort = None):
        """Return the string that denotes constant ``c`` in the SMT encoding.

        :param c: the constant to translate.
        :param sort: sort under which ``c`` is interpreted; defaults to ``c.sort``.
        :return: the literal of ``c`` for integer/real sorts, a set expression
            (``emptyset`` / ``singleton`` / ``insert``) for set sorts, and
            otherwise the stringified entry of ``self.object_ids`` for ``c``.
            Interval sorts are resolved through their parent sort.
        """
        sort = c.sort if sort is None else sort

        if sort in (self.smtlang.Integer, self.smtlang.Real):
            return str(sort.literal(c))

        if isinstance(sort, Interval):
            return self.resolve_constant(c, parent(sort))

        if not isinstance(sort, Set):
            # Otherwise we must have an enumerated type and simply return the object ID
            return str(self.object_ids[symref(c)])

        # This is slightly tricky, since set denotations are encoded with strings, not with Constant objects
        assert isinstance(c.symbol, set)
        size = len(c.symbol)
        if size == 0:
            return f"(as emptyset {resolve_type_for_sort(self.smtlang, c.sort)})"

        # String members are looked up in the language and resolved recursively; anything else is stringified
        rendered = []
        for member in c.symbol:
            if isinstance(member, str):
                rendered.append(self.resolve_constant(self.smtlang.get(member)))
            else:
                rendered.append(str(member))

        if size == 1:
            return f"(singleton {' '.join(rendered)})"

        # e.g. if the set is {1, 2, 3, 4}, we want to output: (insert 1 2 3 (singleton 4))
        *leading, final = rendered
        return f'(insert {" ".join(leading)} (singleton {final}))'
예제 #4
0
    def resolve_constant(self, c: Constant, sort: Sort = None):
        """Return the SMT term that denotes constant ``c``.

        :param c: the constant to translate.
        :param sort: sort under which ``c`` is interpreted; defaults to ``c.sort``.
        :return: ``Int(c.symbol)`` or ``Real(c.symbol)`` for the integer and
            real sorts; for any other sort, ``Int`` applied to the entry of
            ``self.object_ids`` for ``c``. Interval and set sorts are resolved
            through their parent sort.
        """
        sort = c.sort if sort is None else sort
        smtlang = self.smtlang

        if sort == smtlang.Integer:
            return Int(c.symbol)
        if sort == smtlang.Real:
            return Real(c.symbol)

        # Intervals and sets both defer to whatever their parent sort resolves to
        if isinstance(sort, (Interval, Set)):
            return self.resolve_constant(c, parent(sort))

        # Otherwise we must have an enumerated type and simply return the object ID
        object_id = self.object_ids[symref(c)]
        return Int(object_id)
예제 #5
0
 def get_types(self):
     """Return the type declarations, one ``<sort> : <parent>;`` line per sort.

     Builtin sorts and the sort named 'object' are ignored. Interval sorts are
     not declared; they are stored in ``self.need_constraints`` under their
     name instead. Every declared sort is also added to ``self.need_obj_decl``.
     """
     from tarski.syntax.sorts import parent

     declarations = []
     for sort in self.task.L.sorts:
         if sort.builtin or sort.name == 'object':
             continue
         if isinstance(sort, Interval):
             # Intervals are recorded for later handling rather than declared as types
             self.need_constraints[sort.name] = sort
         else:
             declarations.append('{} : {};'.format(sort.name, parent(sort).name))
             self.need_obj_decl += [sort]
     return '\n'.join(declarations)
예제 #6
0
 def get_types(self):
     """Return the type declarations of ``self.lang`` as a single string.

     Each non-builtin sort other than ``self.lang.Object`` yields an entry
     ``<type> - <parent type>`` when it has a parent, or just ``<type>``
     otherwise. Entries are joined with a newline followed by two ``_TAB``.
     """
     declarations = []
     for sort in self.lang.sorts:
         # Don't declare builtin elements
         if not (sort.builtin or sort == self.lang.Object):
             pddl_name = tarsky_to_pddl_type(sort)
             supertype = parent(sort)
             entry = ("{} - {}".format(pddl_name, tarsky_to_pddl_type(supertype))
                      if supertype else pddl_name)
             declarations.append(entry)
     return ("\n" + _TAB*2).join(declarations)
예제 #7
0
    def resolve_type_for_sort(self, s):
        """Return the SMT type (``INT``, ``REAL`` or ``BOOL``) used to model sort ``s``.

        :param s: a sort of ``self.smtlang``, or the Python ``bool`` type.
        :return: ``INT`` for the integer sort, ``REAL`` for the real sort,
            ``BOOL`` for ``bool``, the type of the parent sort for intervals,
            and ``INT`` for everything else.
        """
        if s == self.smtlang.Integer:
            smt_type = INT
        elif s == self.smtlang.Real:
            smt_type = REAL
        elif s is bool:
            smt_type = BOOL
        elif isinstance(s, Interval):
            # An interval is modelled with the type of the sort it derives from
            smt_type = self.resolve_type_for_sort(parent(s))
        else:
            # Otherwise we have an enumerated type, which we'll model as an integer
            smt_type = INT
        return smt_type
예제 #8
0
def resolve_type_for_sort(lang, s):
    """Return the name of the SMT type used to model sort ``s``, as a string.

    :param lang: the language providing the ``Integer`` and ``Real`` sorts.
    :param s: a sort of ``lang``, or the Python ``bool`` type.
    :return: ``"Int"`` for the integer sort, ``"Real"`` for the real sort,
        ``"Bool"`` for ``bool``, the type of the parent sort for intervals,
        ``"(Set <t>)"`` for sets whose subtype resolves to ``<t>``, and
        ``"Int"`` for everything else.
    """
    if s == lang.Integer:
        type_name = "Int"
    elif s == lang.Real:
        type_name = "Real"
    elif s is bool:
        type_name = "Bool"
    elif isinstance(s, Interval):
        # An interval is modelled with the type of the sort it derives from
        type_name = resolve_type_for_sort(lang, parent(s))
    elif isinstance(s, Set):
        element_type = resolve_type_for_sort(lang, s.subtype)
        type_name = f'(Set {element_type})'
    else:
        # Otherwise we have an enumerated type, which we'll model as an integer
        type_name = "Int"
    return type_name
예제 #9
0
    def create_variable(self, elem, sort=None, name=None):
        """Create the SMT symbol standing for ``elem`` (currently disabled).

        The leading ``assert 0`` makes every call fail; the remaining code is
        kept for reference only.

        :param elem: the element to create a variable for.
        :param sort: sort to use; defaults to ``elem.sort``.
        :param name: symbol name; defaults to ``str(elem)``.
        :return: a ``Symbol`` of type ``INT`` or ``REAL``.
        """
        # TODO This code is currently unused and needs to be revised / removed
        assert 0
        if sort is None:
            sort = elem.sort
        if name is None:
            name = str(elem)

        numeric_types = ((self.smtlang.Integer, INT), (self.smtlang.Real, REAL))
        for numeric_sort, smt_type in numeric_types:
            if sort == numeric_sort:
                return Symbol(name, smt_type)

        if isinstance(sort, Interval):
            # Let's go seek the underlying type of the interval recursively!
            return self.create_variable(elem, parent(sort), name)

        # Otherwise assume we have a enumerated type and simply return the index of the object
        enum_var = Symbol(name, INT)
        self.create_enum_type_domain_axioms(enum_var, elem.sort)
        return enum_var