def __init__(self, source, save_path, output_norm=True, freeze=True, pretrain=True):
    """Build a wav2vec2 encoder from a HuggingFace identifier.

    Arguments
    ---------
    source : str
        Model identifier or path handed to the HuggingFace ``from_pretrained`` loaders.
    save_path : str
        Cache directory for the downloaded files.
    output_norm : bool
        Stored as ``self.output_norm``; its use is outside this method.
    freeze : bool
        Stored as ``self.freeze``. When truthy the model is put in eval mode,
        otherwise in train mode.
    pretrain : bool
        When truthy, the pretrained weights are downloaded and loaded. Otherwise
        only the config is fetched and the model is randomly initialised from it.
    """
    super().__init__()

    # The feature extractor is kept only to learn whether inputs must be normalised.
    self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(source, cache_dir=save_path)

    if pretrain:
        # Download and load the pretrained weights.
        self.model = Wav2Vec2Model.from_pretrained(source, cache_dir=save_path)
    else:
        # Architecture only: randomly initialised weights built from the remote config.
        remote_config = Wav2Vec2Config.from_pretrained(source, cache_dir=save_path)
        self.model = Wav2Vec2Model(remote_config)

    # Whether inputs must be normalised the way the pretrained wav2vec2 expects.
    self.normalize_wav = self.feature_extractor.do_normalize

    self.freeze = freeze
    self.output_norm = output_norm

    # nn.Module.eval() is train(False): frozen -> eval mode, otherwise train mode.
    self.model.train(not self.freeze)
def convert_wav2vec2_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None,
                                is_finetuned=True):
    """
    Port the weights of a fairseq wav2vec2 checkpoint into a transformers model and save it.

    Args:
        checkpoint_path: fairseq checkpoint file to read.
        pytorch_dump_folder_path: destination passed to ``save_pretrained``.
        config_path: optional ``Wav2Vec2Config`` source; a default config is used when ``None``.
        dict_path: forwarded to fairseq as the ``data`` override when ``is_finetuned`` is set.
        is_finetuned: build a ``Wav2Vec2ForCTC`` head (True) or a bare ``Wav2Vec2Model`` (False).
    """
    config = Wav2Vec2Config() if config_path is None else Wav2Vec2Config.from_pretrained(config_path)

    target_cls = Wav2Vec2ForCTC if is_finetuned else Wav2Vec2Model
    hf_wav2vec = target_cls(config)

    # Only the fine-tuned path overrides fairseq's "data" argument.
    fairseq_kwargs = {"arg_overrides": {"data": dict_path}} if is_finetuned else {}
    ensemble, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path], **fairseq_kwargs)
    fairseq_model = ensemble[0].eval()

    recursively_load_weights(fairseq_model, hf_wav2vec, is_finetuned)
    hf_wav2vec.save_pretrained(pytorch_dump_folder_path)
def convert_s3prl_checkpoint(base_model_name, config_path, checkpoint_path, model_dump_path):
    """
    Convert an S3PRL downstream checkpoint into a transformers model plus feature extractor.

    The target head is chosen from the first entry of ``hf_config.architectures``. Its suffix
    selects the converter (sequence classification, audio-frame classification or x-vector).
    Unsupported architectures raise ``NotImplementedError``. Both the feature extractor and
    the model are saved to ``model_dump_path``.
    """
    # NOTE(review): torch.load unpickles the file -- only run this on trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    downstream_dict = checkpoint["Downstream"]

    hf_config = Wav2Vec2Config.from_pretrained(config_path)
    hf_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
        base_model_name, return_attention_mask=True, do_normalize=False
    )

    arch = hf_config.architectures[0]
    # None of these suffixes is a suffix of another, so at most one entry matches.
    suffix_to_converter = (
        ("ForSequenceClassification", convert_classification),
        ("ForAudioFrameClassification", convert_diarization),
        ("ForXVector", convert_xvector),
    )
    for suffix, converter in suffix_to_converter:
        if arch.endswith(suffix):
            hf_model = converter(base_model_name, hf_config, downstream_dict)
            break
    else:
        raise NotImplementedError(f"S3PRL weights conversion is not supported for {arch}")

    if hf_config.use_weighted_layer_sum:
        hf_model.layer_weights.data = checkpoint["Featurizer"]["weights"]

    hf_feature_extractor.save_pretrained(model_dump_path)
    hf_model.save_pretrained(model_dump_path)
def test_inference_pretrained(self):
    """Pretrained Flax wav2vec2 must predict its quantized targets far better than a random model.

    ``facebook/wav2vec2-large-lv60`` (converted from PyTorch) and a randomly initialised model with
    the same config are run on the same two audio samples and the same time mask. The test asserts
    that the mean cosine similarity between projected states and projected quantized states at
    masked positions is more than five times larger for the pretrained model.
    """
    # Shared numerical guard so both similarities are computed the same way.
    epsilon = 1e-8

    model = FlaxWav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-large-lv60", from_pt=True)
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
        "facebook/wav2vec2-large-lv60", return_attention_mask=True
    )
    input_speech = self._load_datasamples(2)

    inputs_dict = feature_extractor(input_speech, return_tensors="np", padding=True)

    # (batch size, feature-encoder output length for the padded input length)
    features_shape = (
        inputs_dict["input_values"].shape[0],
        model._get_feat_extract_output_lengths(np.array(inputs_dict["input_values"].shape[1])),
    )

    # NOTE(review): assumes `_compute_mask_indices` samples spans from NumPy's global RNG;
    # seeding it makes the masked positions (and therefore this test) reproducible.
    np.random.seed(0)
    mask_time_indices = _compute_mask_indices(
        features_shape,
        model.config.mask_time_prob,
        model.config.mask_time_length,
        min_masks=2,
    )

    outputs = model(
        inputs_dict.input_values,
        attention_mask=inputs_dict.attention_mask,
        mask_time_indices=mask_time_indices,
    )

    # compute cosine similarity
    cosine_sim = optax.cosine_similarity(
        outputs.projected_states, outputs.projected_quantized_states, epsilon=epsilon
    )

    # retrieve cosine sim of masked features
    cosine_sim_masked = cosine_sim[mask_time_indices]

    # ... now compare to randomly initialized model

    config = Wav2Vec2Config.from_pretrained("facebook/wav2vec2-large-lv60")
    model_rand = FlaxWav2Vec2ForPreTraining(config)

    outputs_rand = model_rand(
        inputs_dict.input_values,
        attention_mask=inputs_dict.attention_mask,
        mask_time_indices=mask_time_indices,
    )

    # compute cosine similarity with the same epsilon as for the pretrained model
    cosine_sim_rand = optax.cosine_similarity(
        outputs_rand.projected_states, outputs_rand.projected_quantized_states, epsilon=epsilon
    )

    # retrieve cosine sim of masked features
    cosine_sim_masked_rand = cosine_sim_rand[mask_time_indices]

    # a pretrained wav2vec2 model has learned to predict the quantized latent states
    # => the cosine similarity between quantized states and predicted states > 0.5
    # a random wav2vec2 model has not learned to predict the quantized latent states
    # => the cosine similarity between quantized states and predicted states is very likely < 0.1
    self.assertTrue(cosine_sim_masked.mean().item() - 5 * cosine_sim_masked_rand.mean().item() > 0)
