def get_ground_atom_index(self, ground_atom: GroundAtom) -> int:
    """Return the global index of ``ground_atom``.

    A predicate ``pred`` of arity ``k`` owns ``n ** k`` consecutive indices
    starting at ``self._ground_atom_base_index[pred]``, where ``n`` is
    ``len(self._constants)``. Inside that range the argument tuple is read
    as a base-``n`` number whose most significant digit is the first
    argument. So:

    - arity 0 maps to ``base``;
    - arity 1 maps to ``base + i0``;
    - arity 2 maps to ``base + i0 * n + i1``;
    - higher arities extend the same scheme.

    Args:
        ground_atom: a ``(pred, consts)`` pair.

    Returns:
        The integer index of the ground atom.

    Raises:
        KeyError: if ``pred`` has no entry in ``self._ground_atom_base_index``.
        IndexError: if ``consts`` has fewer entries than the predicate's arity.

    NOTE(review): an unknown constant fails in the ``self._constants.map``
    lookup, presumably with KeyError; confirm against the OrderedSet
    implementation.
    """
    pred, consts = ground_atom
    # Use the deduplicated constant count: it is the per-position slot count
    # used when the base indices are laid out.
    radix = len(self._constants)
    offset = 0
    for position in range(u.arity(pred)):
        # Horner-style accumulation: shift earlier digits one place left.
        offset = offset * radix + self._constants.map[consts[position]]
    return self._ground_atom_base_index[pred] + offset
def get_ground_atoms(language_model: LanguageModel, program_template: ProgramTemplate) -> List[GroundAtom]:
    """List every ground atom of the program's predicates.

    Predicates are taken in the order ``preds_ext``, ``preds_aux``, then the
    target predicate. For each one, every argument tuple of its arity over
    ``language_model.constants`` is generated with ``itertools.product``.
    """
    all_preds = (
        language_model.preds_ext
        + program_template.preds_aux
        + [language_model.target]
    )
    return [
        GroundAtom(pred, args)
        for pred in all_preds
        for args in itertools.product(language_model.constants, repeat=u.arity(pred))
    ]
def all_ground_atom_generator(self) -> Iterable[GroundAtom]:
    """Yield every ground atom of every predicate in ``self._preds``.

    Predicates are visited in ``self._preds`` order. For each predicate, all
    argument tuples of its arity over ``self._constants`` are produced by
    ``itertools.product``, so the first argument varies slowest. A
    0-arity predicate yields a single atom with ``()`` as its arguments.
    Predicates of any arity are enumerated, matching the index slots
    reserved for them.
    """
    for pred in self._preds:
        for args in itertools.product(self._constants, repeat=u.arity(pred)):
            yield GroundAtom(pred, args)
def __init__(self, language_model: LanguageModel, program_template: ProgramTemplate):
    """Lay out a contiguous index space for all ground atoms.

    Index 0 is reserved for falsum. Each predicate in
    ``preds_ext + preds_aux + [target]`` then receives a block of
    ``n ** arity`` consecutive indices, in that order, where ``n`` is
    ``len(self._constants)``.

    Args:
        language_model: supplies ``constants``, ``preds_ext`` and ``target``.
        program_template: supplies ``preds_aux``.
    """
    self._constants = OrderedSet(language_model.constants)
    # Derive the count from the set (not the raw input) so it agrees with
    # the per-position slot count used to size each predicate's range.
    self._number_of_constants = len(self._constants)
    self._preds = language_model.preds_ext + program_template.preds_aux + [
        language_model.target
    ]
    # key: predicate,
    # value: index of predicate's first ground atom (amongst all ground atoms)
    self._ground_atom_base_index = {}
    next_index = 1  # index 0 is falsum
    for pred in self._preds:
        self._ground_atom_base_index[pred] = next_index
        next_index += self._number_of_constants ** u.arity(pred)
    # Total size of the index space, including the falsum slot.
    self._len = next_index
def test_arity(self):
    """``u.arity`` reports the arity a ``Predicate`` was constructed with."""
    expected = 2
    self.assertEqual(u.arity(Predicate('p', expected)), expected)
def from_pred(cls, pred: Predicate):
    """Build an instance for ``pred`` with placeholder variable arguments.

    The arguments are ``Variable('tmp0')``, ``Variable('tmp1')``, ..., one
    per argument position, and every groundedness flag is ``False``.
    """
    n_args = u.arity(pred)
    placeholders = [Variable(f'tmp{idx}') for idx in range(n_args)]
    flags = [False for _ in range(n_args)]
    return cls(pred, placeholders, flags)
def from_atom(cls, atom: Atom):
    """Build an instance from an ``(pred, args)`` atom.

    Every groundedness flag is set to ``False``.
    """
    predicate, arguments = atom
    flags = [False for _ in range(u.arity(predicate))]
    return cls(predicate, arguments, flags)
def from_ground_atom(cls, ground_atom: GroundAtom):
    """Build an instance from a ``(pred, args)`` ground atom.

    Every groundedness flag is set to ``True``.
    """
    predicate, arguments = ground_atom
    flags = [True for _ in range(u.arity(predicate))]
    return cls(predicate, arguments, flags)
def arity(self) -> int:
    """Return the arity of ``self._pred`` as reported by ``u.arity``."""
    predicate = self._pred
    return u.arity(predicate)