def expand(self, clause: typing.Union[Clause, Body, Procedure]) -> typing.Sequence[typing.Union[Body, Clause, Procedure]]:
        """
            Expands a clause/body/procedure through `self._expand_clause`.

            A `Body` or `Clause` is passed to `self._expand_clause` as is. Anything
            else is treated as a procedure: each of its clauses is expanded on its
            own, and every expansion of a single clause produces one new procedure
            in which only that clause has been replaced.

            :param clause: clause, body or procedure to expand
            :return: the expansions; for a procedure input these are `Disjunction`
                     objects when the input is a `Disjunction`, `Recursion`
                     objects otherwise
        """
        if isinstance(clause, (Body, Clause)):
            return self._expand_clause(clause)

        members = clause.get_clauses()
        # the result keeps the kind of procedure that was passed in
        rebuild = Disjunction if isinstance(clause, Disjunction) else Recursion

        expanded = []
        for position, member in enumerate(members):
            for candidate in self._expand_clause(member):
                # swap in the expanded clause, keep every other clause untouched
                variant = [
                    candidate if idx == position else untouched
                    for idx, untouched in enumerate(members)
                ]
                expanded.append(rebuild(variant))

        return expanded
def variable_instantiation(
    clause: typing.Union[Clause, Body, Procedure],
    constant: Constant) -> typing.Sequence[typing.Union[Clause, Body, Procedure]]:
    """
    Extends a clause by instantiation, replacing all occurrences
    of a variable with a constant

    A `Clause` or `Body` is handed directly to `_instantiate_var_clause`.
    Anything else is treated as a procedure: each of its clauses is
    instantiated on its own, and every instantiation of a single clause
    produces one new procedure in which only that clause has been replaced.

    :param clause: clause, body or procedure to extend
    :param constant: constant forwarded to `_instantiate_var_clause`
    :return: the extensions; for a procedure input these are `Disjunction`
             objects when the input is a `Disjunction`, `Recursion` objects
             otherwise
    """
    if isinstance(clause, (Clause, Body)):
        return _instantiate_var_clause(clause, constant)

    # Imported locally so this function does not rely on the module's import block.
    import logging

    clauses = clause.get_clauses()
    # the result keeps the kind of procedure that was passed in
    rebuild = Disjunction if isinstance(clause, Disjunction) else Recursion

    # extend each clause individually
    extensions = []
    for cl_ind, current in enumerate(clauses):
        for extended in _instantiate_var_clause(current, constant):
            # swap in the instantiated clause, keep every other clause untouched
            cls = [
                extended if x == cl_ind else other
                for x, other in enumerate(clauses)
            ]
            extensions.append(rebuild(cls))

    # Diagnostic output goes through logging (DEBUG) instead of an
    # unconditional print, so callers decide whether they want to see it.
    logging.getLogger(__name__).debug(
        "Extensions for %s are %s", clause, extensions
    )
    return extensions
def plain_extension(
    clause: typing.Union[Clause, Body, Procedure],
    predicate: Predicate,
    connected_clauses: bool = True,
    negated: bool = False,
) -> typing.Sequence[typing.Union[Clause, Body, Procedure]]:
    """
    Adds the predicate to a clause or to each clause of a procedure, without
    any bias (only variable type match is checked).

    With `negated` set, the work is done by `_plain_extend_negation_clause`;
    otherwise by `_plain_extend_clause`, which also receives
    `connected_clauses` (the negated variant is called without it).

    For a procedure, each clause is extended on its own and every extension
    of a single clause produces one new procedure in which only that clause
    has been replaced. The new procedures are `Disjunction` objects when the
    input is a `Disjunction`, `Recursion` objects otherwise.
    """

    def extend_one(target):
        # single place that picks the negated / non-negated extender
        if negated:
            return _plain_extend_negation_clause(target, predicate)
        return _plain_extend_clause(
            target, predicate, connected_clause=connected_clauses
        )

    if isinstance(clause, (Clause, Body)):
        return extend_one(clause)

    members = clause.get_clauses()
    # the result keeps the kind of procedure that was passed in
    rebuild = Disjunction if isinstance(clause, Disjunction) else Recursion

    results = []
    for position, member in enumerate(members):
        for candidate in extend_one(member):
            # swap in the extended clause, keep every other clause untouched
            variant = [
                candidate if idx == position else untouched
                for idx, untouched in enumerate(members)
            ]
            results.append(rebuild(variant))

    return results
Esempio n. 4
0
    def _get_recursions(self, node: Body) -> typing.Sequence[Recursion]:
        """
        Prepares the valid recursions for `node`.

        The pointer named by the node's "partner" attribute marks where the
        search for base cases starts. For every head registered for `node`,
        the hypothesis space is walked in FIFO order from that pointer through
        successors that do not carry a "partner" attribute. No visited set is
        kept, so a node reachable via several paths is processed once per path.

        At each visited node, every head whose predicate has the same argument
        types as the candidate head yields one
        `Recursion([base clause, recursive clause])`.

        Side effects: the pointer is reset to its initial value after each
        candidate head, and finally set to the last node explored
        (`reset_pointer` receives `None` when `node` has no heads).

        :param node: body for which recursions are prepared
        :return: list of the constructed recursions
        """
        partner_pointer = self._hypothesis_space.nodes[node]["partner"]
        start_value = self._pointers[partner_pointer]
        last_visited = None

        candidate_heads = list(self._hypothesis_space.nodes[node]["heads"].keys())
        found = []

        for head in candidate_heads:
            recursive_clause = Clause(head, node)

            # FIFO queue of nodes still to be inspected as base-case bodies
            queue = [self._pointers[partner_pointer]]

            while queue:
                current = queue.pop(0)

                # keep only heads whose argument types agree with the candidate head
                matching = [
                    atom
                    for atom in self._hypothesis_space.nodes[current]["heads"].keys()
                    if atom.get_predicate().get_arg_types()
                    == head.get_predicate().get_arg_types()
                ]

                for atom in matching:
                    if isinstance(self._head_constructor, Predicate):
                        base_head = atom
                    else:
                        # if the filler predicate is used to construct heads,
                        # make sure the same head predicate is used
                        arguments = atom.get_arguments()
                        base_head = Atom(head.get_predicate(), arguments)

                    found.append(
                        Recursion([Clause(base_head, current), recursive_clause])
                    )

                # grow the queue, skipping recursive nodes
                queue.extend(
                    successor
                    for successor in self._hypothesis_space.successors(current)
                    if "partner" not in self._hypothesis_space.nodes[successor]
                )
                last_visited = current

            # restore the pointer before handling the next candidate head
            self.reset_pointer(partner_pointer, start_value)

        # leave the pointer at the last explored clause
        self.reset_pointer(partner_pointer, last_visited)

        return found