def convert_unispeech_sat_checkpoint(
    checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True
):
    """
    Copy/paste/tweak model's weights to transformers design.

    Args:
        checkpoint_path: Path of the fairseq checkpoint to load.
        pytorch_dump_folder_path: Directory handed to `save_pretrained` for the converted model.
        config_path: Optional path given to `UniSpeechSatConfig.from_pretrained`; a default
            `UniSpeechSatConfig()` is used when it is omitted.
        dict_path: Optional path of the dictionary file. Its parent directory (everything before
            the last "/") is passed to fairseq as the `data` override; when omitted the override
            is the empty string.
        is_finetuned: Build a `UniSpeechSatForCTC` when true, otherwise a
            `UniSpeechSatForPreTraining`.
    """
    if config_path is not None:
        config = UniSpeechSatConfig.from_pretrained(config_path)
    else:
        config = UniSpeechSatConfig()

    # Honour a caller-supplied dictionary path instead of discarding it, and
    # fall back to "" so the split below never runs on None.
    dict_path = dict_path or ""
    data_dir = "/".join(dict_path.split("/")[:-1])

    if is_finetuned:
        hf_wav2vec = UniSpeechSatForCTC(config)
    else:
        hf_wav2vec = UniSpeechSatForPreTraining(config)

    # The loader returns a 3-tuple; only the list of models is needed here.
    model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
        [checkpoint_path], arg_overrides={"data": data_dir}
    )
    model = model[0].eval()

    recursively_load_weights(model, hf_wav2vec)

    hf_wav2vec.save_pretrained(pytorch_dump_folder_path)
def check_ctc_training(self, config, input_values, *args):
    """Run one CTC forward/backward pass in train mode and check the loss is not infinite.

    The first three rows of `input_values` are used; each row is zero-padded
    in place beyond its own length (full, half and quarter of the sequence).
    """
    config.ctc_zero_infinity = True
    ctc_model = UniSpeechSatForCTC(config=config)
    ctc_model.to(torch_device)
    ctc_model.train()

    # freeze feature encoder
    ctc_model.freeze_feature_encoder()

    input_values = input_values[:3]
    batch_size = input_values.shape[0]
    full_length = input_values.shape[-1]

    input_lengths = [full_length // divisor for divisor in [4, 2, 1]]
    max_length_labels = ctc_model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
    labels = ids_tensor((batch_size, max(max_length_labels) - 2), ctc_model.config.vocab_size)

    # pad input
    for row, (num_samples, num_frames) in enumerate(zip(input_lengths, max_length_labels)):
        input_values[row, num_samples:] = 0.0

        if num_frames < labels.shape[-1]:
            # it's important that we make sure that target lengths are at least
            # one shorter than logit lengths to prevent -inf
            labels[row, num_frames - 1 :] = -100

    loss = ctc_model(input_values, labels=labels).loss
    loss_is_inf = torch.isinf(loss).item()
    self.parent.assertFalse(loss_is_inf)

    loss.backward()
def check_ctc_loss(self, config, input_values, *args):
    """Compute the CTC loss with "sum" and then "mean" reduction and check both are floats.

    The first three rows of `input_values` are used; each row is zero-padded
    in place beyond its own length and masked out in the attention mask.
    """
    ctc_model = UniSpeechSatForCTC(config=config)
    ctc_model.to(torch_device)

    # make sure that dropout is disabled
    ctc_model.eval()

    input_values = input_values[:3]
    attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long)

    full_length = input_values.shape[-1]
    input_lengths = [full_length // divisor for divisor in [4, 2, 1]]
    max_length_labels = ctc_model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
    label_shape = (input_values.shape[0], min(max_length_labels) - 1)
    labels = ids_tensor(label_shape, ctc_model.config.vocab_size)

    # pad input
    for row, num_samples in enumerate(input_lengths):
        input_values[row, num_samples:] = 0.0
        attention_mask[row, num_samples:] = 0

    def loss_with_reduction(reduction):
        # Switch the reduction on the live config, then run a forward pass.
        ctc_model.config.ctc_loss_reduction = reduction
        outputs = ctc_model(input_values, attention_mask=attention_mask, labels=labels)
        return outputs.loss.item()

    sum_loss = loss_with_reduction("sum")
    mean_loss = loss_with_reduction("mean")

    self.parent.assertTrue(isinstance(sum_loss, float))
    self.parent.assertTrue(isinstance(mean_loss, float))
def test_mask_time_prob_ctc(self):
    """Set up a tiny CTC model with time masking, its processor, and random audio clips.

    NOTE(review): as visible here the method only builds these objects; `processor`
    and `input_features` are never used and nothing is asserted. Confirm whether
    the remainder of the test is missing.
    """
    checkpoint = "hf-internal-testing/tiny-random-unispeech-sat"

    model = UniSpeechSatForCTC.from_pretrained(checkpoint, mask_time_prob=0.2, mask_time_length=2)
    model_on_device = model.to(torch_device)
    model_on_device.train()

    processor = Wav2Vec2Processor.from_pretrained(checkpoint, return_attention_mask=True)

    # One random waveform per clip, 16_000 samples for each second of duration.
    batch_duration_in_seconds = [1, 3, 2, 6]
    input_features = []
    for seconds in batch_duration_in_seconds:
        input_features.append(np.random.random(16_000 * seconds))
def check_labels_out_of_vocab(self, config, input_values, *args):
    """Check that the CTC model raises `ValueError` for labels drawn from an oversized id range.

    Labels are sampled with an upper bound of `vocab_size + 100`, using the
    first three rows of `input_values`.
    """
    ctc_model = UniSpeechSatForCTC(config)
    ctc_model.to(torch_device)
    ctc_model.train()

    input_values = input_values[:3]
    batch_size = input_values.shape[0]
    full_length = input_values.shape[-1]

    input_lengths = [full_length // divisor for divisor in [4, 2, 1]]
    max_length_labels = ctc_model._get_feat_extract_output_lengths(torch.tensor(input_lengths))

    oversized_vocab = ctc_model.config.vocab_size + 100
    labels = ids_tensor((batch_size, max(max_length_labels) - 2), oversized_vocab)

    with pytest.raises(ValueError):
        ctc_model(input_values, labels=labels)