Example #1
from typing import List, Optional

import torch
from allennlp.modules.token_embedders import ElmoTokenEmbedder


class ElmoWordEmbedding(torch.nn.Module):
    """
    Compute a single layer of ELMo word representations.
    """
    def __init__(self,
                 options_file: str,
                 weight_file: str,
                 vocab_to_cache: List[str],
                 do_layer_norm: bool = False,
                 dropout: float = 0.5,
                 requires_grad: bool = False,
                 projection_dim: Optional[int] = None) -> None:
        super(ElmoWordEmbedding, self).__init__()

        self._elmo = ElmoTokenEmbedder(options_file=options_file,
                                       weight_file=weight_file,
                                       do_layer_norm=do_layer_norm,
                                       dropout=dropout,
                                       requires_grad=requires_grad,
                                       projection_dim=projection_dim,
                                       vocab_to_cache=vocab_to_cache)

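        # Expose the wrapped embedder's optional projection layer.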
        self._projection = self._elmo._projection

    def get_output_dim(self):
        if self._projection is not None:
            return self._projection.out_features
        else:
            return self._elmo.get_output_dim()

    def forward(self, word_inputs: torch.Tensor) -> torch.Tensor:
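        # ``word_inputs`` holds ids into the cached vocabulary; a 1D tensor of
        # ids is reshaped to (batch_size, 1) before being passed to ELMo.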
        if len(word_inputs.shape) == 1:
            word_inputs = word_inputs.unsqueeze(dim=-1)
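        # The same ids serve as both the character-level inputs and the cached
        # word ids, so ELMo's cached word-embedding path is used.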
        return self._elmo.forward(word_inputs, word_inputs)

    @property
    def weight(self):
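        # Duplicate the cached word embedding along the feature dimension to
        # match the width of the bidirectional ELMo representation, then apply
        # the optional projection.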
        embedding_weight = torch.cat(
            (self.word_embedding.weight, self.word_embedding.weight), dim=1)
        if self._projection:
            embedding_weight = self._projection(embedding_weight)
        return embedding_weight

    @property
    def num_embeddings(self):
        return self.word_embedding.num_embeddings

    @property
    def word_embedding(self):
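        # The cached word embedding lives on the biLM wrapped inside the
        # ElmoTokenEmbedder.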
        return self._elmo._elmo._elmo_lstm._word_embedding
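
A minimal usage sketch for the class above; the option/weight paths and the toy cached vocabulary are hypothetical placeholders, not part of the original example:

# Hypothetical paths to a standard ELMo options/weights pair.
options_file = 'elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json'
weight_file = 'elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5'
vocab_to_cache = ['@@PADDING@@', 'the', 'cat', 'sat']

embedder = ElmoWordEmbedding(options_file=options_file,
                             weight_file=weight_file,
                             vocab_to_cache=vocab_to_cache)

# A batch of ids into the cached vocabulary, shape (batch_size, seq_len).
word_ids = torch.tensor([[1, 2, 3]])
vectors = embedder(word_ids)  # (1, 3, embedder.get_output_dim())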
Example #2
    def embeddings_returner(self, vocab=None):
        '''
        Build the token embedder selected by ``self.embedding_strategy``.

        For the BERT branch, the weight location is either the name of a
        pretrained model (e.g. bert-base-uncased) or the path to the .tar.gz
        file with the model weights.
        :param vocab: needed only when pretrained (GloVe) embeddings are used.
        :return: (embedder, embedding_dim, BasicTextFieldEmbedder)

        Notes on casing:
            "bert-base-uncased": do_lower_case=True
            "bert-base-cased":   do_lower_case=False
            https://github.com/huggingface/pytorch-transformers/issues/712
            https://qiita.com/uedake722/items/b7f4b75b4d77d9bd358b
        '''
        if self.embedding_strategy == 'bert':
            self.bertmodel_dir = ''
            if self.ifbert_use_whichmodel == 'general':
                self.bertmodel_dir += 'bert-base-uncased/'  # the uncased version is recommended in the original repository
                self.bertmodel_relative_dirpath = self.bert_src_dir + self.bertmodel_dir

                # 'bert-base-uncased' ships with pytorch_transformers, so the
                # model name itself is used instead of a local weight file.
                self.bert_weight_filepath = 'bert-base-uncased'

            elif self.ifbert_use_whichmodel == 'scibert':
                self.bertmodel_dir += 'scibert_scivocab_uncased/'  # the uncased version is recommended in the original repository
                self.bertmodel_relative_dirpath = self.bert_src_dir + self.bertmodel_dir
                self.bert_weight_filepath = self.bertmodel_relative_dirpath + 'weights.tar.gz'

            elif self.ifbert_use_whichmodel == 'biobert':
                self.bertmodel_dir += 'biobert_v1.1_pubmed/'  # currently only the cased version is supported
                self.bertmodel_relative_dirpath = self.bert_src_dir + self.bertmodel_dir
                self.bert_weight_filepath = self.bertmodel_relative_dirpath + 'weights.tar.gz'  # the archive contains bert_config.json and the .bin weights

            # Load embedder
            bert_embedder = PretrainedBertEmbedder(
                pretrained_model=self.bert_weight_filepath,
                top_layer_only=self.bert_top_layer_only,
                requires_grad=self.emb_requires_grad)
            return (bert_embedder,
                    bert_embedder.get_output_dim(),
                    BasicTextFieldEmbedder({'tokens': bert_embedder},
                                           allow_unmatched_keys=True))

        elif self.embedding_strategy == 'elmo':
            if self.ifelmo_use_whichmodel == 'general':
                options_file = 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway_5.5B/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json'
                weight_file = 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway_5.5B/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5'
            elif self.ifelmo_use_whichmodel == 'pubmed':
                options_file = 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pubmed/elmo_2x4096_512_2048cnn_2xhighway_options.json'
                weight_file = 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pubmed/elmo_2x4096_512_2048cnn_2xhighway_weights_PubMed_only.hdf5'
            elif self.ifelmo_use_whichmodel == 'bioelmo':
                options_file = self.elmo_src_dir + 'BioELMo/weights/biomed_elmo_options.json'
                weight_file = self.elmo_src_dir + 'BioELMo/weights/biomed_elmo_weights.hdf5'
            else:
                raise ValueError(
                    'Unknown ELMo model: {}'.format(self.ifelmo_use_whichmodel))
            elmo_embedder = ElmoTokenEmbedder(
                options_file=options_file,
                weight_file=weight_file,
                requires_grad=self.emb_requires_grad)
            return (elmo_embedder,
                    elmo_embedder.get_output_dim(),
                    BasicTextFieldEmbedder({'tokens': elmo_embedder}))

        elif self.embedding_strategy == 'pretrained':

            print('\nLoading pretrained GloVe vocab\n')

            if 'glove' in self.args.ifpretrained_use_whichmodel:
                embedding_dim = 300
            else:
                embedding_dim = 200

            pretrain_emb_embedder = Embedding.from_params(
                vocab=vocab,
                params=Params({
                    'pretrained_file': self.glove_embeddings_file,
                    'embedding_dim': embedding_dim,
                    'trainable': False,
                    'padding_index': 0
                }))

            return (pretrain_emb_embedder,
                    pretrain_emb_embedder.get_output_dim(),
                    BasicTextFieldEmbedder({'tokens': pretrain_emb_embedder}))
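
A sketch of how the returned triple might be consumed; ``config`` stands in for an instance of the (unnamed) class that defines embeddings_returner(), and the attribute names follow the code above:

# Hypothetical driver code; the surrounding configuration class is not shown.
config.embedding_strategy = 'elmo'
config.ifelmo_use_whichmodel = 'general'
config.emb_requires_grad = False

embedder, emb_dim, text_field_embedder = config.embeddings_returner()

# Inside an AllenNLP model, the BasicTextFieldEmbedder maps a TextField tensor
# dict (here called ``tokens``) to a (batch_size, seq_len, emb_dim) tensor.
embedded = text_field_embedder(tokens)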