def test_from_pretrained_dynamic_processor(self):
        """Load a remote-code processor and check each of its components.

        ``hf-internal-testing/test_dynamic_processor`` is loaded with
        ``trust_remote_code=True``. The processor, its feature extractor and its
        tokenizer must all expose ``special_attribute_present`` and carry the
        expected class names. When ``tokenizers`` is available the default
        tokenizer must be the fast one, and ``use_fast=False`` must yield the
        slow one; otherwise only the slow tokenizer is checked.
        """
        repo_id = "hf-internal-testing/test_dynamic_processor"

        processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
        self.assertTrue(processor.special_attribute_present)
        self.assertEqual(processor.__class__.__name__, "NewProcessor")

        extractor = processor.feature_extractor
        self.assertTrue(extractor.special_attribute_present)
        self.assertEqual(extractor.__class__.__name__, "NewFeatureExtractor")

        tok = processor.tokenizer
        self.assertTrue(tok.special_attribute_present)
        if not is_tokenizers_available():
            # Without the `tokenizers` backend only the slow class is expected.
            self.assertEqual(tok.__class__.__name__, "NewTokenizer")
            return

        self.assertEqual(tok.__class__.__name__, "NewTokenizerFast")

        # Explicitly requesting the slow implementation must also work.
        slow_tok = AutoProcessor.from_pretrained(
            repo_id, trust_remote_code=True, use_fast=False
        ).tokenizer
        self.assertTrue(slow_tok.special_attribute_present)
        self.assertEqual(slow_tok.__class__.__name__, "NewTokenizer")
示例#2
0
    def test_new_processor_registration(self):
        """Register custom classes with the Auto* APIs and round-trip a processor.

        After registration, a ``CustomProcessor`` saved to disk must reload via
        ``AutoProcessor`` as a ``CustomProcessor``. Registering a processor for
        ``Wav2Vec2Config`` must raise ``ValueError``. All registrations are
        removed from each mapping's ``_extra_content`` in ``finally``.
        """
        try:
            AutoConfig.register("custom", CustomConfig)
            AutoFeatureExtractor.register(CustomConfig, CustomFeatureExtractor)
            AutoTokenizer.register(CustomConfig, slow_tokenizer_class=CustomTokenizer)
            AutoProcessor.register(CustomConfig, CustomProcessor)

            # Registering over a config Transformers already maps must be refused.
            with self.assertRaises(ValueError):
                AutoProcessor.register(Wav2Vec2Config, Wav2Vec2Processor)

            # With the config registered it behaves like any other auto-API config.
            extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)

            with tempfile.TemporaryDirectory() as vocab_dir:
                vocab_path = os.path.join(vocab_dir, "vocab.txt")
                with open(vocab_path, "w", encoding="utf-8") as handle:
                    handle.write("".join(token + "\n" for token in self.vocab_tokens))
                tokenizer = CustomTokenizer(vocab_path)

            processor = CustomProcessor(extractor, tokenizer)

            with tempfile.TemporaryDirectory() as save_dir:
                processor.save_pretrained(save_dir)
                reloaded = AutoProcessor.from_pretrained(save_dir)
                self.assertIsInstance(reloaded, CustomProcessor)

        finally:
            # Undo every registration so other tests see pristine mappings.
            registrations = (
                (CONFIG_MAPPING, "custom"),
                (FEATURE_EXTRACTOR_MAPPING, CustomConfig),
                (TOKENIZER_MAPPING, CustomConfig),
                (PROCESSOR_MAPPING, CustomConfig),
            )
            for mapping, key in registrations:
                mapping._extra_content.pop(key, None)
示例#3
0
    def __init__(
        self,
        backbone: str = "facebook/wav2vec2-base-960h",
        processor_backbone: Optional[str] = None,
        optimizer: OPTIMIZER_TYPE = "Adam",
        lr_scheduler: LR_SCHEDULER_TYPE = None,
        learning_rate: Optional[float] = None,
    ):
        """Build the speech-recognition task around a registered backbone.

        Args:
            backbone: Key looked up in ``self.backbones``; the returned factory is
                called with no arguments to build the model. Also passed to
                ``SpeechRecognitionOutputTransform`` and, when
                ``processor_backbone`` is ``None``, to ``AutoProcessor.from_pretrained``.
            processor_backbone: Optional name/path to load the processor from
                instead of ``backbone``.
            optimizer: Forwarded unchanged to the parent ``__init__``.
            lr_scheduler: Forwarded unchanged to the parent ``__init__``.
            learning_rate: Forwarded unchanged to the parent ``__init__``.

        Side effects:
            Sets ``TOKENIZERS_PARALLELISM`` and ``PYTHONWARNINGS`` in ``os.environ``
            and installs a process-wide ``warnings.simplefilter("ignore")``.
        """
        os.environ["TOKENIZERS_PARALLELISM"] = "TRUE"
        # Silence all warnings in this process (the filter is process-wide,
        # not limited to Hugging Face libraries).
        warnings.simplefilter("ignore")
        # Presumably so that child Python processes (e.g. dataloader workers)
        # start with warnings ignored as well — PYTHONWARNINGS is read at startup.
        os.environ["PYTHONWARNINGS"] = "ignore"

        model = self.backbones.get(backbone)()
        super().__init__(
            model=model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            learning_rate=learning_rate,
            output_transform=SpeechRecognitionOutputTransform(backbone),
        )

        # Must run after super().__init__(). NOTE(review): presumably Lightning's
        # save_hyperparameters, which records the __init__ arguments by
        # inspecting this frame — confirm before reordering this method.
        self.save_hyperparameters()

        # Collator pads batches using the processor of `processor_backbone`,
        # falling back to the model backbone's processor.
        self.collate_fn = DataCollatorCTCWithPadding(
            AutoProcessor.from_pretrained(backbone)
            if processor_backbone is None
            else AutoProcessor.from_pretrained(processor_backbone)
        )
示例#4
0
    def test_processor_from_local_directory_from_repo(self):
        """A Hub processor saved next to a Wav2Vec2 config reloads as Wav2Vec2Processor."""
        with tempfile.TemporaryDirectory() as local_dir:
            config = Wav2Vec2Config()
            source = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")

            # Write the model config and the processor files into the same folder.
            config.save_pretrained(local_dir)
            source.save_pretrained(local_dir)

            reloaded = AutoProcessor.from_pretrained(local_dir)

        self.assertIsInstance(reloaded, Wav2Vec2Processor)
示例#5
0
    def test_processor_from_local_directory_from_extractor_config(self):
        """A folder with only a feature-extractor config and a vocab loads as Wav2Vec2Processor."""
        with tempfile.TemporaryDirectory() as local_dir:
            fixtures = (
                (SAMPLE_PROCESSOR_CONFIG, FEATURE_EXTRACTOR_NAME),
                (SAMPLE_VOCAB, "vocab.json"),
            )
            for src, dst_name in fixtures:
                copyfile(src, os.path.join(local_dir, dst_name))

            reloaded = AutoProcessor.from_pretrained(local_dir)

        self.assertIsInstance(reloaded, Wav2Vec2Processor)
示例#6
0
    def test_processor_from_local_directory_from_model_config(self):
        """AutoProcessor loads a Wav2Vec2Processor from a folder whose model config names it.

        The folder holds a ``Wav2Vec2Config`` with ``processor_class`` set, a
        vocab file, and a feature-extractor config containing just ``{}``.
        """
        with tempfile.TemporaryDirectory() as local_dir:
            Wav2Vec2Config(processor_class="Wav2Vec2Processor").save_pretrained(local_dir)
            copyfile(SAMPLE_VOCAB, os.path.join(local_dir, "vocab.json"))

            # Write an empty JSON object as the feature-extractor config.
            extractor_config_path = os.path.join(local_dir, FEATURE_EXTRACTOR_NAME)
            with open(extractor_config_path, "w") as handle:
                handle.write("{}")

            reloaded = AutoProcessor.from_pretrained(local_dir)

        self.assertIsInstance(reloaded, Wav2Vec2Processor)
示例#7
0
    def test_push_to_hub_dynamic_processor(self):
        """Push a processor built from custom classes and reload it with remote code.

        After ``register_for_auto_class``, ``save_pretrained`` must write
        ``auto_map`` entries for the feature extractor and the tokenizer and copy
        the custom modules into the repo folder. The pushed repo must then load
        via ``AutoProcessor`` with ``trust_remote_code=True``.
        """
        for custom_cls in (CustomFeatureExtractor, CustomTokenizer, CustomProcessor):
            custom_cls.register_for_auto_class()

        extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)

        with tempfile.TemporaryDirectory() as vocab_dir:
            vocab_path = os.path.join(vocab_dir, "vocab.txt")
            with open(vocab_path, "w", encoding="utf-8") as handle:
                handle.write("".join(token + "\n" for token in self.vocab_tokens))
            tokenizer = CustomTokenizer(vocab_path)

        processor = CustomProcessor(extractor, tokenizer)
        repo_id = f"{USER}/test-dynamic-processor"

        with tempfile.TemporaryDirectory() as clone_dir:
            repo = Repository(clone_dir, clone_from=repo_id, use_auth_token=self._token)
            processor.save_pretrained(clone_dir)

            # auto_map recorded on the feature extractor.
            expected_extractor_map = {
                "AutoFeatureExtractor": "custom_feature_extraction.CustomFeatureExtractor",
                "AutoProcessor": "custom_processing.CustomProcessor",
            }
            self.assertDictEqual(processor.feature_extractor.auto_map, expected_extractor_map)

            # auto_map written into the saved tokenizer config.
            with open(os.path.join(clone_dir, "tokenizer_config.json")) as handle:
                saved_tokenizer_config = json.load(handle)
            expected_tokenizer_map = {
                "AutoTokenizer": ["custom_tokenization.CustomTokenizer", None],
                "AutoProcessor": "custom_processing.CustomProcessor",
            }
            self.assertDictEqual(saved_tokenizer_config["auto_map"], expected_tokenizer_map)

            # Source files of the custom classes must sit next to the configs.
            for module_file in (
                "custom_feature_extraction.py",
                "custom_tokenization.py",
                "custom_processing.py",
            ):
                self.assertTrue(os.path.isfile(os.path.join(clone_dir, module_file)))

            repo.push_to_hub()

        reloaded = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
        # The reloaded class lives in a dynamically imported module, so compare
        # by class name rather than with isinstance against the local class.
        self.assertEqual(reloaded.__class__.__name__, "CustomProcessor")
示例#8
0
    def test_processor_from_feat_extr_processor_class(self):
        """AutoProcessor still returns a Wav2Vec2Processor after ``processor_class``
        has been removed from the saved tokenizer config."""
        with tempfile.TemporaryDirectory() as local_dir:
            processor = Wav2Vec2Processor(
                Wav2Vec2FeatureExtractor(),
                AutoTokenizer.from_pretrained("facebook/wav2vec2-base-960h"),
            )
            processor.save_pretrained(local_dir)

            # Strip `processor_class` from the tokenizer config on disk.
            tokenizer_config_path = os.path.join(local_dir, TOKENIZER_CONFIG_FILE)
            with open(tokenizer_config_path, "r") as handle:
                tokenizer_config = json.load(handle)
            del tokenizer_config["processor_class"]
            with open(tokenizer_config_path, "w") as handle:
                handle.write(json.dumps(tokenizer_config))

            reloaded = AutoProcessor.from_pretrained(local_dir)

        self.assertIsInstance(reloaded, Wav2Vec2Processor)
