def convert_s3prl_checkpoint(base_model_name, config_path, checkpoint_path, model_dump_path):
    """
    Convert an S3PRL downstream checkpoint into a saved transformers model.

    The checkpoint is loaded on CPU and its "Downstream" entry is handed to the
    converter matching the first architecture listed in the config. The
    resulting model and the feature extractor are both written to
    `model_dump_path`.

    Args:
        base_model_name: Identifier passed to the feature extractor's
            `from_pretrained` and to the selected converter.
        config_path: Location given to `UniSpeechSatConfig.from_pretrained`.
        checkpoint_path: Path of the checkpoint read with `torch.load`.
        model_dump_path: Destination passed to both `save_pretrained` calls.

    Raises:
        NotImplementedError: If the first configured architecture ends with
            none of the supported head suffixes.
    """
    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # installed torch version / defaults; only convert checkpoints you trust.
    s3prl_state = torch.load(checkpoint_path, map_location="cpu")
    head_state = s3prl_state["Downstream"]

    config = UniSpeechSatConfig.from_pretrained(config_path)
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
        base_model_name, return_attention_mask=True, do_normalize=False
    )

    arch = config.architectures[0]

    # Reject unsupported heads up front; the three suffixes are mutually exclusive.
    supported_suffixes = ("ForSequenceClassification", "ForAudioFrameClassification", "ForXVector")
    if not arch.endswith(supported_suffixes):
        raise NotImplementedError(f"S3PRL weights conversion is not supported for {arch}")

    if arch.endswith("ForSequenceClassification"):
        build_model = convert_classification
    elif arch.endswith("ForAudioFrameClassification"):
        build_model = convert_diarization
    else:
        build_model = convert_xvector
    converted_model = build_model(base_model_name, config, head_state)

    # The layer-mixing weights live under "Featurizer", not "Downstream".
    if config.use_weighted_layer_sum:
        converted_model.layer_weights.data = s3prl_state["Featurizer"]["weights"]

    feature_extractor.save_pretrained(model_dump_path)
    converted_model.save_pretrained(model_dump_path)
def convert_unispeech_sat_checkpoint(
    checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True
):
    """
    Copy/paste/tweak model's weights to transformers design.

    Loads a fairseq checkpoint, copies its weights into a freshly built
    UniSpeechSat model and saves that model to `pytorch_dump_folder_path`.

    Args:
        checkpoint_path: Path of the fairseq checkpoint to load.
        pytorch_dump_folder_path: Destination passed to `save_pretrained`.
        config_path: Optional location given to
            `UniSpeechSatConfig.from_pretrained`; a default
            `UniSpeechSatConfig()` is used when it is `None`.
        dict_path: Optional "/"-separated path; everything before its last
            component is passed to fairseq as the "data" override. When
            `None`, an empty string is passed instead.
        is_finetuned: Build a `UniSpeechSatForCTC` when true, otherwise a
            `UniSpeechSatForPreTraining`.
    """
    if config_path is not None:
        config = UniSpeechSatConfig.from_pretrained(config_path)
    else:
        config = UniSpeechSatConfig()

    # Only substitute a default when the caller gave nothing: the argument must
    # neither be discarded nor reach `.split` below as None.
    if dict_path is None:
        dict_path = ""

    if is_finetuned:
        hf_wav2vec = UniSpeechSatForCTC(config)
    else:
        hf_wav2vec = UniSpeechSatForPreTraining(config)

    # Drop the last path component; an empty `dict_path` yields "".
    data_dir = "/".join(dict_path.split("/")[:-1])
    model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
        [checkpoint_path], arg_overrides={"data": data_dir}
    )
    # Use the first entry of the returned list, switched to eval mode.
    model = model[0].eval()

    recursively_load_weights(model, hf_wav2vec)

    hf_wav2vec.save_pretrained(pytorch_dump_folder_path)