def filter_scores(self, batch, filter_col, scores):
    """Filter out the scores of known positive triples.

    Builds a sparse positive filter for ``batch`` against every known triple in
    ``self.dl.all``, then applies it to ``scores`` via ``filter_scores_``.

    :param batch: batch of (head, relation, tail) id triples being scored
        (presumably shape ``(batch_size, 3)`` -- TODO confirm against callers).
    :param filter_col: column of the triples that is being replaced
        (0 when scoring heads, 2 when scoring tails).
    :param scores: scores for each batch row and candidate entity.
    :return: the filtered scores as returned by ``filter_scores_``; note that
        ``filter_scores_`` may modify ``scores`` in place.
    """
    # The relation filter returned alongside the positives is not needed here.
    positive_filter, _ = create_sparse_positive_filter_(
        hrt_batch=batch,
        all_pos_triples=self.dl.all,
        filter_col=filter_col,
    )
    return filter_scores_(
        scores=scores,
        filter_batch=positive_filter,
    )
def test_create_sparse_positive_filter_(self):
    """Test method create_sparse_positive_filter_.

    Uses the head-based filter (``filter_col=0``) on the first few Nations
    training triples and checks it in both directions:

    * soundness -- every reported (batch_id, entity_id) pair yields a known
      triple when ``entity_id`` replaces the head of ``batch[batch_id]``;
    * completeness -- every known triple sharing the (relation, tail) of a
      batch triple is reported for that batch row.
    """
    batch_size = 4
    factory = Nations().training
    all_triples = factory.mapped_triples
    batch = all_triples[:batch_size, :]

    # head based filter; the relation filter is not needed by this test
    sparse_positives, _ = create_sparse_positive_filter_(
        hrt_batch=batch,
        all_pos_triples=all_triples,
        relation_filter=None,
        filter_col=0,
    )

    # preprocessing for faster lookup
    triples = {tuple(map(int, trip)) for trip in all_triples.detach().numpy()}

    # (batch_id, entity_id) pairs reported by the filter, as plain ints
    found = {
        (int(batch_id), int(entity_id))
        for batch_id, entity_id in sparse_positives
    }

    # soundness: all found positives are positive
    for batch_id, entity_id in found:
        same = tuple(map(int, batch[batch_id, 1:]))
        assert (entity_id,) + same in triples

    # completeness: every positive sharing (relation, tail) with a batch
    # triple is reported for that batch row
    for batch_id, (_, relation, tail) in enumerate(batch.tolist()):
        expected = {h for h, r, t in triples if (r, t) == (relation, tail)}
        actual = {e for b, e in found if b == batch_id}
        assert actual == expected
def test_filter_corrupted_triples(self):
    """Test filtering of corrupted-triple scores.

    Builds a tiny set of known positives and a batch of two triples. For both
    head (``filter_col=0``) and tail (``filter_col=2``) filtering, it checks that
    ``create_sparse_positive_filter_`` + ``filter_scores_``:

    * modify the score tensor in place, and
    * make exactly the scores at the expected positive positions non-finite,
      leaving every other score unchanged.
    """
    batch_size = 2
    num_entities = 4
    all_pos_triples = torch.tensor(
        [
            [0, 1, 2],
            [1, 2, 3],
            [1, 3, 3],
            [3, 4, 1],
            [0, 2, 2],
            [3, 1, 2],
            [1, 2, 0],
        ],
        dtype=torch.long,
    )
    batch = torch.tensor(
        [
            [0, 1, 2],
            [1, 2, 3],
        ],
        dtype=torch.long,
    )
    # exp_*_filter_mask[i, e] is True iff replacing the head (resp. tail) of
    # batch[i] with entity e yields a triple contained in all_pos_triples.
    exp_head_filter_mask = torch.tensor(
        [
            [True, False, False, True],
            [False, True, False, False],
        ],
        dtype=torch.bool,
    )
    exp_tail_filter_mask = torch.tensor(
        [
            [False, False, True, False],
            [True, False, False, True],
        ],
        dtype=torch.bool,
    )
    assert batch.shape == (batch_size, 3)
    assert exp_head_filter_mask.shape == (batch_size, num_entities)
    assert exp_tail_filter_mask.shape == (batch_size, num_entities)

    def _check_filtering(filter_col, exp_filter_mask, relation_filter):
        """Filter random scores along ``filter_col``, verify them, return the relation filter."""
        scores = torch.randn(batch_size, num_entities, generator=self.generator)
        old_scores = scores.detach().clone()
        positive_filter, relation_filter = create_sparse_positive_filter_(
            hrt_batch=batch,
            all_pos_triples=all_pos_triples,
            relation_filter=relation_filter,
            filter_col=filter_col,
        )
        filtered_scores = filter_scores_(
            scores=scores,
            filter_batch=positive_filter,
        )

        # Assert in-place modification
        mask = torch.isfinite(scores)
        assert (scores[mask] == filtered_scores[mask]).all()
        assert not torch.isfinite(filtered_scores[~mask]).any()

        # Assert correct filtering
        assert (old_scores[~exp_filter_mask] == filtered_scores[~exp_filter_mask]).all()
        assert not torch.isfinite(filtered_scores[exp_filter_mask]).any()
        return relation_filter

    # Test head scores (drawn first, so the generator is consumed in a fixed order)
    relation_filter = _check_filtering(0, exp_head_filter_mask, None)
    # Test tail scores, reusing the relation filter computed for the head call
    # (presumably it depends only on the batch relations -- confirm).
    _check_filtering(2, exp_tail_filter_mask, relation_filter)