def get_sentence_embedding(self, tokens: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """Auxiliary function that extracts sentence embeddings for a batch of sentences.

    :param tokens: sequences [batch_size x seq_len]
    :param lengths: lengths [batch_size]

    :return: torch.Tensor [batch_size x hidden_size]. For ``pool == "cls+avg"``
        the CLS vector and the average-pooled vector are concatenated along
        dim 1, so the result is wider.

    :raises ValueError: if no scalar mix is set and ``self.layer`` is outside
        ``[0, self.encoder.num_layers)``, or if ``self.hparams.pool`` is not one
        of "default", "max", "avg", "cls", "cls+avg".
    """
    # When using just one GPU this should not change behavior, but when
    # splitting batches across GPUs the tokens keep the padding of the entire
    # original batch, so trim them to this shard's longest sequence.
    if self.trainer and self.trainer.use_dp and self.trainer.num_gpus > 1:
        tokens = tokens[:, : lengths.max()]

    encoder_out = self.encoder(tokens, lengths)

    # Word-level embeddings: either a learned mix of all layers or one layer.
    if self.scalar_mix:
        embeddings = self.scalar_mix(encoder_out["all_layers"], encoder_out["mask"])
    elif 0 <= self.layer < self.encoder.num_layers:
        embeddings = encoder_out["all_layers"][self.layer]
    else:
        raise ValueError("Invalid model layer {}.".format(self.layer))

    pool = self.hparams.pool
    if pool == "default":
        # Use the encoder's own sentence embedding, not the selected layer.
        sentemb = encoder_out["sentemb"]
    elif pool == "max":
        sentemb = max_pooling(tokens, embeddings, self.encoder.tokenizer.padding_index)
    elif pool == "avg":
        sentemb = average_pooling(
            tokens,
            embeddings,
            encoder_out["mask"],
            self.encoder.tokenizer.padding_index,
        )
    elif pool == "cls":
        # First position of every sequence (presumably the CLS token).
        sentemb = embeddings[:, 0, :]
    elif pool == "cls+avg":
        cls_sentemb = embeddings[:, 0, :]
        avg_sentemb = average_pooling(
            tokens,
            embeddings,
            encoder_out["mask"],
            self.encoder.tokenizer.padding_index,
        )
        sentemb = torch.cat((cls_sentemb, avg_sentemb), dim=1)
    else:
        raise ValueError("Invalid pooling technique.")

    return sentemb
def test_get_sentence_embedding(self):
    """Check the ranker's pooling modes against references built from the encoder output.

    Max and AVG pooling themselves are tested in test_utils.py.
    """
    self.ranker.scalar_mix = None
    self.ranker.layer = 12
    # The reference ``encoder_out`` and the method under test come from two
    # separate forward passes and are compared exactly with ``torch.equal``.
    # Put the encoder in eval mode (as the estimator test does) so stochastic
    # layers such as dropout, if present, cannot make the comparison flaky.
    self.ranker.encoder.eval()

    # tokens from ["hello world", "how are your?"]
    tokens = torch.tensor([[29733, 4139, 0, 0], [2231, 137, 57374, 8]])
    lengths = torch.tensor([2, 4])
    encoder_out = self.ranker.encoder(tokens, lengths)
    # NOTE(review): the references below use encoder_out["wordemb"], which
    # assumes it equals the layer-12 output the ranker selects — confirm.

    # Expected sentence output with pool = 'default'
    self.ranker.hparams = Namespace(encoder_model="LASER", pool="default")
    expected = encoder_out["sentemb"]
    sentemb = self.ranker.get_sentence_embedding(tokens, lengths)
    self.assertTrue(torch.equal(sentemb, expected))

    # LASER always used default... for the remaining modes we need to pretend
    # our encoder is another one.

    # Expected sentence output with pool = 'max'
    self.ranker.hparams = Namespace(encoder_model="other", pool="max")
    expected = max_pooling(tokens, encoder_out["wordemb"], 0)
    sentemb = self.ranker.get_sentence_embedding(tokens, lengths)
    self.assertTrue(torch.equal(sentemb, expected))

    # Expected sentence output with pool = 'avg'
    self.ranker.hparams = Namespace(encoder_model="other", pool="avg")
    expected = average_pooling(tokens, encoder_out["wordemb"], encoder_out["mask"], 0)
    sentemb = self.ranker.get_sentence_embedding(tokens, lengths)
    self.assertTrue(torch.equal(sentemb, expected))

    # Expected sentence output with pool = 'cls'
    self.ranker.hparams = Namespace(encoder_model="other", pool="cls")
    expected = encoder_out["wordemb"][:, 0, :]
    sentemb = self.ranker.get_sentence_embedding(tokens, lengths)
    self.assertTrue(torch.equal(sentemb, expected))
def test_get_sentence_embedding(self):
    """Compare every pooling mode of the estimator with a reference built from the encoder output.

    Max and AVG pooling themselves are tested in test_utils.py.
    """
    estimator = self.estimator
    estimator.scalar_mix = None
    estimator.encoder.eval()

    # tokens from ["hello world", "how are your?"]
    tokens = torch.tensor([[29733, 4139, 0, 0], [2231, 137, 57374, 8]])
    lengths = torch.tensor([2, 4])
    encoder_out = estimator.encoder(tokens, lengths)
    word_embeddings = encoder_out["wordemb"]
    mask = encoder_out["mask"]

    def check_pool(pool, expected, repeats=1):
        # Select the pooling strategy, then compare `repeats` consecutive calls.
        estimator.hparams = Namespace(pool=pool)
        for _ in range(repeats):
            sentemb = estimator.get_sentence_embedding(tokens, lengths)
            self.assertTrue(torch.equal(sentemb, expected))

    # Max pooling is checked twice: consecutive calls must agree.
    check_pool("max", max_pooling(tokens, word_embeddings, 0), repeats=2)
    check_pool("avg", average_pooling(tokens, word_embeddings, mask, 0))
    check_pool("cls", word_embeddings[:, 0, :])

    cls_part = word_embeddings[:, 0, :]
    avg_part = average_pooling(tokens, word_embeddings, mask, 0)
    check_pool("cls+avg", torch.cat((cls_part, avg_part), dim=1))

    check_pool("default", encoder_out["sentemb"])
def test_max_pooling(self):
    """max_pooling takes the per-dimension max over non-padding positions only."""
    padding_index = 0
    # The second sequence ends with one padding token.
    tokens = torch.tensor([[2, 2, 2, 2], [2, 2, 2, padding_index]])

    unpadded_rows = [
        [3.1416, 4.1416, 5.1416],
        [3.1416, 3.1416, 3.1416],
        [3.1416, 3.1416, 3.1416],
        [3.1416, 3.1416, 3.1416],
    ]
    padded_rows = [
        [6.1416, 3.1416, -3.1416],
        [3.1416, 3.1416, -1.1416],
        [3.1416, 7.1416, -3.1416],
        [0.0000, 0.0000, 0.0000],  # padding position
    ]
    embeddings = torch.tensor([unpadded_rows, padded_rows])

    # In the last column of the second sequence the padding value 0.0 would
    # win if it were not masked out; the correct max is -1.1416.
    expected = torch.tensor([[3.1416, 4.1416, 5.1416], [6.1416, 7.1416, -1.1416]])
    pooled = max_pooling(tokens, embeddings, padding_index)
    self.assertTrue(torch.equal(pooled, expected))