def _calc_logits(self, encoded, candidate_entity_ids):
  if self.model_params.use_wiki2vec or self.model_params.use_sum_encoder:
    return Logits()(encoded, self.entity_embeds(candidate_entity_ids))
  else:
    desc_embeds, mention_context_embeds = encoded
    if self.model_params.use_adaptive_softmax:
      raise NotImplementedError('No longer supported')
    else:
      logits = Logits()
      desc_logits = logits(desc_embeds,
                           self.entity_embeds(candidate_entity_ids))
      mention_logits = logits(mention_context_embeds,
                              self.entity_embeds(candidate_entity_ids))
      return desc_logits, mention_logits
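# Illustrative sketch (not part of the original code): `Logits` is assumed to score
# each candidate entity by a dot product between the encoded context vector and the
# candidate's embedding, broadcasting over the candidate dimension, roughly:
#
#   class Logits(nn.Module):
#     def forward(self, hidden, candidate_entity_embeds):
#       # hidden: (batch, embed_len); candidates: (batch, num_candidates, embed_len)
#       return (hidden.unsqueeze(1) * candidate_entity_embeds).sum(dim=2)
#
# The actual implementation may differ; this only conveys the assumed shapes.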
def predict_deep_el(embedding, token_idx_lookup, p_prior, model, batch,
                    ablation, entity_embeds, use_stacker):
  model.eval()
  if ablation == ['prior']:
    # Prior-only ablation: pick the candidate with the highest prior probability.
    return torch.argmax(p_prior, dim=1)
  elif 'local_context' in ablation:
    left_splits, right_splits = embed_and_pack_batch(
        embedding, token_idx_lookup, batch['sentence_splits'])
    if 'document_context' in ablation:
      # Full encoder: local mention context plus document-level context.
      mention_embeds, desc_embeds = model.encoder(
          ((left_splits, right_splits),
           batch['embedded_page_content'],
           batch['entity_page_mentions']))
    else:
      # Local-context-only ablation.
      mention_embeds = model.encoder.mention_context_encoder(
          ((left_splits, right_splits),
           batch['embedded_page_content'],
           batch['entity_page_mentions']))
    logits = Logits()
    calc_logits = lambda embeds, ids: logits(embeds, entity_embeds(ids))
    men_logits = calc_logits(mention_embeds, batch['candidate_ids'])
    if use_stacker:
      p_text, __ = model.calc_scores(
          (men_logits, torch.zeros_like(men_logits)),
          batch['candidate_mention_sim'],
          p_prior)
    else:
      p_text = men_logits
    return torch.argmax(p_text, dim=1)
  else:
    raise NotImplementedError
def get_calc(context):
  if self.model_params.use_adaptive_softmax:
    softmax = self.adaptive_logits[context].log_prob
    calc = lambda hidden, _: softmax(hidden)
  else:
    calc_logits = Logits()
    softmax = Softmax()
    calc = lambda hidden, candidate_entity_ids: softmax(
        calc_logits(hidden, self.entity_embeds(candidate_entity_ids)))
  return calc
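# Usage note (assumption, inferred from the tester test below): `get_calc` appears to
# be a closure defined inside the model/runner class, used to build a per-context
# scoring table such as the `logits_and_softmax` dict passed to `t.Tester`, e.g.
#
#   logits_and_softmax = {context: get_calc(context) for context in ('desc', 'mention')}
#   probs = logits_and_softmax['mention'](hidden, candidate_entity_ids)
#
# The exact context names are an assumption; only 'mention' appears in this excerpt.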
def predict_wiki2vec(embedding, token_idx_lookup, p_prior, model, batch,
                     ablation, entity_embeds):
  model.eval()
  context = model.encoder(batch['bag_of_nouns'])
  logits = Logits()
  calc_logits = lambda embeds, ids: logits(embeds, entity_embeds(ids))
  context_logits = calc_logits(context, batch['candidate_ids'])
  p_text, __ = model.calc_scores(
      (context_logits, torch.zeros_like(context_logits)),
      batch['candidate_mention_sim'],
      p_prior)
  return torch.argmax(p_text, dim=1)
def predict_sum_encoder(embedding, token_idx_lookup, p_prior, model, batch,
                        ablation, entity_embeds, use_stacker):
  model.eval()
  # Bag-of-words counts for each mention sentence; document counts come precomputed.
  context_bows = [Counter(to_idx(token_idx_lookup, token) for token in sentence)
                  for sentence in batch['mention_sentence']]
  doc_bows = batch['page_token_cnts']
  encoded = model.encoder(context_bows, doc_bows)
  logits = Logits()
  calc_logits = lambda embeds, ids: logits(embeds, entity_embeds(ids))
  men_logits = calc_logits(encoded, batch['candidate_ids'])
  if use_stacker:
    p_text, __ = model.calc_scores(
        (men_logits, torch.zeros_like(men_logits)),
        batch['candidate_mention_sim'],
        p_prior)
  else:
    p_text = men_logits
  return torch.argmax(p_text, dim=1)
def test_description_encoder_loss():
  embed_len = 10
  word_embed_len = 15
  num_entities = 20
  batch_size = 2
  desc_len = 9
  pad_vector = torch.randn((word_embed_len,))
  entity_embeds = torch.nn.Embedding(num_entities,
                                     embed_len,
                                     _weight=torch.randn((num_entities, embed_len)))
  desc_enc = DescriptionEncoder(word_embed_len, entity_embeds, pad_vector)
  descriptions = torch.randn((batch_size, desc_len, word_embed_len))
  desc_embeds = desc_enc(descriptions)
  labels_for_batch = torch.arange(batch_size, dtype=torch.long)
  calc_logits = Logits()
  criterion = nn.CrossEntropyLoss()
  candidates = entity_embeds(torch.randint(0, num_entities, (batch_size,), dtype=torch.long))
  logits = calc_logits(desc_embeds, candidates)
  loss = criterion(logits, labels_for_batch)
  assert isinstance(loss, torch.Tensor)
  assert loss.shape == torch.Size([])
def test_tester(monkeypatch, myMock):
  dataset = [{'label': 0,
              'sentence_splits': [['a', 'b', 'c'], ['c', 'd']],
              'candidate_ids': torch.tensor([0, 1]),
              'embedded_page_content': torch.tensor([[1], [-2], [2], [3], [-3], [4]]),
              'entity_page_mentions': torch.tensor([[1], [-2], [0], [3], [0], [4]]),
              'p_prior': torch.tensor([0.1, 0.9])},
             {'label': 2,
              'sentence_splits': [['a', 'b', 'c'], ['c', 'd']],
              'candidate_ids': torch.tensor([2, 1]),
              'embedded_page_content': torch.tensor([[1], [-2], [2], [3], [-3], [4]]),
              'entity_page_mentions': torch.tensor([[1], [-2], [0], [3], [0], [4]]),
              'p_prior': torch.tensor([0.1, 0.9])},
             {'label': 1,
              'sentence_splits': [['a', 'b', 'c'], ['c', 'd']],
              'candidate_ids': torch.tensor([3, 1]),
              'embedded_page_content': torch.tensor([[1], [-2], [2], [3], [-3], [4]]),
              'entity_page_mentions': torch.tensor([[1], [-2], [0], [3], [0], [4]]),
              'p_prior': torch.tensor([0.1, 0.9])}]
  num_entities = 10
  embed_len = 200
  batch_size = 3
  entity_embeds = nn.Embedding(num_entities,
                               embed_len,
                               _weight=torch.randn((num_entities, embed_len)))
  embedding_dict = dict(zip(string.ascii_lowercase,
                            [torch.tensor([i]) for i, char in enumerate(string.ascii_lowercase)]))
  token_idx_lookup = dict(zip(embedding_dict.keys(),
                              range(len(embedding_dict))))
  embedding = nn.Embedding.from_pretrained(
      torch.stack([embedding_dict[token] for token in token_idx_lookup]))
  vector_to_return = entity_embeds(torch.tensor([1, 1, 1]))
  model = get_mock_model(vector_to_return)
  device = None
  batch_sampler = BatchSampler(RandomSampler(dataset), batch_size, True)
  mock_experiment = create_autospec(Experiment, instance=True)
  calc_logits = Logits()
  softmax = Softmax()
  logits_and_softmax = {
      'mention': lambda hidden, candidate_ids_or_targets: softmax(
          calc_logits(hidden, entity_embeds(candidate_ids_or_targets)))
  }
  with monkeypatch.context() as m:
    m.setattr(nn, 'DataParallel', _.identity)
    m.setattr(u, 'tensors_to_device', lambda batch, device: batch)
    tester = t.Tester(dataset=dataset,
                      batch_sampler=batch_sampler,
                      model=model,
                      logits_and_softmax=logits_and_softmax,
                      embedding=embedding,
                      token_idx_lookup=token_idx_lookup,
                      device=device,
                      experiment=mock_experiment,
                      ablation=['prior', 'local_context', 'document_context'],
                      use_adaptive_softmax=False)
    assert tester.test() == (1, 3)
    labels_for_batch = tester._get_labels_for_batch(
        torch.tensor([elem['label'] for elem in dataset]),
        torch.tensor([[1, 0], [4, 5], [1, 0]]))
    assert torch.equal(labels_for_batch, torch.tensor([1, -1, 0]))