from transformers import (
    BertConfig,
    BertGenerationConfig,
    BertGenerationDecoder,
    BertGenerationEncoder,
    load_tf_weights_in_bert_generation,
)


def convert_tf_checkpoint_to_pytorch(tf_hub_path, pytorch_dump_path, is_encoder_named_decoder, vocab_size, is_encoder):
    # Initialise PyTorch model
    bert_config = BertConfig.from_pretrained(
        "bert-large-cased",
        vocab_size=vocab_size,
        max_position_embeddings=512,
        is_decoder=True,
        add_cross_attention=True,
    )
    bert_config_dict = bert_config.to_dict()
    del bert_config_dict["type_vocab_size"]
    config = BertGenerationConfig(**bert_config_dict)
    if is_encoder:
        model = BertGenerationEncoder(config)
    else:
        model = BertGenerationDecoder(config)
    print("Building PyTorch model from configuration: {}".format(str(config)))

    # Load weights from tf checkpoint
    load_tf_weights_in_bert_generation(
        model,
        tf_hub_path,
        model_class="bert",
        is_encoder_named_decoder=is_encoder_named_decoder,
        is_encoder=is_encoder,
    )

    # Save PyTorch model
    print("Save PyTorch model and config to {}".format(pytorch_dump_path))
    model.save_pretrained(pytorch_dump_path)
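# Hedged usage sketch (not taken from the script above): one way the converter could be
# wired up as a small CLI. The flag names simply mirror the function parameters, and the
# default vocab_size of 50358 matches BertGenerationConfig's default; treat both as
# illustrative assumptions rather than the script's actual interface.
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--tf_hub_path", type=str, required=True, help="Path to the TensorFlow checkpoint/Hub export.")
    parser.add_argument("--pytorch_dump_path", type=str, required=True, help="Where to save the converted PyTorch model.")
    parser.add_argument("--is_encoder_named_decoder", action="store_true", help="Set if the encoder weights live under a 'decoder' scope in the checkpoint.")
    parser.add_argument("--vocab_size", type=int, default=50358, help="Vocab size of the model being converted.")
    parser.add_argument("--is_encoder", action="store_true", help="Convert to a BertGenerationEncoder instead of a BertGenerationDecoder.")
    args = parser.parse_args()

    convert_tf_checkpoint_to_pytorch(
        args.tf_hub_path,
        args.pytorch_dump_path,
        args.is_encoder_named_decoder,
        args.vocab_size,
        args.is_encoder,
    )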
def prepare_config_and_inputs(self):
    input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

    input_mask = None
    if self.use_input_mask:
        input_mask = random_attention_mask([self.batch_size, self.seq_length])

    token_labels = None  # initialise so the return below is defined even when use_labels is False
    if self.use_labels:
        token_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

    config = BertGenerationConfig(
        vocab_size=self.vocab_size,
        hidden_size=self.hidden_size,
        num_hidden_layers=self.num_hidden_layers,
        num_attention_heads=self.num_attention_heads,
        intermediate_size=self.intermediate_size,
        hidden_act=self.hidden_act,
        hidden_dropout_prob=self.hidden_dropout_prob,
        attention_probs_dropout_prob=self.attention_probs_dropout_prob,
        max_position_embeddings=self.max_position_embeddings,
        is_decoder=False,
        initializer_range=self.initializer_range,
        return_dict=True,
    )

    return config, input_ids, input_mask, token_labels
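# Hedged sketch (assumption, not part of the tester class): how the tuple returned by
# prepare_config_and_inputs is typically consumed, i.e. build a BertGenerationEncoder from
# the config and run a forward pass on the random ids and mask. The helper name
# check_encoder_forward is hypothetical.
import torch
from transformers import BertGenerationEncoder


def check_encoder_forward(config, input_ids, input_mask):
    model = BertGenerationEncoder(config)
    model.eval()
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=input_mask)
    # last_hidden_state has shape (batch_size, seq_length, hidden_size)
    assert outputs.last_hidden_state.shape == (input_ids.shape[0], input_ids.shape[1], config.hidden_size)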
from transformers import (
    BertGenerationConfig,
    BertGenerationDecoder,
    BertGenerationEncoder,
    EncoderDecoderModel,
)


def create_slt_transformer(input_vocab_size=1, output_vocab_size=1, **bert_params):
    if input_vocab_size == 1:
        print('WARNING: Input vocab size is 1')
    if output_vocab_size == 1:
        print('WARNING: Output vocab size is 1')

    params = {
        'vocab_size': input_vocab_size,
        'hidden_size': 512,
        'intermediate_size': 2048,
        'max_position_embeddings': 500,
        'num_attention_heads': 8,
        'num_hidden_layers': 3,
        'hidden_act': 'relu',
        'type_vocab_size': 1,
        'hidden_dropout_prob': 0.1,
        'attention_probs_dropout_prob': 0.1,
    }
    params.update(bert_params)

    # Encoder uses the input vocabulary
    config = BertGenerationConfig(**params)
    encoder = BertGenerationEncoder(config=config)

    # Decoder uses the output vocabulary and adds cross-attention over the encoder states
    params['vocab_size'] = output_vocab_size
    decoder_config = BertGenerationConfig(is_decoder=True, add_cross_attention=True, **params)
    decoder = BertGenerationDecoder(config=decoder_config)

    transformer = EncoderDecoderModel(encoder=encoder, decoder=decoder)

    def count_parameters(m):
        return sum(p.numel() for p in m.parameters() if p.requires_grad)

    print(f'The encoder has {count_parameters(encoder):,} trainable parameters')
    print(f'The decoder has {count_parameters(decoder):,} trainable parameters')
    print(f'The whole model has {count_parameters(transformer):,} trainable parameters')

    return transformer
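# Hedged usage example: the vocab sizes and the num_hidden_layers override below are
# illustrative placeholders, not values from the original project.
transformer = create_slt_transformer(
    input_vocab_size=3000,
    output_vocab_size=5000,
    num_hidden_layers=2,  # any BertGenerationConfig field can be overridden via **bert_params
)
# The returned EncoderDecoderModel is used like any Hugging Face seq2seq model, e.g.
# transformer(input_ids=..., decoder_input_ids=..., labels=...).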
def get_config(self):
    return BertGenerationConfig(
        vocab_size=self.vocab_size,
        hidden_size=self.hidden_size,
        num_hidden_layers=self.num_hidden_layers,
        num_attention_heads=self.num_attention_heads,
        intermediate_size=self.intermediate_size,
        hidden_act=self.hidden_act,
        hidden_dropout_prob=self.hidden_dropout_prob,
        attention_probs_dropout_prob=self.attention_probs_dropout_prob,
        max_position_embeddings=self.max_position_embeddings,
        is_decoder=False,
        initializer_range=self.initializer_range,
    )
def test_torch_encode_plus_sent_to_model(self):
    import torch

    from transformers import BertGenerationConfig, BertGenerationEncoder

    # Build sequence
    first_ten_tokens = list(self.big_tokenizer.get_vocab().keys())[:10]
    sequence = " ".join(first_ten_tokens)
    encoded_sequence = self.big_tokenizer.encode_plus(sequence, return_tensors="pt", return_token_type_ids=False)
    batch_encoded_sequence = self.big_tokenizer.batch_encode_plus(
        [sequence + " " + sequence], return_tensors="pt", return_token_type_ids=False
    )

    config = BertGenerationConfig()
    model = BertGenerationEncoder(config)

    assert model.get_input_embeddings().weight.shape[0] >= self.big_tokenizer.vocab_size

    with torch.no_grad():
        model(**encoded_sequence)
        model(**batch_encoded_sequence)