示例#1
0
    def assert_frame_axioms(self):
        """ Append the explanatory frame axioms of every fluent symbol to the theory.

        All axioms are universally quantified over the symbol arguments and the timestep variable t.
        For each fluent predicate p two axioms are added:
            pos: not p(x, t) and p(x, t+1)  => gamma_p^+(x, t)
            neg: p(x, t) and not p(x, t+1)  => gamma_p^-(x, t)
        For each fluent function f a single axiom is added:
            fun: f(x, t) != f(x, t+1)       => gamma_f[y/f(x, t+1)]
        The gamma formulas are read from `self.gamma_pos`, `self.gamma_neg` and `self.gamma_fun`,
        keyed by symbol name. Before a symbol's axioms are added, a ";; Frame axiom ..." comment string
        is stored in `self.comments` under the index its first axiom will occupy in `self.theory`.
        """
        ml = self.metalang
        t = _get_timestep_var(ml)

        # Static symbols never change value, so only fluent ones (predicates AND functions) need frame axioms
        fluents = (s for s in get_symbols(self.lang, type_="all", include_builtin=False)
                   if self.symbol_is_fluent(s))

        for symbol in fluents:
            self.comments[len(self.theory)] = f";; Frame axiom for symbol {symbol}:"

            atom = symbol(*generate_symbol_arguments(self.lang, symbol))
            quantified = generate_symbol_arguments(ml, symbol) + [t]

            # The atom (resp. term) evaluated at timesteps t and t+1 in the metalanguage
            now = self.to_metalang(atom, t)
            nxt = self.to_metalang(atom, t + 1)

            if isinstance(symbol, Predicate):
                became_true = forall(*quantified,
                                     implies(~now & nxt, self.gamma_pos[symbol.name]))
                became_false = forall(*quantified,
                                      implies(now & ~nxt, self.gamma_neg[symbol.name]))
                self.theory.extend((became_true, became_false))
            else:
                # gamma_f is expressed over a placeholder variable y standing for the new value;
                # substitute f(x, t+1) for it.
                y = ml.variable("y", ml.get_sort(symbol.codomain.name))
                explanation = term_substitution(self.gamma_fun[symbol.name], {symref(y): nxt})
                self.theory.append(forall(*quantified, implies(now != nxt, explanation)))
示例#2
0
    def extract_parallel_plan(self, model, horizon, print_full_model):
        """ Extract a parallel plan from a model of the SMT encoding.

        :param model: the solver model; `model[term].constant_value()` is used to decide whether
            an action atom holds.
        :param horizon: the plan horizon, forwarded to `compute_signature_bindings` and `self.rewrite`.
        :param print_full_model: if true, print the model value of every atom (debugging aid).
        :return: a `defaultdict(list)` mapping each timestep to the actions that hold at it, each
            formatted as "(name arg1 ... argn)"; actions without parameters are rendered "(name)".
        """
        plan = defaultdict(list)

        for aname, (pred, _) in self.actionvars.items():
            for binding in compute_signature_bindings(self.smtlang, pred.domain, horizon):
                term = self.rewrite(pred(*binding), {}, horizon)
                if model[term].constant_value():
                    # The last binding element is the timestep constant, whose name is the step number
                    timestep = int(binding[-1].name)
                    # Join the name together with the arguments so that nullary actions
                    # render as "(name)" rather than "(name )"
                    tokens = [str(aname)] + [str(elem.name) for elem in binding[:-1]]
                    plan[timestep].append(f"({' '.join(tokens)})")

        # Useful for debugging
        if print_full_model:
            print("A list of all atoms: ")
            for pred in get_symbols(self.smtlang, type_="all", include_builtin=False):
                print(pred)
                # NOTE(review): bindings use horizon + 1 here but horizon above; presumably so that
                # fluents of the final state are included — confirm.
                for binding in compute_signature_bindings(self.smtlang, pred.domain, horizon + 1):
                    l0_term = pred(*binding)
                    term = self.rewrite(l0_term, {}, horizon)
                    print(f"{l0_term}: {model[term]}")

        return plan