def test_pretrained_checkpoints_are_set_correctly(self):
    """Each archived checkpoint's tokenizer returns an attention mask iff its feature encoder uses layer norm.

    Checkpoints whose ``feat_extract_norm`` is anything other than ``"layer"`` (e.g. group norm)
    must have ``return_attention_mask`` disabled on their tokenizer.
    """
    for checkpoint in WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST:
        config = Wav2Vec2Config.from_pretrained(checkpoint)
        tokenizer = Wav2Vec2Tokenizer.from_pretrained(checkpoint)

        expects_attention_mask = config.feat_extract_norm == "layer"
        self.assertEqual(tokenizer.return_attention_mask, expects_attention_mask)
def __init__(
    self,
    source,
    save_path,
    mask_prob=0.65,
    mask_length=10,
    normalize_wav=True,
):
    """Build a randomly initialised ``Wav2Vec2ForPreTraining`` from a HuggingFace config.

    Arguments
    ---------
    source : str
        Model identifier or path whose config is fetched with ``Wav2Vec2Config.from_pretrained``.
    save_path : str
        Cache directory for the downloaded config.
    mask_prob : float
        Stored as ``self.mask_prob``; consumed outside this method.
    mask_length : int
        Stored as ``self.mask_length``; consumed outside this method.
    normalize_wav : bool
        Stored as ``self.normalize_wav``; consumed outside this method.
    """
    super().__init__()

    self.mask_prob, self.mask_length, self.normalize_wav = mask_prob, mask_length, normalize_wav

    # Only the configuration is downloaded; the weights are freshly initialised below.
    self.config = Wav2Vec2Config.from_pretrained(source, cache_dir=save_path)
    # Ask the model to also return every hidden state.
    self.config.output_hidden_states = True

    self.model = Wav2Vec2ForPreTraining(self.config)
    # Gradient checkpointing is switched off because DDP requires it.
    self.model.gradient_checkpointing_disable()
    self.model.train()
def convert_wav2vec2_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None,
                                is_finetuned=True):
    """
    Copy/paste/tweak model's weights to transformers design.

    Args:
        checkpoint_path: fairseq checkpoint file to convert.
        pytorch_dump_folder_path: output folder for the converted model (and, for fine-tuned
            checkpoints with a dictionary, ``vocab.json`` plus the saved processor). The folder
            is created if missing. If the path exists but is not a directory, an error is
            logged and the function returns without converting.
        config_path: optional ``Wav2Vec2Config`` source; a default config is used when ``None``.
        dict_path: fairseq dictionary. When given together with ``is_finetuned``, it defines the
            special-token ids and vocabulary size, and is forwarded to fairseq as the ``data``
            override.
        is_finetuned: convert into ``Wav2Vec2ForCTC`` (True) or a bare ``Wav2Vec2Model`` (False).
    """
    if config_path is not None:
        config = Wav2Vec2Config.from_pretrained(config_path)
    else:
        config = Wav2Vec2Config()

    if is_finetuned:
        if dict_path:
            target_dict = Dictionary.load(dict_path)

            # Take special-token ids and vocabulary size from the fairseq dictionary.
            config.bos_token_id = target_dict.bos_index
            config.eos_token_id = target_dict.eos_index
            config.pad_token_id = target_dict.pad_index
            config.vocab_size = len(target_dict.symbols)
            vocab_path = os.path.join(pytorch_dump_folder_path, "vocab.json")

            # Reject only a path that exists but is not a directory (e.g. a file). A missing
            # folder is created below. Previously this check also rejected missing folders,
            # which made the makedirs call dead.
            if os.path.exists(pytorch_dump_folder_path) and not os.path.isdir(pytorch_dump_folder_path):
                logger.error(
                    "--pytorch_dump_folder_path ({}) should be a directory".format(pytorch_dump_folder_path)
                )
                return
            os.makedirs(pytorch_dump_folder_path, exist_ok=True)

            with open(vocab_path, "w", encoding="utf-8") as vocab_handle:
                json.dump(target_dict.indices, vocab_handle)

            tokenizer = Wav2Vec2CTCTokenizer(
                vocab_path,
                unk_token=target_dict.unk_word,
                pad_token=target_dict.pad_word,
                bos_token=target_dict.bos_word,
                eos_token=target_dict.eos_word,
                word_delimiter_token="|",
                do_lower_case=False,
            )
            # Only "layer"-norm feature encoders get an attention mask from the extractor.
            return_attention_mask = config.feat_extract_norm == "layer"
            feature_extractor = Wav2Vec2FeatureExtractor(
                feature_size=1,
                sampling_rate=16000,
                padding_value=0,
                do_normalize=True,
                return_attention_mask=return_attention_mask,
            )
            processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
            processor.save_pretrained(pytorch_dump_folder_path)

        hf_wav2vec = Wav2Vec2ForCTC(config)
    else:
        hf_wav2vec = Wav2Vec2Model(config)

    if is_finetuned:
        model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            [checkpoint_path], arg_overrides={"data": dict_path}
        )
    else:
        model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path])

    model = model[0].eval()

    recursively_load_weights(model, hf_wav2vec, is_finetuned)

    hf_wav2vec.save_pretrained(pytorch_dump_folder_path)
