Example #1
    def _embed(self, x):
        """Embed a padded batch of index sequences with flair, caching per-sequence results."""
        vocab_idx = self.vocab_idx
        embeddings = []
        for sequence in x:
            # Starts at the full padded length; after the token loop it equals
            # the number of padding positions that still need zero vectors.
            padding_length = sequence.size(0)
            if self._get_from_cache(sequence) is None:
                flair_sentence = flair.data.Sentence()
                for index in sequence:
                    index = index.item()
                    if index == self.pad_index:
                        # Sequences are right-padded, so stop at the first pad index.
                        break
                    padding_length = padding_length - 1
                    # Map the index back to its token string, falling back to '[UNK]'.
                    token = vocab_idx.get(index, '[UNK]')
                    flair_sentence.add_token(token)
                self.embeddings.embed(flair_sentence)
                sentence_embedding = torch.stack(
                    [token.embedding.cpu() for token in flair_sentence.tokens])
                self._add_to_cache(sequence, sentence_embedding)
            else:
                # Cache hit: reuse the stored embedding and recompute the missing padding.
                sentence_embedding = self._get_from_cache(sequence)
                padding_length = padding_length - sentence_embedding.size(0)
            if padding_length:
                # Re-append zero vectors so the sequence matches the padded batch length.
                sentence_embedding = torch.cat(
                    (sentence_embedding,
                     torch.zeros(padding_length,
                                 sentence_embedding.size(-1),
                                 device=sentence_embedding.device)))
            embeddings.append(sentence_embedding)

        return torch.stack(embeddings).to(self.device)
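
Below is a minimal, self-contained sketch of the padding-restoration step at the end of _embed: after embedding only the real tokens, zero vectors are concatenated back so every sequence matches the padded batch length again. The embedding dimension (768) and the number of skipped pad positions are hypothetical values chosen for illustration.

import torch

# 5 real tokens were embedded; 2 positions of padding were skipped.
sentence_embedding = torch.randn(5, 768)
padding_length = 2

# Append zero vectors of the same width as the token embeddings.
padded = torch.cat(
    (sentence_embedding,
     torch.zeros(padding_length,
                 sentence_embedding.size(-1),
                 device=sentence_embedding.device)))
print(padded.shape)  # torch.Size([7, 768])
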
Example #2
    def _embed_list_of_tensors(self, batch: List[torch.Tensor]):
        """Embed a batch of padded index tensors, returning one embedding per sequence."""
        embeddings = []
        for sequence in batch:
            # Drop padding indices and convert the tensor to a plain list of ints.
            index_sequence = [
                index.item() for index in sequence.cpu().clone().detach()
                if index.item() != self.pad_index
            ]
            sentence_embedding = self._get_index_sequence_embedding(
                index_sequence)

            embeddings.append(sentence_embedding.to(self.device))
        return torch.stack(embeddings).to(self.device)
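
A minimal standalone sketch of the pad-stripping step used above, assuming a hypothetical pad index of 0 and a small right-padded sequence:

import torch

pad_index = 0                               # hypothetical padding index
sequence = torch.tensor([12, 7, 45, 0, 0])  # hypothetical right-padded sequence

# Same comprehension as above: keep only real token indices as plain ints.
index_sequence = [
    index.item() for index in sequence.cpu().clone().detach()
    if index.item() != pad_index
]
print(index_sequence)  # [12, 7, 45]
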
Example #3
    def _embed_list_of_sequences(self, batch: List[List[torch.Tensor]]):
        """Embed a batch of documents, each given as a list of padded index tensors."""
        embeddings = []
        for sequences in batch:
            sequence_embeddings = []
            for sequence in sequences:
                # Drop padding indices and convert the tensor to a plain list of ints.
                index_sequence = [
                    index.item() for index in sequence.cpu().clone().detach()
                    if index.item() != self.pad_index
                ]

                sentence_embedding = self._get_index_sequence_embedding(
                    index_sequence).to(self.device)
                sequence_embeddings.append(sentence_embedding)

            if sequence_embeddings:
                # Aggregate the per-sequence embeddings into one document embedding.
                sequence_embeddings = self.encoder(
                    torch.stack(sequence_embeddings, 0))
            else:
                # Empty document: fall back to a zero vector of the expected size.
                sequence_embeddings = torch.zeros(self.embedding_dim,
                                                  requires_grad=False,
                                                  device=self.device)

            embeddings.append(sequence_embeddings)
        return torch.stack(embeddings).to(self.device)
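
A minimal standalone sketch of the empty-document fallback above: when a document contains no sequences, a zero vector of the embedding size stands in for the document embedding, so the final torch.stack still receives a tensor for every document. The embedding size and device are hypothetical.

import torch

embedding_dim = 768              # hypothetical embedding size
device = torch.device('cpu')     # hypothetical target device

sequence_embeddings = []         # an empty document
if sequence_embeddings:
    document_embedding = torch.stack(sequence_embeddings, 0)
else:
    document_embedding = torch.zeros(embedding_dim,
                                     requires_grad=False,
                                     device=device)
print(document_embedding.shape)  # torch.Size([768])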