def test(model: torch.nn.Module, data_generator: torch_data.DataLoader, entities_count: int,
         device: torch.device, epoch_id: int, metric_suffix: str,
         ) -> METRICS:
    """Evaluate link prediction by scoring every entity as candidate tail and head.

    For each ``(head, relation, tail)`` batch, every entity id is substituted
    first as the tail and then as the head, the model scores all resulting
    triplets, and Hits@1/3/10 and MRR are accumulated over both directions.
    Each input triplet therefore contributes two ranking examples.

    This function does not call ``model.eval()``; it only disables gradient
    tracking while scoring.

    Args:
        model: Object exposing ``predict(triplets)``. It is called with a
            ``[B * e, 3]`` tensor of (head, relation, tail) ids and its result
            is reshaped to ``[B, e]``.
        data_generator: Iterable yielding ``(head, relation, tail)`` tensors.
        entities_count: Number of entities; candidate ids are
            ``0 .. entities_count - 1``.
        device: Device the batch tensors and candidate ids are placed on.
        epoch_id: Unused in this variant; kept for call compatibility.
        metric_suffix: Unused in this variant; kept for call compatibility.

    Returns:
        ``(hits_at_1, hits_at_3, hits_at_10, mrr)``, each the accumulated
        metric divided by the number of ranking examples (not scaled to
        percentages).

    Raises:
        ZeroDivisionError: If ``data_generator`` yields no batches.
    """
    examples_count = 0.0
    hits_at_1 = 0.0
    hits_at_3 = 0.0
    hits_at_10 = 0.0
    mrr = 0.0

    entity_ids = torch.arange(end=entities_count, device=device).unsqueeze(0)  # [1, e]

    # Pure evaluation: disable autograd so that scoring all B * e candidate
    # triplets (twice per batch) does not record a computation graph.
    with torch.no_grad():
        for head, relation, tail in data_generator:
            current_batch_size = head.size()[0]
            head, relation, tail = head.to(device), relation.to(device), tail.to(device)  # [B]

            # expand() creates broadcast views instead of materialised copies;
            # torch.stack below produces the fresh tensors the model sees.
            all_entities = entity_ids.expand(current_batch_size, -1)  # [B, e]
            heads = head.reshape(-1, 1).expand_as(all_entities)  # [B, e]
            relations = relation.reshape(-1, 1).expand_as(all_entities)  # [B, e]
            tails = tail.reshape(-1, 1).expand_as(all_entities)  # [B, e]

            # Check all possible tails
            triplets = torch.stack((heads, relations, all_entities), dim=2).reshape(-1, 3)  # [B, e, 3] -> [B*e, 3]
            tails_predictions = model.predict(triplets).reshape(current_batch_size, -1)  # [B*e] -> [B, e]

            # Check all possible heads
            triplets = torch.stack((all_entities, relations, tails), dim=2).reshape(-1, 3)
            heads_predictions = model.predict(triplets).reshape(current_batch_size, -1)

            # Concat predictions: the first B rows rank tails, the last B rows rank heads.
            predictions = torch.cat((tails_predictions, heads_predictions), dim=0)  # [2*B, e]
            ground_truth_entity_id = torch.cat((tail.reshape(-1, 1), head.reshape(-1, 1)))  # [2*B, 1]

            hits_at_1 += metric.hit_at_k(predictions, ground_truth_entity_id, device=device, k=1)
            hits_at_3 += metric.hit_at_k(predictions, ground_truth_entity_id, device=device, k=3)
            hits_at_10 += metric.hit_at_k(predictions, ground_truth_entity_id, device=device, k=10)
            # Original author's note: for a query whose first correct answer is
            # ranked n-th, the MRR score is 1/n (0 if there is no correct answer).
            mrr += metric.mrr(predictions, ground_truth_entity_id)

            examples_count += predictions.size()[0]

    hits_at_1_score = hits_at_1 / examples_count
    hits_at_3_score = hits_at_3 / examples_count
    hits_at_10_score = hits_at_10 / examples_count
    mrr_score = mrr / examples_count

    return hits_at_1_score, hits_at_3_score, hits_at_10_score, mrr_score