示例#3
0
    def setup_metalang(self, problem):
        """ Build the Tarski metalanguage in which the SMT compilation is expressed.

        The metalanguage mirrors the sorts, constants and symbols of `problem.language` and always
        includes the equality and arithmetic theories. It also declares a "timestep" interval sort
        (stored as `ml.Timestep`), appends a timestep argument to every fluent symbol, and declares
        one predicate per action. The correspondence between original and mirrored sorts is
        recorded in `self.sort_map`.

        :param problem: the planning problem whose language is mirrored.
        :return: the new metalanguage.
        """
        lang = problem.language
        ml = tarski.fstrips.language(f"{lang.name}-smt",
                                     theories=lang.theories | {Theory.EQUALITY, Theory.ARITHMETIC})

        # Mirror every user-defined sort; builtin sorts and the root "object" sort are mapped below
        for srt in lang.sorts:
            if srt.builtin or srt.name == "object":
                continue
            parent_name = parent(srt).name
            if isinstance(srt, Interval):
                mirrored = ml.interval(srt.name, parent_name, srt.lower_bound, srt.upper_bound)
            else:
                mirrored = ml.sort(srt.name, parent_name)
            self.sort_map[srt] = mirrored

        # Builtin sorts map onto the metalanguage's own builtin sorts
        self.sort_map[lang.Object] = ml.Object

        if Theory.ARITHMETIC in lang.theories:
            for src, dst in ((lang.Integer, ml.Integer), (lang.Natural, ml.Natural), (lang.Real, ml.Real)):
                self.sort_map[src] = dst

        if Theory.SETS in lang.theories:
            for src, dst in ((lang.Object, ml.Object), (lang.Integer, ml.Integer)):
                self.sort_map[sorts.Set(lang, src)] = sorts.Set(ml, dst)

        # Deliberately large range for the timestep sort, to be adjusted once the horizon is known
        ml.Timestep = ml.interval("timestep", ml.Natural, 0, 99999)

        for const in lang.constants():
            ml.constant(const.symbol, const.sort.name)

        # Fluent symbols get an extra timestep argument, placed after the regular arguments
        # (and, for functions, before the codomain)
        for symbol in get_symbols(lang, type_="all", include_builtin=False):
            time_arg = [_get_timestep_sort(ml)] if self.symbol_is_fluent(symbol) else []
            if isinstance(symbol, Predicate):
                ml.predicate(symbol.name, *(t.name for t in symbol.sort), *time_arg)
            else:
                ml.function(symbol.name, *(t.name for t in symbol.domain), *time_arg,
                            symbol.codomain.name)

        # One extra predicate per action, over the action parameters plus a timestep
        for action in problem.actions.values():
            ml.predicate(action.name, *(param.sort.name for param in action.parameters),
                         _get_timestep_sort(ml))

        return ml
示例#4
0
def compute_choice_symbols(lang, init):
    """ Return the set of choice symbols of `lang`: the non-builtin function symbols whose
    signature has no extension in `init.function_extensions`.

    Predicate symbols are currently never considered choice symbols. Under the closed-world
    assumption, a predicate without an initial denotation simply has the empty set as its
    denotation. Representing choice predicates would need a separate, explicit mechanism.
    """
    return {symbol for symbol in get_symbols(lang, type_="function", include_builtin=False)
            if symbol.signature not in init.function_extensions}
示例#5
0
    def setup_language(self, smtlang):
        """ Create a function and its type, via `self.create_function_type`, for every non-builtin
        symbol of `smtlang`.

        :return: a pair `(smt_funs, smt_actions)`. `smt_funs` maps every symbol name to its
            `(function, type)` pair. `smt_actions` maps the name of every symbol that appears in
            `self.action_names` to `(symbol, function)`.
        """
        smt_funs, smt_actions = {}, {}
        for symbol in get_symbols(smtlang, type_="all", include_builtin=False):
            # Nullary symbols cannot be fluent, and static nullary symbols are expected
            # to have been compiled away before this point.
            assert symbol.arity > 0
            function, function_type = self.create_function_type(symbol)
            smt_funs[symbol.name] = (function, function_type)

            if symbol.name in self.action_names:
                smt_actions[symbol.name] = (symbol, function)
        return smt_funs, smt_actions
示例#6
0
    def assert_initial_state(self):
        """ Append to the theory a complete description of the initial state at timestep 0.

        For every non-choice symbol and every binding of its domain, one sentence is appended.
        Predicates contribute the atom or its negation, depending on its truth in
        `self.problem.init`. Functions contribute `f(x) == v`, with `v` the value in
        `self.problem.init`. Choice symbols are skipped because their value is not fixed in advance.
        """
        for symbol in get_symbols(self.lang, type_="all", include_builtin=False):
            if self.symbol_is_choice(symbol):
                continue

            is_predicate = isinstance(symbol, Predicate)
            for binding in compute_signature_bindings(symbol.domain):
                term = symbol(*binding)
                if is_predicate:
                    # Atoms that do not hold initially are asserted negated
                    sentence = term if self.problem.init[term] else ~term
                else:
                    sentence = term == self.problem.init[term]
                self.theory.append(self.to_metalang(sentence, 0))
