예제 #1
0
def _turn_clause_to_interim_repr(clause: Clause, suffix: str = "_x"):
    """Turn `clause` into a list of tuples, one per literal of ``clause.get_literals()``.

    Each tuple is ``(predicate, term_1, ..., term_n)``. A term that is a
    ``Variable`` occurring in the head is replaced by its 0-based position in
    ``clause.get_head().get_variables()``. Every other term is kept unchanged.

    Args:
        clause: clause to convert.
        suffix: currently unused; kept so existing callers keep working.
            NOTE(review): confirm whether it was meant to affect the output.

    Returns:
        list of tuples as described above.
    """
    # If a head variable is listed more than once, its last position wins.
    head_positions = {
        var: position
        for position, var in enumerate(clause.get_head().get_variables())
    }
    return [
        (
            literal.get_predicate(),
            *(
                head_positions[term]
                if isinstance(term, Variable) and term in head_positions
                else term
                for term in literal.get_terms()
            ),
        )
        for literal in clause.get_literals()
    ]
예제 #2
0
    def evaluate(self, clause: Clause, examples: Task,
                 covered: Sequence[Atom]):
        """Score `clause` by example coverage minus a length penalty.

        The score is ``#covered positives - #covered negatives - #literals + 1``.
        ``examples.get_examples()`` supplies the positive and negative example
        sets, and `covered` is intersected with each of them.

        Side effect: increments ``self._clauses_evaluated``.

        Returns:
            ``(score, #covered positives)`` when ``self._return_upperbound`` is
            truthy, otherwise just ``score``.
        """
        self._clauses_evaluated += 1

        positives, negatives = examples.get_examples()
        n_pos = len(positives.intersection(covered))
        n_neg = len(negatives.intersection(covered))
        n_literals = len(clause.get_literals())

        score = n_pos - n_neg - n_literals + 1
        if not self._return_upperbound:
            return score
        return score, n_pos
def _valid_positions(cl: Clause, allowed_positions_dict, allowed_reflexivity=()):
    """Return True iff every literal of `cl` respects the position/reflexivity rules.

    For each literal in ``cl.get_literals()``:

    * every ``Constant`` argument must sit at a 0-based position contained in
      ``allowed_positions_dict[constant][predicate]``;
    * the literal must not be reflexive (two or more arguments, all equal,
      e.g. ``next(X, X)``) unless its predicate is in ``allowed_reflexivity``.

    Args:
        cl: clause to check.
        allowed_positions_dict: nested mapping constant -> predicate ->
            container of allowed argument positions. It is accessed by
            subscription, so a plain dict raises ``KeyError`` for a missing pair.
        allowed_reflexivity: predicates for which reflexive literals are
            allowed. Defaults to an empty tuple (immutable, so it cannot be
            shared and mutated across calls).

    Returns:
        True if all literals pass both checks, False at the first violation.
    """
    for atom in cl.get_literals():
        pred = atom.get_predicate()
        args = atom.get_arguments()

        # Constants must appear only at their permitted argument positions.
        for position, arg in enumerate(args):
            if isinstance(arg, Constant) and position not in allowed_positions_dict[arg][pred]:
                return False

        # Reflexivity needs at least two arguments; a unary literal such as
        # p(X) trivially has "all arguments equal" but is not reflexive.
        is_reflexive = len(args) > 1 and all(arg == args[0] for arg in args)
        if is_reflexive and pred not in allowed_reflexivity:
            return False
    return True
예제 #4
0
 def encode2(self, clause: Clause):
     """Encode `clause` into problem, primitive and variable slot vectors.

     Works on copies of ``self._encodingProblems``,
     ``self._encodingprimitives`` and ``self._encodingvariables``, leaving the
     instance templates untouched, and returns
     ``problems + primitives + variables``.
     NOTE(review): this is concatenation for lists; if the templates are
     numpy arrays ``+`` is element-wise. Confirm the intended type.

     Filling rules:
       * If the head predicate has no entry in ``self._problemindices``, it is
         registered via ``self.addproblem(predicate, 10)`` before the copies
         are taken. Its problem slot is then set to 1.
       * The slot of the head's first variable is set to 1.
       * The k-th literal of ``clause.get_literals()`` (k = 0, 1, ...) gets
         the number ``3 + 2 * k``. That number is written into:
           - the first zero slot at or after the predicate's primitive index;
           - the first zero slot at or after its first variable's index.
         If the literal has exactly two variables, ``number + 1`` is written
         into the first zero slot at or after the second variable's index.

     Raises:
         IndexError: if a probe runs off the end of a vector without finding
             a zero slot (for list/ndarray templates), or if the head or a
             literal has no variables.
     """
     def fill_first_free(slots, start, value):
         # Linear probe forward from `start` to the first slot still equal to 0.
         while slots[start] != 0:
             start += 1
         slots[start] = value

     head = clause.get_head()
     head_predicate = head.get_predicate()
     if head_predicate not in self._problemindices:
         self.addproblem(head_predicate, 10)

     # Copies are taken only after addproblem, which presumably may grow the
     # templates or index maps.
     problems = self._encodingProblems.copy()
     primitives = self._encodingprimitives.copy()
     variables = self._encodingvariables.copy()

     problems[self._problemindices[head_predicate]] = 1
     variables[self._variablesindices[head.get_variables()[0]]] = 1

     for offset, literal in enumerate(clause.get_literals()):
         number = 3 + 2 * offset
         fill_first_free(
             primitives, self._primitivesindices[literal.get_predicate()], number)
         literal_variables = literal.get_variables()
         fill_first_free(
             variables, self._variablesindices[literal_variables[0]], number)
         if len(literal_variables) == 2:
             fill_first_free(
                 variables, self._variablesindices[literal_variables[1]],
                 number + 1)
     return problems + primitives + variables
