def test(model: torch.nn.Module, data_generator: torch_data.DataLoader, entities_count: int,
         device: torch.device, epoch_id: int, metric_suffix: str,
         ) -> METRICS:
    examples_count = 0.0
    hits_at_1 = 0.0
    hits_at_3 = 0.0
    hits_at_10 = 0.0
    mrr = 0.0

    entity_ids = torch.arange(end=entities_count, device=device).unsqueeze(0)
    for head, relation, tail in data_generator:
        current_batch_size = head.size()[0]

        head, relation, tail = head.to(device), relation.to(device), tail.to(device)  # [B]
        all_entities = entity_ids.repeat(current_batch_size, 1)  # [B, e]
        heads = head.reshape(-1, 1).repeat(1, all_entities.size()[1])  # [B, e]
        relations = relation.reshape(-1, 1).repeat(1, all_entities.size()[1])  # [B, e]
        tails = tail.reshape(-1, 1).repeat(1, all_entities.size()[1])  # [B, e]

        # Check all possible tails
        triplets = torch.stack((heads, relations, all_entities), dim=2).reshape(-1, 3)  # [B, e, 3]->[B*e, 3]
        tails_predictions = model.predict(triplets).reshape(current_batch_size, -1)  # B*e ->[B, e]
        # Check all possible heads
        triplets = torch.stack((all_entities, relations, tails), dim=2).reshape(-1, 3)
        heads_predictions = model.predict(triplets).reshape(current_batch_size, -1)

        # Concat predictions
        predictions = torch.cat((tails_predictions, heads_predictions), dim=0)  # [2*B, e]
        ground_truth_entity_id = torch.cat((tail.reshape(-1, 1), head.reshape(-1, 1)))  # [2*B, 1]: first B rows are the true tails, last B the true heads

        hits_at_1 += metric.hit_at_k(predictions, ground_truth_entity_id, device=device, k=1)
        hits_at_3 += metric.hit_at_k(predictions, ground_truth_entity_id, device=device, k=3)
        hits_at_10 += metric.hit_at_k(predictions, ground_truth_entity_id, device=device, k=10)
        mrr += metric.mrr(predictions, ground_truth_entity_id)  # MRR: for each query, if the first correct answer is ranked at position n, the score is 1/n (0 if there is no correct answer)

        examples_count += predictions.size()[0]

    # hits_at_1_score = hits_at_1 / examples_count * 100
    # hits_at_3_score = hits_at_3 / examples_count * 100
    # hits_at_10_score = hits_at_10 / examples_count * 100
    # mrr_score = mrr / examples_count * 100
    hits_at_1_score = hits_at_1 / examples_count
    hits_at_3_score = hits_at_3 / examples_count
    hits_at_10_score = hits_at_10 / examples_count
    mrr_score = mrr / examples_count
    # summary_writer.add_scalar('Metrics/Hits_1/' + metric_suffix, hits_at_1_score, global_step=epoch_id)
    # summary_writer.add_scalar('Metrics/Hits_3/' + metric_suffix, hits_at_3_score, global_step=epoch_id)
    # summary_writer.add_scalar('Metrics/Hits_10/' + metric_suffix, hits_at_10_score, global_step=epoch_id)
    # summary_writer.add_scalar('Metrics/MRR/' + metric_suffix, mrr_score, global_step=epoch_id)

    return hits_at_1_score, hits_at_3_score, hits_at_10_score, mrr_score
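
The metric.hit_at_k and metric.mrr helpers are not shown on this page. Below is a minimal sketch of what these examples appear to assume: model.predict returns a dissimilarity score (lower is better, as in TransE), hit_at_k returns the number of hits in the batch, and mrr returns the sum of reciprocal ranks, so both are later divided by examples_count. This is an illustration consistent with the unit tests in Examples #3 and #4, not necessarily the project's original implementation.

import torch


def hit_at_k(predictions: torch.Tensor, ground_truth_idx: torch.Tensor,
             device: torch.device, k: int = 10) -> int:
    # predictions: [B, e] dissimilarity scores (lower is better)
    # ground_truth_idx: [B, 1] index of the correct entity for each row
    zero_tensor = torch.tensor([0], device=device)
    one_tensor = torch.tensor([1], device=device)
    _, indices = predictions.topk(k=k, largest=False)  # k smallest scores per row
    # Count the rows whose ground-truth index appears among the top-k candidates
    return torch.where(indices == ground_truth_idx, one_tensor, zero_tensor).sum().item()


def mrr(predictions: torch.Tensor, ground_truth_idx: torch.Tensor) -> float:
    # Rank of the ground-truth entity when each row is sorted by ascending score
    order = predictions.argsort()
    ranks = (order == ground_truth_idx).nonzero()[:, 1].float() + 1.0
    # Sum (not mean) of reciprocal ranks; the caller divides by examples_count
    return torch.reciprocal(ranks).sum().item()
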
Example #2
def test(model: torch.nn.Module, data_generator: torch_data.DataLoader, entities_count: int,
         summary_writer: tensorboard.SummaryWriter, device: torch.device, epoch_id: int, metric_suffix: str,
         ) -> METRICS:
    examples_count = 0.0
    hits_at_1 = 0.0
    hits_at_3 = 0.0
    hits_at_10 = 0.0
    mrr = 0.0

    entity_ids = torch.arange(end=entities_count, device=device).unsqueeze(0)
    for head, relation, tail in data_generator:
        current_batch_size = head.size()[0]

        head, relation, tail = head.to(device), relation.to(device), tail.to(device)
        all_entities = entity_ids.repeat(current_batch_size, 1)
        heads = head.reshape(-1, 1).repeat(1, all_entities.size()[1])
        relations = relation.reshape(-1, 1).repeat(1, all_entities.size()[1])
        tails = tail.reshape(-1, 1).repeat(1, all_entities.size()[1])

        # Check all possible tails
        triplets = torch.stack((heads, relations, all_entities), dim=2).reshape(-1, 3)
        tails_predictions = model.predict(triplets).reshape(current_batch_size, -1)
        # Check all possible heads
        triplets = torch.stack((all_entities, relations, tails), dim=2).reshape(-1, 3)
        heads_predictions = model.predict(triplets).reshape(current_batch_size, -1)

        # Concat predictions
        predictions = torch.cat((tails_predictions, heads_predictions), dim=0)
        ground_truth_entity_id = torch.cat((tail.reshape(-1, 1), head.reshape(-1, 1)))

        hits_at_1 += metric.hit_at_k(predictions, ground_truth_entity_id, device=device, k=1)
        hits_at_3 += metric.hit_at_k(predictions, ground_truth_entity_id, device=device, k=3)
        hits_at_10 += metric.hit_at_k(predictions, ground_truth_entity_id, device=device, k=10)
        mrr += metric.mrr(predictions, ground_truth_entity_id)

        examples_count += predictions.size()[0]

    hits_at_1_score = hits_at_1 / examples_count * 100
    hits_at_3_score = hits_at_3 / examples_count * 100
    hits_at_10_score = hits_at_10 / examples_count * 100
    mrr_score = mrr / examples_count * 100
    summary_writer.add_scalar('Metrics/Hits_1/' + metric_suffix, hits_at_1_score, global_step=epoch_id)
    summary_writer.add_scalar('Metrics/Hits_3/' + metric_suffix, hits_at_3_score, global_step=epoch_id)
    summary_writer.add_scalar('Metrics/Hits_10/' + metric_suffix, hits_at_10_score, global_step=epoch_id)
    summary_writer.add_scalar('Metrics/MRR/' + metric_suffix, mrr_score, global_step=epoch_id)

    return hits_at_1_score, hits_at_3_score, hits_at_10_score, mrr_score
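
This variant is the same evaluation loop, but it scales the scores to percentages and logs them to TensorBoard. The candidate grid built inside the loop can be hard to picture; a toy run with a batch of 2 triples and 4 entities (hypothetical values, shown only to illustrate the shapes) looks like this:

import torch

entities_count = 4
head = torch.tensor([0, 2])      # [B]
relation = torch.tensor([1, 1])  # [B]

entity_ids = torch.arange(end=entities_count).unsqueeze(0)  # [1, e]
all_entities = entity_ids.repeat(head.size(0), 1)           # [B, e]
heads = head.reshape(-1, 1).repeat(1, entities_count)       # [B, e]
relations = relation.reshape(-1, 1).repeat(1, entities_count)

# One candidate (head, relation, tail) triple per entity, flattened to [B*e, 3]
triplets = torch.stack((heads, relations, all_entities), dim=2).reshape(-1, 3)
print(triplets.shape)  # torch.Size([8, 3])
print(triplets[:4])    # tensor([[0, 1, 0], [0, 1, 1], [0, 1, 2], [0, 1, 3]])
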
Example #3
    def test_one_element_batch_with_last_element_correct(self):
        # given
        predictions = torch.tensor([[0.1, 0.0, 0.4, 0.9]])
        ground_truth_idx = torch.tensor([0])
        expected = 1

        # when
        actual = metric.hit_at_k(predictions,
                                 ground_truth_idx,
                                 k=2,
                                 device=torch.device('cpu'))

        # then
        self.assertEqual(expected, actual)
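
The expected hit only makes sense if the prediction scores are dissimilarities (lower is better), so hit_at_k presumably keeps the k smallest scores. A quick check of that assumption on the same tensor:

import torch

predictions = torch.tensor([[0.1, 0.0, 0.4, 0.9]])
_, indices = predictions.topk(k=2, largest=False)  # two smallest scores
print(indices)  # tensor([[1, 0]]) -> ground-truth index 0 is in the top 2, hence a hit
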
Example #4
    def test_multiple_elements_batch(self):
        # given
        predictions = torch.tensor([[0.1, 0.0, 0.4, 0.9], [0.0, 0.9, 0.1, 1.0],
                                    [0.0, 0.1, 0.2, 0.8]])
        k = 2
        ground_truth_idx = torch.tensor([[1], [2], [3]])
        # the third row's ground truth is not in its top 2
        expected = 2

        # when
        actual = metric.hit_at_k(predictions,
                                 ground_truth_idx,
                                 k=k,
                                 device=torch.device('cpu'))

        # then
        self.assertEqual(expected, actual)
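
Under the same lower-is-better reading, the ground-truth entities of these three rows sit at ranks 1, 2 and 4 of the ascending sort, so a sum-style mrr would return 1 + 1/2 + 1/4 = 1.75. A quick sketch of that computation (an illustration, not the project's metric code):

import torch

predictions = torch.tensor([[0.1, 0.0, 0.4, 0.9], [0.0, 0.9, 0.1, 1.0],
                            [0.0, 0.1, 0.2, 0.8]])
ground_truth_idx = torch.tensor([[1], [2], [3]])

order = predictions.argsort()                        # ascending: best candidate first
ranks = (order == ground_truth_idx).nonzero()[:, 1] + 1
print(ranks)                                         # tensor([1, 2, 4])
print(torch.reciprocal(ranks.float()).sum().item())  # 1.75
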
Example #5
def test(
    model: torch.nn.Module,
    data_generator: torch_data.DataLoader,
    entities_count: int,
    summary_writer: tensorboard.SummaryWriter,
    device: torch.device,
    epoch_id: int,
    metric_suffix: str,
) -> METRICS:
    examples_count = 0.0
    hits_at_1 = 0.0
    hits_at_3 = 0.0
    hits_at_10 = 0.0
    mrr = 0.0

    entity_ids = torch.arange(end=entities_count, device=device).unsqueeze(0)
    # torch.arange returns a 1-D tensor with values 0..entities_count - 1;
    # unsqueeze(0) adds a leading dimension of size one, giving shape [1, entities_count].
    for head, relation, tail in data_generator:
        # print(head, relation, tail)
        current_batch_size = head.size()[0]

        head, relation, tail = head.to(device), relation.to(device), tail.to(
            device)
        all_entities = entity_ids.repeat(
            current_batch_size, 1
        )  # Tensor.repeat() tiles along each dimension, giving shape [current_batch_size, entities_count]
        heads = head.reshape(-1, 1).repeat(1, all_entities.size()[1])
        relations = relation.reshape(-1, 1).repeat(1, all_entities.size()[1])
        tails = tail.reshape(-1, 1).repeat(1, all_entities.size()[1])

        # Check all possible tails
        triplets = torch.stack((heads, relations, all_entities),
                               dim=2).reshape(-1, 3)
        # print(triplets)
        tails_predictions = model.predict(triplets).reshape(
            current_batch_size, -1)
        # Check all possible heads
        triplets = torch.stack((all_entities, relations, tails),
                               dim=2).reshape(-1, 3)
        heads_predictions = model.predict(triplets).reshape(
            current_batch_size, -1)

        # Concat predictions
        predictions = torch.cat((tails_predictions, heads_predictions), dim=0)
        ground_truth_entity_id = torch.cat(
            (tail.reshape(-1, 1), head.reshape(-1, 1)))

        # Each prediction row scores all entities_count candidates; there are
        # 2 * batch_size rows (one tail query and one head query per sample),
        # matching the 2 * batch_size ground-truth entity ids.

        # https://medium.com/@m_n_malaeb/recall-and-precision-at-k-for-recommender-systems-618483226c54
        hits_at_1 += metric.hit_at_k(predictions,
                                     ground_truth_entity_id,
                                     device=device,
                                     k=1)
        hits_at_3 += metric.hit_at_k(predictions,
                                     ground_truth_entity_id,
                                     device=device,
                                     k=3)
        hits_at_10 += metric.hit_at_k(predictions,
                                      ground_truth_entity_id,
                                      device=device,
                                      k=10)
        mrr += metric.mrr(predictions, ground_truth_entity_id)

        examples_count += predictions.size()[0]

    hits_at_1_score = hits_at_1 / examples_count * 100
    hits_at_3_score = hits_at_3 / examples_count * 100
    hits_at_10_score = hits_at_10 / examples_count * 100
    mrr_score = mrr / examples_count * 100
    summary_writer.add_scalar('Metrics/Hits_1/' + metric_suffix,
                              hits_at_1_score,
                              global_step=epoch_id)
    summary_writer.add_scalar('Metrics/Hits_3/' + metric_suffix,
                              hits_at_3_score,
                              global_step=epoch_id)
    summary_writer.add_scalar('Metrics/Hits_10/' + metric_suffix,
                              hits_at_10_score,
                              global_step=epoch_id)
    summary_writer.add_scalar('Metrics/MRR/' + metric_suffix,
                              mrr_score,
                              global_step=epoch_id)

    return hits_at_1_score, hits_at_3_score, hits_at_10_score, mrr_score
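
A minimal way to call this variant end to end, with a dummy model and random data standing in for a real knowledge graph. DummyModel, the random triples and the log directory below are hypothetical and for illustration only; metric, METRICS and the test function itself come from the snippets above.

import torch
import torch.utils.data as torch_data
from torch.utils import tensorboard


class DummyModel(torch.nn.Module):
    # Stand-in for a KGE model: predict returns one dissimilarity score per (h, r, t) triple
    def predict(self, triplets: torch.Tensor) -> torch.Tensor:
        return torch.rand(triplets.size(0))


entities_count = 50
triples = torch.randint(0, entities_count, (200, 3))
loader = torch_data.DataLoader(
    torch_data.TensorDataset(triples[:, 0], triples[:, 1], triples[:, 2]),
    batch_size=32)

writer = tensorboard.SummaryWriter(log_dir="runs/example")
scores = test(DummyModel(), loader, entities_count, writer,
              device=torch.device("cpu"), epoch_id=0, metric_suffix="val")
print(scores)  # (hits@1, hits@3, hits@10, MRR), each scaled to a percentage
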