def test(model: torch.nn.Module, data_generator: torch_data.DataLoader, entities_count: int,
         summary_writer: tensorboard.SummaryWriter, device: torch.device, epoch_id: int, metric_suffix: str,
         ) -> METRICS:
    """Evaluate link prediction and log Hits@1/3/10 and MRR as percentages.

    For each ``(head, relation, tail)`` batch, every entity id is substituted
    first as the tail and then as the head, the model scores all resulting
    triplets, and the metrics are accumulated over both directions. Each
    input triplet therefore contributes two ranking examples.

    This function does not call ``model.eval()``; it only disables gradient
    tracking while scoring.

    Args:
        model: Object exposing ``predict(triplets)``. It is called with a
            ``[B * e, 3]`` tensor of (head, relation, tail) ids and its result
            is reshaped to ``[B, e]``.
        data_generator: Iterable yielding ``(head, relation, tail)`` tensors.
        entities_count: Number of entities; candidate ids are
            ``0 .. entities_count - 1``.
        summary_writer: Receives one ``add_scalar`` call per metric under
            ``'Metrics/<name>/' + metric_suffix``.
        device: Device the batch tensors and candidate ids are placed on.
        epoch_id: Passed as ``global_step`` to the summary writer.
        metric_suffix: Appended to each scalar tag.

    Returns:
        ``(hits_at_1, hits_at_3, hits_at_10, mrr)``, each the accumulated
        metric divided by the number of ranking examples and multiplied by 100.

    Raises:
        ZeroDivisionError: If ``data_generator`` yields no batches.
    """
    examples_count = 0.0
    hits_at_1 = 0.0
    hits_at_3 = 0.0
    hits_at_10 = 0.0
    mrr = 0.0

    entity_ids = torch.arange(end=entities_count, device=device).unsqueeze(0)  # [1, e]

    # Pure evaluation: disable autograd so that scoring all B * e candidate
    # triplets (twice per batch) does not record a computation graph.
    with torch.no_grad():
        for head, relation, tail in data_generator:
            current_batch_size = head.size()[0]
            head, relation, tail = head.to(device), relation.to(device), tail.to(device)

            # expand() creates broadcast views instead of materialised copies;
            # torch.stack below produces the fresh tensors the model sees.
            all_entities = entity_ids.expand(current_batch_size, -1)  # [B, e]
            heads = head.reshape(-1, 1).expand_as(all_entities)
            relations = relation.reshape(-1, 1).expand_as(all_entities)
            tails = tail.reshape(-1, 1).expand_as(all_entities)

            # Check all possible tails
            triplets = torch.stack((heads, relations, all_entities), dim=2).reshape(-1, 3)
            tails_predictions = model.predict(triplets).reshape(current_batch_size, -1)

            # Check all possible heads
            triplets = torch.stack((all_entities, relations, tails), dim=2).reshape(-1, 3)
            heads_predictions = model.predict(triplets).reshape(current_batch_size, -1)

            # Concat predictions: the first B rows rank tails, the last B rows rank heads.
            predictions = torch.cat((tails_predictions, heads_predictions), dim=0)
            ground_truth_entity_id = torch.cat((tail.reshape(-1, 1), head.reshape(-1, 1)))

            hits_at_1 += metric.hit_at_k(predictions, ground_truth_entity_id, device=device, k=1)
            hits_at_3 += metric.hit_at_k(predictions, ground_truth_entity_id, device=device, k=3)
            hits_at_10 += metric.hit_at_k(predictions, ground_truth_entity_id, device=device, k=10)
            mrr += metric.mrr(predictions, ground_truth_entity_id)

            examples_count += predictions.size()[0]

    hits_at_1_score = hits_at_1 / examples_count * 100
    hits_at_3_score = hits_at_3 / examples_count * 100
    hits_at_10_score = hits_at_10 / examples_count * 100
    mrr_score = mrr / examples_count * 100

    summary_writer.add_scalar('Metrics/Hits_1/' + metric_suffix, hits_at_1_score, global_step=epoch_id)
    summary_writer.add_scalar('Metrics/Hits_3/' + metric_suffix, hits_at_3_score, global_step=epoch_id)
    summary_writer.add_scalar('Metrics/Hits_10/' + metric_suffix, hits_at_10_score, global_step=epoch_id)
    summary_writer.add_scalar('Metrics/MRR/' + metric_suffix, mrr_score, global_step=epoch_id)

    return hits_at_1_score, hits_at_3_score, hits_at_10_score, mrr_score
def test_one_element_batch_with_last_element_correct(self):
    """Single-row batch, k=2: ``hit_at_k`` must report exactly one hit for index 0."""
    # given
    scores = torch.tensor([[0.1, 0.0, 0.4, 0.9]])
    target_idx = torch.tensor([0])
    cpu = torch.device('cpu')

    # when
    hits = metric.hit_at_k(scores, target_idx, k=2, device=cpu)

    # then
    # NOTE(review): index 0 holds the second-lowest score, so this expectation
    # presumably relies on hit_at_k ranking lower scores first - confirm
    # against metric.hit_at_k.
    self.assertEqual(1, hits)
def test_multiple_elements_batch(self):
    """Three-row batch, k=2: ``hit_at_k`` must report two hits in total."""
    # given
    scores = torch.tensor([
        [0.1, 0.0, 0.4, 0.9],
        [0.0, 0.9, 0.1, 1.0],
        [0.0, 0.1, 0.2, 0.8],
    ])
    # One ground-truth entity index per row; the third row doesn't have a hit in the top 2.
    target_idx = torch.tensor([[1], [2], [3]])
    top_k = 2
    cpu = torch.device('cpu')

    # when
    hits = metric.hit_at_k(scores, target_idx, k=top_k, device=cpu)

    # then
    self.assertEqual(2, hits)