示例#9
0
    def test_processor_from_auto_processor(self):
        """AutoProcessor and Wav2Vec2ProcessorWithLM loaded from the same repo agree.

        Encoded inputs are compared per key by their sums (within ``1e-2``),
        and the decoded text of the same dummy logits must match exactly.
        """
        repo_id = "hf-internal-testing/processor_with_lm"
        explicit = Wav2Vec2ProcessorWithLM.from_pretrained(repo_id)
        automatic = AutoProcessor.from_pretrained(repo_id)

        speech = floats_list((3, 1000))
        expected_inputs = explicit(speech, return_tensors="np")
        actual_inputs = automatic(speech, return_tensors="np")

        for name, expected in expected_inputs.items():
            self.assertAlmostEqual(expected.sum(), actual_inputs[name].sum(), delta=1e-2)

        logits = self._get_dummy_logits()
        self.assertListEqual(
            explicit.batch_decode(logits).text,
            automatic.batch_decode(logits).text,
        )
 def __init__(self, model_name, hotwords=None):
     """Load a processor and a CTC model for ``model_name``.

     Args:
         model_name: Name or path passed to both ``AutoProcessor.from_pretrained``
             and ``AutoModelForCTC.from_pretrained``.
         hotwords: Optional hotwords, stored as-is on ``self.hotwords``.
             When omitted, each instance gets its own new empty list.
     """
     self.processor = AutoProcessor.from_pretrained(model_name)
     self.model = AutoModelForCTC.from_pretrained(model_name)
     # A literal `[]` default would be a single list shared by every instance
     # created without hotwords, so build a fresh one per call instead.
     self.hotwords = [] if hotwords is None else hotwords
