from transformers import (
    BertConfig,
    BertGenerationConfig,
    BertGenerationDecoder,
    BertGenerationEncoder,
    EncoderDecoderModel,
    load_tf_weights_in_bert_generation,
)


def convert_tf_checkpoint_to_pytorch(tf_hub_path, pytorch_dump_path, is_encoder_named_decoder, vocab_size, is_encoder):
    # Initialise PyTorch model
    bert_config = BertConfig.from_pretrained(
        "bert-large-cased",
        vocab_size=vocab_size,
        max_position_embeddings=512,
        is_decoder=True,
        add_cross_attention=True,
    )
    bert_config_dict = bert_config.to_dict()
    del bert_config_dict["type_vocab_size"]
    config = BertGenerationConfig(**bert_config_dict)
    if is_encoder:
        model = BertGenerationEncoder(config)
    else:
        model = BertGenerationDecoder(config)
    print("Building PyTorch model from configuration: {}".format(str(config)))

    # Load weights from tf checkpoint
    load_tf_weights_in_bert_generation(
        model,
        tf_hub_path,
        model_class="bert",
        is_encoder_named_decoder=is_encoder_named_decoder,
        is_encoder=is_encoder,
    )

    # Save pytorch-model
    print("Save PyTorch model and config to {}".format(pytorch_dump_path))
    model.save_pretrained(pytorch_dump_path)
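

# Hedged usage sketch (not part of the original snippet): a minimal argparse entry
# point for the conversion function above. The flag names simply mirror the
# function's parameters; the vocab_size default of 50358 is an assumption and
# should be set to match the TF Hub checkpoint being converted.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--tf_hub_path", type=str, required=True, help="Path to the TF Hub checkpoint to convert.")
    parser.add_argument("--pytorch_dump_path", type=str, required=True, help="Directory for the converted PyTorch model.")
    parser.add_argument("--is_encoder_named_decoder", action="store_true")
    parser.add_argument("--vocab_size", type=int, default=50358)  # assumed default; match the checkpoint
    parser.add_argument("--is_encoder", action="store_true")
    args = parser.parse_args()

    convert_tf_checkpoint_to_pytorch(
        args.tf_hub_path,
        args.pytorch_dump_path,
        args.is_encoder_named_decoder,
        args.vocab_size,
        is_encoder=args.is_encoder,
    )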
    def prepare_config_and_inputs(self):
        # ids_tensor and random_attention_mask are helpers from the transformers
        # test utilities (test_modeling_common): random integer id tensors and
        # random 0/1 attention masks of the given shape.
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        token_labels = None
        if self.use_labels:
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        config = BertGenerationConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            is_decoder=False,
            initializer_range=self.initializer_range,
            return_dict=True,
        )

        return config, input_ids, input_mask, token_labels
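
    # Hedged sketch (hypothetical companion method, not from the original tester):
    # shows how the tuple returned above is typically consumed, i.e. build an
    # encoder from the config and run a forward pass over the generated inputs.
    def create_and_check_encoder(self, config, input_ids, input_mask, token_labels):
        import torch

        model = BertGenerationEncoder(config)
        model.eval()
        with torch.no_grad():
            result = model(input_ids, attention_mask=input_mask)
        # last_hidden_state has shape (batch_size, seq_length, hidden_size)
        assert result.last_hidden_state.shape == (self.batch_size, self.seq_length, self.hidden_size)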
def create_slt_transformer(input_vocab_size=1,
                           output_vocab_size=1,
                           **bert_params):
    """Build an encoder-decoder transformer from two BertGeneration stacks.

    Keyword arguments passed via ``bert_params`` override the default BERT
    hyperparameters below; the default vocab sizes of 1 are placeholders and
    should be overridden by the caller.
    """
    if input_vocab_size == 1:
        print('WARNING: Input vocab size is 1')
    if output_vocab_size == 1:
        print('WARNING: Output vocab size is 1')

    params = {
        'vocab_size': input_vocab_size,
        'hidden_size': 512,
        'intermediate_size': 2048,
        'max_position_embeddings': 500,
        'num_attention_heads': 8,
        'num_hidden_layers': 3,
        'hidden_act': 'relu',
        'type_vocab_size': 1,
        'hidden_dropout_prob': 0.1,
        'attention_probs_dropout_prob': 0.1
    }
    params.update(bert_params)

    config = BertGenerationConfig(**params)
    encoder = BertGenerationEncoder(config=config)

    params['vocab_size'] = output_vocab_size
    decoder_config = BertGenerationConfig(is_decoder=True,
                                          add_cross_attention=True,
                                          **params)
    decoder = BertGenerationDecoder(config=decoder_config)

    transformer = EncoderDecoderModel(encoder=encoder, decoder=decoder)

    def count_parameters(m):
        return sum(p.numel() for p in m.parameters() if p.requires_grad)

    print(
        f'The encoder has {count_parameters(encoder):,} trainable parameters')
    print(
        f'The decoder has {count_parameters(decoder):,} trainable parameters')
    print(
        f'The whole model has {count_parameters(transformer):,} trainable parameters'
    )

    return transformer
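

# Hedged usage sketch with assumed (hypothetical) vocabulary sizes: builds a small
# encoder-decoder model and overrides one of the default BERT hyperparameters
# through **bert_params.
slt_model = create_slt_transformer(
    input_vocab_size=3000,    # assumed source vocabulary size
    output_vocab_size=5000,   # assumed target vocabulary size
    num_hidden_layers=2,      # overrides the default of 3 hidden layers
)
print(type(slt_model).__name__)  # EncoderDecoderModel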
    def get_config(self):
        return BertGenerationConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            is_decoder=False,
            initializer_range=self.initializer_range,
        )
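
    # Hedged sketch (hypothetical companion method, modelled on the usual
    # transformers tester pattern): decoder tests reuse the same config but flip
    # it into a causal decoder with cross-attention enabled.
    def get_decoder_config(self):
        config = self.get_config()
        config.is_decoder = True
        config.add_cross_attention = True
        return config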
    def test_torch_encode_plus_sent_to_model(self):
        import torch

        from transformers import BertGenerationConfig, BertGenerationEncoder

        # Build sequence
        first_ten_tokens = list(self.big_tokenizer.get_vocab().keys())[:10]
        sequence = " ".join(first_ten_tokens)
        encoded_sequence = self.big_tokenizer.encode_plus(sequence, return_tensors="pt", return_token_type_ids=False)
        batch_encoded_sequence = self.big_tokenizer.batch_encode_plus(
            [sequence + " " + sequence], return_tensors="pt", return_token_type_ids=False
        )

        config = BertGenerationConfig()
        model = BertGenerationEncoder(config)

        assert model.get_input_embeddings().weight.shape[0] >= self.big_tokenizer.vocab_size

        with torch.no_grad():
            model(**encoded_sequence)
            model(**batch_encoded_sequence)
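
    # Hedged follow-up sketch (hypothetical test, not part of the original suite):
    # pairs a tiny BertGenerationEncoder with a BertGenerationDecoder inside an
    # EncoderDecoderModel and runs a full sequence-to-sequence forward pass on
    # randomly generated ids. All sizes below are assumptions chosen to keep the
    # model small.
    def test_encoder_decoder_forward_sketch(self):
        import torch

        from transformers import (
            BertGenerationConfig,
            BertGenerationDecoder,
            BertGenerationEncoder,
            EncoderDecoderModel,
        )

        common = dict(vocab_size=100, hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
        encoder = BertGenerationEncoder(BertGenerationConfig(**common))
        decoder = BertGenerationDecoder(
            BertGenerationConfig(is_decoder=True, add_cross_attention=True, **common)
        )
        model = EncoderDecoderModel(encoder=encoder, decoder=decoder)

        input_ids = torch.randint(0, 100, (1, 8))
        with torch.no_grad():
            outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)
        # logits: (batch_size, target_seq_length, decoder vocab_size)
        assert outputs.logits.shape == (1, 8, 100)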