import json
import os

import fairseq
from torch import nn

from transformers import (
    Speech2Text2Config,
    Speech2Text2ForCausalLM,
    Speech2Text2Tokenizer,
    SpeechEncoderDecoderConfig,
    SpeechEncoderDecoderModel,
    Wav2Vec2Config,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2Model,
    logging,
)


logging.set_verbosity_info()
logger = logging.get_logger(__name__)


def get_encoder_decoder_model(config, decoder_config):
    # Instantiate both models in eval mode so dropout and other
    # training-only behavior is disabled.
    encoder_model = Wav2Vec2Model(config).eval()
    decoder_model = Speech2Text2ForCausalLM(decoder_config).eval()
    return encoder_model, decoder_model
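# A minimal usage sketch for the helper above. The freshly initialized
# default configs are illustrative assumptions, not values from this script:
#
#     encoder, decoder = get_encoder_decoder_model(Wav2Vec2Config(), Speech2Text2Config())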
def convert_wav2vec2_checkpoint(
    checkpoint_path,
    pytorch_dump_folder_path,
    dict_path,
    encoder_config_path,
    decoder_config_path,
    vocab_size,
    num_decoder_layers,
):
    """
    Copy/paste/tweak the original model's weights to the transformers design.
    """
    encoder_config = Wav2Vec2Config.from_pretrained(encoder_config_path)
    decoder_config = Speech2Text2Config.from_pretrained(
        decoder_config_path, vocab_size=vocab_size, decoder_layers=num_decoder_layers, do_stable_layer_norm=True
    )

    feature_extractor = Wav2Vec2FeatureExtractor(
        feature_size=1,
        sampling_rate=16000,
        padding_value=0,
        do_normalize=True,
        return_attention_mask=True,
    )

    # Load the fairseq checkpoint; fairseq expects the dictionary to live in
    # the directory containing `dict_path`.
    model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
        [checkpoint_path], arg_overrides={"data": "/".join(dict_path.split("/")[:-1])}
    )
    model = model[0].eval()

    # Set weights for the wav2vec2 encoder. `recursively_load_weights_wav2vec2`
    # (defined elsewhere in this script) copies the fairseq encoder weights into
    # the HF encoder and returns the encoder-to-decoder projection layer.
    hf_encoder = Wav2Vec2Model(encoder_config)
    projection_layer = recursively_load_weights_wav2vec2(model.encoder, hf_encoder)

    hf_decoder = Speech2Text2ForCausalLM(decoder_config)
    missing_keys, unexpected_keys = hf_decoder.model.decoder.load_state_dict(model.decoder.state_dict(), strict=False)

    # Set the output linear layer from the fairseq output embedding.
    unexpected_keys.remove("embed_out")
    hf_decoder.lm_head.weight = nn.Parameter(model.decoder.embed_out.detach())

    # Layer norm is initialized to the identity matrix, so leaving it as-is is fine.
    logger.warning(f"The following keys are missing when loading the decoder weights: {missing_keys}")
    logger.warning(f"The following keys are unexpected when loading the decoder weights: {unexpected_keys}")

    hf_wav2vec = SpeechEncoderDecoderModel(encoder=hf_encoder, decoder=hf_decoder)
    hf_wav2vec.config.tie_word_embeddings = False

    # Add the projection layer that maps encoder hidden states to the decoder's hidden size.
    hf_wav2vec.enc_to_dec_proj.weight = nn.Parameter(projection_layer.weight)
    hf_wav2vec.enc_to_dec_proj.bias = nn.Parameter(projection_layer.bias)

    # `create_vocab_dict` (defined elsewhere in this script) builds a token-to-id
    # mapping from the fairseq dictionary file.
    vocab_dict = create_vocab_dict(dict_path)

    with open(os.path.join(pytorch_dump_folder_path, "vocab.json"), "w") as fp:
        json.dump(vocab_dict, fp)

    tokenizer = Speech2Text2Tokenizer(os.path.join(pytorch_dump_folder_path, "vocab.json"))
    tokenizer.save_pretrained(pytorch_dump_folder_path)

    # Propagate the tokenizer's special token ids and the processor classes into
    # the model config before saving.
    config = hf_wav2vec.config.to_dict()
    config["pad_token_id"] = tokenizer.pad_token_id
    config["bos_token_id"] = tokenizer.bos_token_id
    config["eos_token_id"] = tokenizer.eos_token_id
    config["tokenizer_class"] = "speech_to_text_2"
    config["feature_extractor_type"] = "wav2vec2"

    hf_wav2vec.config = SpeechEncoderDecoderConfig.from_dict(config)

    hf_wav2vec.save_pretrained(pytorch_dump_folder_path)
    feature_extractor.save_pretrained(pytorch_dump_folder_path)
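# A sketch of a command-line entry point for the conversion function above,
# assuming argparse flags that mirror the function's parameters; the flag
# names, defaults, and help strings here are illustrative assumptions, not
# necessarily the original script's.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint_path", required=True, type=str, help="Path to the fairseq checkpoint.")
    parser.add_argument("--pytorch_dump_folder_path", required=True, type=str, help="Path to the output PyTorch model.")
    parser.add_argument("--dict_path", required=True, type=str, help="Path to the fairseq dictionary file.")
    parser.add_argument("--encoder_config_path", required=True, type=str, help="Path or hub id of the wav2vec2 encoder config.")
    parser.add_argument("--decoder_config_path", required=True, type=str, help="Path or hub id of the speech_to_text_2 decoder config.")
    parser.add_argument("--vocab_size", default=10224, type=int, help="Vocab size of the decoder (assumed default).")
    parser.add_argument("--num_decoder_layers", default=7, type=int, help="Number of decoder layers (assumed default).")
    args = parser.parse_args()

    convert_wav2vec2_checkpoint(
        args.checkpoint_path,
        args.pytorch_dump_folder_path,
        args.dict_path,
        encoder_config_path=args.encoder_config_path,
        decoder_config_path=args.decoder_config_path,
        vocab_size=args.vocab_size,
        num_decoder_layers=args.num_decoder_layers,
    )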