def main():
    """Pre-train a randomly initialised ``Wav2Vec2ForPreTraining`` with ``Wav2Vec2PreTrainer``.

    Steps:
    1. Parse ``ModelArguments``, ``DataTrainingArguments`` and ``TrainingArguments``.
    2. Build a ``DatasetDict`` that holds only "validation" and "train". The dataset's own
       "validation" split is used if present; otherwise the first
       ``validation_split_percentage`` percent of the train split is carved off.
    3. Load each audio file with librosa at the feature extractor's sampling rate, drop clips
       of ``max_duration_in_seconds`` or longer, and normalise them with the feature extractor.
    4. Train.

    Raises:
        ValueError: if the config is not a stable-layer-norm / layer-norm feature-encoder model.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    configure_logger(model_args, training_args)

    # Downloading and loading a dataset from the hub. This first load is used to check
    # whether the dataset ships its own "validation" split.
    datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)

    def load_split(split):
        # Load a single (possibly sliced) split of the configured dataset.
        return load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split=split,
            cache_dir=model_args.cache_dir,
        )

    if "validation" not in datasets.keys():
        validation_split = f"{data_args.train_split_name}[:{data_args.validation_split_percentage}%]"
        train_split = f"{data_args.train_split_name}[{data_args.validation_split_percentage}%:]"
    else:
        validation_split = "validation"
        train_split = f"{data_args.train_split_name}"

    # make sure only "validation" and "train" keys remain
    datasets = DatasetDict()
    datasets["validation"] = load_split(validation_split)
    datasets["train"] = load_split(train_split)

    # only normalized-inputs-training is supported
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
        model_args.model_name_or_path, cache_dir=model_args.cache_dir, do_normalize=True
    )

    # Maximum clip length in samples, computed once rather than once per example.
    max_num_samples = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)

    def prepare_dataset(batch):
        # Decode the audio file, resampling it to the feature extractor's sampling rate.
        batch["speech"], _ = librosa.load(batch[data_args.speech_file_column], sr=feature_extractor.sampling_rate)
        return batch

    # load audio files into numpy arrays
    vectorized_datasets = datasets.map(
        prepare_dataset, num_proc=data_args.preprocessing_num_workers, remove_columns=datasets["train"].column_names
    )

    # filter audio files that are too long
    vectorized_datasets = vectorized_datasets.filter(lambda data: len(data["speech"]) < max_num_samples)

    def normalize(batch):
        return feature_extractor(batch["speech"], sampling_rate=feature_extractor.sampling_rate)

    # normalize and transform to `BatchFeatures`
    vectorized_datasets = vectorized_datasets.map(
        normalize,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
        remove_columns=vectorized_datasets["train"].column_names,
    )

    # pretraining is only supported for "newer" stable layer norm architecture
    # apply_spec_augment has to be True, mask_feature_prob has to be 0.0
    config = Wav2Vec2Config.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        gradient_checkpointing=training_args.gradient_checkpointing,
    )

    if not config.do_stable_layer_norm or config.feat_extract_norm != "layer":
        raise ValueError(
            "PreTraining is only supported for ``config.do_stable_layer_norm=True`` and"
            " ``config.feat_extract_norm='layer'"
        )

    model = Wav2Vec2ForPreTraining(config)

    data_collator = DataCollatorForWav2Vec2Pretraining(model=model, feature_extractor=feature_extractor)

    trainer = Wav2Vec2PreTrainer(
        model=model,
        data_collator=data_collator,
        args=training_args,
        train_dataset=vectorized_datasets["train"],
        eval_dataset=vectorized_datasets["validation"],
        tokenizer=feature_extractor,
        max_gumbel_temp=model_args.max_gumbel_temperature,
        min_gumbel_temp=model_args.min_gumbel_temperature,
        gumbel_temp_decay=model_args.gumbel_temperature_decay,
    )
    trainer.train()
def test_inference_pretrained(self):
    """Pretrained wav2vec2-base must predict its quantized targets far better than a random model.

    Both models see the same two audio samples and the same seeded time mask. The mean cosine
    similarity between projected and projected-quantized states at masked positions must be
    more than five times larger for the pretrained checkpoint.
    """
    checkpoint = "facebook/wav2vec2-base"

    model = Wav2Vec2ForPreTraining.from_pretrained(checkpoint)
    model.to(torch_device)
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(checkpoint, return_attention_mask=True)
    input_speech = self._load_datasamples(2)

    inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)
    input_values = inputs_dict["input_values"]

    # (batch size, feature-encoder output length for the padded input length)
    features_shape = (
        input_values.shape[0],
        model._get_feat_extract_output_lengths(torch.tensor(input_values.shape[1])),
    )

    torch.manual_seed(0)
    mask_time_indices = _compute_mask_indices(
        features_shape,
        model.config.mask_time_prob,
        model.config.mask_time_length,
        device=input_values.device,
        min_masks=2,
    ).to(torch_device)

    def masked_mean_similarity(candidate):
        # Forward pass without gradients, then average the cosine similarity over masked frames.
        with torch.no_grad():
            outputs = candidate(
                inputs_dict.input_values.to(torch_device),
                attention_mask=inputs_dict.attention_mask.to(torch_device),
                mask_time_indices=mask_time_indices,
            )
        similarity = torch.cosine_similarity(
            outputs.projected_states, outputs.projected_quantized_states, dim=-1
        )
        return similarity[mask_time_indices].mean().item()

    pretrained_score = masked_mean_similarity(model)

    # Same architecture, random weights.
    config = Wav2Vec2Config.from_pretrained(checkpoint)
    model_rand = Wav2Vec2ForPreTraining(config).to(torch_device).eval()
    random_score = masked_mean_similarity(model_rand)

    # Pretrained similarity is expected > 0.5 and random < 0.1, so a 5x margin must hold.
    self.assertTrue(pretrained_score - 5 * random_score > 0)
def main():
    """Pre-train a randomly initialised ``Wav2Vec2ForPreTraining`` with 🤗 Accelerate.

    Steps:
    1. Parse arguments and initialise the ``Accelerator``. Optionally start wandb, and create
       the output directory or a hub repository.
    2. Load every ``(args.dataset_config_names[i], args.dataset_split_names[i])`` split,
       concatenate and shuffle them, and carve the first ``args.validation_split_percentage``
       percent off as validation.
    3. Resample audio to the feature extractor's rate, extract and truncate features, and drop
       clips that are not longer than ``args.min_duration_in_seconds``. Return early if
       ``args.preprocessing_only`` is set.
    4. Train with AdamW and a scheduler from ``get_scheduler``. Gradients are rescaled by the
       number of masked frames, and the gumbel temperature decays per completed update step.
    5. Log every ``logging_steps`` updates, save every ``saving_steps`` updates, and validate
       and save after every epoch.

    Raises:
        ValueError: if no dataset split was loaded, the validation percentage yields zero
            samples, the feature extractor does not normalise, or the config is not a
            stable-layer-norm / layer-norm feature-encoder model.
    """
    # See all possible arguments in src/transformers/args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    args = parse_args()

    # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
    accelerator = Accelerator()
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()

        # set up weights and biases if available
        if is_wandb_available():
            import wandb

            wandb.init(project=args.output_dir.split("/")[-1])
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.push_to_hub and not args.preprocessing_only:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            repo = Repository(args.output_dir, clone_from=repo_name)
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    # 1. Download and create train, validation dataset
    # We load all dataset configuration and dataset split pairs passed in
    # ``args.dataset_config_names`` and ``args.dataset_split_names``
    datasets_splits = []
    for dataset_config_name, train_split_name in zip(args.dataset_config_names, args.dataset_split_names):
        # load dataset
        dataset_split = load_dataset(
            args.dataset_name,
            dataset_config_name,
            split=train_split_name,
            cache_dir=args.cache_dir,
        )
        datasets_splits.append(dataset_split)

    # Fail with a clear message instead of an IndexError on `datasets_splits[0]` below.
    if not datasets_splits:
        raise ValueError(
            "No dataset split was loaded: `args.dataset_config_names` and `args.dataset_split_names` "
            "must contain at least one pair."
        )

    # Next, we concatenate all configurations and splits into a single training dataset
    raw_datasets = DatasetDict()
    if len(datasets_splits) > 1:
        raw_datasets["train"] = concatenate_datasets(datasets_splits).shuffle(seed=args.seed)
    else:
        raw_datasets["train"] = datasets_splits[0]

    # Take ``args.validation_split_percentage`` percent of the training dataset as the validation split
    num_validation_samples = raw_datasets["train"].num_rows * args.validation_split_percentage // 100

    if num_validation_samples == 0:
        raise ValueError(
            "`args.validation_split_percentage` is less than a single sample "
            f"for {len(raw_datasets['train'])} training samples. Increase "
            "`args.validation_split_percentage`. "
        )

    raw_datasets["validation"] = raw_datasets["train"].select(range(num_validation_samples))
    raw_datasets["train"] = raw_datasets["train"].select(range(num_validation_samples, raw_datasets["train"].num_rows))

    # 2. Now we preprocess the datasets including loading the audio, resampling and normalization
    # Thankfully, `datasets` takes care of automatically loading and resampling the audio,
    # so that we just need to set the correct target sampling rate and normalize the input
    # via the `feature_extractor`
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(args.model_name_or_path)

    # make sure that dataset decodes audio with correct sampling rate
    raw_datasets = raw_datasets.cast_column(
        args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
    )

    # only normalized-inputs-training is supported
    if not feature_extractor.do_normalize:
        raise ValueError(
            "Training is only supported for normalized inputs. Make sure ``feature_extractor.do_normalize == True``"
        )

    # set max & min audio length in number of samples
    max_length = int(args.max_duration_in_seconds * feature_extractor.sampling_rate)
    min_length = int(args.min_duration_in_seconds * feature_extractor.sampling_rate)

    def prepare_dataset(batch):
        sample = batch[args.audio_column_name]

        inputs = feature_extractor(
            sample["array"], sampling_rate=sample["sampling_rate"], max_length=max_length, truncation=True
        )
        batch["input_values"] = inputs.input_values[0]
        batch["input_length"] = len(inputs.input_values[0])

        return batch

    # load via mapped files via path
    cache_file_names = None
    if args.train_cache_file_name is not None:
        cache_file_names = {"train": args.train_cache_file_name, "validation": args.validation_cache_file_name}

    # load audio files into numpy arrays
    with accelerator.main_process_first():
        vectorized_datasets = raw_datasets.map(
            prepare_dataset,
            num_proc=args.preprocessing_num_workers,
            remove_columns=raw_datasets["train"].column_names,
            cache_file_names=cache_file_names,
        )

        if min_length > 0.0:
            vectorized_datasets = vectorized_datasets.filter(
                lambda x: x > min_length,
                num_proc=args.preprocessing_num_workers,
                input_columns=["input_length"],
            )

        vectorized_datasets = vectorized_datasets.remove_columns("input_length")

    # for large datasets it is advised to run the preprocessing on a
    # single machine first with ``args.preprocessing_only`` since there will most likely
    # be a timeout when running the script in distributed mode.
    # In a second step ``args.preprocessing_only`` can then be set to `False` to load the
    # cached dataset
    if args.preprocessing_only:
        return

    # 3. Load model
    config = Wav2Vec2Config.from_pretrained(args.model_name_or_path)

    # pretraining is only supported for "newer" stable layer norm architecture
    # apply_spec_augment has to be True, mask_feature_prob has to be 0.0
    if not config.do_stable_layer_norm or config.feat_extract_norm != "layer":
        raise ValueError(
            "PreTraining is only supported for ``config.do_stable_layer_norm=True`` and"
            " ``config.feat_extract_norm='layer'"
        )

    # initialize random model
    model = Wav2Vec2ForPreTraining(config)

    # Activate gradient checkpointing if needed
    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable()

    # 4. Define data collator, optimizer and scheduler
    data_collator = DataCollatorForWav2Vec2Pretraining(
        model=model, feature_extractor=feature_extractor, pad_to_multiple_of=args.pad_to_multiple_of
    )
    train_dataloader = DataLoader(
        vectorized_datasets["train"],
        shuffle=True,
        collate_fn=data_collator,
        batch_size=args.per_device_train_batch_size,
    )
    eval_dataloader = DataLoader(
        vectorized_datasets["validation"], collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
    )

    # Optimizer
    optimizer = AdamW(
        list(model.parameters()),
        lr=args.learning_rate,
        betas=[args.adam_beta1, args.adam_beta2],
        eps=args.adam_epsilon,
    )

    # Prepare everything with our `accelerator`.
    model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader
    )

    # Scheduler and math around the number of training steps.
    # Computed after `accelerator.prepare`, which may change the dataloader length.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)

    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    else:
        args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.max_train_steps,
    )

    # 5. Train
    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(vectorized_datasets['train'])}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")
    completed_steps = 0
    starting_epoch = 0

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    for epoch in range(starting_epoch, args.num_train_epochs):
        model.train()
        for step, batch in enumerate(train_dataloader):
            # compute num of losses
            num_losses = batch["mask_time_indices"].sum()
            sub_attention_mask = batch.pop("sub_attention_mask", None)
            sub_attention_mask = (
                sub_attention_mask if sub_attention_mask is not None else torch.ones_like(batch["mask_time_indices"])
            )
            percent_masked = num_losses / sub_attention_mask.sum()

            # forward
            outputs = model(**batch)

            # divide loss by gradient accumulation steps since gradients
            # are accumulated for multiple backward passes in PyTorch
            loss = outputs.loss / args.gradient_accumulation_steps
            accelerator.backward(loss)

            # make sure that `num_losses` is summed for distributed training
            # and average gradients over losses of all devices
            if accelerator.state.num_processes > 1:
                num_losses = accelerator.gather(num_losses).sum()
                gradient_multiplier = accelerator.state.num_processes / num_losses
                multiply_grads(model.module.parameters(), gradient_multiplier)
            else:
                multiply_grads(model.parameters(), 1 / num_losses)

            # update step
            if (step + 1) % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                # compute grad norm for monitoring
                scale = (
                    accelerator.scaler._scale.item()
                    if hasattr(accelerator, "scaler") and accelerator.scaler is not None
                    else 1
                )
                if accelerator.state.num_processes > 1:
                    grad_norm = get_grad_norm(model.module.parameters(), scale)
                else:
                    grad_norm = get_grad_norm(model.parameters(), scale)

                # update parameters
                optimizer.step()
                optimizer.zero_grad()

                if not accelerator.optimizer_step_was_skipped:
                    lr_scheduler.step()
                elif accelerator.is_local_main_process:
                    progress_bar.write(
                        f"Gradients have overflown - skipping update step... Updating gradient scale to {scale}..."
                    )

                # update gumbel temperature
                gumbel_temperature = max(
                    args.max_gumbel_temperature * args.gumbel_temperature_decay**completed_steps,
                    args.min_gumbel_temperature,
                )
                if hasattr(model, "module"):
                    model.module.set_gumbel_temperature(gumbel_temperature)
                else:
                    model.set_gumbel_temperature(gumbel_temperature)

                progress_bar.update(1)
                completed_steps += 1

            # 6. Log all results
            if (step + 1) % (args.gradient_accumulation_steps * args.logging_steps) == 0:
                # `Tensor.detach()` returns a new tensor, so its result must be kept. The previous
                # bare `loss.detach()` calls were no-ops and graph-attached tensors were logged.
                loss = loss.detach()
                contrastive_loss = outputs.contrastive_loss.detach()
                diversity_loss = outputs.diversity_loss.detach()

                if accelerator.state.num_processes > 1:
                    loss = accelerator.gather(loss).sum()
                    contrastive_loss = accelerator.gather(contrastive_loss).sum()
                    diversity_loss = accelerator.gather(diversity_loss).sum()
                    percent_masked = accelerator.gather(percent_masked).sum()

                train_logs = {
                    "loss": (loss * args.gradient_accumulation_steps) / num_losses,
                    "constrast_loss": contrastive_loss / num_losses,
                    "div_loss": diversity_loss / num_losses,
                    "%_mask_idx": percent_masked / accelerator.num_processes,
                    "ppl": outputs.codevector_perplexity,
                    "lr": torch.tensor(optimizer.param_groups[0]["lr"]),
                    "temp": torch.tensor(gumbel_temperature),
                    "grad_norm": torch.tensor(grad_norm),
                }
                log_str = ""
                for k, v in train_logs.items():
                    log_str += "| {}: {:.3e}".format(k, v.item())

                if accelerator.is_local_main_process:
                    progress_bar.write(log_str)
                    if is_wandb_available():
                        wandb.log(train_logs)

            # save model every `args.saving_steps` steps
            if (step + 1) % (args.gradient_accumulation_steps * args.saving_steps) == 0:
                if (args.push_to_hub and epoch < args.num_train_epochs - 1) or args.output_dir is not None:
                    accelerator.wait_for_everyone()
                    unwrapped_model = accelerator.unwrap_model(model)
                    unwrapped_model.save_pretrained(
                        args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
                    )

                if (args.push_to_hub and epoch < args.num_train_epochs - 1) and accelerator.is_main_process:
                    repo.push_to_hub(
                        commit_message=f"Training in progress step {completed_steps}",
                        blocking=False,
                        auto_lfs_prune=True,
                    )

            # if completed steps > `args.max_train_steps` stop
            if completed_steps >= args.max_train_steps:
                break

        # 7. Validate!
        model.eval()

        # init logs
        val_logs = {
            "val_loss": 0,
            "val_contrastive_loss": 0,
            "val_diversity_loss": 0,
            "val_num_losses": 0,
        }
        for step, batch in enumerate(eval_dataloader):
            with torch.no_grad():
                batch.pop("sub_attention_mask", None)
                outputs = model(**batch)

            val_logs["val_loss"] += outputs.loss
            val_logs["val_contrastive_loss"] += outputs.contrastive_loss
            val_logs["val_diversity_loss"] += outputs.diversity_loss
            val_logs["val_num_losses"] += batch["mask_time_indices"].sum()

        # sum over devices in multi-processing
        if accelerator.num_processes > 1:
            val_logs = {k: accelerator.gather(v).sum() for k, v in val_logs.items()}

        val_logs = {k: v / val_logs["val_num_losses"] for k, v in val_logs.items()}

        log_str = ""
        for k, v in val_logs.items():
            log_str += "| {}: {:.3e}".format(k, v.item())

        if accelerator.is_local_main_process:
            progress_bar.write(log_str)
            if is_wandb_available():
                wandb.log(val_logs)

        if args.output_dir is not None:
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(
                args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
            )
            if accelerator.is_main_process:
                if args.push_to_hub:
                    repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
def main(): # See all possible arguments in src/transformers/training_args.py # or by passing the --help flag to this script. # We now keep distinct sets of args, for a cleaner separation of concerns. parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) model_args, data_args, training_args = parser.parse_args_into_dataclasses() configure_logger(model_args, training_args) # Downloading and loading a dataset from the hub. datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir) if "validation" not in datasets.keys(): # make sure only "validation" and "train" keys remain" datasets = DatasetDict() datasets["validation"] = load_dataset( data_args.dataset_name, data_args.dataset_config_name, split=f"{data_args.train_split_name}[:{data_args.validation_split_percentage}%]", cache_dir=model_args.cache_dir, ) datasets["train"] = load_dataset( data_args.dataset_name, data_args.dataset_config_name, split=f"{data_args.train_split_name}[{data_args.validation_split_percentage}%:]", cache_dir=model_args.cache_dir, ) else: # make sure only "validation" and "train" keys remain" datasets = DatasetDict() datasets["validation"] = load_dataset( data_args.dataset_name, data_args.dataset_config_name, split="validation", cache_dir=model_args.cache_dir, ) datasets["train"] = load_dataset( data_args.dataset_name, data_args.dataset_config_name, split=f"{data_args.train_split_name}", cache_dir=model_args.cache_dir, ) # only normalized-inputs-training is supported feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( model_args.model_name_or_path, cache_dir=model_args.cache_dir, do_normalize=True ) def prepare_dataset(batch): # check that all files have the correct sampling rate batch["speech"], _ = librosa.load(batch[data_args.speech_file_column], sr=feature_extractor.sampling_rate) return batch # load audio files into numpy arrays vectorized_datasets = datasets.map( prepare_dataset, 
num_proc=data_args.preprocessing_num_workers, remove_columns=datasets["train"].column_names ) # filter audio files that are too long vectorized_datasets = vectorized_datasets.filter( lambda data: len(data["speech"]) < int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate) ) def normalize(batch): return feature_extractor(batch["speech"], sampling_rate=feature_extractor.sampling_rate) # normalize and transform to `BatchFeatures` vectorized_datasets = vectorized_datasets.map( normalize, batched=True, num_proc=data_args.preprocessing_num_workers, load_from_cache_file=not data_args.overwrite_cache, remove_columns=vectorized_datasets["train"].column_names, ) # pretraining is only supported for "newer" stable layer norm architecture # apply_spec_augment has to be True, mask_feature_prob has to be 0.0 config = Wav2Vec2Config.from_pretrained( model_args.model_name_or_path, cache_dir=model_args.cache_dir, gradient_checkpointing=model_args.gradient_checkpointing, ) if not config.do_stable_layer_norm or config.feat_extract_norm != "layer": raise ValueError( "PreTraining is only supported for ``config.do_stable_layer_norm=True`` and ``config.feat_extract_norm='layer'" ) model = FlaxWav2Vec2ForPreTraining( config, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype) ) data_collator = FlaxDataCollatorForWav2Vec2Pretraining( model=model, feature_extractor=feature_extractor, pad_to_multiple_of=data_args.pad_to_multiple_of ) # Enable tensorboard only on the master node has_tensorboard = is_tensorboard_available() if has_tensorboard and jax.process_index() == 0: try: from flax.metrics.tensorboard import SummaryWriter summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir)) except ImportError as ie: has_tensorboard = False logger.warning( f"Unable to display metrics through TensorBoard because some package are not installed: {ie}" ) else: logger.warning( "Unable to display metrics through TensorBoard because the package is not 
installed: " "Please run pip install tensorboard to enable." ) # Initialize our training rng = jax.random.PRNGKey(training_args.seed) dropout_rngs = jax.random.split(rng, jax.local_device_count()) gumbel_rngs = jax.random.split(rng, jax.local_device_count()) num_epochs = int(training_args.num_train_epochs) train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() num_train_steps = len(vectorized_datasets["train"]) // train_batch_size * num_epochs # Create learning rate schedule warmup_fn = optax.linear_schedule( init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps ) decay_fn = optax.linear_schedule( init_value=training_args.learning_rate, end_value=0, transition_steps=num_train_steps - training_args.warmup_steps, ) linear_decay_lr_schedule_fn = optax.join_schedules( schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps] ) # We use Optax's "masking" functionality to not apply weight decay # to bias and LayerNorm scale parameters. decay_mask_fn returns a # mask boolean with the same structure as the parameters. # The mask is True for parameters that should be decayed. 
def decay_mask_fn(params): flat_params = traverse_util.flatten_dict(params) flat_mask = { path: (path[-1] != "bias" and path[-2:] not in [("layer_norm", "scale"), ("final_layer_norm", "scale")]) for path in flat_params } return traverse_util.unflatten_dict(flat_mask) # create adam optimizer adamw = optax.adamw( learning_rate=linear_decay_lr_schedule_fn, b1=training_args.adam_beta1, b2=training_args.adam_beta2, eps=training_args.adam_epsilon, weight_decay=training_args.weight_decay, mask=decay_mask_fn, ) # Setup train state and define training hyper-parameters state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw) num_negatives = model.config.num_negatives contrastive_logits_temperature = model.config.contrastive_logits_temperature num_codevectors = model.config.num_codevectors_per_group * model.config.num_codevector_groups diversity_loss_weight = model.config.diversity_loss_weight # Define gradient update step fn def train_step(state, batch, dropout_rng, gumbel_rng): dropout_rng, new_dropout_rng = jax.random.split(dropout_rng) gumbel_rng, new_gumbel_rng = jax.random.split(gumbel_rng) def loss_fn(params): negative_indices = batch.pop("sampled_negative_indices") gumbel_temperature = jnp.clip( model_args.max_gumbel_temperature * model_args.gumbel_temperature_decay ** state.step, a_min=model_args.min_gumbel_temperature, ) outputs = state.apply_fn( **batch, gumbel_temperature=gumbel_temperature, params=params, dropout_rng=dropout_rng, gumbel_rng=gumbel_rng, train=True, ) contrastive_loss = compute_contrastive_loss( outputs.projected_quantized_states, outputs.projected_states, negative_indices, batch["mask_time_indices"], contrastive_logits_temperature, num_negatives, ) diversity_loss = (num_codevectors - outputs.codevector_perplexity) / num_codevectors loss = contrastive_loss + diversity_loss_weight * diversity_loss return loss grad_fn = jax.value_and_grad(loss_fn) loss, grad = grad_fn(state.params) grad = jax.lax.pmean(grad, 
"batch") new_state = state.apply_gradients(grads=grad) metrics = jax.lax.pmean( {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch" ) return new_state, metrics, new_dropout_rng, new_gumbel_rng # Create parallel version of the train step p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,)) # Define eval fn def eval_step(params, batch): negative_indices = batch.pop("sampled_negative_indices") outputs = model(**batch, params=params, train=False) contrastive_loss = compute_contrastive_loss( outputs.projected_quantized_states, outputs.projected_states, negative_indices, batch["mask_time_indices"], contrastive_logits_temperature, num_negatives, ) diversity_loss = (num_codevectors - outputs.codevector_perplexity) / num_codevectors loss = contrastive_loss + diversity_loss_weight * diversity_loss # summarize metrics metrics = {"loss": loss.mean(), "codevector_perplexity": outputs.codevector_perplexity} metrics = jax.lax.pmean(metrics, axis_name="batch") return metrics p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,)) # Replicate the train state on each device state = jax_utils.replicate(state) train_time = 0 epochs = tqdm(range(num_epochs), desc=f"Epoch ... 
(1/{num_epochs})", position=0) for epoch in epochs: # ======================== Training ================================ train_start = time.time() train_metrics = [] # Create sampling rng rng, input_rng = jax.random.split(rng) # Generate an epoch by shuffling sampling indices from the train dataset num_train_samples = len(vectorized_datasets["train"]) train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples)) train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size) # Gather the indexes for creating the batch and do a training step for i, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)): samples = [vectorized_datasets["train"][int(idx)] for idx in batch_idx] model_inputs = data_collator(samples) model_inputs = shard(model_inputs.data) # Model forward state, train_metric, dropout_rngs, gumbel_rngs = p_train_step( state, model_inputs, dropout_rngs, gumbel_rngs ) train_metrics.append(train_metric) train_time += time.time() - train_start epochs.write( f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})" ) # ======================== Evaluating ============================== num_eval_samples = len(vectorized_datasets["validation"]) eval_samples_idx = jnp.arange(num_eval_samples) eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size) eval_metrics = [] for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)): samples = [vectorized_datasets["validation"][int(idx)] for idx in batch_idx] model_inputs = data_collator(samples) # Model forward model_inputs = shard(model_inputs.data) metrics = p_eval_step(state.params, model_inputs) eval_metrics.append(metrics) # get eval metrics eval_metrics = get_metrics(eval_metrics) eval_metrics = jax.tree_map(jnp.mean, eval_metrics) # Update progress bar epochs.write( f"Epoch... 
({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Perplexity: {eval_metrics['codevector_perplexity']})" ) # Save metrics if has_tensorboard and jax.process_index() == 0: cur_step = epoch * (len(vectorized_datasets["train"]) // train_batch_size) write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step) # save checkpoint after each epoch and push checkpoint to the hub if jax.process_index() == 0: params = jax.device_get(jax.tree_map(lambda x: x[0], state.params)) model.save_pretrained(training_args.output_dir, params=params, push_to_hub=training_args.push_to_hub)
def convert_wav2vec2_checkpoint(
    checkpoint_path,
    pytorch_dump_folder_path,
    dict_path,
    encoder_config_path,
    decoder_config_path,
    vocab_size,
    num_decoder_layers,
):
    """
    Copy/paste/tweak model's weights to transformers design.

    Converts a fairseq wav2vec2-encoder + Speech2Text2-decoder checkpoint into a
    ``SpeechEncoderDecoderModel`` and writes the model, the tokenizer vocabulary
    (``vocab.json``), the tokenizer files and the feature extractor to
    ``pytorch_dump_folder_path``.

    Args:
        checkpoint_path: path of the fairseq checkpoint to load.
        pytorch_dump_folder_path: output directory; created if it does not exist.
        dict_path: path of the fairseq dictionary file. Its parent directory is
            passed to fairseq as the ``data`` override, and the file itself is
            converted to ``vocab.json`` via ``create_vocab_dict``.
        encoder_config_path: name or path of the ``Wav2Vec2Config`` to use.
        decoder_config_path: name or path of the ``Speech2Text2Config`` to use.
        vocab_size: decoder vocabulary size (overrides the decoder config).
        num_decoder_layers: number of decoder layers (overrides the decoder config).

    Raises:
        ValueError: if ``"embed_out"`` is not among the unexpected keys reported
            when loading the fairseq decoder state dict (raised by ``list.remove``).
    """
    encoder_config = Wav2Vec2Config.from_pretrained(encoder_config_path)
    decoder_config = Speech2Text2Config.from_pretrained(
        decoder_config_path, vocab_size=vocab_size, decoder_layers=num_decoder_layers, do_stable_layer_norm=True
    )

    feature_extractor = Wav2Vec2FeatureExtractor(
        feature_size=1,
        sampling_rate=16000,
        padding_value=0,
        do_normalize=True,
        return_attention_mask=True,
    )

    # The ``data`` override is the directory that contains ``dict_path``
    # (presumably where fairseq looks up the dictionary -- confirm).
    model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
        [checkpoint_path], arg_overrides={"data": "/".join(dict_path.split("/")[:-1])}
    )
    model = model[0].eval()

    # set weights for wav2vec2 encoder
    hf_encoder = Wav2Vec2Model(encoder_config)
    projection_layer = recursively_load_weights_wav2vec2(model.encoder, hf_encoder)

    # strict=False: key mismatches are collected and logged below instead of raising.
    hf_decoder = Speech2Text2ForCausalLM(decoder_config)
    missing_keys, unexpected_keys = hf_decoder.model.decoder.load_state_dict(model.decoder.state_dict(), strict=False)

    # set output linear layer; "embed_out" is copied manually, so it is not really unexpected
    unexpected_keys.remove("embed_out")
    hf_decoder.lm_head.weight = nn.Parameter(model.decoder.embed_out.detach())

    # layer norm is init to identity matrix so leaving it is fine
    logger.warning(f"The following keys are missing when loading the decoder weights: {missing_keys}")
    logger.warning(f"The following keys are unexpected when loading the decoder weights: {unexpected_keys}")

    hf_wav2vec = SpeechEncoderDecoderModel(encoder=hf_encoder, decoder=hf_decoder)
    hf_wav2vec.config.tie_word_embeddings = False

    # add projection layer
    hf_wav2vec.enc_to_dec_proj.weight = nn.Parameter(projection_layer.weight)
    hf_wav2vec.enc_to_dec_proj.bias = nn.Parameter(projection_layer.bias)

    # vocab.json is written with a plain ``open`` before any ``save_pretrained`` call,
    # so the output directory has to exist first.
    os.makedirs(pytorch_dump_folder_path, exist_ok=True)

    vocab_dict = create_vocab_dict(dict_path)
    with open(os.path.join(pytorch_dump_folder_path, "vocab.json"), "w") as fp:
        json.dump(vocab_dict, fp)

    tokenizer = Speech2Text2Tokenizer(os.path.join(pytorch_dump_folder_path, "vocab.json"))
    tokenizer.save_pretrained(pytorch_dump_folder_path)

    # Record the tokenizer's special ids and the processor classes in the saved config.
    config = hf_wav2vec.config.to_dict()
    config["pad_token_id"] = tokenizer.pad_token_id
    config["bos_token_id"] = tokenizer.bos_token_id
    config["eos_token_id"] = tokenizer.eos_token_id
    config["tokenizer_class"] = "speech_to_text_2"
    config["feature_extractor_type"] = "wav2vec2"

    hf_wav2vec.config = SpeechEncoderDecoderConfig.from_dict(config)

    hf_wav2vec.save_pretrained(pytorch_dump_folder_path)
    feature_extractor.save_pretrained(pytorch_dump_folder_path)
def str2bool(value):
    """Parse a command-line boolean flag value.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value -- including ``"False"`` -- becomes ``True``. This converter
    accepts the usual spellings instead.

    Args:
        value: the raw argument string; a ``bool`` is returned unchanged.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=1234)
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--n_epochs', type=int, default=100)
    parser.add_argument('--lr', type=float, default=0.000005)
    # Boolean flags use str2bool: ``type=bool`` would turn "False" into True.
    parser.add_argument('--verbose', type=str2bool, default=True)
    parser.add_argument('--accumulate_steps', type=int, default=1)
    parser.add_argument('--step_scheduler', type=str2bool, default=False)
    parser.add_argument('--validation_scheduler', type=str2bool, default=True)
    parser.add_argument('--weight_decay', type=float, default=1e-6)
    parser.add_argument('--val_freq', type=int, default=1)
    parser.add_argument('--data_dir', type=str, default='data/fluent_speech_commands_dataset')
    parser.add_argument('--output_dir', type=str, default='data/model/slu')
    parser.add_argument('--pretrained_dir', type=str, default='data/model/wav2vec2-base-960h')
    parser.add_argument('--pretrained_model', type=str, default='facebook/wav2vec2-base-960h')
    args = parser.parse_args()

    set_seed(args.seed)

    # Build the model from the pretrained config and train on GPU when available.
    _config = Wav2Vec2Config.from_pretrained(args.pretrained_model, cache_dir=args.pretrained_dir)
    _device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    _net = SLUModel(_config, args.pretrained_model, args.pretrained_dir).to(_device)

    run_training(_net, args, _device)
