import fairseq
import torch

from transformers import (
    UniSpeechSatConfig,
    UniSpeechSatForCTC,
    UniSpeechSatForPreTraining,
    Wav2Vec2FeatureExtractor,
)

# convert_classification, convert_diarization, convert_xvector, and
# recursively_load_weights are assumed to be defined elsewhere in this script.


def convert_s3prl_checkpoint(base_model_name, config_path, checkpoint_path,
                             model_dump_path):
    """
    Copy/paste/tweak model's weights to transformers design.
    """
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

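    # S3PRL checkpoints keep the downstream head's parameters under "Downstream".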
    downstream_dict = checkpoint["Downstream"]

    hf_config = UniSpeechSatConfig.from_pretrained(config_path)
    hf_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
        base_model_name, return_attention_mask=True, do_normalize=False)

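    # Dispatch on the architecture name recorded in the HF config.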
    arch = hf_config.architectures[0]
    if arch.endswith("ForSequenceClassification"):
        hf_model = convert_classification(base_model_name, hf_config,
                                          downstream_dict)
    elif arch.endswith("ForAudioFrameClassification"):
        hf_model = convert_diarization(base_model_name, hf_config,
                                       downstream_dict)
    elif arch.endswith("ForXVector"):
        hf_model = convert_xvector(base_model_name, hf_config, downstream_dict)
    else:
        raise NotImplementedError(
            f"S3PRL weights conversion is not supported for {arch}")

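    # S3PRL's Featurizer learns a weighted sum over the encoder's hidden layers;
    # copy those weights when the HF config uses weighted layer summation.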
    if hf_config.use_weighted_layer_sum:
        hf_model.layer_weights.data = checkpoint["Featurizer"]["weights"]

    hf_feature_extractor.save_pretrained(model_dump_path)
    hf_model.save_pretrained(model_dump_path)
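

# A minimal usage sketch (hypothetical paths; "microsoft/unispeech-sat-base" is
# one plausible base model id on the HF hub):
#
#   convert_s3prl_checkpoint(
#       base_model_name="microsoft/unispeech-sat-base",
#       config_path="path/to/config.json",
#       checkpoint_path="path/to/s3prl_downstream.ckpt",
#       model_dump_path="path/to/dump_dir",
#   )
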
def convert_unispeech_sat_checkpoint(
    checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True
):
    """
    Copy/paste/tweak model's weights to transformers design.
    """
    if config_path is not None:
        config = UniSpeechSatConfig.from_pretrained(config_path)
    else:
        config = UniSpeechSatConfig()

    dict_path = ""

    if is_finetuned:
        hf_wav2vec = UniSpeechSatForCTC(config)
    else:
        hf_wav2vec = UniSpeechSatForPreTraining(config)

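    # fairseq resolves its dictionary relative to the "data" directory, so pass
    # the directory that contains dict_path.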
    model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
        [checkpoint_path], arg_overrides={"data": "/".join(dict_path.split("/")[:-1])}
    )
    model = model[0].eval()

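    # Map the fairseq parameter names onto the HF module hierarchy.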
    recursively_load_weights(model, hf_wav2vec)

    hf_wav2vec.save_pretrained(pytorch_dump_folder_path)
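

# A minimal usage sketch (hypothetical paths; assumes a fairseq UniSpeechSat
# checkpoint on disk):
#
#   convert_unispeech_sat_checkpoint(
#       checkpoint_path="path/to/unispeech_sat.pt",
#       pytorch_dump_folder_path="path/to/dump_dir",
#       config_path=None,   # None -> default UniSpeechSatConfig()
#       dict_path=None,     # directory of dict_path is passed to fairseq as "data"
#       is_finetuned=True,  # True -> CTC head, False -> pretraining head
#   )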