示例#11
0
def main():
    """Fine-tune and/or evaluate a speech model on an XTREME-S task.

    Arguments are parsed into ``ModelArguments``, ``DataTrainingArguments`` and
    ``Seq2SeqTrainingArguments``. They come from a single ``.json`` file when
    that is the only CLI argument, otherwise from the command line.

    The model class depends on the target column and config:
    ``AutoModelForCTC`` when the target column is ``"transcription"``;
    otherwise ``AutoModelForSpeechSeq2Seq`` when ``config.is_encoder_decoder``;
    otherwise ``AutoModelForAudioClassification``.

    Returns:
        ``None`` when ``--preprocessing_only`` is set. Otherwise it returns the
        ``results`` dict, which this function never populates (metrics are
        logged and saved through the trainer instead).
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(
            training_args.output_dir
    ) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(
                training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome.")
        elif last_checkpoint is not None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank
                                                    ) else logging.WARN)

    # local_rank == -1 means no distributed process group; used both for the
    # summary below and to decide whether a barrier is needed later.
    is_distributed = training_args.local_rank != -1

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        f"distributed training: {is_distributed}, 16-bits training: {training_args.fp16}"
    )
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # 1. First, let's load the dataset
    raw_datasets = DatasetDict()
    task_name = data_args.task
    lang_id = data_args.language

    if task_name is None:
        raise ValueError("Set --task should be set to '<xtreme_s_task>' "
                         "(e.g. 'fleurs-asr', 'mls', 'covost2', 'minds14') ")
    if lang_id is None:
        raise ValueError(
            "Set --language should be set to the language id of the sub dataset "
            "config to be used (e.g. 'pl', 'en.tr', 'fr-FR') or 'all'"
            " for multi-lingual fine-tuning.")

    if data_args.target_column_name is None:
        target_column_name = TASK_TO_TARGET_COLUMN_NAME[task_name]
    else:
        target_column_name = data_args.target_column_name

    # here we differentiate between tasks with text as the target and classification tasks
    is_text_target = target_column_name in ("transcription", "translation")

    # Dataset config is "<task prefix before the first '-'>.<language id>".
    config_name = ".".join([task_name.split("-")[0], lang_id])

    if training_args.do_train:
        raw_datasets["train"] = load_dataset(
            data_args.dataset_name,
            config_name,
            split=data_args.train_split_name,
            use_auth_token=data_args.use_auth_token,
            cache_dir=model_args.cache_dir,
        )

        if data_args.audio_column_name not in raw_datasets[
                "train"].column_names:
            raise ValueError(
                f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
                "Make sure to set `--audio_column_name` to the correct audio column - one of "
                f"{', '.join(raw_datasets['train'].column_names)}.")

        if target_column_name not in raw_datasets["train"].column_names:
            raise ValueError(
                f"--target_column_name {target_column_name} not found in dataset '{data_args.dataset_name}'. "
                "Make sure to set `--target_column_name` to the correct text column - one of "
                f"{', '.join(raw_datasets['train'].column_names)}.")

        if data_args.max_train_samples is not None:
            raw_datasets["train"] = raw_datasets["train"].select(
                range(data_args.max_train_samples))

        if not is_text_target:
            label_list = raw_datasets["train"].features[
                target_column_name].names
            num_labels = len(label_list)

    if training_args.do_eval:
        raw_datasets["eval"] = load_dataset(
            data_args.dataset_name,
            config_name,
            split=data_args.eval_split_name,
            use_auth_token=data_args.use_auth_token,
            cache_dir=model_args.cache_dir,
        )

        if data_args.max_eval_samples is not None:
            raw_datasets["eval"] = raw_datasets["eval"].select(
                range(data_args.max_eval_samples))

    if training_args.do_predict:
        raw_datasets["predict"] = load_dataset(
            data_args.dataset_name,
            config_name,
            split=data_args.predict_split_name,
            use_auth_token=data_args.use_auth_token,
            cache_dir=model_args.cache_dir,
        )

        if data_args.max_predict_samples is not None:
            raw_datasets["predict"] = raw_datasets["predict"].select(
                range(data_args.max_predict_samples))

    # 2. We remove some special characters from the datasets
    # that make training complicated and do not help in transcribing the speech
    # E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
    # that could be easily picked up by the model
    chars_to_ignore_regex = (f'[{"".join(data_args.chars_to_ignore)}]' if
                             data_args.chars_to_ignore is not None else None)

    def remove_special_characters(batch):
        # Lower-case the target, strip ignored characters and append a trailing space.
        if chars_to_ignore_regex is not None:
            batch["target_text"] = re.sub(
                chars_to_ignore_regex, "",
                batch[target_column_name]).lower() + " "
        else:
            batch["target_text"] = batch[target_column_name].lower() + " "
        return batch

    if is_text_target:
        with training_args.main_process_first(
                desc="dataset map special characters removal"):
            raw_datasets = raw_datasets.map(
                remove_special_characters,
                remove_columns=[target_column_name],
                desc="remove special characters from datasets",
            )

        # save special tokens for tokenizer
        word_delimiter_token = data_args.word_delimiter_token
        unk_token = data_args.unk_token
        pad_token = data_args.pad_token

    # 3. Next, let's load the config as we might need it to create
    # the tokenizer
    config = AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_auth_token=data_args.use_auth_token)

    if is_text_target:
        # 4. (Optional, for ASR and translation) If no tokenizer file is defined,
        # we create the vocabulary of the model by extracting all unique characters from
        # the training and evaluation datasets
        # We need to make sure that only first rank saves vocabulary
        # make sure all processes wait until vocab is created
        tokenizer_name_or_path = model_args.tokenizer_name_or_path
        tokenizer_kwargs = {}
        if tokenizer_name_or_path is None:
            # save vocab in training output dir
            tokenizer_name_or_path = training_args.output_dir

            vocab_file = os.path.join(tokenizer_name_or_path, "vocab.json")

            with training_args.main_process_first():
                if training_args.overwrite_output_dir and os.path.isfile(
                        vocab_file):
                    os.remove(vocab_file)

            with training_args.main_process_first(
                    desc="dataset map vocabulary creation"):
                if not os.path.isfile(vocab_file):
                    os.makedirs(tokenizer_name_or_path, exist_ok=True)
                    vocab_dict = create_vocabulary_from_data(
                        raw_datasets,
                        word_delimiter_token=word_delimiter_token,
                        unk_token=unk_token,
                        pad_token=pad_token,
                    )

                    # save vocab dict to be loaded into tokenizer
                    with open(vocab_file, "w") as file:
                        json.dump(vocab_dict, file)

            # if tokenizer has just been created
            # it is defined by `tokenizer_class` if present in config else by `model_type`
            if not config.is_encoder_decoder:
                tokenizer_kwargs = {
                    "config":
                    config if config.tokenizer_class is not None else None,
                    "tokenizer_type":
                    config.model_type
                    if config.tokenizer_class is None else None,
                    "unk_token":
                    unk_token,
                    "pad_token":
                    pad_token,
                    "word_delimiter_token":
                    word_delimiter_token,
                }
            else:
                tokenizer_kwargs = {}

    # 5. Now we can instantiate the feature extractor, tokenizer and model
    # Note for distributed training, the .from_pretrained methods guarantee that only
    # one local process can concurrently download model & vocab.

    # load feature_extractor and tokenizer
    if is_text_target:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name_or_path,
            use_auth_token=data_args.use_auth_token,
            **tokenizer_kwargs,
        )
    feature_extractor = AutoFeatureExtractor.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_auth_token=data_args.use_auth_token)

    # adapt config
    config.update({
        "feat_proj_dropout": model_args.feat_proj_dropout,
        "attention_dropout": model_args.attention_dropout,
        "hidden_dropout": model_args.hidden_dropout,
        "final_dropout": model_args.final_dropout,
        "mask_time_prob": model_args.mask_time_prob,
        "mask_time_length": model_args.mask_time_length,
        "mask_feature_prob": model_args.mask_feature_prob,
        "mask_feature_length": model_args.mask_feature_length,
        "gradient_checkpointing": training_args.gradient_checkpointing,
        "layerdrop": model_args.layerdrop,
        "ctc_loss_reduction": model_args.ctc_loss_reduction,
        "activation_dropout": model_args.activation_dropout,
    })
    if training_args.do_train:
        if is_text_target:
            config.pad_token_id = tokenizer.pad_token_id
            config.vocab_size = len(tokenizer)
        else:
            label_to_id = {v: i for i, v in enumerate(label_list)}
            config.label2id = label_to_id
            config.id2label = {id: label for label, id in label_to_id.items()}
            config.num_labels = num_labels

    # create model
    if target_column_name == "transcription":
        model = AutoModelForCTC.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            config=config,
            use_auth_token=data_args.use_auth_token,
        )
    elif config.is_encoder_decoder:
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            config=config,
            use_auth_token=data_args.use_auth_token,
        )
        if model.config.decoder_start_token_id is None:
            raise ValueError(
                "Make sure that `config.decoder_start_token_id` is correctly defined"
            )
    else:
        model = AutoModelForAudioClassification.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            config=config,
            use_auth_token=data_args.use_auth_token,
        )

    # freeze encoder
    if model_args.freeze_feature_encoder:
        model.freeze_feature_encoder()

    # 6. Now we preprocess the datasets including loading the audio, resampling and normalization
    # Thankfully, `datasets` takes care of automatically loading and resampling the audio,
    # so that we just need to set the correct target sampling rate and normalize the input
    # via the `feature_extractor`

    # make sure that dataset decodes audio with correct sampling rate
    dataset_sampling_rate = next(iter(raw_datasets.values())).features[
        data_args.audio_column_name].sampling_rate
    if dataset_sampling_rate != feature_extractor.sampling_rate:
        raw_datasets = raw_datasets.cast_column(
            data_args.audio_column_name,
            datasets.features.Audio(
                sampling_rate=feature_extractor.sampling_rate))

    # derive max & min input length for sample rate & max duration
    max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
    min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
    audio_column_name = data_args.audio_column_name
    num_workers = data_args.preprocessing_num_workers

    # `phoneme_language` is only relevant if the model is fine-tuned on phoneme classification
    phoneme_language = data_args.phoneme_language

    # Preprocessing the datasets.
    # We need to read the audio files as arrays and tokenize the targets.
    def prepare_dataset(batch):
        # load audio
        sample = batch[audio_column_name]

        inputs = feature_extractor(sample["array"],
                                   sampling_rate=sample["sampling_rate"])
        batch["input_values"] = inputs.input_values[0]
        batch["length"] = len(batch["input_values"])

        # encode targets
        additional_kwargs = {}
        if phoneme_language is not None:
            additional_kwargs["phonemizer_lang"] = phoneme_language

        if is_text_target:
            batch["labels"] = tokenizer(batch["target_text"],
                                        **additional_kwargs).input_ids
        else:
            batch["labels"] = batch[target_column_name]
        return batch

    with training_args.main_process_first(desc="dataset map preprocessing"):
        vectorized_datasets = raw_datasets.map(
            prepare_dataset,
            remove_columns=next(iter(raw_datasets.values())).column_names,
            num_proc=num_workers,
            desc="preprocess datasets",
        )

        if training_args.do_train:

            def is_audio_in_length_range(length):
                return length > min_input_length and length < max_input_length

            # keep only training samples strictly between min_input_length
            # and max_input_length (eval/predict splits are not filtered)
            vectorized_datasets["train"] = vectorized_datasets["train"].filter(
                is_audio_in_length_range,
                num_proc=num_workers,
                input_columns=["length"],
            )

    # 7. Next, we can prepare for the training step.
    # Let's use the appropriate XTREME-S evaluation metric,
    # instantiate a data collator and the trainer

    # Define evaluation metrics during training, *i.e.* word error rate, character error rate
    eval_metric = load_metric("xtreme_s", task_name)

    # for large datasets it is advised to run the preprocessing on a
    # single machine first with ``args.preprocessing_only`` since there will mostly likely
    # be a timeout when running the script in distributed mode.
    # In a second step ``args.preprocessing_only`` can then be set to `False` to load the
    # cached dataset
    if data_args.preprocessing_only:
        logger.info(
            f"Data preprocessing finished. Files cached at {vectorized_datasets.cache_files}"
        )
        return

    def compute_asr_metric(pred):
        pred_logits = pred.predictions
        pred_ids = np.argmax(pred_logits, axis=-1)

        # -100 marks ignored label positions; map them back to pad for decoding
        pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id

        pred_str = tokenizer.batch_decode(pred_ids)
        # we do not want to group tokens when computing the metrics
        label_str = tokenizer.batch_decode(pred.label_ids, group_tokens=False)

        metric = eval_metric.compute(predictions=pred_str,
                                     references=label_str)
        return metric

    def compute_classification_metric(pred):
        pred_ids = np.argmax(pred.predictions, axis=1)
        metric = eval_metric.compute(predictions=pred_ids,
                                     references=pred.label_ids)
        return metric

    # Now save everything to be able to create a single processor later
    if is_main_process(training_args.local_rank):
        # save feature extractor, tokenizer and config
        feature_extractor.save_pretrained(training_args.output_dir)
        if is_text_target:
            tokenizer.save_pretrained(training_args.output_dir)
        config.save_pretrained(training_args.output_dir)
    # Wait until configs are saved in the main process before loading the
    # processor. Only valid with a distributed process group: calling barrier()
    # in a single-process run (local_rank == -1) raises because no default
    # process group has been initialized.
    if is_distributed:
        torch.distributed.barrier()

    if is_text_target:
        processor = AutoProcessor.from_pretrained(training_args.output_dir)
    else:
        processor = AutoFeatureExtractor.from_pretrained(
            training_args.output_dir)

    # Instantiate custom data collator
    data_collator = SpeechDataCollatorWithPadding(processor=processor,
                                                  pad_labels=is_text_target)

    # Initialize Trainer
    if target_column_name == "translation":
        trainer = Seq2SeqTrainer(
            model=model,
            data_collator=data_collator,
            args=training_args,
            compute_metrics=compute_asr_metric
            if training_args.predict_with_generate else None,
            train_dataset=vectorized_datasets["train"]
            if training_args.do_train else None,
            eval_dataset=vectorized_datasets["eval"]
            if training_args.do_eval else None,
            tokenizer=feature_extractor,
        )
    else:
        trainer = Trainer(
            model=model,
            data_collator=data_collator,
            args=training_args,
            compute_metrics=compute_asr_metric
            if is_text_target else compute_classification_metric,
            train_dataset=vectorized_datasets["train"]
            if training_args.do_train else None,
            eval_dataset=vectorized_datasets["eval"]
            if training_args.do_eval else None,
            tokenizer=feature_extractor,
        )

    # 8. Finally, we can start training

    # Training
    if training_args.do_train:

        # use last checkpoint if exist
        if last_checkpoint is not None:
            checkpoint = last_checkpoint
        elif os.path.isdir(model_args.model_name_or_path):
            checkpoint = model_args.model_name_or_path
        else:
            checkpoint = None

        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()

        metrics = train_result.metrics
        max_train_samples = (data_args.max_train_samples
                             if data_args.max_train_samples is not None else
                             len(vectorized_datasets["train"]))
        metrics["train_samples"] = min(max_train_samples,
                                       len(vectorized_datasets["train"]))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation on the test set
    results = {}
    if training_args.do_predict:
        logger.info(
            f"*** Evaluating on the `{data_args.predict_split_name}` set ***")
        metrics = trainer.evaluate(vectorized_datasets["predict"])
        max_predict_samples = (data_args.max_predict_samples
                               if data_args.max_predict_samples is not None
                               else len(vectorized_datasets["predict"]))
        metrics["predict_samples"] = min(max_predict_samples,
                                         len(vectorized_datasets["predict"]))

        trainer.log_metrics("predict", metrics)
        trainer.save_metrics("predict", metrics)

    # Write model card and (optionally) push to hub
    kwargs = {
        "finetuned_from": model_args.model_name_or_path,
        "tasks": task_name,
        "tags": [task_name, data_args.dataset_name],
        "dataset_args":
        f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split: {data_args.eval_split_name}, Predict split: {data_args.predict_split_name}",
        "dataset": f"{data_args.dataset_name.upper()} - {config_name.upper()}",
        "language": data_args.language,
    }

    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)

    return results
示例#12
0
def main():
    """Fine-tune a CTC speech-recognition model end to end.

    Parses ``ModelArguments``, ``DataTrainingArguments`` and ``TrainingArguments``
    (from a single ``.json`` path given as the only CLI argument, or from CLI
    flags), then:

    1. loads the requested train/eval splits and validates the audio/text columns,
    2. strips ``chars_to_ignore`` from the transcriptions and lower-cases them,
    3. builds a character vocabulary in ``output_dir`` when no tokenizer is given,
    4. loads tokenizer, feature extractor, config and ``AutoModelForCTC``,
    5. resamples, featurizes, tokenizes and length-filters the audio,
    6. trains and/or evaluates with ``Trainer`` and writes a model card
       (or pushes to the Hub).

    Returns:
        ``None`` when ``--preprocessing_only`` is set; otherwise the ``results``
        dict, which this function never populates (so it is always empty).

    Raises:
        ValueError: if ``output_dir`` is non-empty, holds no checkpoint and
            ``--overwrite_output_dir`` is not set, or if the configured audio or
            text column is missing from the training split.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_speech_recognition_ctc", model_args, data_args)

    # Setup logging. This runs before checkpoint detection on purpose: if
    # `logger` has no explicit level before this point, it inherits the root
    # logger's level (WARNING by default), and the `logger.info` message
    # emitted during checkpoint detection would be silently dropped.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank)
                    else logging.WARN)

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
    logger.info("Training/evaluation parameters %s", training_args)

    # Detecting last checkpoint: refuse to reuse a non-empty output_dir that
    # holds no checkpoint, but resume from a checkpoint if one is found there.
    last_checkpoint = None
    if (os.path.isdir(training_args.output_dir) and training_args.do_train
            and not training_args.overwrite_output_dir):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome.")
        elif last_checkpoint is not None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # 1. First, let's load the dataset
    raw_datasets = DatasetDict()

    if training_args.do_train:
        raw_datasets["train"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split=data_args.train_split_name,
            use_auth_token=data_args.use_auth_token,
        )

        if data_args.audio_column_name not in raw_datasets["train"].column_names:
            raise ValueError(
                f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'."
                " Make sure to set `--audio_column_name` to the correct audio column - one of"
                f" {', '.join(raw_datasets['train'].column_names)}.")

        if data_args.text_column_name not in raw_datasets["train"].column_names:
            raise ValueError(
                f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
                "Make sure to set `--text_column_name` to the correct text column - one of "
                f"{', '.join(raw_datasets['train'].column_names)}.")

        if data_args.max_train_samples is not None:
            raw_datasets["train"] = raw_datasets["train"].select(
                range(data_args.max_train_samples))

    if training_args.do_eval:
        raw_datasets["eval"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split=data_args.eval_split_name,
            use_auth_token=data_args.use_auth_token,
        )

        if data_args.max_eval_samples is not None:
            raw_datasets["eval"] = raw_datasets["eval"].select(
                range(data_args.max_eval_samples))

    # 2. We remove some special characters from the datasets
    # that make training complicated and do not help in transcribing the speech
    # E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
    # that could be easily picked up by the model
    chars_to_ignore_regex = (f'[{"".join(data_args.chars_to_ignore)}]'
                             if data_args.chars_to_ignore is not None else None)
    text_column_name = data_args.text_column_name

    def remove_special_characters(batch):
        # Writes the cleaned, lower-cased transcription (plus a trailing space)
        # to "target_text"; the original text column is dropped by `map` below.
        if chars_to_ignore_regex is not None:
            batch["target_text"] = re.sub(
                chars_to_ignore_regex, "",
                batch[text_column_name]).lower() + " "
        else:
            batch["target_text"] = batch[text_column_name].lower() + " "
        return batch

    with training_args.main_process_first(
            desc="dataset map special characters removal"):
        raw_datasets = raw_datasets.map(
            remove_special_characters,
            remove_columns=[text_column_name],
            desc="remove special characters from datasets",
        )

    # save special tokens for tokenizer
    word_delimiter_token = data_args.word_delimiter_token
    unk_token = data_args.unk_token
    pad_token = data_args.pad_token

    # 3. Next, let's load the config as we might need it to create
    # the tokenizer
    # load config
    config = AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_auth_token=data_args.use_auth_token)

    # 4. Next, if no tokenizer file is defined,
    # we create the vocabulary of the model by extracting all unique characters from
    # the training and evaluation datasets
    # We need to make sure that only first rank saves vocabulary
    # make sure all processes wait until vocab is created
    tokenizer_name_or_path = model_args.tokenizer_name_or_path
    tokenizer_kwargs = {}
    if tokenizer_name_or_path is None:
        # save vocab in training output dir
        tokenizer_name_or_path = training_args.output_dir

        vocab_file = os.path.join(tokenizer_name_or_path, "vocab.json")

        with training_args.main_process_first():
            if training_args.overwrite_output_dir and os.path.isfile(vocab_file):
                try:
                    os.remove(vocab_file)
                except OSError:
                    # in shared file-systems it might be the case that
                    # two processes try to delete the vocab file at the same time
                    pass

        with training_args.main_process_first(
                desc="dataset map vocabulary creation"):
            if not os.path.isfile(vocab_file):
                os.makedirs(tokenizer_name_or_path, exist_ok=True)
                vocab_dict = create_vocabulary_from_data(
                    raw_datasets,
                    word_delimiter_token=word_delimiter_token,
                    unk_token=unk_token,
                    pad_token=pad_token,
                )

                # save vocab dict to be loaded into tokenizer
                with open(vocab_file, "w") as file:
                    json.dump(vocab_dict, file)

        # if tokenizer has just been created
        # it is defined by `tokenizer_class` if present in config else by `model_type`
        tokenizer_kwargs = {
            "config": config if config.tokenizer_class is not None else None,
            "tokenizer_type": config.model_type if config.tokenizer_class is None else None,
            "unk_token": unk_token,
            "pad_token": pad_token,
            "word_delimiter_token": word_delimiter_token,
        }

    # 5. Now we can instantiate the feature extractor, tokenizer and model
    # Note for distributed training, the .from_pretrained methods guarantee that only
    # one local process can concurrently download model & vocab.

    # load feature_extractor and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_name_or_path,
        use_auth_token=data_args.use_auth_token,
        **tokenizer_kwargs,
    )
    feature_extractor = AutoFeatureExtractor.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_auth_token=data_args.use_auth_token)

    # adapt config: regularization/masking hyper-parameters from the CLI, and
    # vocabulary size / pad id taken from the (possibly freshly built) tokenizer
    config.update({
        "feat_proj_dropout": model_args.feat_proj_dropout,
        "attention_dropout": model_args.attention_dropout,
        "hidden_dropout": model_args.hidden_dropout,
        "final_dropout": model_args.final_dropout,
        "mask_time_prob": model_args.mask_time_prob,
        "mask_time_length": model_args.mask_time_length,
        "mask_feature_prob": model_args.mask_feature_prob,
        "mask_feature_length": model_args.mask_feature_length,
        "gradient_checkpointing": training_args.gradient_checkpointing,
        "layerdrop": model_args.layerdrop,
        "ctc_loss_reduction": model_args.ctc_loss_reduction,
        "pad_token_id": tokenizer.pad_token_id,
        "vocab_size": len(tokenizer),
        "activation_dropout": model_args.activation_dropout,
    })

    # create model
    model = AutoModelForCTC.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        config=config,
        use_auth_token=data_args.use_auth_token,
    )

    # freeze encoder
    if model_args.freeze_feature_encoder:
        model.freeze_feature_encoder()

    # 6. Now we preprocess the datasets including loading the audio, resampling and normalization
    # Thankfully, `datasets` takes care of automatically loading and resampling the audio,
    # so that we just need to set the correct target sampling rate and normalize the input
    # via the `feature_extractor`

    # make sure that dataset decodes audio with correct sampling rate
    dataset_sampling_rate = next(iter(raw_datasets.values())).features[
        data_args.audio_column_name].sampling_rate
    if dataset_sampling_rate != feature_extractor.sampling_rate:
        raw_datasets = raw_datasets.cast_column(
            data_args.audio_column_name,
            datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate))

    # derive max & min input length (in samples) for sample rate & duration limits
    max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
    min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
    audio_column_name = data_args.audio_column_name
    num_workers = data_args.preprocessing_num_workers

    # `phoneme_language` is only relevant if the model is fine-tuned on phoneme classification
    phoneme_language = data_args.phoneme_language

    # Preprocessing the datasets.
    # We need to read the audio files as arrays and tokenize the targets.
    def prepare_dataset(batch):
        # load audio
        sample = batch[audio_column_name]

        inputs = feature_extractor(sample["array"],
                                   sampling_rate=sample["sampling_rate"])
        batch["input_values"] = inputs.input_values[0]
        batch["input_length"] = len(batch["input_values"])

        # encode targets
        additional_kwargs = {}
        if phoneme_language is not None:
            additional_kwargs["phonemizer_lang"] = phoneme_language

        batch["labels"] = tokenizer(batch["target_text"],
                                    **additional_kwargs).input_ids
        return batch

    with training_args.main_process_first(desc="dataset map preprocessing"):
        vectorized_datasets = raw_datasets.map(
            prepare_dataset,
            remove_columns=next(iter(raw_datasets.values())).column_names,
            num_proc=num_workers,
            desc="preprocess datasets",
        )

        def is_audio_in_length_range(length):
            # strict bounds on both sides: keep min < length < max
            return length > min_input_length and length < max_input_length

        # filter out data that is not strictly between min_input_length and max_input_length
        vectorized_datasets = vectorized_datasets.filter(
            is_audio_in_length_range,
            num_proc=num_workers,
            input_columns=["input_length"],
        )

    # 7. Next, we can prepare the training.
    # Let's use word error rate (WER) as our evaluation metric,
    # instantiate a data collator and the trainer

    # Define evaluation metrics during training, *i.e.* word error rate, character error rate
    eval_metrics = {
        metric: evaluate.load(metric)
        for metric in data_args.eval_metrics
    }

    # for large datasets it is advised to run the preprocessing on a
    # single machine first with ``args.preprocessing_only`` since there will most likely
    # be a timeout when running the script in distributed mode.
    # In a second step ``args.preprocessing_only`` can then be set to `False` to load the
    # cached dataset
    if data_args.preprocessing_only:
        logger.info(
            f"Data preprocessing finished. Files cached at {vectorized_datasets.cache_files}"
        )
        return

    def compute_metrics(pred):
        # Greedy (argmax) decoding of the logits, then every configured metric.
        pred_logits = pred.predictions
        pred_ids = np.argmax(pred_logits, axis=-1)

        # -100 marks ignored label positions; map them to the pad id
        # (in place on pred.label_ids) so they can be decoded
        pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id

        pred_str = tokenizer.batch_decode(pred_ids)
        # we do not want to group tokens when computing the metrics
        label_str = tokenizer.batch_decode(pred.label_ids, group_tokens=False)

        metrics = {
            k: v.compute(predictions=pred_str, references=label_str)
            for k, v in eval_metrics.items()
        }

        return metrics

    # Now save everything to be able to create a single processor later
    if is_main_process(training_args.local_rank):
        # save feature extractor, tokenizer and config
        feature_extractor.save_pretrained(training_args.output_dir)
        tokenizer.save_pretrained(training_args.output_dir)
        config.save_pretrained(training_args.output_dir)

    try:
        processor = AutoProcessor.from_pretrained(training_args.output_dir)
    except (OSError, KeyError):
        warnings.warn(
            "Loading a processor from a feature extractor config that does not"
            " include a `processor_class` attribute is deprecated and will be removed in v5. Please add the following "
            " attribute to your `preprocessor_config.json` file to suppress this warning: "
            " `'processor_class': 'Wav2Vec2Processor'`",
            FutureWarning,
        )
        processor = Wav2Vec2Processor.from_pretrained(training_args.output_dir)

    # Instantiate custom data collator
    data_collator = DataCollatorCTCWithPadding(processor=processor)

    # Initialize Trainer
    trainer = Trainer(
        model=model,
        data_collator=data_collator,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
        eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
        tokenizer=feature_extractor,
    )

    # 8. Finally, we can start training

    # Training
    if training_args.do_train:

        # use last checkpoint if exist; otherwise resume from a local model dir
        if last_checkpoint is not None:
            checkpoint = last_checkpoint
        elif os.path.isdir(model_args.model_name_or_path):
            checkpoint = model_args.model_name_or_path
        else:
            checkpoint = None

        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()

        metrics = train_result.metrics
        max_train_samples = (data_args.max_train_samples
                             if data_args.max_train_samples is not None else
                             len(vectorized_datasets["train"]))
        metrics["train_samples"] = min(max_train_samples,
                                       len(vectorized_datasets["train"]))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate()
        max_eval_samples = (data_args.max_eval_samples
                            if data_args.max_eval_samples is not None else
                            len(vectorized_datasets["eval"]))
        metrics["eval_samples"] = min(max_eval_samples,
                                      len(vectorized_datasets["eval"]))

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Write model card and (optionally) push to hub
    config_name = data_args.dataset_config_name if data_args.dataset_config_name is not None else "na"
    kwargs = {
        "finetuned_from": model_args.model_name_or_path,
        "tasks": "speech-recognition",
        "tags": ["automatic-speech-recognition", data_args.dataset_name],
        "dataset_args":
        (f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split:"
         f" {data_args.eval_split_name}"),
        "dataset": f"{data_args.dataset_name.upper()} - {config_name.upper()}",
    }
    if "common_voice" in data_args.dataset_name:
        kwargs["language"] = config_name

    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)

    return results
 def test_auto_processor_reverts_to_feature_extractor(self):
     """AutoProcessor on the BEiT checkpoint falls back to a BeitFeatureExtractor."""
     beit_checkpoint = "microsoft/beit-base-patch16-224"
     loaded = AutoProcessor.from_pretrained(beit_checkpoint)
     self.assertIsInstance(loaded, BeitFeatureExtractor)
 def test_auto_processor_reverts_to_tokenizer(self):
     """AutoProcessor on ``bert-base-cased`` falls back to a BertTokenizerFast."""
     bert_checkpoint = "bert-base-cased"
     loaded = AutoProcessor.from_pretrained(bert_checkpoint)
     self.assertIsInstance(loaded, BertTokenizerFast)