示例#7
0
def test_symbol_casing():
    """ Test the special casing for PDDL parsing. See issue #67 """
    problem = parse_benchmark_instance("spider-sat18-strips:p01.pddl")
    lang = problem.language

    # The PDDL declares the predicate TO-DEAL, but PDDL parsing lowercases every symbol,
    # so only the lowercase name is known to the language.
    lang.get_predicate("to-deal")
    with pytest.raises(UndefinedPredicate):
        lang.get_predicate("TO-DEAL")

    # A predicate already written in lowercase (current-deal) is unaffected.
    lang.get_predicate("current-deal")

    predicate_names = {pred.symbol for pred in get_symbols(lang, type_="predicate", include_builtin=False)}
    assert "to-deal" in predicate_names
示例#8
0
    def translate(self, theory):
        """ Translate a theory over `self.smtlang` into an SMT-LIB theory.

        Every non-builtin symbol is declared with a `declare-fun` command; predicates get the codomain
        `Bool`. Every sentence of `theory` is then untyped, rewritten and emitted, in order, as an
        `(assert ...)` command.

        :param theory: iterable of sentences over `self.smtlang`.
        :return: an `SMTLibTheory` holding the declarations and the assertions.
        """
        result = SMTLibTheory()

        # Functions and predicates
        for s in get_symbols(self.smtlang, type_="all", include_builtin=False):
            # arity=0 implies the symbol is not fluent, but static symbols of arity 0 should have
            # already been compiled away
            assert s.arity > 0
            dom = [resolve_type_for_sort(self.smtlang, sort) for sort in s.domain]
            codomain = resolve_type_for_sort(self.smtlang, s.codomain) if isinstance(s, Function) else "Bool"
            result.declarations.append(f'(declare-fun {s.name} ({" ".join(dom)}) {codomain})')

        # Theory translation
        ut = Untyper(self.smtlang, self.sort_bounds)
        for phi in theory:
            phi_prime = ut.untype(phi)  # Compile away any typed quantification first
            rewritten = self.run(phi_prime, inplace=False)
            result.assertions.append(f"(assert {rewritten.smtlib})")

        return result
示例#9
0
    def compute_gammas(self, problem, ml):
        """ Compute the gamma formulas of every fluent symbol of `problem.language`.

        :param problem: the planning problem whose fluent symbols are processed.
        :param ml: the metalanguage, forwarded to `self.compute_gamma`.
        :return: a triple `(gamma_pos, gamma_neg, gamma_f)` of dicts keyed by symbol name.
            The first two hold the gammas of fluent predicates built from `self.eff_index['add']`
            and `self.eff_index['del']` respectively. The third holds the gammas of fluent functions
            built from `self.eff_index['fun']`.
        """
        gamma_pos, gamma_neg, gamma_f = {}, {}, {}

        fluents = (s for s in get_symbols(problem.language, type_="all", include_builtin=False)
                   if self.symbol_is_fluent(s))
        for symbol in fluents:
            name = symbol.name
            if isinstance(symbol, Predicate):
                gamma_pos[name] = self.compute_gamma(ml, symbol, self.eff_index['add'])
                gamma_neg[name] = self.compute_gamma(ml, symbol, self.eff_index['del'])
            else:
                gamma_f[name] = self.compute_gamma(ml, symbol, self.eff_index['fun'])

        return gamma_pos, gamma_neg, gamma_f
示例#10
0
def test_symbol_classification(instance_file, domain_file):
    """ Check the fluent/static symbol classification on a few fully parsed standard benchmarks. """
    problem = reader().read_problem(domain_file, instance_file)
    fluent, static = approximate_symbol_fluency(problem)

    # Expected (number of fluent symbols, number of static symbols) per domain, total-cost terms included
    expected = {
        "grid-visit-all": (2, 1),
        "Trucks": (6, 4),
        "BLOCKS": (5, 0),
        "gripper-strips": (4, 3),
        "elevators-sequencedstrips": (4, 6),
        "sokoban-sequential": (3, 3),
        "parking": (5, 0),
        "transport": (3, 3),
        "spider": (15, 6),
        "counters-fn": (1, 1),
        "settlers": (26, 23),
        "nurikabe": (9, 3),
    }
    expected_counts = expected[problem.domain_name]

    # Sanity check: fluent + static must account for every non-builtin symbol of the language
    all_symbols = set(get_symbols(problem.language, include_builtin=False))
    assert len(all_symbols) == sum(expected_counts)
    assert (len(fluent), len(static)) == expected_counts