    def test_fold_long_sequences(self):
        # Let's just say [PAD] is 0, [CLS] is 1, and [SEP] is 2
        token_ids = torch.LongTensor([
            [1, 101, 102, 103, 104, 2, 1, 105, 106, 107, 108, 2, 1, 109, 2],
            [1, 201, 202, 203, 204, 2, 1, 205, 206, 207, 208, 2, 0, 0, 0],
            [1, 301, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        ])  # Shape: [3, 15]
        segment_concat_mask = (token_ids > 0).long()

        folded_token_ids = torch.LongTensor([
            [1, 101, 102, 103, 104, 2],
            [1, 105, 106, 107, 108, 2],
            [1, 109, 2, 0, 0, 0],
            [1, 201, 202, 203, 204, 2],
            [1, 205, 206, 207, 208, 2],
            [0, 0, 0, 0, 0, 0],
            [1, 301, 2, 0, 0, 0],
            [0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0],
        ])
        folded_segment_concat_mask = (folded_token_ids > 0).long()

        token_embedder = PretrainedTransformerEmbedder("bert-base-uncased",
                                                       max_length=6)

        (
            folded_token_ids_out,
            folded_segment_concat_mask_out,
            _,
        ) = token_embedder._fold_long_sequences(token_ids, segment_concat_mask)
        assert (folded_token_ids_out == folded_token_ids).all()
        assert (folded_segment_concat_mask_out == folded_segment_concat_mask
                ).all()

    def test_unfold_long_sequences(self):
        # Let's just say [PAD] is 0, [CLS] is xxx1, and [SEP] is xxx2
        # We assume embeddings are 1-dim and are the same as indices
        embeddings = torch.LongTensor([
            [1001, 101, 102, 103, 104, 1002],
            [1011, 105, 106, 107, 108, 1012],
            [1021, 109, 1022, 0, 0, 0],
            [2001, 201, 202, 203, 204, 2002],
            [2011, 205, 206, 207, 208, 2012],
            [0, 0, 0, 0, 0, 0],
            [3001, 301, 3002, 0, 0, 0],
            [0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0],
        ]).unsqueeze(-1)
        mask = (embeddings > 0).long()

        unfolded_embeddings = torch.LongTensor([
            [1001, 101, 102, 103, 104, 105, 106, 107, 108, 109, 1022],
            [2001, 201, 202, 203, 204, 205, 206, 207, 208, 2012, 0],
            [3001, 301, 3002, 0, 0, 0, 0, 0, 0, 0, 0],
        ]).unsqueeze(-1)

        token_embedder = PretrainedTransformerEmbedder("bert-base-uncased",
                                                       max_length=6)

        unfolded_embeddings_out = token_embedder._unfold_long_sequences(
            embeddings, mask, unfolded_embeddings.size(0), 15)
        assert (unfolded_embeddings_out == unfolded_embeddings).all()
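
A note on the shapes checked above: with max_length=6, each [3, 15] batch is split into ceil(15 / 6) = 3 segments per sequence, so the folded tensors have 3 * 3 = 9 rows of length 6. A minimal sketch that reproduces just the shape (assuming `bert-base-uncased` can be loaded):

import torch
from allennlp.modules.token_embedders import PretrainedTransformerEmbedder

# Fold a [3, 15] batch with max_length=6: ceil(15 / 6) = 3 segments per sequence.
embedder = PretrainedTransformerEmbedder("bert-base-uncased", max_length=6)
token_ids = torch.randint(1, 100, (3, 15))
segment_concat_mask = torch.ones_like(token_ids)
folded_ids, folded_mask, _ = embedder._fold_long_sequences(token_ids, segment_concat_mask)
assert folded_ids.shape == (3 * 3, 6)
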
 def __init__(
     self,
     model_name: str,
     max_length: int = None,
     train_parameters: bool = True,
     last_layer_only: bool = True,
     override_weights_file: Optional[str] = None,
     override_weights_strip_prefix: Optional[str] = None,
     load_weights: bool = True,
     gradient_checkpointing: Optional[bool] = None,
     tokenizer_kwargs: Optional[Dict[str, Any]] = None,
     transformer_kwargs: Optional[Dict[str, Any]] = None,
     sub_token_mode: Optional[str] = "avg",
 ) -> None:
     super().__init__()
     # The matched version vs. mismatched
     self._matched_embedder = PretrainedTransformerEmbedder(
         model_name,
         max_length=max_length,
         train_parameters=train_parameters,
         last_layer_only=last_layer_only,
         override_weights_file=override_weights_file,
         override_weights_strip_prefix=override_weights_strip_prefix,
         load_weights=load_weights,
         gradient_checkpointing=gradient_checkpointing,
         tokenizer_kwargs=tokenizer_kwargs,
         transformer_kwargs=transformer_kwargs,
     )
     self.sub_token_mode = sub_token_mode
Example n. 4
 def __init__(self,
              model_name: str,
              max_length: int = None,
              requires_grad: bool = True) -> None:
     super().__init__()
     # The matched version vs. mismatched
     self._matched_embedder = PretrainedTransformerEmbedder(
         model_name, max_length, requires_grad)
Example n. 5
 def __init__(self,
              model_name: str,
              max_length: int = None,
              train_parameters: bool = True) -> None:
     super().__init__()
     # The matched version vs. mismatched
     self._matched_embedder = PretrainedTransformerEmbedder(
         model_name,
         max_length=max_length,
         train_parameters=train_parameters)
Example n. 6
 def test_embeddings_resize(self):
     regular_token_embedder = PretrainedTransformerEmbedder(
         "bert-base-cased")
     assert (regular_token_embedder.transformer_model.embeddings.
             word_embeddings.num_embeddings == 28996)
     tokenizer_kwargs = {"additional_special_tokens": ["<NEW_TOKEN>"]}
     enhanced_token_embedder = PretrainedTransformerEmbedder(
         "bert-base-cased", tokenizer_kwargs=tokenizer_kwargs)
     assert (enhanced_token_embedder.transformer_model.embeddings.
             word_embeddings.num_embeddings == 28997)
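
Resizing the embedding matrix only helps if the data side adds the same special token; a minimal sketch of the matching indexer (assuming the standard PretrainedTransformerIndexer, which accepts the same tokenizer_kwargs):

from allennlp.data.token_indexers import PretrainedTransformerIndexer

# Index with the same extra special token so "<NEW_TOKEN>" maps to the ID
# covered by the resized embedding matrix.
indexer = PretrainedTransformerIndexer(
    "bert-base-cased",
    tokenizer_kwargs={"additional_special_tokens": ["<NEW_TOKEN>"]},
)
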
Example n. 7
    def __init__(self, vocab, num_labels):
        super().__init__(vocab)
        self.bert_embedder = PretrainedTransformerEmbedder('bert-base-uncased')
        self.pooler = ClsPooler(self.bert_embedder.get_output_dim())

        self.linear = torch.nn.Sequential(
            torch.nn.Dropout(0.1),
            torch.nn.Linear(in_features=768, out_features=num_labels))

        self.accuracy = CategoricalAccuracy()
        self.loss_function = torch.nn.CrossEntropyLoss()
def emb_returner(config):
    if config.bert_name == 'japanese-bert':
        huggingface_model = 'cl-tohoku/bert-base-japanese'
    else:
        huggingface_model = 'dummy'
        print(config.bert_name, 'is not supported')
        exit()
    bert_embedder = PretrainedTransformerEmbedder(
        model_name=huggingface_model)
    return bert_embedder, bert_embedder.get_output_dim(
    ), BasicTextFieldEmbedder({'tokens': bert_embedder})
Example n. 9
 def emb_returner(self):
     if self.args.bert_name == 'bert-base-uncased':
         huggingface_model = 'bert-base-uncased'
     else:
         huggingface_model = 'dummy'
         print(self.args.bert_name, 'is not supported')
         exit()
     bert_embedder = PretrainedTransformerEmbedder(
         model_name=huggingface_model)
     return bert_embedder, bert_embedder.get_output_dim(
     ), BasicTextFieldEmbedder({'tokens': bert_embedder},
                               allow_unmatched_keys=True)
Example n. 10
    def from_params(cls,
                    vocab: Vocabulary,
                    params: Params,
                    constructor_to_call=None,
                    constructor_to_inspect=None) -> 'BertModel':
        #initialize the class using JSON params
        embedder_params = params.pop("text_field_embedder")
        token_params = embedder_params.pop("tokens")
        embedding = PretrainedTransformerEmbedder.from_params(
            vocab=vocab, params=token_params)
        text_field_embedder = BasicTextFieldEmbedder(
            token_embedders={'tokens': embedding})
        #         text_field_embedder = TextFieldEmbedder.from_params(vocab, embedder_params)

        seq2vec_encoder_params = params.pop("seq2vec_encoder")
        seq2vec_encoder = Seq2VecEncoder.from_params(seq2vec_encoder_params)

        initializer = InitializerApplicator()  # .from_params(params.pop("initializer", []))

        params.assert_empty(cls.__name__)
        #         print(cls)
        return cls(vocab=vocab,
                   text_field_embedder=text_field_embedder,
                   seq2vec_encoder=seq2vec_encoder,
                   initializer=initializer)
Example n. 11
    def __init__(
        self,
        vocab: Vocabulary,
        transformer_model: str = "roberta-large",
        override_weights_file: Optional[str] = None,
        override_weights_strip_prefix: Optional[str] = None,
        **kwargs
    ) -> None:
        super().__init__(vocab, **kwargs)

        self._text_field_embedder = PretrainedTransformerEmbedder(
            transformer_model,
            override_weights_file=override_weights_file,
            override_weights_strip_prefix=override_weights_strip_prefix,
        )
        self._text_field_embedder = BasicTextFieldEmbedder(
            {"tokens": self._text_field_embedder})
        self._pooler = BertPooler(
            transformer_model,
            override_weights_file=override_weights_file,
            override_weights_strip_prefix=override_weights_strip_prefix,
            dropout=0.1,
        )

        self._linear_layer = torch.nn.Linear(
            self._text_field_embedder.get_output_dim(), 1)
        self._linear_layer.weight.data.normal_(mean=0.0, std=0.02)
        self._linear_layer.bias.data.zero_()

        self._loss = torch.nn.CrossEntropyLoss()
        self._accuracy = CategoricalAccuracy()
Example n. 12
def build_model(vocab: Vocabulary, wbrun: Any) -> Model:
    """
    Build the Model object, along with the embedder and encoder.

    :param vocab: The pre-instantiated vocabulary object.
    :return Model: The model object itself.
    """
    log.debug("Building the model.")
    # vocab_size = vocab.get_vocab_size("tokens")

    # TokenEmbedder object.
    bert_embedder = PretrainedTransformerEmbedder("bert-base-uncased")

    # TextFieldEmbedder that wraps TokenEmbedder objects. Each
    # TokenEmbedder output from one TokenIndexer--the data produced
    # by a TextField is a dict {names:representations}, hence
    # TokenEmbedders have corresponding names.
    embedder: TextFieldEmbedder = BasicTextFieldEmbedder(
        {"tokens": bert_embedder}
    )

    log.debug("Embedder built.")
    encoder = BertPooler("bert-base-uncased", requires_grad=True)
    # encoder = PytorchSeq2VecWrapper(torch.nn.LSTM(768,20,batch_first=True))
    log.debug("Encoder built.")

    return BertLinearClassifier(vocab, embedder, encoder, wbrun).cuda(0)
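
For the "tokens" key above to receive wordpiece IDs, the dataset reader needs a matching tokenizer/indexer pair; a minimal sketch of that assumed data-side setup:

from allennlp.data.token_indexers import PretrainedTransformerIndexer
from allennlp.data.tokenizers import PretrainedTransformerTokenizer

# The dict key "tokens" must match the key used for the TokenEmbedder above.
tokenizer = PretrainedTransformerTokenizer("bert-base-uncased")
token_indexers = {"tokens": PretrainedTransformerIndexer("bert-base-uncased")}
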
Example n. 13
def build_model(vocab: Vocabulary, bert_model: str = None) -> Model:
    if bert_model:
        embedder = BasicTextFieldEmbedder({"bert": PretrainedTransformerEmbedder(model_name=bert_model,
                                                                                 train_parameters=True)})
        encoder = BertPooler(pretrained_model=bert_model, requires_grad=True)
    else:
        # (3) How to get vectors for each Token ID:
        # (3.1) embed each token
        token_embedding = Embedding(embedding_dim=10, num_embeddings=vocab.get_vocab_size("token_vocab"))
        # pretrained_file='https://allennlp.s3.amazonaws.com/datasets/glove/glove.6B.50d.txt.gz'

        # (3.2) embed each character in each token
        character_embedding = Embedding(embedding_dim=3, num_embeddings=vocab.get_vocab_size("character_vocab"))
        cnn_encoder = CnnEncoder(embedding_dim=3, num_filters=4, ngram_filter_sizes=[3,])
        token_encoder = TokenCharactersEncoder(character_embedding, cnn_encoder)
        # (3.3) embed the POS of each token
        pos_tag_embedding = Embedding(embedding_dim=10, num_embeddings=vocab.get_vocab_size("pos_tag_vocab"))

        # Each TokenEmbedders embeds its input, and the result is concatenated in an arbitrary (but consistent) order
        # cf: https://docs.allennlp.org/master/api/modules/text_field_embedders/basic_text_field_embedder/
        embedder = BasicTextFieldEmbedder(
            token_embedders={"tokens": token_embedding,
                             "token_characters": token_encoder,
                             "pos_tags": pos_tag_embedding}
        )  # emb_dim = 10 + 4 + 10 = 24
        encoder = BagOfEmbeddingsEncoder(embedding_dim=24, averaged=True)
        #                                                  ^
        # average the embeddings across time, rather than simply summing
        # (ie. we will divide the summed embeddings by the length of the sentence).
    return SimpleClassifier(vocab, embedder, encoder)
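
The concatenation described in the comments of the non-BERT branch only works if the TextFields were indexed under the same keys and vocabulary namespaces; a sketch of that assumed indexer setup (mirroring the AllenNLP guide this snippet follows):

from allennlp.data.token_indexers import SingleIdTokenIndexer, TokenCharactersIndexer

# Keys line up with the token_embedders dict above; namespaces line up with the
# vocabulary namespaces ("token_vocab", "character_vocab", "pos_tag_vocab").
token_indexers = {
    "tokens": SingleIdTokenIndexer(namespace="token_vocab"),
    # min_padding_length=3 so short words still fit the width-3 CNN filters
    "token_characters": TokenCharactersIndexer(namespace="character_vocab",
                                               min_padding_length=3),
    "pos_tags": SingleIdTokenIndexer(namespace="pos_tag_vocab", feature_name="tag_"),
}
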
 def test_xlnet_token_type_ids(self):
     token_embedder = PretrainedTransformerEmbedder("xlnet-base-cased")
     token_ids = torch.LongTensor([[1, 2, 3], [2, 3, 4]])
     mask = torch.ones_like(token_ids).bool()
     type_ids = torch.zeros_like(token_ids)
     type_ids[1, 1] = 1
     token_embedder(token_ids, mask, type_ids)
def build_adversarial_transformer_model(vocab: Vocabulary, transformer_model: str) -> Model:
    print("Building the model")
    vocab_size = vocab.get_vocab_size("tokens")
    embedding = PretrainedTransformerEmbedder(model_name=transformer_model)
    embedder = BasicTextFieldEmbedder(token_embedders={'bert_tokens': embedding})
    encoder = BertPooler(transformer_model)
    return SimpleClassifier(vocab, embedder, encoder)
def build_pool_transformer_model(vocab: Vocabulary, transformer_model: str) -> Model:
    print("Building the model")
    vocab_size = vocab.get_vocab_size("tokens")
    embedding = PretrainedTransformerEmbedder(model_name=transformer_model)
    embedder = BasicTextFieldEmbedder(token_embedders={'bert_tokens': embedding})
    encoder = BagOfEmbeddingsEncoder(embedding_dim=embedder.get_output_dim(), averaged=True)
    #encoder = ClsPooler(embedding_dim=embedder.get_output_dim())
    return SimpleClassifier(vocab, embedder, encoder)
 def __init__(
     self,
     model_name: str,
     max_length: int = None,
     train_parameters: bool = True,
     last_layer_only: bool = True,
     gradient_checkpointing: Optional[bool] = None,
 ) -> None:
     super().__init__()
     # The matched version vs. mismatched
     self._matched_embedder = PretrainedTransformerEmbedder(
         model_name,
         max_length=max_length,
         train_parameters=train_parameters,
         last_layer_only=last_layer_only,
         gradient_checkpointing=gradient_checkpointing,
     )
 def test_big_token_type_ids(self):
     token_embedder = PretrainedTransformerEmbedder("roberta-base")
     token_ids = torch.LongTensor([[1, 2, 3], [2, 3, 4]])
     mask = torch.ones_like(token_ids).bool()
     type_ids = torch.zeros_like(token_ids)
     type_ids[1, 1] = 1
     with pytest.raises(ValueError):
         token_embedder(token_ids, mask, type_ids)
 def test_forward_runs_when_initialized_from_params(self):
     # This code just passes things off to pytorch-transformers, so we only have a very simple
     # test.
     params = Params({'model_name': 'bert-base-uncased'})
     embedder = PretrainedTransformerEmbedder.from_params(params)
     tensor = torch.randint(0, 100, (1, 4))
     output = embedder(tensor)
     assert tuple(output.size()) == (1, 4, 768)
 def test_forward_runs_when_initialized_from_params(self):
     # This code just passes things off to `transformers`, so we only have a very simple
     # test.
     params = Params({"model_name": "bert-base-uncased"})
     embedder = PretrainedTransformerEmbedder.from_params(params)
     token_ids = torch.randint(0, 100, (1, 4))
     mask = torch.randint(0, 2, (1, 4))
     output = embedder(token_ids=token_ids, mask=mask)
     assert tuple(output.size()) == (1, 4, 768)
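
from_params works here because PretrainedTransformerEmbedder is registered under the name "pretrained_transformer", so the same embedder can also be built through the generic TokenEmbedder base class, the way a training config would; a minimal sketch:

from allennlp.common import Params
from allennlp.modules.token_embedders import TokenEmbedder

# Equivalent to the config fragment
# {"type": "pretrained_transformer", "model_name": "bert-base-uncased"}
embedder = TokenEmbedder.from_params(
    Params({"type": "pretrained_transformer", "model_name": "bert-base-uncased"})
)
assert embedder.get_output_dim() == 768
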
Example n. 21
    def __init__(
            self,
            vocab: Vocabulary,
            span_extractor: SpanExtractor,
            transformer_model_name_or_archive_path: str = "bert-base-uncased",
            encoder: Optional[Seq2SeqEncoder] = None,
            freeze: bool = False,
            smoothing: bool = False,
            drop_out: float = 0.0,
            **kwargs) -> None:
        super().__init__(vocab, **kwargs)
        self._span_extractor = span_extractor

        # text_field_embedder
        if "model.tar.gz" in transformer_model_name_or_archive_path:
            archive = load_archive(transformer_model_name_or_archive_path)
            self._text_field_embedder = archive.extract_module(
                "_text_field_embedder", freeze)
        else:
            self._text_field_embedder = BasicTextFieldEmbedder({
                "tokens":
                PretrainedTransformerEmbedder(
                    transformer_model_name_or_archive_path)
            })
            if freeze:
                for parameter in self._text_field_embedder.parameters(
                ):  # type: ignore
                    parameter.requires_grad_(not freeze)

        # encoder
        if encoder is None:
            self._encoder = None
        else:
            self._encoder = encoder

        # linear
        self._linear = nn.Linear(self._span_extractor.get_output_dim(), 1)

        # drop out
        self._drop_out = nn.Dropout(drop_out)

        # loss
        self._smoothing = smoothing
        if not smoothing:
            self._loss = nn.CrossEntropyLoss()
        else:
            self._loss = nn.KLDivLoss(reduction="batchmean")

        # metric
        self._span_accuracy = CategoricalAccuracy()
        self._processed_span_accuracy = CategoricalAccuracy()
        self._candidate_jaccard = Jaccard()
Example n. 22
    def __init__(
        self, vocab: Vocabulary, transformer_model_name: str = "bert-base-cased", **kwargs
    ) -> None:
        super().__init__(vocab, **kwargs)
        self._text_field_embedder = BasicTextFieldEmbedder(
            {"tokens": PretrainedTransformerEmbedder(transformer_model_name)}
        )
        self._linear_layer = nn.Linear(self._text_field_embedder.get_output_dim(), 2)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._per_instance_metrics = SquadEmAndF1()
Example n. 23
 def __init__(self, model_name, use_pretrained_embeddings=False):
     super().__init__()
     # or some flag that indicates the bart encoder in its entirety could be used
     if use_pretrained_embeddings:
         # will use the entire bart encoder including all embeddings
         bart = PretrainedTransformerEmbedder(model_name,
                                              sub_module="encoder")
         self.hidden_dim = bart.config.hidden_size
         self.bart_encoder = bart.transformer_model
     else:
         bart = BartModel.from_pretrained(model_name)
         self.hidden_dim = bart.config.hidden_size
         self.bart_encoder = bart.encoder
         # pass pre-embedded inputs straight through: identity token embeddings
         # and all-zero position embeddings
         self.bart_encoder.embed_tokens = lambda x: x
         self.bart_encoder.embed_positions = lambda x: torch.zeros(
             (x.shape[0], x.shape[1], self.hidden_dim), dtype=torch.float32)
Example n. 24
    def test_eval_mode(self):
        token_embedder = PretrainedTransformerEmbedder(
            "epwalsh/bert-xsmall-dummy", train_parameters=False)
        assert token_embedder.training and not token_embedder.transformer_model.training

        class TrainableModule(torch.nn.Module):
            def __init__(self, fixed_module):
                super().__init__()
                self.fixed_module = fixed_module

        trainable = TrainableModule(token_embedder)
        assert (trainable.training and trainable.fixed_module.training
                and not trainable.fixed_module.transformer_model.training)

        trainable.train()
        assert not trainable.fixed_module.transformer_model.training
Example n. 25
 def __init__(
     self,
     word_embedding_dropout: float = 0.05,
     bert_model_name: str = 'japanese_bert',
     word_embedder: BasicTextFieldEmbedder = BasicTextFieldEmbedder({
         'tokens':
         PretrainedTransformerEmbedder(
             model_name='cl-tohoku/bert-base-japanese')
     })):
     super(BertPoolerForMention, self).__init__()
     self.bert_model_name = bert_model_name
     self.huggingface_nameloader()
     self.bertpooler_sec2vec = BertPooler(
         pretrained_model=self.bert_weight_filepath)
     self.word_embedder = word_embedder
     self.word_embedding_dropout = nn.Dropout(word_embedding_dropout)
Example n. 26
def build_model_Transformer(vocab: Vocabulary, use_reg: bool = True) -> Model:
    print("Building the model")
    vocab_size = vocab.get_vocab_size("tokens")
    # Turn the token IDs into contextual BERT embeddings, then pool the [CLS]
    # representation into a single vector per sequence.
    embedder = PretrainedTransformerEmbedder(BERT_MODEL_NAME)
    encoder = BertPooler(BERT_MODEL_NAME)
    # encoder = BertPooler("bert-base-cased")
    # The output dim of both is the transformer's hidden size.

    # construct the regularizer applicator
    regularizer_applicator = None
    if use_reg:
        l2_reg = L2Regularizer()
        regexes = [("embedder", l2_reg), ("encoder", l2_reg),
                   ("classifier", l2_reg)]
        regularizer_applicator = RegularizerApplicator(regexes)

    return MortalityClassifier(vocab, embedder, encoder,
                               regularizer_applicator)
Example n. 27
class BertClassifier(Model):
    def __init__(self, vocab, num_labels):
        super().__init__(vocab)
        self.bert_embedder = PretrainedTransformerEmbedder('bert-base-uncased')
        self.pooler = ClsPooler(self.bert_embedder.get_output_dim())

        self.linear = torch.nn.Sequential(
            torch.nn.Dropout(0.1),
            torch.nn.Linear(in_features=768, out_features=num_labels))

        self.accuracy = CategoricalAccuracy()
        self.loss_function = torch.nn.CrossEntropyLoss()

    def forward(self, sent, label=None):
        bert_embeddings = self.bert_embedder(
            token_ids=sent['tokens']['token_ids'],
            type_ids=sent['tokens']['type_ids'],
            mask=sent['tokens']['mask'])
        bert_vec = self.pooler(bert_embeddings)

        logits = self.linear(bert_vec)
        output = {"logits": logits, "probs": F.softmax(logits, dim=1)}
        if label is not None:
            self.accuracy(logits, label)
            output["loss"] = self.loss_function(logits, label)
        return output

    def get_metrics(self, reset=False):
        return {'accuracy': self.accuracy.get_metric(reset)}

    def get_optimizer(self):
        optimizer = AdamW(self.parameters(), lr=2e-5, eps=1e-8)
        # get_linear_schedule_with_warmup(
        #     optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
        # )
        return optimizer
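
The nested sent['tokens'][...] lookups in forward match the tensor dict produced by a TextField indexed with a PretrainedTransformerIndexer under the key "tokens"; a minimal sketch of that assumed data side:

from allennlp.data import Batch, Instance, Vocabulary
from allennlp.data.fields import TextField
from allennlp.data.token_indexers import PretrainedTransformerIndexer
from allennlp.data.tokenizers import PretrainedTransformerTokenizer

tokenizer = PretrainedTransformerTokenizer("bert-base-uncased")
indexers = {"tokens": PretrainedTransformerIndexer("bert-base-uncased")}
field = TextField(tokenizer.tokenize("a tiny example"), indexers)

batch = Batch([Instance({"sent": field})])
batch.index_instances(Vocabulary())
tensors = batch.as_tensor_dict()
# tensors["sent"]["tokens"] contains "token_ids", "mask", and "type_ids",
# matching the lookups in forward() above.
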
Example n. 28
class PretrainedTransformerMismatchedEmbedder(TokenEmbedder):
    """
    Use this embedder to embed wordpieces given by `PretrainedTransformerMismatchedIndexer`
    and to pool the resulting vectors to get word-level representations.

    Registered as a `TokenEmbedder` with name "pretrained_transformer_mismatched".

    # Parameters

    model_name : `str`
        The name of the `transformers` model to use. Should be the same as the corresponding
        `PretrainedTransformerMismatchedIndexer`.
    max_length : `int`, optional (default = `None`)
        If positive, folds input token IDs into multiple segments of this length, passes them
        through the transformer model independently, and concatenates the final representations.
        Should be set to the same value as the `max_length` option on the
        `PretrainedTransformerMismatchedIndexer`.
    train_parameters: `bool`, optional (default = `True`)
        If this is `True`, the transformer weights get updated during training.
    """
    def __init__(self,
                 model_name: str,
                 max_length: int = None,
                 train_parameters: bool = True) -> None:
        super().__init__()
        # The matched version vs. mismatched
        self._matched_embedder = PretrainedTransformerEmbedder(
            model_name,
            max_length=max_length,
            train_parameters=train_parameters)

    @overrides
    def get_output_dim(self):
        return self._matched_embedder.get_output_dim()

    @overrides
    def forward(
        self,
        token_ids: torch.LongTensor,
        mask: torch.BoolTensor,
        offsets: torch.LongTensor,
        wordpiece_mask: torch.BoolTensor,
        type_ids: Optional[torch.LongTensor] = None,
        segment_concat_mask: Optional[torch.BoolTensor] = None,
    ) -> torch.Tensor:  # type: ignore
        """
        # Parameters

        token_ids: `torch.LongTensor`
            Shape: [batch_size, num_wordpieces] (for exception see `PretrainedTransformerEmbedder`).
        mask: `torch.BoolTensor`
            Shape: [batch_size, num_orig_tokens].
        offsets: `torch.LongTensor`
            Shape: [batch_size, num_orig_tokens, 2].
            Maps indices for the original tokens, i.e. those given as input to the indexer,
            to a span in token_ids. `token_ids[i][offsets[i][j][0]:offsets[i][j][1] + 1]`
            corresponds to the original j-th token from the i-th batch.
        wordpiece_mask: `torch.BoolTensor`
            Shape: [batch_size, num_wordpieces].
        type_ids: `Optional[torch.LongTensor]`
            Shape: [batch_size, num_wordpieces].
        segment_concat_mask: `Optional[torch.BoolTensor]`
            See `PretrainedTransformerEmbedder`.

        # Returns

        `torch.Tensor`
            Shape: [batch_size, num_orig_tokens, embedding_size].
        """
        # Shape: [batch_size, num_wordpieces, embedding_size].
        embeddings = self._matched_embedder(
            token_ids,
            wordpiece_mask,
            type_ids=type_ids,
            segment_concat_mask=segment_concat_mask)

        # span_embeddings: (batch_size, num_orig_tokens, max_span_length, embedding_size)
        # span_mask: (batch_size, num_orig_tokens, max_span_length)
        span_embeddings, span_mask = util.batched_span_select(
            embeddings.contiguous(), offsets)
        span_mask = span_mask.unsqueeze(-1)
        span_embeddings *= span_mask  # zero out paddings

        span_embeddings_sum = span_embeddings.sum(2)
        span_embeddings_len = span_mask.sum(2)
        # Shape: (batch_size, num_orig_tokens, embedding_size)
        orig_embeddings = span_embeddings_sum / span_embeddings_len

        # All the places where the span length is zero, write in zeros.
        orig_embeddings[(span_embeddings_len == 0).expand(
            orig_embeddings.shape)] = 0

        return orig_embeddings
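
The span averaging above is easiest to see on a tiny concrete case; a minimal sketch with hypothetical numbers, where one original token spans wordpieces 1 through 3:

import torch
from allennlp.nn import util

# One sequence, five 1-dim wordpiece "embeddings" with values 0..4.
embeddings = torch.arange(5, dtype=torch.float).view(1, 5, 1)
# One original token whose wordpiece span is [1, 3] (inclusive), as in `offsets`.
offsets = torch.LongTensor([[[1, 3]]])

span_embeddings, span_mask = util.batched_span_select(embeddings, offsets)
span_mask = span_mask.unsqueeze(-1)
mean = (span_embeddings * span_mask).sum(2) / span_mask.sum(2)
assert torch.allclose(mean, torch.tensor([[[2.0]]]))  # average of 1, 2, 3
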
Example n. 29
def bert_emb_returner():
    return BasicTextFieldEmbedder({
        'tokens':
        PretrainedTransformerEmbedder(
            model_name='cl-tohoku/bert-base-japanese')
    })
Example n. 30
 def test_encoder_decoder_model(self):
     token_embedder = PretrainedTransformerEmbedder("facebook/bart-large", sub_module="encoder")
     token_ids = torch.LongTensor([[1, 2, 3], [2, 3, 4]])
     mask = torch.ones_like(token_ids).bool()
     token_embedder(token_ids, mask)