def convert_wav2vec2_checkpoint(
    checkpoint_path,
    pytorch_dump_folder_path,
    dict_path,
    config_yaml_path,
    encoder_config_path,
    decoder_config_path,
    add_adapter,
    adapter_kernel_size,
    adapter_stride,
    decoder_start_token_id,
    encoder_output_dim,
):
    """
    Copy/paste/tweak model's weights to transformers design.

    Converts a fairseq wav2vec2-encoder + mBART-decoder checkpoint into a
    ``SpeechEncoderDecoderModel`` and saves the model, an ``MBart50Tokenizer`` and
    a ``Wav2Vec2FeatureExtractor`` to ``pytorch_dump_folder_path``.

    Args:
        checkpoint_path: fairseq checkpoint to load; also passed to fairseq as the
            ``w2v_path`` override.
        pytorch_dump_folder_path: output directory for all saved artifacts.
        dict_path: file passed as the vocab file to ``MBart50Tokenizer``; its parent
            directory is passed to fairseq as the ``data`` override.
        config_yaml_path: passed to fairseq as the ``config_yaml`` override.
        encoder_config_path: name or path of the ``Wav2Vec2Config`` and of the
            feature extractor.
        decoder_config_path: name or path of the ``MBartConfig``.
        add_adapter: whether the encoder config enables adapter layers.
        adapter_kernel_size: kernel size of the encoder adapter layers.
        adapter_stride: stride of the encoder adapter layers.
        decoder_start_token_id: NOTE(review): currently unused -- the saved config's
            ``decoder_start_token_id`` is set to the tokenizer's EOS id. Confirm which
            is intended before wiring this argument up.
        encoder_output_dim: encoder ``output_hidden_size``.
    """
    # load configs (honour the caller's ``add_adapter`` flag instead of forcing True)
    encoder_config = Wav2Vec2Config.from_pretrained(
        encoder_config_path,
        add_adapter=add_adapter,
        adapter_stride=adapter_stride,
        adapter_kernel_size=adapter_kernel_size,
        use_auth_token=True,
        output_hidden_size=encoder_output_dim,
    )
    decoder_config = MBartConfig.from_pretrained(decoder_config_path)

    # load model
    model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
        [checkpoint_path],
        arg_overrides={
            "config_yaml": config_yaml_path,
            "data": "/".join(dict_path.split("/")[:-1]),
            "w2v_path": checkpoint_path,
            "load_pretrained_decoder_from": None,
        },
    )
    model = model[0].eval()

    # load feature extractor
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(encoder_config_path, use_auth_token=True)

    # set weights for wav2vec2 encoder
    hf_encoder = Wav2Vec2Model(encoder_config)
    recursively_load_weights_wav2vec2(model.encoder, hf_encoder)

    # load decoder weights; strict=False so key mismatches are logged instead of raising
    hf_decoder = MBartForCausalLM(decoder_config)
    missing_keys, unexpected_keys = hf_decoder.model.decoder.load_state_dict(model.decoder.state_dict(), strict=False)
    logger.warning(f"The following keys are missing when loading the decoder weights: {missing_keys}")
    logger.warning(f"The following keys are unexpected when loading the decoder weights: {unexpected_keys}")

    hf_wav2vec = SpeechEncoderDecoderModel(encoder=hf_encoder, decoder=hf_decoder)
    hf_wav2vec.config.tie_word_embeddings = False

    tokenizer = MBart50Tokenizer(dict_path)
    tokenizer.save_pretrained(pytorch_dump_folder_path)

    # Record the tokenizer's special ids and generation settings in the saved config.
    config = hf_wav2vec.config.to_dict()
    config["pad_token_id"] = tokenizer.pad_token_id
    config["bos_token_id"] = tokenizer.bos_token_id
    config["eos_token_id"] = tokenizer.eos_token_id
    config["tokenizer_class"] = "mbart50"
    config["feature_extractor_type"] = "wav2vec2"

    config["decoder_start_token_id"] = tokenizer.eos_token_id
    # NOTE(review): hard-coded forced BOS id; presumably a target-language code id in the
    # mBART-50 vocabulary -- confirm against the tokenizer before reusing.
    config["forced_bos_token_id"] = 250004
    config["forced_eos_token_id"] = tokenizer.eos_token_id

    hf_wav2vec.config = SpeechEncoderDecoderConfig.from_dict(config)

    hf_wav2vec.save_pretrained(pytorch_dump_folder_path)
    feature_extractor.save_pretrained(pytorch_dump_folder_path)