예제 #5
0
    def encode(self, clause: Clause):
        """Return a count vector over adjacent literal pairs of `clause`.

        Steps:

        1. The head followed by the body (``clause.get_head()`` then
           ``clause.get_literals()``) is grouped by variable signature, i.e. the
           concatenation of the literal's variables' names. Groups keep
           first-seen order, and literals keep their order within a group.
        2. Groups are sorted by a numeric key computed from the signature (see
           ``signature_value``). The literals are then flattened back into one
           sequence.
        3. Each adjacent pair of literals is passed to
           ``self.variableSubstition``. The result, converted to a tuple, is
           looked up in ``self._dictionary``, and that slot is incremented.
        4. The last slot (index 1849) is set to ``len(clause.get_variables())``.

        Returns:
            float numpy array of length 1850 (built with ``np.zeros``).

        Raises:
            KeyError: if a substituted literal pair has no entry in
                ``self._dictionary``.
            IndexError: if a literal has no variables (empty signature).
        """
        vector_size = 1850
        encoding_clause = np.zeros(vector_size)

        # Maps the permutation that sorts the first three letter codes
        # ascending to a tag distinguishing the original variable order.
        # Permutations are listed lexicographically; a stable sort yields the
        # first permutation that sorts the codes, which fixes tie-breaking.
        permutation_tags = {
            (0, 1, 2): 0, (0, 2, 1): 1, (1, 0, 2): 2,
            (1, 2, 0): 3, (2, 0, 1): 4, (2, 1, 0): 5,
        }

        def signature_value(signature):
            # Sort key for a variable signature. Letter codes: 'A' -> 1,
            # 'B' -> 2, ... (ord(ch) - 64).
            # NOTE(review): multi-character variable names are treated as
            # several single-letter variables here.
            codes = [ord(ch) - 64 for ch in signature]
            if len(codes) == 1:
                (a,) = codes
                return 100000 * a + 3500 * a + 130 * a
            if len(codes) == 2:
                a, b = codes
                # A descending pair gets +1 so that e.g. "AB" and "BA" differ.
                return 100000 * a + 3500 * a + 130 * b + (1 if a > b else 0)
            # Three or more variables: only the first three count. An empty
            # signature raises IndexError inside the sort key.
            order = tuple(sorted(range(3), key=lambda i: codes[i]))
            first, second, third = (codes[i] for i in order)
            return (100000 * first + 3500 * second + 130 * third
                    + permutation_tags[order])

        # Group literals by signature. New groups are numbered len(groups), so
        # looking up a repeated signature never disturbs the numbering.
        group_index = {}
        groups = []  # (sort value, [literals]) in first-seen order
        for lit in [clause.get_head(), *clause.get_literals()]:
            signature = ''.join(
                variable.get_name() for variable in lit.get_variables())
            position = group_index.get(signature)
            if position is None:
                group_index[signature] = len(groups)
                groups.append((signature_value(signature), [lit]))
            else:
                groups[position][1].append(lit)

        # Ties on the sort value fall back to comparing the literal lists, as a
        # plain tuple sort does.
        groups.sort()
        new_clause = [lit for _, members in groups for lit in members]

        pairs = [
            self.variableSubstition(new_clause[i:i + 2])
            for i in range(len(new_clause) - 1)
        ]
        for element in pairs:
            key = tuple(element)
            slot = self._dictionary.get(key)
            if slot is None:
                # Indexing an ndarray with None would apply += to every slot;
                # fail loudly instead of silently corrupting the vector.
                raise KeyError(f"no encoding slot for literal pair {key!r}")
            encoding_clause[slot] += 1
        encoding_clause[vector_size - 1] = len(clause.get_variables())
        return encoding_clause