Code Example #1
    def _find_model_function(self):
        embedding_dim = self.configuration['embed_size']
        embedding_matrix_filepath = self.base_data_dir + 'embedding_matrix'
        if os.path.exists(embedding_matrix_filepath):
            embedding_matrix = super()._load_object(embedding_matrix_filepath)
        else:
            embedding_filepath = self.configuration['embedding_filepath']
            embedding_matrix = embedding._read_embeddings_from_text_file(embedding_filepath, embedding_dim,
                                                                         self.vocab, namespace='tokens')
            super()._save_object(embedding_matrix_filepath, embedding_matrix)
        embedding_matrix = embedding_matrix.to(self.configuration['device'])
        token_embedding = Embedding(num_embeddings=self.vocab.get_vocab_size(namespace='tokens'),
                                    embedding_dim=embedding_dim, padding_index=0, vocab_namespace='tokens',
                                    trainable=self._is_train_token_embeddings(), weight=embedding_matrix)
        # the embedder maps the input tokens to the appropriate embedding matrix
        word_embedder: TextFieldEmbedder = BasicTextFieldEmbedder({"tokens": token_embedding})

        position_embedding = Embedding(num_embeddings=self.vocab.get_vocab_size(namespace='position'),
                                       embedding_dim=self._get_position_embeddings_dim(), padding_index=0)
        position_embedder: TextFieldEmbedder = BasicTextFieldEmbedder({"position": position_embedding},
                                                                      # we'll be ignoring masks so we'll need to set this to True
                                                                      allow_unmatched_keys=True)

        model_function = self._find_model_function_pure()
        model = model_function(
            word_embedder,
            position_embedder,
            self.distinct_polarities,
            self.vocab,
            self.configuration,
        )
        self._print_args(model)
        model = model.to(self.configuration['device'])
        return model
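Code Examples #1 and #2 cache the computed embedding matrix on disk through _save_object and _load_object, which are defined on the parent class and not shown here. Below is a minimal sketch of what such helpers could look like, assuming they are plain pickle wrappers; the project's actual implementation may differ.

import pickle


def _save_object(filepath, obj):
    # Hypothetical helper: serialize an object (e.g. the embedding matrix
    # tensor) to disk so later runs can skip re-reading the embedding file.
    with open(filepath, 'wb') as f:
        pickle.dump(obj, f)


def _load_object(filepath):
    # Hypothetical counterpart: load a previously cached object from disk.
    with open(filepath, 'rb') as f:
        return pickle.load(f)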
Code Example #2
    def _find_model_function(self):
        embedding_dim = self.configuration['embed_size']
        embedding_matrix_filepath = self.base_data_dir + 'embedding_matrix'
        if os.path.exists(embedding_matrix_filepath):
            embedding_matrix = super()._load_object(embedding_matrix_filepath)
        else:
            embedding_filepath = self.configuration['embedding_filepath']
            embedding_matrix = embedding._read_embeddings_from_text_file(
                embedding_filepath,
                embedding_dim,
                self.vocab,
                namespace='tokens')
            super()._save_object(embedding_matrix_filepath, embedding_matrix)
        token_embedding = Embedding(
            num_embeddings=self.vocab.get_vocab_size(namespace='tokens'),
            embedding_dim=embedding_dim,
            padding_index=0,
            vocab_namespace='tokens',
            trainable=False,
            weight=embedding_matrix)
        # the embedder maps the input tokens to the appropriate embedding matrix
        word_embedder: TextFieldEmbedder = BasicTextFieldEmbedder(
            {"tokens": token_embedding})

        position_embedding = Embedding(
            num_embeddings=self.vocab.get_vocab_size(namespace='position'),
            embedding_dim=25,
            padding_index=0)
        position_embedder: TextFieldEmbedder = BasicTextFieldEmbedder(
            {"position": position_embedding},
            # we'll be ignoring masks so we'll need to set this to True
            allow_unmatched_keys=True)

        # bert_embedder = PretrainedBertEmbedder(
        #     pretrained_model=self.bert_file_path,
        #     top_layer_only=True,  # conserve memory
        #     requires_grad=True
        # )
        # bert_word_embedder: TextFieldEmbedder = BasicTextFieldEmbedder({"bert": bert_embedder},
        #                                                                  # we'll be ignoring masks so we'll need to set this to True
        #                                                                  allow_unmatched_keys=True)
        bert_word_embedder = self._get_bert_word_embedder()

        model = pytorch_models.AsMilSimultaneouslyBert(
            word_embedder,
            position_embedder,
            self.distinct_categories,
            self.distinct_polarities,
            self.vocab,
            self.configuration,
            bert_word_embedder=bert_word_embedder)
        self._print_args(model)
        model = model.to(self.configuration['device'])
        return model
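Code Example #2 delegates BERT embedding construction to self._get_bert_word_embedder(), which is not shown. Based on the commented-out block directly above that call, a sketch of what the helper might return is given below; the method body is an assumption reconstructed from that comment, not the project's actual code.

    def _get_bert_word_embedder(self):
        # Sketch only: wrap a pretrained BERT model as a TextFieldEmbedder,
        # mirroring the commented-out code in Code Example #2.
        bert_embedder = PretrainedBertEmbedder(
            pretrained_model=self.bert_file_path,
            top_layer_only=True,  # conserve memory
            requires_grad=True)
        bert_word_embedder: TextFieldEmbedder = BasicTextFieldEmbedder(
            {"bert": bert_embedder},
            # we'll be ignoring masks so we'll need to set this to True
            allow_unmatched_keys=True)
        return bert_word_embedder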
Code Example #3
# Imports assume the AllenNLP 0.x API used throughout these examples.
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder, TextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.modules.token_embedders.embedding import _read_embeddings_from_text_file


def load_glove_embeddings(vocab):
    """
    Loads pre-trained GloVe embeddings.

    Returns
    -------
    TextFieldEmbedder
    """
    embedding_matrix = _read_embeddings_from_text_file(file_uri="https://s3-us-west-2.amazonaws.com/allennlp/datasets/glove/glove.840B.300d.txt.gz",
                                                       embedding_dim=300,
                                                       vocab=vocab)
    print("Pre-trained Glove loaded:", embedding_matrix.size())

    token_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),
                                embedding_dim=300,
                                weight=embedding_matrix)
    word_embeddings: TextFieldEmbedder = BasicTextFieldEmbedder({"tokens": token_embedding})
    return word_embeddings
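A short usage sketch for Code Example #3, assuming a vocabulary built from AllenNLP instances; the variable names are placeholders, and note that the call downloads the full glove.840B.300d file, which is several gigabytes.

from allennlp.data.vocabulary import Vocabulary

# train_instances / validation_instances stand in for whatever instances the
# project's DatasetReader produces.
vocab = Vocabulary.from_instances(train_instances + validation_instances)
word_embeddings = load_glove_embeddings(vocab)
print(word_embeddings.get_output_dim())  # 300 for glove.840B.300d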
Code Example #4
discourse_validation_dataset = discourse_reader.read(
    cached_path(DISCOURSE_VALIDATION_PATH))
vocab = Vocabulary.from_instances(claim_train_dataset +
                                  claim_validation_dataset +
                                  discourse_train_dataset +
                                  discourse_validation_dataset)
discourse_dict = {
    'RESULTS': 0,
    'METHODS': 1,
    'CONCLUSIONS': 2,
    'BACKGROUND': 3,
    'OBJECTIVE': 4
}
claim_dict = {'0': 0, '1': 1}
embedding_matrix = _read_embeddings_from_text_file(
    file_uri=PUBMED_PRETRAINED_FILE,
    embedding_dim=EMBEDDING_DIM,
    vocab=vocab)
token_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),
                            embedding_dim=EMBEDDING_DIM,
                            weight=embedding_matrix)
word_embeddings = BasicTextFieldEmbedder({"tokens": token_embedding})
sentence_encoder = PytorchSeq2VecWrapper(
    torch.nn.LSTM(EMBEDDING_DIM,
                  HIDDEN_DIM,
                  batch_first=True,
                  bidirectional=True))
model = DiscourseClaimCrfClassifier(
    vocab,
    word_embeddings,
    sentence_encoder,
)
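Code Example #4 stops after constructing the model. In the AllenNLP 0.x API the usual next step is to index an iterator with the vocabulary and hand everything to a Trainer. The sketch below is a hedged continuation, not part of the original snippet, so the optimizer, batch size, epoch count, sorting key, and the choice of datasets to train on are all assumptions.

import torch.optim as optim
from allennlp.data.iterators import BucketIterator
from allennlp.training.trainer import Trainer

# Bucket batches by length to reduce padding; the sorting key must match the
# field names produced by the project's DatasetReader (assumed here).
iterator = BucketIterator(batch_size=32, sorting_keys=[("tokens", "num_tokens")])
iterator.index_with(vocab)

trainer = Trainer(model=model,
                  optimizer=optim.Adam(model.parameters(), lr=1e-3),
                  iterator=iterator,
                  train_dataset=discourse_train_dataset,
                  validation_dataset=discourse_validation_dataset,
                  num_epochs=10)
trainer.train()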