def predict_deep_el(embedding, token_idx_lookup, p_prior, model, batch, ablation,
                    entity_embeds, use_stacker):
    model.eval()
    if ablation == ['prior']:
        # Prior-only ablation: pick the candidate with the highest prior probability.
        return torch.argmax(p_prior, dim=1)
    elif 'local_context' in ablation:
        left_splits, right_splits = embed_and_pack_batch(embedding,
                                                         token_idx_lookup,
                                                         batch['sentence_splits'])
        if 'document_context' in ablation:
            # Full encoder: returns both mention-context and document-context embeddings.
            mention_embeds, desc_embeds = model.encoder(((left_splits, right_splits),
                                                         batch['embedded_page_content'],
                                                         batch['entity_page_mentions']))
        else:
            # Local-context-only ablation: use just the mention context encoder.
            mention_embeds = model.encoder.mention_context_encoder(
                ((left_splits, right_splits),
                 batch['embedded_page_content'],
                 batch['entity_page_mentions']))
        logits = Logits()
        calc_logits = lambda embeds, ids: logits(embeds, entity_embeds(ids))
        men_logits = calc_logits(mention_embeds, batch['candidate_ids'])
        if use_stacker:
            # Stack mention logits with the mention-similarity feature and the prior.
            p_text, __ = model.calc_scores((men_logits, torch.zeros_like(men_logits)),
                                           batch['candidate_mention_sim'],
                                           p_prior)
        else:
            p_text = men_logits
        return torch.argmax(p_text, dim=1)
    else:
        raise NotImplementedError
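# Hedged sketch (not the repository's implementation): the `Logits` module used above
# presumably scores each candidate by an inner product between the mention embedding and
# that candidate's entity embedding. The class name and tensor shapes below are assumptions.
import torch
import torch.nn as nn


class DotProductLogits(nn.Module):
    def forward(self, mention_embeds, candidate_entity_embeds):
        # mention_embeds: (batch, embed_dim)
        # candidate_entity_embeds: (batch, num_candidates, embed_dim)
        # Returns unnormalized scores of shape (batch, num_candidates).
        return torch.bmm(candidate_entity_embeds,
                         mention_embeds.unsqueeze(2)).squeeze(2)


# Illustrative usage: DotProductLogits()(torch.randn(4, 100), torch.randn(4, 30, 100))
# yields a (4, 30) score matrix, one row of candidate scores per mention.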
def train_deep_el(self):
    for epoch_num in range(self.num_epochs):
        self.experiment.update_epoch(epoch_num)
        self._dataset = self.get_dataset()
        dataloader = DataLoader(dataset=self._dataset,
                                batch_sampler=self.get_batch_sampler(),
                                collate_fn=collate_deep_el)
        for batch_num, batch in enumerate(dataloader):
            self.model.train()
            self.optimizer.zero_grad()
            batch = tensors_to_device(batch, self.device)
            if self.use_adaptive_softmax:
                labels = batch['label']
            else:
                # Map gold entity ids to their index within each candidate list.
                labels = self._get_labels_for_batch(batch['label'],
                                                    batch['candidate_ids'])
            left_splits, right_splits = embed_and_pack_batch(self.embedding,
                                                             self.token_idx_lookup,
                                                             batch['sentence_splits'])
            encoded = self.model.encoder(((left_splits, right_splits),
                                          batch['embedded_page_content'],
                                          batch['entity_page_mentions']))
            logits = self.calc_logits(encoded, batch['candidate_ids'])
            scores = self.model.calc_scores(logits, batch['candidate_mention_sim'])  # batch['prior'])
            loss = self.calc_loss(scores, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(itertools.chain(self.model.parameters(),
                                                           self._get_adaptive_logits_params()),
                                           self.clip_grad)
            self.optimizer.step()
            # Re-run the batch in eval mode to track per-batch training error.
            with torch.no_grad():
                self.model.eval()
                encoded_test = self.model.encoder(((left_splits, right_splits),
                                                   batch['embedded_page_content'],
                                                   batch['entity_page_mentions']))
                logits_test = self.calc_logits(encoded_test, batch['candidate_ids'])
                mention_probas, desc_probas = self.model.calc_scores(logits_test,
                                                                     batch['candidate_mention_sim'])
                mention_context_error = self._classification_error(mention_probas, labels)
                document_context_error = self._classification_error(desc_probas, labels)
                self.experiment.record_metrics({'mention_context_error': mention_context_error,
                                                'document_context_error': document_context_error,
                                                'loss': loss.item()},
                                               batch_num=batch_num)
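# Hedged sketches (assumptions, not taken from the source): `_get_labels_for_batch` likely
# converts gold entity ids into positions within each example's candidate list, and
# `_classification_error` likely reports the fraction of argmax predictions that miss the
# gold index. The names and signatures below are illustrative only.
import torch


def _get_labels_for_batch_sketch(labels, candidate_ids):
    # labels: (batch,) gold entity ids; candidate_ids: (batch, num_candidates).
    # Returns the index of the gold id within each candidate row, or -1 if absent.
    matches = candidate_ids == labels.unsqueeze(1)
    idx = torch.argmax(matches.int(), dim=1)
    idx[~matches.any(dim=1)] = -1
    return idx


def _classification_error_sketch(probas, labels):
    # Fraction of examples whose highest-scoring candidate is not the gold index.
    preds = torch.argmax(probas, dim=1)
    return (preds != labels).float().mean().item()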
import torch
import torch.nn as nn

# `dt` is the module under test, i.e. the one defining embed_and_pack_batch
# (imported elsewhere in this test module).


def test_embed_and_pack_batch():
    embedding_dict = {'a': torch.tensor([1]), 'b': torch.tensor([2])}
    token_idx_lookup = dict(zip(embedding_dict.keys(),
                                range(len(embedding_dict))))
    embedding = nn.Embedding.from_pretrained(
        torch.stack([embedding_dict[token] for token in token_idx_lookup]))
    sentence_splits_batch = [[['a', 'b', 'a', 'b'], ['b', 'a']],
                             [['b', 'a'], ['a', 'b', 'a', 'b']]]
    # Expected packed contexts, already sorted by length (longest first).
    left = [torch.tensor(vec) for vec in [[[1], [2], [1], [2]], [[2], [1]]]]
    right = [torch.tensor(vec) for vec in [[[1], [2], [1], [2]], [[2], [1]]]]
    result = dt.embed_and_pack_batch(embedding,
                                     token_idx_lookup,
                                     sentence_splits_batch)
    assert torch.equal(result[0]['embeddings'].data,
                       nn.utils.rnn.pack_sequence(left).data)
    assert torch.equal(result[0]['embeddings'].batch_sizes,
                       nn.utils.rnn.pack_sequence(left).batch_sizes)
    assert result[0]['order'] == [0, 1]
    assert torch.equal(result[1]['embeddings'].data,
                       nn.utils.rnn.pack_sequence(right).data)
    assert torch.equal(result[1]['embeddings'].batch_sizes,
                       nn.utils.rnn.pack_sequence(right).batch_sizes)
    assert result[1]['order'] == [1, 0]
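# Hedged reference sketch of embed_and_pack_batch consistent with the expectations encoded
# in the test above (an assumption about the real implementation, not a copy of it): for
# each side of the mention, sort contexts by length descending, embed the tokens, pack
# them, and keep the sort order so downstream code can unsort the RNN outputs.
import torch
import torch.nn as nn


def embed_and_pack_batch_sketch(embedding, token_idx_lookup, sentence_splits_batch):
    packed_sides = []
    for side in range(2):  # 0: left contexts, 1: right contexts
        contexts = [splits[side] for splits in sentence_splits_batch]
        # pack_sequence expects sequences sorted by length, longest first.
        order = sorted(range(len(contexts)), key=lambda i: len(contexts[i]), reverse=True)
        embedded = [embedding(torch.tensor([token_idx_lookup[tok] for tok in contexts[i]]))
                    for i in order]
        packed_sides.append({'embeddings': nn.utils.rnn.pack_sequence(embedded),
                             'order': order})
    return packed_sides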