示例#15
0
    def test_word_time_stamp_integration(self):
        """Check word-level time stamps decoded with an LM-backed Wav2Vec2 processor.

        Streams the first Common Voice (en, train) sample resampled to 16 kHz,
        runs ``patrickvonplaten/wav2vec2-base-100h-with-lm``, decodes with
        ``output_word_offsets=True`` and compares the words and their start/end
        times (rounded to two decimals) against reference values.
        """
        import torch

        checkpoint = "patrickvonplaten/wav2vec2-base-100h-with-lm"

        stream = load_dataset("common_voice", "en", split="train", streaming=True)
        stream = stream.cast_column("audio", datasets.Audio(sampling_rate=16_000))
        first_sample = next(iter(stream))

        processor = AutoProcessor.from_pretrained(checkpoint)
        model = Wav2Vec2ForCTC.from_pretrained(checkpoint)

        # compare to filename `common_voice_en_100038.mp3` of dataset viewer on https://huggingface.co/datasets/common_voice/viewer/en/train
        input_values = processor(first_sample["audio"]["array"], return_tensors="pt").input_values

        with torch.no_grad():
            logits = model(input_values).logits.cpu().numpy()

        decoded = processor.decode(logits[0], output_word_offsets=True)

        # converts logit-frame offsets to seconds (assumes inputs_to_logits_ratio
        # is the number of input samples per logit frame)
        seconds_per_frame = model.config.inputs_to_logits_ratio / processor.feature_extractor.sampling_rate
        word_time_stamps = []
        for offset in decoded["word_offsets"]:
            word_time_stamps.append({
                "start_time": offset["start_offset"] * seconds_per_frame,
                "end_time": offset["end_offset"] * seconds_per_frame,
                "word": offset["word"],
            })

        EXPECTED_TEXT = "WHY DOES A MILE SANDRA LOOK LIKE SHE WANTS TO CONSUME JOHN SNOW ON THE RIVER AT THE WALL"

        # output words
        self.assertEqual(" ".join(self.get_from_offsets(word_time_stamps, "word")), EXPECTED_TEXT)
        self.assertEqual(" ".join(self.get_from_offsets(word_time_stamps, "word")), decoded.text)

        # output times
        start_times = [round(t, 2) for t in self.get_from_offsets(word_time_stamps, "start_time")]
        end_times = [round(t, 2) for t in self.get_from_offsets(word_time_stamps, "end_time")]

        # fmt: off
        self.assertListEqual(
            start_times,
            [1.42, 1.64, 2.12, 2.26, 2.54, 3.0, 3.24, 3.6, 3.8, 4.1,
             4.26, 4.94, 5.28, 5.66, 5.78, 5.94, 6.32, 6.54, 6.66],
        )

        self.assertListEqual(
            end_times,
            [1.54, 1.88, 2.14, 2.46, 2.9, 3.18, 3.54, 3.72, 4.02, 4.18,
             4.76, 5.16, 5.56, 5.7, 5.86, 6.2, 6.38, 6.62, 6.94],
        )
        # fmt: on