def test(
        model: torch.nn.Module,
        data_generator: torch_data.DataLoader,
        entities_count: int,
        summary_writer: tensorboard.SummaryWriter,
        device: torch.device,
        epoch_id: int,
        metric_suffix: str,
) -> METRICS:
    """Evaluate link prediction and log Hits@1/3/10 and MRR as percentages.

    For each ``(head, relation, tail)`` batch, every entity id is substituted
    first as the tail and then as the head, the model scores all resulting
    triplets, and the metrics are accumulated over both directions. Each
    input triplet therefore contributes two ranking examples.

    This function does not call ``model.eval()``; it only disables gradient
    tracking while scoring.

    Args:
        model: Object exposing ``predict(triplets)``. It is called with a
            ``[B * e, 3]`` tensor of (head, relation, tail) ids and its result
            is reshaped to ``[B, e]``.
        data_generator: Iterable yielding ``(head, relation, tail)`` tensors.
        entities_count: Number of entities; candidate ids are
            ``0 .. entities_count - 1``.
        summary_writer: Receives one ``add_scalar`` call per metric under
            ``'Metrics/<name>/' + metric_suffix``.
        device: Device the batch tensors and candidate ids are placed on.
        epoch_id: Passed as ``global_step`` to the summary writer.
        metric_suffix: Appended to each scalar tag.

    Returns:
        ``(hits_at_1, hits_at_3, hits_at_10, mrr)``, each the accumulated
        metric divided by the number of ranking examples and multiplied by 100.

    Raises:
        ZeroDivisionError: If ``data_generator`` yields no batches.
    """
    examples_count = 0.0
    hits_at_1 = 0.0
    hits_at_3 = 0.0
    hits_at_10 = 0.0
    mrr = 0.0

    # 1-D tensor holding the ids 0 .. entities_count - 1, with a leading
    # dimension of size one added so it can be broadcast across the batch.
    entity_ids = torch.arange(end=entities_count, device=device).unsqueeze(0)

    # Pure evaluation: disable autograd so that scoring all B * e candidate
    # triplets (twice per batch) does not record a computation graph.
    with torch.no_grad():
        for head, relation, tail in data_generator:
            current_batch_size = head.size()[0]
            head, relation, tail = head.to(device), relation.to(device), tail.to(device)

            # expand() creates broadcast views instead of materialised copies;
            # torch.stack below produces the fresh tensors the model sees.
            all_entities = entity_ids.expand(current_batch_size, -1)
            heads = head.reshape(-1, 1).expand_as(all_entities)
            relations = relation.reshape(-1, 1).expand_as(all_entities)
            tails = tail.reshape(-1, 1).expand_as(all_entities)

            # Check all possible tails
            triplets = torch.stack((heads, relations, all_entities), dim=2).reshape(-1, 3)
            tails_predictions = model.predict(triplets).reshape(current_batch_size, -1)

            # Check all possible heads
            triplets = torch.stack((all_entities, relations, tails), dim=2).reshape(-1, 3)
            heads_predictions = model.predict(triplets).reshape(current_batch_size, -1)

            # Concat predictions. Each row holds one score per entity, and
            # there are 2 * batch_size rows (tail rankings first, then head
            # rankings), matched row-by-row with the ground-truth ids below.
            predictions = torch.cat((tails_predictions, heads_predictions), dim=0)
            ground_truth_entity_id = torch.cat((tail.reshape(-1, 1), head.reshape(-1, 1)))

            # Background on top-k metrics:
            # https://medium.com/@m_n_malaeb/recall-and-precision-at-k-for-recommender-systems-618483226c54
            hits_at_1 += metric.hit_at_k(predictions, ground_truth_entity_id, device=device, k=1)
            hits_at_3 += metric.hit_at_k(predictions, ground_truth_entity_id, device=device, k=3)
            hits_at_10 += metric.hit_at_k(predictions, ground_truth_entity_id, device=device, k=10)
            mrr += metric.mrr(predictions, ground_truth_entity_id)

            examples_count += predictions.size()[0]

    hits_at_1_score = hits_at_1 / examples_count * 100
    hits_at_3_score = hits_at_3 / examples_count * 100
    hits_at_10_score = hits_at_10 / examples_count * 100
    mrr_score = mrr / examples_count * 100

    summary_writer.add_scalar('Metrics/Hits_1/' + metric_suffix, hits_at_1_score, global_step=epoch_id)
    summary_writer.add_scalar('Metrics/Hits_3/' + metric_suffix, hits_at_3_score, global_step=epoch_id)
    summary_writer.add_scalar('Metrics/Hits_10/' + metric_suffix, hits_at_10_score, global_step=epoch_id)
    summary_writer.add_scalar('Metrics/MRR/' + metric_suffix, mrr_score, global_step=epoch_id)

    return hits_at_1_score, hits_at_3_score, hits_at_10_score, mrr_score