Example #1
def _calc_logits(self, encoded, candidate_entity_ids):
    if self.model_params.use_wiki2vec or self.model_params.use_sum_encoder:
        # Single-vector encoders: score the encoding against the candidates directly.
        return Logits()(encoded, self.entity_embeds(candidate_entity_ids))
    else:
        # Two-part encoding: description embedding and mention-context embedding.
        desc_embeds, mention_context_embeds = encoded
        if self.model_params.use_adaptive_softmax:
            raise NotImplementedError('No longer supported')
        else:
            logits = Logits()
            desc_logits = logits(desc_embeds,
                                 self.entity_embeds(candidate_entity_ids))
            mention_logits = logits(mention_context_embeds,
                                    self.entity_embeds(candidate_entity_ids))
        return desc_logits, mention_logits
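
The Logits class itself does not appear in these snippets. A minimal sketch that is consistent with how it is called above and with the shapes used in Example #6 below, assuming it scores each encoded vector against candidate entity embeddings by a dot product:

import torch
import torch.nn as nn

class Logits(nn.Module):
    # Hypothetical sketch: score candidates by a dot product between the
    # encoded vector and each candidate entity embedding.
    def forward(self, hidden, candidate_entity_embeds):
        # hidden: (batch, embed_len)
        # candidate_entity_embeds: (batch, num_candidates, embed_len), or
        # (num_entities, embed_len) to score every row against every entity
        # via broadcasting (the in-batch shape used in Example #6).
        # returns: (batch, num_candidates) or (batch, num_entities) scores
        return torch.sum(hidden.unsqueeze(1) * candidate_entity_embeds, dim=2)
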
Example #2
def predict_deep_el(embedding, token_idx_lookup, p_prior, model, batch,
                    ablation, entity_embeds, use_stacker):
    model.eval()
    if ablation == ['prior']:
        # Prior-only ablation: pick the candidate with the highest prior probability.
        return torch.argmax(p_prior, dim=1)
    elif 'local_context' in ablation:
        # Embed the tokens to the left and right of each mention.
        left_splits, right_splits = embed_and_pack_batch(
            embedding, token_idx_lookup, batch['sentence_splits'])
        if 'document_context' in ablation:
            mention_embeds, desc_embeds = model.encoder(
                ((left_splits, right_splits), batch['embedded_page_content'],
                 batch['entity_page_mentions']))
        else:
            mention_embeds = model.encoder.mention_context_encoder(
                ((left_splits, right_splits), batch['embedded_page_content'],
                 batch['entity_page_mentions']))
        # Score the mention encoding against each candidate entity embedding.
        logits = Logits()
        calc_logits = lambda embeds, ids: logits(embeds, entity_embeds(ids))
        men_logits = calc_logits(mention_embeds, batch['candidate_ids'])
        if use_stacker:
            # Stack the context logits with the mention-similarity and prior features.
            p_text, __ = model.calc_scores(
                (men_logits, torch.zeros_like(men_logits)),
                batch['candidate_mention_sim'], p_prior)
        else:
            p_text = men_logits
        return torch.argmax(p_text, dim=1)
    else:
        raise NotImplementedError
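
predict_deep_el returns, for each mention, the argmax position within its candidate list rather than an entity id. A hedged usage sketch, assuming the surrounding objects are set up as in these examples and that batch['candidate_ids'] is a (batch, num_candidates) tensor as in Example #7's dataset:

pred_positions = predict_deep_el(embedding, token_idx_lookup, p_prior, model,
                                 batch, ablation=['local_context', 'document_context'],
                                 entity_embeds=entity_embeds, use_stacker=True)
# Map each argmax position back to the corresponding candidate entity id.
pred_entity_ids = batch['candidate_ids'][torch.arange(len(pred_positions)),
                                         pred_positions]
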
Example #3
def get_calc(context):
    if self.model_params.use_adaptive_softmax:
        # Adaptive softmax scores the full entity vocabulary; candidate ids are ignored.
        softmax = self.adaptive_logits[context].log_prob
        calc = lambda hidden, _: softmax(hidden)
    else:
        calc_logits = Logits()
        softmax = Softmax()
        calc = lambda hidden, candidate_entity_ids: softmax(
            calc_logits(hidden, self.entity_embeds(candidate_entity_ids)))
    return calc
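
A hedged usage sketch of the returned closure (mention_embeds and candidate_ids are hypothetical placeholders; the 'mention' key matches the one used in Example #7): both branches expose the same call signature, but the adaptive-softmax branch ignores the candidate ids entirely, as the lambda hidden, _ above shows.

calc = get_calc('mention')
# Non-adaptive branch: probabilities over each row's candidates;
# adaptive branch: log-probabilities over the full entity vocabulary.
scores = calc(mention_embeds, candidate_ids)
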
Example #4
def predict_wiki2vec(embedding, token_idx_lookup, p_prior, model, batch,
                     ablation, entity_embeds):
    model.eval()
    context = model.encoder(batch['bag_of_nouns'])
    logits = Logits()
    calc_logits = lambda embeds, ids: logits(embeds, entity_embeds(ids))
    context_logits = calc_logits(context, batch['candidate_ids'])
    p_text, __ = model.calc_scores(
        (context_logits, torch.zeros_like(context_logits)),
        batch['candidate_mention_sim'], p_prior)
    return torch.argmax(p_text, dim=1)
Example #5
def predict_sum_encoder(embedding, token_idx_lookup, p_prior, model, batch,
                        ablation, entity_embeds, use_stacker):
    model.eval()
    context_bows = [
        Counter(to_idx(token_idx_lookup, token) for token in sentence)
        for sentence in batch['mention_sentence']
    ]
    doc_bows = batch['page_token_cnts']
    encoded = model.encoder(context_bows, doc_bows)
    logits = Logits()
    calc_logits = lambda embeds, ids: logits(embeds, entity_embeds(ids))
    men_logits = calc_logits(encoded, batch['candidate_ids'])
    if use_stacker:
        p_text, __ = model.calc_scores(
            (men_logits, torch.zeros_like(men_logits)),
            batch['candidate_mention_sim'], p_prior)
    else:
        p_text = men_logits
    return torch.argmax(p_text, dim=1)
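
The to_idx helper used to build the bag-of-words Counters is not shown. A hypothetical stand-in, assuming it resolves a token to its vocabulary index with a lowercase retry and an unknown-token fallback (the '<unk>' key and the fallback index 0 are assumptions):

def to_idx(token_idx_lookup, token):
    # Hypothetical: exact match first, then lowercased, else an assumed
    # unknown-token index.
    if token in token_idx_lookup:
        return token_idx_lookup[token]
    if token.lower() in token_idx_lookup:
        return token_idx_lookup[token.lower()]
    return token_idx_lookup.get('<unk>', 0)
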
Example #6
def test_description_encoder_loss():
  embed_len = 10
  word_embed_len = 15
  num_entities = 20
  batch_size = 2
  desc_len = 9
  pad_vector = torch.randn((word_embed_len, ))
  entity_embeds = torch.nn.Embedding(num_entities,
                                     embed_len,
                                     _weight=torch.randn((num_entities, embed_len)))
  desc_enc = DescriptionEncoder(word_embed_len,
                                entity_embeds,
                                pad_vector)
  descriptions = torch.randn((batch_size, desc_len, word_embed_len))
  desc_embeds = desc_enc(descriptions)
  # One target per description: row i's correct class index is i.
  labels_for_batch = torch.arange(batch_size, dtype=torch.long)
  calc_logits = Logits()
  criterion = nn.CrossEntropyLoss()
  candidates = entity_embeds(torch.randint(0, num_entities, (batch_size,), dtype=torch.long))
  logits = calc_logits(desc_embeds, candidates)
  loss = criterion(logits, labels_for_batch)
  assert isinstance(loss, torch.Tensor)
  assert loss.shape == torch.Size([])
Example #7
def test_tester(monkeypatch, myMock):
    dataset = [{
        'label': 0,
        'sentence_splits': [['a', 'b', 'c'], ['c', 'd']],
        'candidate_ids': torch.tensor([0, 1]),
        'embedded_page_content': torch.tensor([[1], [-2], [2], [3], [-3], [4]]),
        'entity_page_mentions': torch.tensor([[1], [-2], [0], [3], [0], [4]]),
        'p_prior': torch.tensor([0.1, 0.9])
    }, {
        'label': 2,
        'sentence_splits': [['a', 'b', 'c'], ['c', 'd']],
        'candidate_ids': torch.tensor([2, 1]),
        'embedded_page_content': torch.tensor([[1], [-2], [2], [3], [-3], [4]]),
        'entity_page_mentions': torch.tensor([[1], [-2], [0], [3], [0], [4]]),
        'p_prior': torch.tensor([0.1, 0.9])
    }, {
        'label': 1,
        'sentence_splits': [['a', 'b', 'c'], ['c', 'd']],
        'candidate_ids': torch.tensor([3, 1]),
        'embedded_page_content': torch.tensor([[1], [-2], [2], [3], [-3], [4]]),
        'entity_page_mentions': torch.tensor([[1], [-2], [0], [3], [0], [4]]),
        'p_prior': torch.tensor([0.1, 0.9])
    }]
    num_entities = 10
    embed_len = 200
    batch_size = 3
    entity_embeds = nn.Embedding(num_entities,
                                 embed_len,
                                 _weight=torch.randn(
                                     (num_entities, embed_len)))
    embedding_dict = {
        char: torch.tensor([i])
        for i, char in enumerate(string.ascii_lowercase)
    }
    token_idx_lookup = dict(
        zip(embedding_dict.keys(), range(len(embedding_dict))))
    embedding = nn.Embedding.from_pretrained(
        torch.stack([embedding_dict[token] for token in token_idx_lookup]))
    vector_to_return = entity_embeds(torch.tensor([1, 1, 1]))
    model = get_mock_model(vector_to_return)
    device = None
    batch_sampler = BatchSampler(RandomSampler(dataset), batch_size, True)
    mock_experiment = create_autospec(Experiment, instance=True)
    calc_logits = Logits()
    softmax = Softmax()
    logits_and_softmax = {
        'mention': lambda hidden, candidate_ids_or_targets: softmax(
            calc_logits(hidden, entity_embeds(candidate_ids_or_targets)))
    }
    with monkeypatch.context() as m:
        m.setattr(nn, 'DataParallel', _.identity)
        m.setattr(u, 'tensors_to_device', lambda batch, device: batch)
        tester = t.Tester(
            dataset=dataset,
            batch_sampler=batch_sampler,
            model=model,
            logits_and_softmax=logits_and_softmax,
            embedding=embedding,
            token_idx_lookup=token_idx_lookup,
            device=device,
            experiment=mock_experiment,
            ablation=['prior', 'local_context', 'document_context'],
            use_adaptive_softmax=False)
        assert tester.test() == (1, 3)
        labels_for_batch = tester._get_labels_for_batch(
            torch.tensor([elem['label'] for elem in dataset]),
            torch.tensor([[1, 0], [4, 5], [1, 0]]))
        assert torch.equal(labels_for_batch, torch.tensor([1, -1, 0]))
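
The final assertion pins down the behaviour of _get_labels_for_batch: each gold entity id is mapped to its position within that row's candidate list, or to -1 when it is not among the candidates. A minimal re-implementation consistent with that assertion (the name and exact internals in the project may differ):

import torch

def get_labels_for_batch(labels, candidate_ids):
    # labels: (batch,) gold entity ids; candidate_ids: (batch, num_candidates).
    # Returns the index of the gold id within each candidate row, or -1 if absent.
    batch_size = candidate_ids.shape[0]
    out = torch.full((batch_size,), -1, dtype=torch.long)
    for i in range(batch_size):
        matches = (candidate_ids[i] == labels[i]).nonzero(as_tuple=True)[0]
        if matches.numel() > 0:
            out[i] = matches[0]
    return out

# Reproduces the assertion above:
# get_labels_for_batch(torch.tensor([0, 2, 1]),
#                      torch.tensor([[1, 0], [4, 5], [1, 0]]))
# -> tensor([ 1, -1,  0])
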