    def __init__(self,
                 source,
                 save_path,
                 output_norm=True,
                 freeze=True,
                 pretrain=True):
        super().__init__()

        # Download the extractor from HuggingFace.
        # The extractor is only used to retrieve the normalisation settings.
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            source, cache_dir=save_path)

        # Download the model from HuggingFace.
        # If pretrain is False, we do not download the pretrained weights;
        # if it is True, we download and load them.
        if not pretrain:
            config = Wav2Vec2Config.from_pretrained(source,
                                                    cache_dir=save_path)
            self.model = Wav2Vec2Model(config)
        else:
            self.model = Wav2Vec2Model.from_pretrained(source,
                                                       cache_dir=save_path)

        # We check if inputs need to be normalized w.r.t. the pretrained wav2vec2
        self.normalize_wav = self.feature_extractor.do_normalize

        self.freeze = freeze
        self.output_norm = output_norm
        if self.freeze:
            self.model.eval()
        else:
            self.model.train()
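The forward pass is not part of this snippet; the sketch below shows one plausible way the normalize_wav and freeze flags set above could be used (the layer-norm normalisation step is an assumption, not confirmed by the snippet).

    def forward(self, wav):
        import torch
        import torch.nn.functional as F

        # Normalize the raw waveform if the pretrained extractor expects it.
        if self.normalize_wav:
            wav = F.layer_norm(wav, wav.shape)

        # Run the encoder without gradients when it is frozen.
        if self.freeze:
            with torch.no_grad():
                return self.model(wav).last_hidden_state
        return self.model(wav).last_hidden_state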
Example #2
def convert_wav2vec2_checkpoint(checkpoint_path,
                                pytorch_dump_folder_path,
                                config_path=None,
                                dict_path=None,
                                is_finetuned=True):
    """
    Copy/paste/tweak model's weights to transformers design.
    """
    if config_path is not None:
        config = Wav2Vec2Config.from_pretrained(config_path)
    else:
        config = Wav2Vec2Config()

    if is_finetuned:
        hf_wav2vec = Wav2Vec2ForCTC(config)
    else:
        hf_wav2vec = Wav2Vec2Model(config)

    if is_finetuned:
        model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            [checkpoint_path], arg_overrides={"data": dict_path})
    else:
        model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            [checkpoint_path])

    model = model[0].eval()

    recursively_load_weights(model, hf_wav2vec, is_finetuned)

    hf_wav2vec.save_pretrained(pytorch_dump_folder_path)
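A hypothetical command-line wrapper for the conversion function above; the flag names simply mirror the function arguments and are not taken from the original script.

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint_path", required=True)
    parser.add_argument("--pytorch_dump_folder_path", required=True)
    parser.add_argument("--config_path", default=None)
    parser.add_argument("--dict_path", default=None)
    parser.add_argument("--not_finetuned", action="store_true")
    args = parser.parse_args()

    convert_wav2vec2_checkpoint(args.checkpoint_path,
                                args.pytorch_dump_folder_path,
                                args.config_path,
                                args.dict_path,
                                is_finetuned=not args.not_finetuned)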
Example #3
    def create_and_check_batch_inference(self, config, input_values, *args):
        # Not sure how to make this test pass at the moment. Batched input yields
        # same results as official fairseq implementation, but gives different results
        # depending on whether batched input is used or not
        # check: https://github.com/pytorch/fairseq/issues/3227
        model = Wav2Vec2Model(config=config)
        model.to(torch_device)
        model.eval()

        input_values = input_values[:3]
        attention_mask = torch.ones(input_values.shape,
                                    device=torch_device,
                                    dtype=torch.bool)

        input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]

        # pad input
        for i in range(len(input_lengths)):
            input_values[i, input_lengths[i]:] = 0.0
            attention_mask[i, input_lengths[i]:] = 0.0

        batch_outputs = model(input_values,
                              attention_mask=attention_mask).last_hidden_state

        for i in range(input_values.shape[0]):
            input_slice = input_values[i:i + 1, :input_lengths[i]]
            output = model(input_slice).last_hidden_state

            batch_output = batch_outputs[i:i + 1, :output.shape[1]]
            self.parent.assertTrue(
                torch.allclose(output, batch_output, atol=1e-3))
Example #4
    def __init__(self,
                 ckpt: str = None,
                 model_config: str = None,
                 feature_selection: str = None,
                 **kwargs):
        """
        Args:
            ckpt:
                The checkpoint path for loading your pretrained weights.

            model_config:
                The config path for constructing your model.
                Might not be needed if you also save it in your checkpoint file.

            feature_selection:
                A string controlling different behaviors of the same
                pretrained model, such as extracting different layers as
                the representations.
        """
        super().__init__()

        self.processor = Wav2Vec2Processor.from_pretrained(ckpt)
        self.model = Wav2Vec2Model.from_pretrained(ckpt)

        pseudo_input = [torch.randn(SAMPLE_RATE)]
        pseudo_output = self.forward(pseudo_input)
        self._output_dim = pseudo_output[0].size(-1)
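The probing call above relies on a forward method that accepts a list of raw waveforms; the snippet does not show it, so the version below is only a plausible sketch that uses the processor for padding.

    def forward(self, wavs):
        # wavs: list of 1-D float tensors sampled at SAMPLE_RATE
        inputs = self.processor([w.numpy() for w in wavs],
                                sampling_rate=SAMPLE_RATE,
                                return_tensors="pt",
                                padding=True)
        hidden = self.model(inputs.input_values).last_hidden_state
        # Return one (frames, hidden_size) tensor per utterance.
        return [h for h in hidden]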
    def create_and_check_batch_inference(self, config, input_values, *args):
        # test does not pass for models making use of `group_norm`
        # check: https://github.com/pytorch/fairseq/issues/3227
        model = Wav2Vec2Model(config=config)
        model.to(torch_device)
        model.eval()

        input_values = input_values[:3]
        attention_mask = torch.ones(input_values.shape,
                                    device=torch_device,
                                    dtype=torch.bool)

        input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]

        # pad input
        for i in range(len(input_lengths)):
            input_values[i, input_lengths[i]:] = 0.0
            attention_mask[i, input_lengths[i]:] = 0.0

        batch_outputs = model(input_values,
                              attention_mask=attention_mask).last_hidden_state

        for i in range(input_values.shape[0]):
            input_slice = input_values[i:i + 1, :input_lengths[i]]
            output = model(input_slice).last_hidden_state

            batch_output = batch_outputs[i:i + 1, :output.shape[1]]
            self.parent.assertTrue(
                torch.allclose(output, batch_output, atol=1e-3))
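Outside of the test suite, the padding and attention mask built by hand above are usually produced by the feature extractor; a minimal sketch (checkpoint name assumed):

import torch
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model

extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h").eval()

# Two utterances of different lengths; the extractor pads them and returns a mask.
waves = [torch.randn(16000).numpy(), torch.randn(8000).numpy()]
inputs = extractor(waves, sampling_rate=16000, return_tensors="pt",
                   padding=True, return_attention_mask=True)
with torch.no_grad():
    hidden = model(inputs.input_values,
                   attention_mask=inputs.attention_mask).last_hidden_state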
def _main():
    keys = [
        # pretrained
        "facebook/wav2vec2-base",
        "facebook/wav2vec2-large",
        "facebook/wav2vec2-large-lv60",
        "facebook/wav2vec2-base-10k-voxpopuli",
        "facebook/wav2vec2-large-xlsr-53",
        # finetuned
        "facebook/wav2vec2-base-960h",
        "facebook/wav2vec2-large-960h",
        "facebook/wav2vec2-large-960h-lv60",
        "facebook/wav2vec2-large-960h-lv60-self",
        "facebook/wav2vec2-large-xlsr-53-german",
    ]
    for key in keys:
        path = os.path.join(_THIS_DIR, f'{key}.json')
        print('Generating ', path)
        cfg = Wav2Vec2Model.from_pretrained(key).config
        cfg = json.loads(cfg.to_json_string())
        del cfg['_name_or_path']

        with open(path, 'w') as file_:
            file_.write(json.dumps(cfg, indent=4, sort_keys=True))
            file_.write('\n')
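The generated JSON files can later be turned back into a config without downloading the model; a small sketch with an assumed file path:

import json
from transformers import Wav2Vec2Config

with open('facebook/wav2vec2-base.json') as file_:
    config = Wav2Vec2Config(**json.load(file_))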
    def __init__(self, device):
        self.device = device

        self.tokenizer = Wav2Vec2Tokenizer.from_pretrained(
            "facebook/wav2vec2-base-960h")
        self.raw_model = Wav2Vec2Model.from_pretrained(
            "facebook/wav2vec2-base-960h").to(self.device)
    def __init__(self,
                 source,
                 save_path,
                 output_norm=True,
                 freeze=True,
                 pretrain=True):
        super().__init__()

        # Download the model from HuggingFace and load it.
        # The Processor is only used to retrieve the normalisation settings.
        self.proc = Wav2Vec2Processor.from_pretrained(source,
                                                      cache_dir=save_path)
        self.model = Wav2Vec2Model.from_pretrained(source, cache_dir=save_path)

        # Randomly re-initialize the layers if pretrain is False
        if not pretrain:
            self.reset_layer(self.model)

        # We check if inputs need to be normalized w.r.t. the pretrained wav2vec2
        self.normalize_wav = self.proc.feature_extractor.do_normalize

        self.freeze = freeze
        self.output_norm = output_norm
        if self.freeze:
            self.model.eval()
        else:
            self.model.train()
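reset_layer is not shown in this snippet; a plausible implementation re-initialises every submodule that defines reset_parameters (an assumption about the original code):

    def reset_layer(self, model):
        # Re-initialise all submodules that know how to reset themselves.
        for module in model.modules():
            if hasattr(module, "reset_parameters"):
                module.reset_parameters()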
Example #9
    def __init__(self, config, pretrained_model, pretrained_dir):
        super().__init__(config)

        self.wav2vec2 = Wav2Vec2Model.from_pretrained(pretrained_model, cache_dir=pretrained_dir)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(config.hidden_size, 23)

        self.init_weights()
Example #10
def load_wav2vec2(ckpt_path='facebook/wav2vec2-base-960h'):
    """Load pretrained Wav2Vec2 model."""
    def extract_features(self, wav, mask):
        return [self(wav).last_hidden_state]

    Wav2Vec2Model.extract_features = extract_features  # for same behaviour as fairseq.Wav2Vec2Model
    model = Wav2Vec2Model.from_pretrained(ckpt_path).eval()
    return model
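A quick usage sketch of the patched extract_features; the input shape and expected output shape are assumptions:

import torch

model = load_wav2vec2()
wav = torch.randn(1, 16000)  # batch of one, roughly one second at 16 kHz
with torch.no_grad():
    feats = model.extract_features(wav, mask=None)[0]
print(feats.shape)  # (1, frames, 768) for the base model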
    def create_and_check_model(self, config, input_values, attention_mask):
        model = Wav2Vec2Model(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_values, attention_mask=attention_mask)
        self.parent.assertEqual(
            result.last_hidden_state.shape,
            (self.batch_size, self.output_seq_length, self.hidden_size))
Example #12
    def __init__(self, device="cuda"):
        self.encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
        self.encoder.eval()
        self.encoder = self.encoder.to(device)
        self.preprocessor = Wav2Vec2Processor.from_pretrained(
            "facebook/wav2vec2-base")
        self.preprocessor._sample_rate = 16000
        self.device = device
    def __init__(self, device):
        super(COVIDWav2Vec, self).__init__()
        self.tokenizer = Wav2Vec2Tokenizer.from_pretrained(
            "facebook/wav2vec2-base-960h"
        )
        self.model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
        self.linear = nn.Linear(768, 1)
        self.device = device
Example #14
    def __init__(self, config):
        super().__init__(config)

        self.wav2vec2 = Wav2Vec2Model(config)

        self.tanh = nn.Tanh()
        self.linear1 = nn.Linear(1024, 1024)
        self.linear2 = nn.Linear(1024, 5)
        self.init_weights()
def load_wav2vec2(ckpt_path='facebook/wav2vec2-base-960h'):
  """Load pretrained Wav2Vec2 model."""
  def extract_features(self, wav, mask):
    # wav2vec has window of 400, so we pad to center windows
    wav = torch.nn.functional.pad(wav.unsqueeze(1), (200, 200), mode='reflect').squeeze(1)
    return [self(wav).last_hidden_state]

  Wav2Vec2Model.extract_features = extract_features # for same behaviour as fairseq.Wav2Vec2Model
  model = Wav2Vec2Model.from_pretrained(ckpt_path)
  return model
Example #16
    def __init__(self, wave2vec_model_name=None, hidden_size=768, num_classes=397):
        super().__init__()

        self.wav2vec2 = Wav2Vec2Model.from_pretrained(wave2vec_model_name)
        # for param in self.wav2vec2.parameters():
        #     param.requires_grad = False

        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(hidden_size, num_classes),
        )
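The forward pass is not included in this snippet; a common pattern, assumed here, is to mean-pool the encoder output over time before the classifier head:

    def forward(self, input_values):
        hidden = self.wav2vec2(input_values).last_hidden_state  # (batch, frames, hidden)
        pooled = hidden.mean(dim=1)                             # average over time
        return self.classifier(pooled)                          # (batch, num_classes)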
    def __init__(self, hidden_size=512, num_classes=8, device='cpu', sr=16000):
        super(Wav2VecClassifier, self).__init__()
        self.hidden_size = hidden_size
        self.sr = sr
        self.device = device
        self.processor = Wav2Vec2Processor.from_pretrained(
            "facebook/wav2vec2-base-960h")
        self.model = Wav2Vec2Model.from_pretrained(
            "facebook/wav2vec2-base-960h")
        self.lstm = nn.LSTM(768, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)
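A plausible forward pass for this classifier, running the wav2vec2 frames through the LSTM and classifying its final hidden state (assumed, not part of the snippet):

    def forward(self, input_values):
        frames = self.model(input_values).last_hidden_state  # (batch, frames, 768)
        _, (h_n, _) = self.lstm(frames)                       # h_n: (1, batch, hidden_size)
        return self.fc(h_n[-1])                               # (batch, num_classes)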
Example #18
    def __init__(self, config):
        super().__init__(config)

        self.wav2vec2 = Wav2Vec2Model(config)

        # self.inner_dim = 128
        # self.feature_size = 999

        self.tanh = nn.Tanh()
        self.linear1 = nn.Linear(1024, 1024)
        self.linear2 = nn.Linear(1024, 2)
        self.init_weights()
Example #19
    def __init__(self, freeze=True):
        super().__init__()
        self.encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
        self.freeze = freeze

        self.dense = nn.Sequential(nn.Linear(768, 128), nn.ReLU(),
                                   nn.Dropout(0.1), nn.Linear(128, 1))

        if self.freeze:
            self.encoder.eval()
            for p in self.encoder.parameters():
                p.requires_grad_(False)
Example #20
    def __init__(self, config):
        super().__init__(config)

        self.wav2vec2 = Wav2Vec2Model(config)

        self.inner_dim = 128
        self.feature_size = 249

        self.tanh = nn.Tanh()
        self.linear1 = nn.Linear(1024, self.inner_dim)
        self.linear2 = nn.Linear(self.inner_dim * self.feature_size, 5)
        self.init_weights()
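A forward sketch consistent with these layer sizes, projecting each frame to inner_dim and flattening before the final linear layer (an assumption, not shown in the snippet):

    def forward(self, input_values):
        hidden = self.wav2vec2(input_values).last_hidden_state  # (batch, 249, 1024)
        x = self.tanh(self.linear1(hidden))                     # (batch, 249, 128)
        x = x.reshape(x.size(0), -1)                            # (batch, 249 * 128)
        return self.linear2(x)                                  # (batch, 5)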
    def __init__(self,
                 model_path: str = "facebook/wav2vec2-large-xlsr-53",
                 device: str = "cpu",
                 target_sample_rate: int = 16000) -> None:
        super().__init__()
        self.target_sample_rate = target_sample_rate
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            model_path, cache_dir=".ckpt")  # , map_location="cpu"
        if device != "cpu" and not torch.cuda.is_available():
            logging.warning("GPU not available, falling back to CPU inference")
            device = "cpu"
        self.device = torch.device(device)
        self.model = Wav2Vec2Model.from_pretrained(
            model_path, cache_dir=".ckpt").to(self.device)
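A minimal sketch of how such a wrapper might embed a single waveform; the method name and shapes are assumptions:

    def embed(self, wav: torch.Tensor) -> torch.Tensor:
        # wav: 1-D float tensor already resampled to target_sample_rate
        inputs = self.feature_extractor(wav.numpy(),
                                        sampling_rate=self.target_sample_rate,
                                        return_tensors="pt").to(self.device)
        with torch.no_grad():
            return self.model(inputs.input_values).last_hidden_state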
Example #22
    def __init__(self, config):
        super().__init__(config)

        self.model = Wav2Vec2Model(config)

        self.inner = 128
        self.features = 499

        self.leakyReLu = nn.LeakyReLU()
        self.sigmoid = nn.Sigmoid()

        self.fc1 = nn.Linear(1024, self.inner)
        self.fc2 = nn.Linear(self.inner * self.features, 1024)
        self.fc3 = nn.Linear(1024, 1024)
        self.fc4 = nn.Linear(1024, 5)
    def __init__(self, hp, pretrain_model, freeze_feature_extractor=False):
        super().__init__()
        self.hp = hp
        self.d_model_d = hp.d_model_d
        self.trg_vocab = hp.vocab_size
        self.encoder_type = hp.encoder
        self.decoder_type = hp.decoder
        self.use_ctc = hp.use_ctc
        self.freeze_feature_extractor = freeze_feature_extractor
        self.iter_freeze_encoder = hp.iter_freeze_encoder

        if self.decoder_type == 'ctc' and self.use_ctc:
            warnings.warn("hp.decoder == 'ctc' and hp.use_ctc is True; setting hp.use_ctc to False")
            self.use_ctc = False
        
        self.frame_stacking = hp.frame_stacking is not None

        self.encoder = Wav2Vec2Model.from_pretrained(pretrain_model)

        #self.encoder.config.mask_feature_prob = hp.feature_mask

        if self.freeze_feature_extractor:
            print('freeze parameters')
            self.encoder.feature_extractor._freeze_parameters()

        if self.encoder_type.lower() == 'conformer':
            if hp.cnn_avepool:
                self.cnn_encoder = CNN_embedding_avepool(hp)
            else:
                self.cnn_encoder = CNN_embedding(hp)
                self.encoder_asr = ConformerEncoder(hp)
                self.decoder = LSTMDecoder(hp)
        else:
            self.dropout = nn.Dropout(0.1)
            if self.decoder_type.lower() == 'transformer':
                self.linear = nn.Linear(1024, self.d_model_d)
                self.decoder = Decoder(hp)
                self.out = nn.Linear(self.d_model_d, self.trg_vocab)
            elif self.decoder_type.lower() == 'ctc':
                self.decoder = nn.Linear(1024, self.trg_vocab)
            else:
                self.linear = nn.Linear(1024, self.d_model_d)
                self.decoder = LSTMDecoder(hp)

        if self.use_ctc:
            self.out_ctc = nn.Linear(1024, self.trg_vocab)
    def __init__(self, filepath, vocab_size=100, internal_vector_size=128,
                 upscale=True, multilingual=False,
                 pretrained_model="facebook/wav2vec2-base-960h"):
        super().__init__()
        self.filepath = filepath
        self.vocab_size = vocab_size
        self.wav2vec = Wav2Vec2Model.from_pretrained(pretrained_model)#, apply_spec_augment=False)
        self.upscale = upscale
        self.multilingual = multilingual
        fc_input_size = self.wav2vec.config.hidden_size
        if upscale:
            self.upscaler = Upscaler(self.wav2vec.config.hidden_size, internal_vector_size)
            fc_input_size = internal_vector_size

        if self.multilingual:
            self.fcs = nn.ModuleDict(
                {k: nn.Linear(fc_input_size, v) for k, v in VOCAB_SIZES.items()})
        else:
            self.fc = nn.Linear(fc_input_size, vocab_size)
Example #25
def load_pretrained_wav2vec(ckpt_path):
    """Load pretrained Wav2Vec model."""

    # ckpt = torch.load(ckpt_path)
    # model = Wav2Vec2Model.build_model(ckpt["args"], task=None)
    # model.load_state_dict(ckpt["model"])
    # model.remove_pretraining_modules()
    # model.eval()

    def extract_features(self, wav, mask):
        # wav2vec has window of 400, so we pad to center windows
        wav = torch.nn.functional.pad(wav.unsqueeze(1), (200, 200),
                                      mode='reflect').squeeze(1)
        return [self(wav).last_hidden_state]

    Wav2Vec2Model.extract_features = extract_features  # for same behaviour as fairseq.Wav2Vec2Model
    model = Wav2Vec2Model.from_pretrained(ckpt_path).eval()
    return model
Example #26
    def __init__(self, path, freeze=True):
        super().__init__()
        self.encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
        self.freeze = freeze

        self.dense = nn.Sequential(nn.Linear(768, 128), nn.ReLU(),
                                   nn.Dropout(0.1), nn.Linear(128, 1))

        if self.freeze:
            self.encoder.eval()
            for p in self.encoder.parameters():
                p.requires_grad_(False)
        self.load_state_dict(
            extract_prefix('model.',
                           torch.load(path)['state_dict']))
        self.eval()
        self.cuda()
        self.processor = Wav2Vec2Processor.from_pretrained(
            "facebook/wav2vec2-base")
Example #27
    def __init__(
        self,
        initial_vocab_tokens: List[str],
        model_name: str = "facebook/wav2vec2-base",
        gradient_checkpointing: bool = False,
        decoder_dropout: float = 0.1,
        learning_rate: float = 3e-4,
        **kwargs: Dict[str, Any],
    ):
        """Wav2Vec model for fine-tuning.

        Args:
            initial_vocab_tokens : List of tokens to be used in the vocab, special tokens should not be included here. Check [`docs`](https://scart97.github.io/thunder-speech/quick%20reference%20guide/#how-to-get-the-initial_vocab_tokens-from-my-dataset)
            model_name : Name of the original huggingface checkpoint to load from.
            gradient_checkpointing : Use gradient checkpointing to save memory at the expense of slower backward pass.
            decoder_dropout : Dropout before the final decoding layer
            learning_rate : Learning rate used on the optimizer.
            kwargs: Any other option that can be passed to the original Wav2Vec2Model.from_pretrained
        """
        super().__init__()
        self.save_hyperparameters()
        self.audio_transform = Wav2Vec2Preprocess()

        self.encoder = Wav2Vec2Model.from_pretrained(
            model_name,
            gradient_checkpointing=gradient_checkpointing,
            **kwargs,
        )
        self.encoder.feature_extractor._freeze_parameters()

        self.text_transform = self.build_text_transform(initial_vocab_tokens)
        self.decoder = self.build_decoder(
            decoder_dropout,
            self.encoder.config.hidden_size,
            len(self.text_transform.vocab),
        )

        # Metrics
        self.val_cer = CER()
        self.val_wer = WER()
        # Example input is one second of fake audio
        self.example_input_array = torch.randn((10, 16000))
Example #28
    def get_encoder_decoder_model(self, config, decoder_config):
        encoder_model = Wav2Vec2Model(config).eval()
        decoder_model = Speech2Text2ForCausalLM(decoder_config).eval()
        return encoder_model, decoder_model
Example #29
    def get_encoder_decoder_model(self, config, decoder_config):
        encoder_model = Wav2Vec2Model(config).eval()
        decoder_model = BertLMHeadModel(decoder_config).eval()
        return encoder_model, decoder_model
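Such encoder/decoder pairs are typically tied together with transformers' SpeechEncoderDecoderModel; a minimal sketch (not part of the test snippets above):

from transformers import SpeechEncoderDecoderModel

def build_speech_encoder_decoder(encoder_model, decoder_model):
    # Combine a speech encoder and an autoregressive decoder into one seq2seq model.
    return SpeechEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)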
def convert_wav2vec2_checkpoint(checkpoint_path,
                                pytorch_dump_folder_path,
                                config_path=None,
                                dict_path=None,
                                is_finetuned=True):
    """
    Copy/paste/tweak model's weights to transformers design.
    """
    if config_path is not None:
        config = Wav2Vec2Config.from_pretrained(config_path)
    else:
        config = Wav2Vec2Config()

    if is_finetuned:
        if dict_path:
            target_dict = Dictionary.load(dict_path)

            config.bos_token_id = target_dict.bos_index
            config.eos_token_id = target_dict.eos_index
            config.pad_token_id = target_dict.pad_index
            config.vocab_size = len(target_dict.symbols)
            vocab_path = os.path.join(pytorch_dump_folder_path, "vocab.json")
            if not os.path.isdir(pytorch_dump_folder_path):
                logger.error(
                    "--pytorch_dump_folder_path ({}) should be a directory".
                    format(pytorch_dump_folder_path))
                return
            os.makedirs(pytorch_dump_folder_path, exist_ok=True)
            with open(vocab_path, "w", encoding="utf-8") as vocab_handle:
                json.dump(target_dict.indices, vocab_handle)
            tokenizer = Wav2Vec2CTCTokenizer(
                vocab_path,
                unk_token=target_dict.unk_word,
                pad_token=target_dict.pad_word,
                bos_token=target_dict.bos_word,
                eos_token=target_dict.eos_word,
                word_delimiter_token="|",
                do_lower_case=False,
            )
            return_attention_mask = True if config.feat_extract_norm == "layer" else False
            feature_extractor = Wav2Vec2FeatureExtractor(
                feature_size=1,
                sampling_rate=16000,
                padding_value=0,
                do_normalize=True,
                return_attention_mask=return_attention_mask,
            )
            processor = Wav2Vec2Processor(feature_extractor=feature_extractor,
                                          tokenizer=tokenizer)
            processor.save_pretrained(pytorch_dump_folder_path)

        hf_wav2vec = Wav2Vec2ForCTC(config)
    else:
        hf_wav2vec = Wav2Vec2Model(config)

    if is_finetuned:

        model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            [checkpoint_path], arg_overrides={"data": dict_path})
    else:
        model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            [checkpoint_path])

    model = model[0].eval()

    recursively_load_weights(model, hf_wav2vec, is_finetuned)

    hf_wav2vec.save_pretrained(pytorch_dump_folder_path)
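After conversion, the dumped folder can be loaded back like any other transformers checkpoint; a small sketch with assumed paths:

from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

processor = Wav2Vec2Processor.from_pretrained("path/to/pytorch_dump_folder")
model = Wav2Vec2ForCTC.from_pretrained("path/to/pytorch_dump_folder")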