示例#16
0
def main():
    """Fine-tune a CTC speech-recognition model (variant without argument parsing).

    Loads the train/eval splits, cleans and lower-cases transcriptions, builds a
    character vocabulary when no tokenizer is given, loads tokenizer / feature
    extractor / config / ``AutoModelForCTC``, preprocesses and length-filters the
    audio, then trains and/or evaluates with ``Trainer`` and writes a model card
    (or pushes to the Hub).

    NOTE(review): this function never defines ``model_args``, ``data_args``,
    ``training_args``, ``last_checkpoint`` or ``logger``; it presumably relies on
    module-level globals with those names. Confirm this, otherwise the first use
    raises NameError.

    Returns:
        ``None`` when ``preprocessing_only`` is set; otherwise the ``results``
        dict, which is never populated here (always empty).

    Raises:
        ValueError: if the configured audio or text column is missing from the
            training split.
    """
    raw_datasets = DatasetDict()

    if training_args.do_train:
        raw_datasets["train"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config,
            split=data_args.train_split_name,
            use_auth_token=data_args.use_auth_token,
        )

        if data_args.audio_column not in raw_datasets["train"].column_names:
            raise ValueError(
                f"--audio_column '{data_args.audio_column}' not found in dataset '{data_args.dataset_name}'. "
                "Make sure to set `--audio_column` to the correct audio column - one of "
                f"{', '.join(raw_datasets['train'].column_names)}."
            )

        if data_args.text_column not in raw_datasets["train"].column_names:
            raise ValueError(
                f"--text_column {data_args.text_column} not found in dataset '{data_args.dataset_name}'. "
                "Make sure to set `--text_column` to the correct text column - one of "
                f"{', '.join(raw_datasets['train'].column_names)}."
            )

        if data_args.max_train_samples is not None:
            raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))

    if training_args.do_eval:
        raw_datasets["eval"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config,
            split=data_args.eval_split_name,
            use_auth_token=data_args.use_auth_token,
        )

        if data_args.max_eval_samples is not None:
            raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))

    # 2. We remove some special characters from the datasets
    # that make training complicated and do not help in transcribing the speech
    # E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
    # that could be easily picked up by the model
    chars_to_ignore_regex = (
        f'[{"".join(data_args.chars_to_ignore)}]' if data_args.chars_to_ignore is not None else None
    )
    text_column = data_args.text_column

    def remove_special_characters(batch):
        # Writes the cleaned, lower-cased transcription (plus a trailing space)
        # to "target_text"; the original text column is dropped by `map` below.
        if chars_to_ignore_regex is not None:
            batch["target_text"] = (
                re.sub(chars_to_ignore_regex, "", batch[text_column]).lower() + " "
            )
        else:
            batch["target_text"] = batch[text_column].lower() + " "
        return batch

    with training_args.main_process_first(desc="dataset map special characters removal"):
        raw_datasets = raw_datasets.map(
            remove_special_characters,
            remove_columns=[text_column],
            desc="remove special characters from datasets",
        )

    # save special tokens for tokenizer
    word_delimiter_token = data_args.word_delimiter_token
    unk = data_args.unk
    pad = data_args.pad

    # 3. Next, let's load the config as we might need it to create
    # the tokenizer
    # load config
    config = AutoConfig.from_pretrained(
        model_args.model_name,
        cache_dir=model_args.cache_dir,
        use_auth_token=data_args.use_auth_token,
    )

    # 4. Next, if no tokenizer file is defined,
    # we create the vocabulary of the model by extracting all unique characters from
    # the training and evaluation datasets
    # We need to make sure that only first rank saves vocabulary
    # make sure all processes wait until vocab is created
    tokenizer_name_or_path = model_args.tokenizer_name_or_path
    tokenizer_kw = {}
    if tokenizer_name_or_path is None:
        # save vocab in training output dir
        tokenizer_name_or_path = training_args.out_dir

        vocab_file = os.path.join(tokenizer_name_or_path, "vocab.json")

        with training_args.main_process_first():
            if training_args.overwrite_out_dir and os.path.isfile(vocab_file):
                try:
                    os.remove(vocab_file)
                except OSError:
                    # in shared file-systems it might be the case that
                    # two processes try to delete the vocab file at the same time
                    pass

        with training_args.main_process_first(desc="dataset map vocabulary creation"):
            if not os.path.isfile(vocab_file):
                os.makedirs(tokenizer_name_or_path, exist_ok=True)
                vocab_dict = create_vocabulary_from_data(
                    raw_datasets,
                    word_delimiter_token=word_delimiter_token,
                    unk=unk,
                    pad=pad,
                )

                # save vocab dict to be loaded into tokenizer
                with open(vocab_file, "w") as file:
                    json.dump(vocab_dict, file)

        # if tokenizer has just been created
        # it is defined by `tokenizer_class` if present in config else by `model_type`
        tokenizer_kw = {
            "config": config if config.tokenizer_class is not None else None,
            "tokenizer_type": config.model_type if config.tokenizer_class is None else None,
            "unk": unk,
            "pad": pad,
            "word_delimiter_token": word_delimiter_token,
        }

    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_name_or_path,
        use_auth_token=data_args.use_auth_token,
        **tokenizer_kw,
    )
    feature_extractor = AutoFeatureExtractor.from_pretrained(
        model_args.model_name,
        cache_dir=model_args.cache_dir,
        use_auth_token=data_args.use_auth_token,
    )

    # adapt config
    # NOTE(review): several keys here ("drop_attn", "drop", "grad_checkpoint",
    # "PAD", "s_vocab", "drop_act") differ from the names the sibling script
    # uses ("attention_dropout", "hidden_dropout", "gradient_checkpointing",
    # "pad_token_id", "vocab_size", "activation_dropout"). Confirm that this
    # config class actually reads them; otherwise they are silently ignored.
    config.update(
        {
            "feat_proj_dropout": model_args.feat_proj_dropout,
            "drop_attn": model_args.drop_attn,
            "drop": model_args.drop,
            "final_dropout": model_args.final_dropout,
            "mask_time_prob": model_args.mask_time_prob,
            "mask_time_length": model_args.mask_time_length,
            "mask_feature_prob": model_args.mask_feature_prob,
            "mask_feature_length": model_args.mask_feature_length,
            "grad_checkpoint": training_args.grad_checkpoint,
            "layerdrop": model_args.layerdrop,
            "ctc_loss_reduction": model_args.ctc_loss_reduction,
            "PAD": tokenizer.PAD,
            "s_vocab": len(tokenizer),
            "drop_act": model_args.drop_act,
        }
    )

    # create model
    model = AutoModelForCTC.from_pretrained(
        model_args.model_name,
        cache_dir=model_args.cache_dir,
        config=config,
        use_auth_token=data_args.use_auth_token,
    )

    # freeze encoder
    if model_args.freeze_feature_encoder:
        model.freeze_feature_encoder()

    # 6. Now we preprocess the datasets including loading the audio, resampling and normalization
    # Thankfully, `datasets` takes care of automatically loading and resampling the audio,
    # so that we just need to set the correct target sampling rate and normalize the input
    # via the `feature_extractor`

    # make sure that dataset decodes audio with correct sampling rate
    dataset_sampling_rate = (
        next(iter(raw_datasets.values())).features[data_args.audio_column].sampling_rate
    )
    if dataset_sampling_rate != feature_extractor.sampling_rate:
        raw_datasets = raw_datasets.cast_column(
            data_args.audio_column,
            datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate),
        )

    # derive max & min input length (in samples) for sample rate & duration limits
    max_input_length = data_args.max_duration * feature_extractor.sampling_rate
    min_input_length = data_args.min_duration * feature_extractor.sampling_rate
    audio_column = data_args.audio_column
    num_workers = data_args.num_workers

    # `phoneme_language` is only relevant if the model is fine-tuned on phoneme classification
    phoneme_language = data_args.phoneme_language

    # Preprocessing the datasets.
    # We need to read the audio files as arrays and tokenize the targets.
    def prepare_dataset(batch):
        # load audio
        sample = batch[audio_column]

        inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
        batch["input_values"] = inputs.input_values[0]
        batch["input_length"] = len(batch["input_values"])

        # encode targets
        additional_kw = {}
        if phoneme_language is not None:
            additional_kw["phonemizer_lang"] = phoneme_language

        batch["labels"] = tokenizer(batch["target_text"], **additional_kw).input_ids
        return batch

    with training_args.main_process_first(desc="dataset map preprocessing"):
        vectorized_datasets = raw_datasets.map(
            prepare_dataset,
            remove_columns=next(iter(raw_datasets.values())).column_names,
            num_proc=num_workers,
            desc="preprocess datasets",
        )

        def is_audio_in_length_range(length):
            # strict bounds on both sides: keep min < length < max
            return length > min_input_length and length < max_input_length

        # filter out data that is not strictly between min_input_length and max_input_length
        vectorized_datasets = vectorized_datasets.filter(
            is_audio_in_length_range,
            num_proc=num_workers,
            input_columns=["input_length"],
        )

    # 7. Next, we can prepare the training.
    # Let's use word error rate (WER) as our evaluation metric,
    # instantiate a data collator and the trainer

    # Define evaluation metrics during training, *i.e.* word error rate, character error rate
    eval_metrics = {metric: load_metric(metric) for metric in data_args.eval_metrics}

    # for large datasets it is advised to run the preprocessing on a
    # single machine first with ``args.preprocessing_only`` since there will most likely
    # be a timeout when running the script in distributed mode.
    # In a second step ``args.preprocessing_only`` can then be set to `False` to load the
    # cached dataset
    if data_args.preprocessing_only:
        logger.info(
            f"Data preprocessing finished. Files cached at {vectorized_datasets.cache_files}"
        )
        return

    def compute_metrics(pred):
        # Greedy (argmax) decoding of the logits, then every configured metric.
        pred_logits = pred.predictions
        pred_ids = np.argmax(pred_logits, axis=-1)

        # -100 marks ignored label positions; map them to the pad id
        # (in place on pred.label_ids) so they can be decoded
        pred.label_ids[pred.label_ids == -100] = tokenizer.PAD

        pred_str = tokenizer.batch_decode(pred_ids)
        # we do not want to group tokens when computing the metrics
        label_str = tokenizer.batch_decode(pred.label_ids, group_tokens=False)

        metrics = {
            k: v.compute(predictions=pred_str, references=label_str)
            for k, v in eval_metrics.items()
        }

        return metrics

    # Now save everything to be able to create a single processor later
    if is_main_process(training_args.local_rank):
        # save feature extractor, tokenizer and config
        feature_extractor.save_pretrained(training_args.out_dir)
        tokenizer.save_pretrained(training_args.out_dir)
        config.save_pretrained(training_args.out_dir)

    try:
        processor = AutoProcessor.from_pretrained(training_args.out_dir)
    except (OSError, KeyError):
        warnings.warn(
            "Loading a processor from a feature extractor config that does not"
            " include a `processor_class` attribute is deprecated and will be removed in v5. Please add the following "
            " attribute to your `preprocessor_config.json` file to suppress this warning: "
            " `'processor_class': 'Wav2Vec2Processor'`",
            FutureWarning,
        )
        processor = Wav2Vec2Processor.from_pretrained(training_args.out_dir)

    # Instantiate custom data collator
    data_collator = DataCollatorCTCWithPadding(processor=processor)

    # Initialize Trainer
    trainer = Trainer(
        model=model,
        data_collator=data_collator,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
        eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
        tokenizer=feature_extractor,
    )

    # 8. Finally, we can start training

    # Training
    if training_args.do_train:

        # use last checkpoint if exist; otherwise resume from a local model dir
        if last_checkpoint is not None:
            checkpoint = last_checkpoint
        elif os.path.isdir(model_args.model_name):
            checkpoint = model_args.model_name
        else:
            checkpoint = None

        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()

        metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples
            if data_args.max_train_samples is not None
            else len(vectorized_datasets["train"])
        )
        metrics["train_samples"] = min(max_train_samples, len(vectorized_datasets["train"]))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate()
        max_eval_samples = (
            data_args.max_eval_samples
            if data_args.max_eval_samples is not None
            else len(vectorized_datasets["eval"])
        )
        metrics["eval_samples"] = min(max_eval_samples, len(vectorized_datasets["eval"]))

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Write model card and (optionally) push to hub
    config_name = data_args.dataset_config if data_args.dataset_config is not None else "na"
    kw = {
        "finetuned_from": model_args.model_name,
        "tasks": "speech-recognition",
        "tags": ["automatic-speech-recognition", data_args.dataset_name],
        "dataset_args": f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split: {data_args.eval_split_name}",
        "dataset": f"{data_args.dataset_name.upper()} - {config_name.upper()}",
    }
    if "common_voice" in data_args.dataset_name:
        kw["language"] = config_name

    if training_args.push_to_hub:
        trainer.push_to_hub(**kw)
    else:
        trainer.create_model_card(**kw)

    return results
 def test_processor_from_model_shortcut(self):
     """AutoProcessor on ``facebook/wav2vec2-base-960h`` yields a Wav2Vec2Processor."""
     wav2vec2_checkpoint = "facebook/wav2vec2-base-960h"
     loaded = AutoProcessor.from_pretrained(wav2vec2_checkpoint)
     self.assertIsInstance(loaded, Wav2Vec2Processor)
