def _test_score_all_triples(self, k: Optional[int], batch_size: int = 16):
    """Test score_all_triples.

    Calls ``predict`` on the model under test and checks the type, shape and
    ID ranges of the returned ``(triples, scores)`` pair.

    :param k: The number of triples to return. Set to None, to keep all.
    :param batch_size: The batch size to use for calculating scores.
    """
    top_triples, top_scores = predict(model=self.instance, batch_size=batch_size, k=k)

    # check type
    assert torch.is_tensor(top_triples)
    assert torch.is_tensor(top_scores)
    assert top_triples.dtype == torch.long
    assert top_scores.dtype == torch.float32

    # check shape
    actual_k, n_cols = top_triples.shape
    assert n_cols == 3
    # With k=None every (head, relation, tail) combination is returned, so the
    # candidate pool is |E|^2 * |R| -- not the number of training triples.
    num_candidates = self.factory.num_entities**2 * self.factory.num_relations
    if k is None:
        assert actual_k == num_candidates
    else:
        # top-k can never exceed the size of the candidate pool
        assert actual_k == min(k, num_candidates)
    assert top_scores.shape == (actual_k,)

    # check ID ranges; max() raises on an empty tensor (e.g. k=0), and there
    # are no IDs to validate in that case anyway.
    if actual_k == 0:
        return
    assert (top_triples >= 0).all()
    assert top_triples[:, [0, 2]].max() < self.instance.num_entities
    assert top_triples[:, 1].max() < self.instance.num_relations
def test_predict(self):
    """Smoke-test the prediction workflow with inverse relations.

    Only verifies that ``predict`` completes for a top-10 request on the
    model under test; the returned values are not inspected.
    """
    top_k = 10
    predict(k=top_k, model=self.instance)