Example #1
import torch

# `embed_and_pack_batch`, `Logits`, and the encoder modules are project-local
# helpers; they are not defined in this snippet.
def predict_deep_el(embedding, token_idx_lookup, p_prior, model, batch,
                    ablation, entity_embeds, use_stacker):
    model.eval()
    if ablation == ['prior']:
        # Prior-only ablation: pick the candidate with the highest prior probability.
        return torch.argmax(p_prior, dim=1)
    elif 'local_context' in ablation:
        # Embed the left/right sentence contexts around each mention.
        left_splits, right_splits = embed_and_pack_batch(
            embedding, token_idx_lookup, batch['sentence_splits'])
        if 'document_context' in ablation:
            mention_embeds, desc_embeds = model.encoder(
                ((left_splits, right_splits), batch['embedded_page_content'],
                 batch['entity_page_mentions']))
        else:
            mention_embeds = model.encoder.mention_context_encoder(
                ((left_splits, right_splits), batch['embedded_page_content'],
                 batch['entity_page_mentions']))
        # Score each candidate by comparing the mention embedding with the
        # candidate entity embeddings.
        logits = Logits()
        calc_logits = lambda embeds, ids: logits(embeds, entity_embeds(ids))
        men_logits = calc_logits(mention_embeds, batch['candidate_ids'])
        if use_stacker:
            p_text, __ = model.calc_scores(
                (men_logits, torch.zeros_like(men_logits)),
                batch['candidate_mention_sim'], p_prior)
        else:
            p_text = men_logits
        return torch.argmax(p_text, dim=1)
    else:
        raise NotImplementedError
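`Logits` is instantiated above but not defined in this snippet. Judging from the call site, it scores each candidate by comparing the mention embedding against that candidate's entity embedding; a minimal sketch consistent with that usage (the dot-product scoring is an assumption, not the project's confirmed implementation) is:

import torch
import torch.nn as nn

class Logits(nn.Module):
    # Sketch (assumption): dot product between the mention embedding and each
    # candidate entity embedding.
    def forward(self, mention_embeds, candidate_entity_embeds):
        # mention_embeds: (batch, dim)
        # candidate_entity_embeds: (batch, num_candidates, dim)
        # returns: (batch, num_candidates) scores
        return torch.sum(mention_embeds.unsqueeze(1) * candidate_entity_embeds,
                         dim=2)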
Example #2
# Method of a trainer-style class in the project; it relies on torch, itertools,
# torch.utils.data.DataLoader, and project helpers such as tensors_to_device,
# collate_deep_el, and embed_and_pack_batch.
def train_deep_el(self):
    for epoch_num in range(self.num_epochs):
        self.experiment.update_epoch(epoch_num)
        self._dataset = self.get_dataset()
        dataloader = DataLoader(dataset=self._dataset,
                                batch_sampler=self.get_batch_sampler(),
                                collate_fn=collate_deep_el)
        for batch_num, batch in enumerate(dataloader):
            self.model.train()
            self.optimizer.zero_grad()
            batch = tensors_to_device(batch, self.device)
            if self.use_adaptive_softmax:
                labels = batch['label']
            else:
                labels = self._get_labels_for_batch(
                    batch['label'], batch['candidate_ids'])
            # Embed the left/right sentence contexts around each mention.
            left_splits, right_splits = embed_and_pack_batch(
                self.embedding, self.token_idx_lookup,
                batch['sentence_splits'])
            encoded = self.model.encoder(((left_splits, right_splits),
                                          batch['embedded_page_content'],
                                          batch['entity_page_mentions']))
            logits = self.calc_logits(encoded, batch['candidate_ids'])
            scores = self.model.calc_scores(logits,
                                            batch['candidate_mention_sim'])
            # batch['prior'])
            loss = self.calc_loss(scores, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                itertools.chain(self.model.parameters(),
                                self._get_adaptive_logits_params()),
                self.clip_grad)
            self.optimizer.step()
            # Re-run the same batch in eval mode, without gradients, to track
            # the per-context classification errors.
            with torch.no_grad():
                self.model.eval()
                encoded_test = self.model.encoder(
                    ((left_splits,
                      right_splits), batch['embedded_page_content'],
                     batch['entity_page_mentions']))
                logits_test = self.calc_logits(encoded_test,
                                               batch['candidate_ids'])
                mention_probas, desc_probas = self.model.calc_scores(
                    logits_test, batch['candidate_mention_sim'])
                mention_context_error = self._classification_error(
                    mention_probas, labels)
                document_context_error = self._classification_error(
                    desc_probas, labels)
            self.experiment.record_metrics(
                {
                    'mention_context_error': mention_context_error,
                    'document_context_error': document_context_error,
                    'loss': loss.item()
                },
                batch_num=batch_num)
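`self._get_labels_for_batch(batch['label'], batch['candidate_ids'])` is also project code that is not shown here. Since the resulting `labels` are fed, together with the per-candidate `scores`, into `self.calc_loss`, a plausible reading is that it returns, for each example, the index of the gold entity id within that example's candidate list. A sketch under that assumption (the standalone function name and the -1 convention for missing candidates are illustrative):

import torch

def get_labels_for_batch(labels, candidate_ids):
    # Sketch (assumption): labels is (batch,) with gold entity ids and
    # candidate_ids is (batch, num_candidates); return the position of the
    # gold id in each row, or -1 if it is absent.
    matches = candidate_ids == labels.unsqueeze(1)
    has_match = matches.any(dim=1)
    first_match = torch.argmax(matches.int(), dim=1)
    return torch.where(has_match, first_match,
                       torch.full_like(first_match, -1))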
import torch
import torch.nn as nn

# `dt` is the project module (imported elsewhere in the test file) that
# provides `embed_and_pack_batch`.
def test_embed_and_pack_batch():
    # Toy one-dimensional "embeddings" for the two tokens 'a' and 'b'.
    embedding_dict = {'a': torch.tensor([1]), 'b': torch.tensor([2])}
    token_idx_lookup = dict(
        zip(embedding_dict.keys(), range(len(embedding_dict))))
    embedding = nn.Embedding.from_pretrained(
        torch.stack([embedding_dict[token] for token in token_idx_lookup]))
    # Each element is [left context, right context] for one mention.
    sentence_splits_batch = [[['a', 'b', 'a', 'b'], ['b', 'a']],
                             [['b', 'a'], ['a', 'b', 'a', 'b']]]
    # Expected packed contents, listed in length-sorted (packing) order.
    left = [torch.tensor(vec) for vec in [[[1], [2], [1], [2]], [[2], [1]]]]
    right = [torch.tensor(vec) for vec in [[[1], [2], [1], [2]], [[2], [1]]]]
    result = dt.embed_and_pack_batch(embedding, token_idx_lookup,
                                     sentence_splits_batch)
    assert torch.equal(result[0]['embeddings'].data,
                       nn.utils.rnn.pack_sequence(left).data)
    assert torch.equal(result[0]['embeddings'].batch_sizes,
                       nn.utils.rnn.pack_sequence(left).batch_sizes)
    assert result[0]['order'] == [0, 1]
    assert torch.equal(result[1]['embeddings'].data,
                       nn.utils.rnn.pack_sequence(right).data)
    assert torch.equal(result[1]['embeddings'].batch_sizes,
                       nn.utils.rnn.pack_sequence(right).batch_sizes)
    assert result[1]['order'] == [1, 0]
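The assertions above effectively specify the contract of `embed_and_pack_batch`: it embeds the left and right token contexts of each mention and returns, per side, a dict with a length-sorted PackedSequence under 'embeddings' and the corresponding sort order under 'order'. A minimal sketch that satisfies this test (the project's actual implementation may differ in details):

import torch
import torch.nn as nn

def embed_and_pack_batch(embedding, token_idx_lookup, sentence_splits_batch):
    result = []
    for side in range(2):  # 0: left contexts, 1: right contexts
        token_lists = [splits[side] for splits in sentence_splits_batch]
        # pack_sequence expects sequences sorted by decreasing length.
        order = sorted(range(len(token_lists)),
                       key=lambda i: len(token_lists[i]), reverse=True)
        embedded = [embedding(torch.tensor([token_idx_lookup[t]
                                            for t in token_lists[i]]))
                    for i in order]
        result.append({'embeddings': nn.utils.rnn.pack_sequence(embedded),
                       'order': order})
    return result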