示例#18
0
def main():
    """Fine-tune a CTC speech-recognition model on streamed (iterable) datasets.

    Arguments are parsed into ``ModelArguments``, ``DataTrainingArguments`` and
    ``TrainingArguments`` from the command line or from a single ``.json`` file.
    The train/eval splits are loaded with ``streaming=True``. Transcripts are
    normalized (optional character stripping, lower-casing, trailing space),
    audio is featurized, and transcripts are tokenized with an existing
    tokenizer. An ``AutoModelForCTC`` is then trained/evaluated with ``Trainer``.
    Finally a model card is written, or the model is pushed to the Hub.

    Returns:
        dict: the ``results`` dictionary, which is created empty and is not
        populated by this function.

    Raises:
        ValueError: if ``output_dir`` is non-empty, contains no checkpoint and
            ``--overwrite_output_dir`` is not set; if the audio or text column is
            missing from the train split; or if ``--tokenizer_name_or_path`` is
            not given.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Detecting last checkpoint.
    last_checkpoint = None
    if (os.path.isdir(training_args.output_dir) and training_args.do_train
            and not training_args.overwrite_output_dir):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome.")
        elif last_checkpoint is not None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)

    # Log on each process the small summary. The two f-strings are concatenated
    # implicitly, so the first one must end with a separator.
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # 1. First, let's load the dataset
    raw_datasets = IterableDatasetDict()
    raw_column_names = {}

    def load_streaming_dataset(split, sampling_rate, **kwargs):
        """Load ``split`` with ``load_dataset`` and cast its audio column to ``sampling_rate``.

        ``split`` may join several split names with ``+``. Those splits are loaded
        separately and combined with ``interleave_datasets``.

        Returns:
            tuple: ``(dataset, features)``. ``features`` is read from the (first)
            loaded split, because it may not be available on the interleaved result.
        """
        if "+" in split:
            dataset_splits = [
                load_dataset(split=split_name, **kwargs)
                for split_name in split.split("+")
            ]
            # `features` and `cast_column` won't be available after interleaving, so we'll use them here
            features = dataset_splits[0].features
            # make sure that the dataset decodes audio with a correct sampling rate
            dataset_splits = [
                dataset.cast_column(
                    data_args.audio_column_name,
                    datasets.features.Audio(sampling_rate=sampling_rate))
                for dataset in dataset_splits
            ]

            interleaved_dataset = interleave_datasets(dataset_splits)
            return interleaved_dataset, features
        else:
            dataset = load_dataset(split=split, **kwargs)
            features = dataset.features
            # make sure that the dataset decodes audio with a correct sampling rate
            dataset = dataset.cast_column(
                data_args.audio_column_name,
                datasets.features.Audio(sampling_rate=sampling_rate))
            return dataset, features

    # `datasets` takes care of automatically loading and resampling the audio,
    # so we just need to set the correct target sampling rate and normalize the input
    # via the `feature_extractor`
    feature_extractor = AutoFeatureExtractor.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_auth_token=data_args.use_auth_token)

    if training_args.do_train:
        raw_datasets["train"], train_features = load_streaming_dataset(
            path=data_args.dataset_name,
            name=data_args.dataset_config_name,
            split=data_args.train_split_name,
            use_auth_token=data_args.use_auth_token,
            streaming=True,
            sampling_rate=feature_extractor.sampling_rate,
        )
        raw_column_names["train"] = list(train_features.keys())

        if data_args.audio_column_name not in raw_column_names["train"]:
            raise ValueError(
                f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'."
                " Make sure to set `--audio_column_name` to the correct audio column - one of"
                f" {', '.join(raw_column_names['train'])}.")

        if data_args.text_column_name not in raw_column_names["train"]:
            raise ValueError(
                f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
                "Make sure to set `--text_column_name` to the correct text column - one of "
                f"{', '.join(raw_column_names['train'])}.")

        if data_args.max_train_samples is not None:
            # `IterableDataset.take` expects a number of elements, not a range of indices.
            raw_datasets["train"] = raw_datasets["train"].take(
                data_args.max_train_samples)

    if training_args.do_eval:
        raw_datasets["eval"], eval_features = load_streaming_dataset(
            path=data_args.dataset_name,
            name=data_args.dataset_config_name,
            split=data_args.eval_split_name,
            use_auth_token=data_args.use_auth_token,
            streaming=True,
            sampling_rate=feature_extractor.sampling_rate,
        )
        raw_column_names["eval"] = list(eval_features.keys())

        if data_args.max_eval_samples is not None:
            # `IterableDataset.take` expects a number of elements, not a range of indices.
            raw_datasets["eval"] = raw_datasets["eval"].take(
                data_args.max_eval_samples)

    # 2. We remove some special characters from the datasets
    # that make training complicated and do not help in transcribing the speech
    # E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
    # that could be easily picked up by the model
    chars_to_ignore_regex = (f'[{"".join(data_args.chars_to_ignore)}]' if
                             data_args.chars_to_ignore is not None else None)
    text_column_name = data_args.text_column_name

    def remove_special_characters(batch):
        """Write a normalized transcript into ``batch["target_text"]``.

        The text is stripped of ``chars_to_ignore`` (if given) and lower-cased,
        and a single trailing space is appended.
        """
        if chars_to_ignore_regex is not None:
            batch["target_text"] = re.sub(
                chars_to_ignore_regex, "",
                batch[text_column_name]).lower() + " "
        else:
            batch["target_text"] = batch[text_column_name].lower() + " "
        return batch

    with training_args.main_process_first(
            desc="dataset map special characters removal"):
        for split, dataset in raw_datasets.items():
            raw_datasets[split] = dataset.map(
                remove_special_characters).remove_columns([text_column_name])

    # 3. Next, let's load the config as we might need it to create
    # the tokenizer
    config = AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_auth_token=data_args.use_auth_token)

    # 4. Now we can instantiate the tokenizer and model
    # Note for distributed training, the .from_pretrained methods guarantee that only
    # one local process can concurrently download model & vocab.

    # In streaming mode the vocabulary cannot be built from the data here, so an
    # existing tokenizer must be provided.
    tokenizer_name_or_path = model_args.tokenizer_name_or_path
    if tokenizer_name_or_path is None:
        raise ValueError(
            "Tokenizer has to be created before training in streaming mode. Please specify --tokenizer_name_or_path"
        )
    # load the tokenizer (the feature extractor was loaded above)
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_name_or_path,
        config=config,
        use_auth_token=data_args.use_auth_token,
    )

    # adapt config: regularization settings from the CLI, plus pad id and vocab
    # size taken from the tokenizer so the CTC head matches its vocabulary
    config.update({
        "feat_proj_dropout": model_args.feat_proj_dropout,
        "attention_dropout": model_args.attention_dropout,
        "hidden_dropout": model_args.hidden_dropout,
        "final_dropout": model_args.final_dropout,
        "mask_time_prob": model_args.mask_time_prob,
        "mask_time_length": model_args.mask_time_length,
        "mask_feature_prob": model_args.mask_feature_prob,
        "mask_feature_length": model_args.mask_feature_length,
        "gradient_checkpointing": training_args.gradient_checkpointing,
        "layerdrop": model_args.layerdrop,
        "ctc_loss_reduction": model_args.ctc_loss_reduction,
        "pad_token_id": tokenizer.pad_token_id,
        "vocab_size": len(tokenizer),
        "activation_dropout": model_args.activation_dropout,
    })

    # create model
    model = AutoModelForCTC.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        config=config,
        use_auth_token=data_args.use_auth_token,
    )

    # freeze encoder
    if model_args.freeze_feature_encoder:
        model.freeze_feature_encoder()

    # 5. Now we preprocess the datasets including loading the audio, resampling and normalization
    audio_column_name = data_args.audio_column_name

    # `phoneme_language` is only relevant if the model is fine-tuned on phoneme classification
    phoneme_language = data_args.phoneme_language

    # Preprocessing the datasets.
    # We need to read the audio files as arrays and tokenize the targets.
    def prepare_dataset(batch):
        """Add ``input_values``, ``input_length`` and ``labels`` to one example."""
        # load audio
        sample = batch[audio_column_name]

        inputs = feature_extractor(sample["array"],
                                   sampling_rate=sample["sampling_rate"])
        batch["input_values"] = inputs.input_values[0]
        batch["input_length"] = len(batch["input_values"])

        # encode targets
        additional_kwargs = {}
        if phoneme_language is not None:
            additional_kwargs["phonemizer_lang"] = phoneme_language

        batch["labels"] = tokenizer(batch["target_text"],
                                    **additional_kwargs).input_ids
        return batch

    vectorized_datasets = IterableDatasetDict()
    with training_args.main_process_first(desc="dataset map preprocessing"):
        for split, dataset in raw_datasets.items():
            vectorized_datasets[split] = (
                dataset.map(prepare_dataset).remove_columns(
                    raw_column_names[split] +
                    ["target_text"]).with_format("torch"))
            if split == "train":
                # streaming datasets are shuffled through a fixed-size buffer
                vectorized_datasets[split] = vectorized_datasets[
                    split].shuffle(
                        buffer_size=data_args.shuffle_buffer_size,
                        seed=training_args.seed,
                    )

    # 6. Next, we can prepare the training.
    # Let's use word error rate (WER) as our evaluation metric,
    # instantiate a data collator and the trainer

    # Define evaluation metrics during training, *i.e.* word error rate, character error rate
    eval_metrics = {
        metric: load_metric(metric)
        for metric in data_args.eval_metrics
    }

    def compute_metrics(pred):
        """Decode greedy (argmax) predictions and labels, then score every metric in ``eval_metrics``."""
        pred_logits = pred.predictions
        pred_ids = np.argmax(pred_logits, axis=-1)

        # -100 marks ignored label positions; map them to the pad token so they can be decoded
        pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id

        pred_str = tokenizer.batch_decode(pred_ids)
        # we do not want to group tokens when computing the metrics
        label_str = tokenizer.batch_decode(pred.label_ids, group_tokens=False)

        metrics = {
            k: v.compute(predictions=pred_str, references=label_str)
            for k, v in eval_metrics.items()
        }

        return metrics

    # Now save everything to be able to create a single processor later
    if is_main_process(training_args.local_rank):
        # save feature extractor, tokenizer and config
        feature_extractor.save_pretrained(training_args.output_dir)
        tokenizer.save_pretrained(training_args.output_dir)
        config.save_pretrained(training_args.output_dir)

    try:
        processor = AutoProcessor.from_pretrained(training_args.output_dir)
    except (OSError, KeyError):
        warnings.warn(
            "Loading a processor from a feature extractor config that does not"
            " include a `processor_class` attribute is deprecated and will be removed in v5. Please add the following "
            " attribute to your `preprocessor_config.json` file to suppress this warning: "
            " `'processor_class': 'Wav2Vec2Processor'`",
            FutureWarning,
        )
        processor = Wav2Vec2Processor.from_pretrained(training_args.output_dir)

    # Instantiate custom data collator
    max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
    data_collator = DataCollatorCTCWithPadding(processor=processor,
                                               max_length=max_input_length)

    # trainer callback to reinitialize and reshuffle the streamable datasets at the beginning of each epoch
    class ShuffleCallback(TrainerCallback):
        """Advance the epoch of a plain ``IterableDataset`` so its shuffle order changes per epoch."""

        def on_epoch_begin(self, args, state, control, train_dataloader,
                           **kwargs):
            if isinstance(train_dataloader.dataset, IterableDatasetShard):
                pass  # set_epoch() is handled by the Trainer
            elif isinstance(train_dataloader.dataset, IterableDataset):
                train_dataloader.dataset.set_epoch(
                    train_dataloader.dataset._epoch + 1)

    # Initialize Trainer
    trainer = Trainer(
        model=model,
        data_collator=data_collator,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=vectorized_datasets["train"]
        if training_args.do_train else None,
        eval_dataset=vectorized_datasets["eval"]
        if training_args.do_eval else None,
        tokenizer=processor,
        callbacks=[ShuffleCallback()],
    )

    # 7. Finally, we can start training

    # Training
    if training_args.do_train:

        # use last checkpoint if exist
        if last_checkpoint is not None:
            checkpoint = last_checkpoint
        elif os.path.isdir(model_args.model_name_or_path):
            checkpoint = model_args.model_name_or_path
        else:
            checkpoint = None

        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()

        metrics = train_result.metrics
        if data_args.max_train_samples:
            metrics["train_samples"] = data_args.max_train_samples

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate()
        if data_args.max_eval_samples:
            metrics["eval_samples"] = data_args.max_eval_samples

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Write model card and (optionally) push to hub
    config_name = data_args.dataset_config_name if data_args.dataset_config_name is not None else "na"
    kwargs = {
        "finetuned_from":
        model_args.model_name_or_path,
        "tasks":
        "speech-recognition",
        "tags": ["automatic-speech-recognition", data_args.dataset_name],
        "dataset_args":
        (f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split:"
         f" {data_args.eval_split_name}"),
        "dataset":
        f"{data_args.dataset_name.upper()} - {config_name.upper()}",
    }
    if "common_voice" in data_args.dataset_name:
        kwargs["language"] = config_name

    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)

    return results
示例#19
0
def main():
    """Fine-tune a sequence-to-sequence speech-recognition model.

    Arguments are parsed into ``ModelArguments``, ``DataTrainingArguments`` and
    ``Seq2SeqTrainingArguments`` from the command line or a single ``.json``
    file. The function then:

    * loads, and resamples if needed, the train/eval splits;
    * extracts audio features and tokenizes the (optionally lower-cased)
      transcripts;
    * drops examples outside the configured duration range;
    * trains/evaluates with ``Seq2SeqTrainer``;
    * writes a model card, or pushes the model to the Hub.

    Returns:
        dict | None: the ``results`` dictionary (created empty and not filled in
        here), or ``None`` when ``--preprocessing_only`` stops after preprocessing.

    Raises:
        ValueError: if neither ``--do_train`` nor ``--do_eval`` is set; if
            ``output_dir`` is non-empty, contains no checkpoint and
            ``--overwrite_output_dir`` is not set; if the audio or text column is
            missing; or if ``config.decoder_start_token_id`` is undefined.
    """
    # 1. Parse input arguments
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_speech_recognition_seq2seq", model_args,
                           data_args)

    # 2. Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # NOTE: this overrides the `log_level` applied to the script logger above:
    # INFO on the main process, WARN on the others.
    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)

    # Log on each process the small summary. The two f-strings are concatenated
    # implicitly, so the first one must end with a separator.
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )

    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
    logger.info("Training/evaluation parameters %s", training_args)

    # 3. Detecting last checkpoint and eventually continue from last checkpoint
    last_checkpoint = None
    if (os.path.isdir(training_args.output_dir) and training_args.do_train
            and not training_args.overwrite_output_dir):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome.")
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # 4. Load dataset
    raw_datasets = DatasetDict()

    if training_args.do_train:
        raw_datasets["train"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split=data_args.train_split_name,
            use_auth_token=True if model_args.use_auth_token else None,
        )

    if training_args.do_eval:
        raw_datasets["eval"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split=data_args.eval_split_name,
            use_auth_token=True if model_args.use_auth_token else None,
        )

    # Without any split, `next(iter(...))` below would raise a bare StopIteration.
    if not raw_datasets:
        raise ValueError("At least one of `--do_train` or `--do_eval` must be set.")

    # Column names and features are read from the first loaded split.
    first_split = next(iter(raw_datasets.values()))
    column_names = first_split.column_names

    if data_args.audio_column_name not in column_names:
        raise ValueError(
            f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
            "Make sure to set `--audio_column_name` to the correct audio column - one of "
            f"{', '.join(column_names)}.")

    if data_args.text_column_name not in column_names:
        raise ValueError(
            f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
            "Make sure to set `--text_column_name` to the correct text column - one of "
            f"{', '.join(column_names)}.")

    # 5. Load pretrained model, tokenizer, and feature extractor
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config = AutoConfig.from_pretrained(
        model_args.config_name
        if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    feature_extractor = AutoFeatureExtractor.from_pretrained(
        model_args.feature_extractor_name if model_args.feature_extractor_name
        else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name
        if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    if model.config.decoder_start_token_id is None:
        raise ValueError(
            "Make sure that `config.decoder_start_token_id` is correctly defined"
        )

    if model_args.freeze_feature_encoder:
        model.freeze_feature_encoder()

    # 6. Resample speech dataset if necessary
    dataset_sampling_rate = first_split.features[
        data_args.audio_column_name].sampling_rate
    if dataset_sampling_rate != feature_extractor.sampling_rate:
        raw_datasets = raw_datasets.cast_column(
            data_args.audio_column_name,
            datasets.features.Audio(
                sampling_rate=feature_extractor.sampling_rate))

    # 7. Preprocessing the datasets.
    # We need to read the audio files as arrays and tokenize the targets.
    # Length bounds are expressed in raw audio samples (seconds * sampling rate).
    max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
    min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
    audio_column_name = data_args.audio_column_name
    num_workers = data_args.preprocessing_num_workers
    text_column_name = data_args.text_column_name
    model_input_name = feature_extractor.model_input_names[0]
    do_lower_case = data_args.do_lower_case

    if data_args.max_train_samples is not None:
        raw_datasets["train"] = raw_datasets["train"].select(
            range(data_args.max_train_samples))

    if data_args.max_eval_samples is not None:
        raw_datasets["eval"] = raw_datasets["eval"].select(
            range(data_args.max_eval_samples))

    def prepare_dataset(batch):
        """Add model inputs, audio length (in samples) and label ids to one example."""
        # process audio
        sample = batch[audio_column_name]
        inputs = feature_extractor(sample["array"],
                                   sampling_rate=sample["sampling_rate"])
        # Index by `model_input_name` rather than hard-coding `input_values`,
        # so extractors with a different first model input also work.
        batch[model_input_name] = inputs[model_input_name][0]
        # Measured in raw audio samples so it is comparable with
        # `min_input_length` / `max_input_length` in the filter below.
        batch["input_length"] = len(sample["array"])

        # process targets
        input_str = batch[text_column_name].lower(
        ) if do_lower_case else batch[text_column_name]
        batch["labels"] = tokenizer(input_str).input_ids
        return batch

    with training_args.main_process_first(desc="dataset map pre-processing"):
        vectorized_datasets = raw_datasets.map(
            prepare_dataset,
            remove_columns=column_names,
            num_proc=num_workers,
            desc="preprocess train dataset",
        )

    # filter data that is shorter than min_input_length or longer than
    # max_input_length
    def is_audio_in_length_range(length):
        """Return True if ``length`` lies strictly between the min and max input lengths."""
        return length > min_input_length and length < max_input_length

    vectorized_datasets = vectorized_datasets.filter(
        is_audio_in_length_range,
        num_proc=num_workers,
        input_columns=["input_length"],
    )

    # for large datasets it is advised to run the preprocessing on a
    # single machine first with `args.preprocessing_only` since there will mostly likely
    # be a timeout when running the script in distributed mode.
    # In a second step `args.preprocessing_only` can then be set to `False` to load the
    # cached dataset
    if data_args.preprocessing_only:
        cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
        logger.info(f"Data preprocessing finished. Files cached at {cache}.")
        return

    # 8. Load Metric
    metric = evaluate.load("wer")

    def compute_metrics(pred):
        """Decode generated ids and labels and return ``{"wer": ...}``."""
        pred_ids = pred.predictions

        # -100 marks ignored label positions; map them to the pad token so they can be decoded
        pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id

        pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        # decode references the same way as predictions (special tokens dropped)
        label_str = tokenizer.batch_decode(pred.label_ids,
                                           skip_special_tokens=True)

        wer = metric.compute(predictions=pred_str, references=label_str)

        return {"wer": wer}

    # 9. Create a single speech processor
    if is_main_process(training_args.local_rank):
        # save feature extractor, tokenizer and config
        feature_extractor.save_pretrained(training_args.output_dir)
        tokenizer.save_pretrained(training_args.output_dir)
        config.save_pretrained(training_args.output_dir)

    processor = AutoProcessor.from_pretrained(training_args.output_dir)

    # 10. Define data collator
    data_collator = DataCollatorSpeechSeq2SeqWithPadding(
        processor=processor,
        decoder_start_token_id=model.config.decoder_start_token_id)

    # 11. Initialize Trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=vectorized_datasets["train"]
        if training_args.do_train else None,
        eval_dataset=vectorized_datasets["eval"]
        if training_args.do_eval else None,
        tokenizer=feature_extractor,
        data_collator=data_collator,
        compute_metrics=compute_metrics
        if training_args.predict_with_generate else None,
    )

    # 12. Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()  # Saves the feature extractor too for easy upload

        metrics = train_result.metrics
        max_train_samples = (data_args.max_train_samples
                             if data_args.max_train_samples is not None else
                             len(vectorized_datasets["train"]))
        metrics["train_samples"] = min(max_train_samples,
                                       len(vectorized_datasets["train"]))
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # 13. Evaluation
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate(metric_key_prefix="eval",
                                   max_length=model.config.max_length,
                                   num_beams=model.config.num_beams)
        max_eval_samples = (data_args.max_eval_samples
                            if data_args.max_eval_samples is not None else len(
                                vectorized_datasets["eval"]))
        metrics["eval_samples"] = min(max_eval_samples,
                                      len(vectorized_datasets["eval"]))

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # 14. Write Training Stats
    kwargs = {
        "finetuned_from": model_args.model_name_or_path,
        "tasks": "speech recognition"
    }
    if data_args.dataset_name is not None:
        kwargs["dataset_tags"] = data_args.dataset_name
        if data_args.dataset_config_name is not None:
            kwargs["dataset_args"] = data_args.dataset_config_name
            kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
        else:
            kwargs["dataset"] = data_args.dataset_name

    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)

    return results
def main():
    """Fine-tune and/or evaluate a speech sequence-to-sequence model.

    Pipeline: load the raw train/eval splits, load config / feature
    extractor / tokenizer / model, resample and preprocess the audio,
    drop examples outside the configured duration range, then train,
    evaluate and finally write a model card (or push to the Hub).

    NOTE(review): this function reads ``model_args``, ``data_args``,
    ``training_args``, ``last_checkpoint`` and ``logger`` from the enclosing
    module scope; they are not defined here -- presumably parsed/created
    before ``main()`` is invoked. Confirm against the caller.

    Returns:
        The ``results`` dict (never populated below, so always ``{}``), or
        ``None`` when ``data_args.preprocessing_only`` is set.

    Raises:
        ValueError: if the audio/text column is missing from the dataset or
            the model config has no ``dec_START``.
    """
    raw_datasets = DatasetDict()

    if training_args.do_train:
        raw_datasets["train"] = load_dataset(data_args.dataset_name,
                                             data_args.dataset_config,
                                             split=data_args.train_split_name)

    if training_args.do_eval:
        raw_datasets["eval"] = load_dataset(data_args.dataset_name,
                                            data_args.dataset_config,
                                            split=data_args.eval_split_name)

    # Columns are validated on the first loaded split only
    # (assumes every split shares the same schema).
    column_names = next(iter(raw_datasets.values())).column_names

    if data_args.audio_column not in column_names:
        raise ValueError(
            f"--audio_column '{data_args.audio_column}' not found in dataset '{data_args.dataset_name}'. "
            "Make sure to set `--audio_column` to the correct audio column - one of "
            f"{', '.join(column_names)}.")

    if data_args.text_column not in column_names:
        raise ValueError(
            f"--text_column {data_args.text_column} not found in dataset '{data_args.dataset_name}'. "
            "Make sure to set `--text_column` to the correct text column - one of "
            f"{', '.join(column_names)}.")

    # 5. Load pretrained model, tokenizer, and feature extractor
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config = AutoConfig.from_pretrained(
        model_args.config_name
        if model_args.config_name else model_args.model_name,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_version,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    feature_extractor = AutoFeatureExtractor.from_pretrained(
        model_args.feature_extractor
        if model_args.feature_extractor else model_args.model_name,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_version,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name
        if model_args.tokenizer_name else model_args.model_name,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_version,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_args.model_name,
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_version,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    if model.config.dec_START is None:
        raise ValueError(
            "Make sure that `config.dec_START` is correctly defined")

    if model_args.freeze_feature_encoder:
        model.freeze_feature_encoder()

    # 6. Resample speech dataset if necessary
    dataset_sampling_rate = (next(iter(
        raw_datasets.values())).features[data_args.audio_column].sampling_rate)
    if dataset_sampling_rate != feature_extractor.sampling_rate:
        raw_datasets = raw_datasets.cast_column(
            data_args.audio_column,
            datasets.features.Audio(
                sampling_rate=feature_extractor.sampling_rate),
        )

    # 7. Preprocessing the datasets.
    # We need to read the audio files as arrays and tokenize the targets.
    # Length bounds are expressed in raw audio samples (seconds * rate).
    max_input_length = data_args.max_duration * feature_extractor.sampling_rate
    min_input_length = data_args.min_duration * feature_extractor.sampling_rate
    audio_column = data_args.audio_column
    num_workers = data_args.num_workers
    text_column = data_args.text_column
    model_input_name = feature_extractor.model_input_names[0]
    lower_case = data_args.lower_case

    if data_args.max_train_samples is not None:
        raw_datasets["train"] = raw_datasets["train"].select(
            range(data_args.max_train_samples))

    if data_args.max_eval_samples is not None:
        raw_datasets["eval"] = raw_datasets["eval"].select(
            range(data_args.max_eval_samples))

    def prepare_dataset(batch):
        """Featurize one example's audio and tokenize its transcript."""
        # process audio
        sample = batch[audio_column]
        inputs = feature_extractor(sample["array"],
                                   sampling_rate=sample["sampling_rate"])
        # Store features under the extractor's own primary input name rather
        # than assuming "input_values" (other extractors use other names).
        batch[model_input_name] = inputs.get(model_input_name)[0]
        # Length in raw audio samples -- the same unit as min/max_input_length.
        # len() of the extracted features would count frames, not samples,
        # for frame-based feature extractors.
        batch["input_length"] = len(sample["array"])

        # process targets
        input_str = batch[text_column].lower(
        ) if lower_case else batch[text_column]
        batch["labels"] = tokenizer(input_str).input_ids
        return batch

    with training_args.main_process_first(desc="dataset map pre-processing"):
        vectorized_datasets = raw_datasets.map(
            prepare_dataset,
            remove_columns=next(iter(raw_datasets.values())).column_names,
            num_proc=num_workers,
            desc="preprocess train dataset",
        )

    # filter data that is shorter than min_input_length or longer than
    # max_input_length
    def is_audio_in_length_range(length):
        return length > min_input_length and length < max_input_length

    vectorized_datasets = vectorized_datasets.filter(
        is_audio_in_length_range,
        num_proc=num_workers,
        input_columns=["input_length"],
    )

    # for large datasets it is advised to run the preprocessing on a
    # single machine first with `args.preprocessing_only` since there will mostly likely
    # be a timeout when running the script in distributed mode.
    # In a second step `args.preprocessing_only` can then be set to `False` to load the
    # cached dataset
    if data_args.preprocessing_only:
        cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
        logger.info(f"Data preprocessing finished. Files cached at {cache}.")
        return

    # 8. Load Metric
    metric = load_metric("wer")

    def compute_metrics(pred):
        """Decode predictions and labels and return their word error rate."""
        pred_ids = pred.predictions

        # -100 marks ignored label positions; swap in the pad id so they decode.
        pred.label_ids[pred.label_ids == -100] = tokenizer.PAD

        pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        # we do not want to group tokens when computing the metrics
        label_str = tokenizer.batch_decode(pred.label_ids,
                                           skip_special_tokens=True)

        wer = metric.compute(predictions=pred_str, references=label_str)

        return {"wer": wer}

    # 9. Create a single speech processor
    # main_process_first lets the main process write the artifacts before the
    # other processes proceed, so no process reads a partially written
    # out_dir in the from_pretrained call below.
    with training_args.main_process_first():
        if is_main_process(training_args.local_rank):
            # save feature extractor, tokenizer and config
            feature_extractor.save_pretrained(training_args.out_dir)
            tokenizer.save_pretrained(training_args.out_dir)
            config.save_pretrained(training_args.out_dir)

    processor = AutoProcessor.from_pretrained(training_args.out_dir)

    # 10. Define data collator
    data_collator = DataCollatorSpeechSeq2SeqWithPadding(
        processor=processor, dec_START=model.config.dec_START)

    # 11. Initialize Trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=vectorized_datasets["train"]
        if training_args.do_train else None,
        eval_dataset=vectorized_datasets["eval"]
        if training_args.do_eval else None,
        tokenizer=feature_extractor,
        data_collator=data_collator,
        compute_metrics=compute_metrics
        if training_args.test_with_gen else None,
    )

    # 12. Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()  # Saves the feature extractor too for easy upload

        metrics = train_result.metrics
        max_train_samples = (data_args.max_train_samples
                             if data_args.max_train_samples is not None else
                             len(vectorized_datasets["train"]))
        metrics["train_samples"] = min(max_train_samples,
                                       len(vectorized_datasets["train"]))
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # 13. Evaluation
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate(
            metric_key_prefix="eval",
            max_len=model.config.max_len,
            n_beams=model.config.n_beams,
        )
        max_eval_samples = (data_args.max_eval_samples
                            if data_args.max_eval_samples is not None else len(
                                vectorized_datasets["eval"]))
        metrics["eval_samples"] = min(max_eval_samples,
                                      len(vectorized_datasets["eval"]))

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # 14. Write Training Stats
    kw = {
        "finetuned_from": model_args.model_name,
        "tasks": "speech recognition"
    }
    if data_args.dataset_name is not None:
        kw["dataset_tags"] = data_args.dataset_name
        if data_args.dataset_config is not None:
            kw["dataset_args"] = data_args.dataset_config
            kw["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config}"
        else:
            kw["dataset"] = data_args.dataset_name

    if training_args.push_to_hub:
        trainer.push_to_hub(**kw)
    else:
        trainer.create_model_card(**kw)

    return results