def __init__(
    self,
    jointnet: Dict[str, Any],
    num_classes: int,
    vocabulary: Optional[List] = None,
    log_softmax: Optional[bool] = None,
    preserve_memory: bool = False,
    fuse_loss_wer: bool = False,
    fused_batch_size: Optional[int] = None,
    experimental_fuse_loss_wer: Any = None,
):
    """Store the joint configuration and build the joint sub-modules.

    Args:
        jointnet: Joint configuration. Must contain the keys
            `encoder_hidden`, `pred_hidden`, `joint_hidden` and `activation`;
            `dropout` is optional and defaults to 0.0.
        num_classes: Vocabulary size without the blank symbol.
        vocabulary: Optional vocabulary, stored as-is.
        log_softmax: Stored as-is on `self.log_softmax`.
        preserve_memory: Stored as-is; a warning is logged when truthy.
        fuse_loss_wer: Whether loss/WER computation is fused into the joint.
        fused_batch_size: Required (not None) whenever `fuse_loss_wer` is set.
        experimental_fuse_loss_wer: Deprecated alias; when not None it
            overrides `fuse_loss_wer`.

    Raises:
        ValueError: If the effective `fuse_loss_wer` is truthy while
            `fused_batch_size` is None.
    """
    super().__init__()

    self.vocabulary = vocabulary
    self._vocab_size = num_classes
    # One extra output class is reserved for the blank symbol.
    self._num_classes = num_classes + 1

    # TODO: Deprecate in 1.6
    if experimental_fuse_loss_wer is not None:
        logging.warning(
            "`experimental_fuse_loss_wer` will be deprecated in NeMo 1.6. Please use `fuse_loss_wer` instead."
        )
        # The deprecated argument takes precedence over `fuse_loss_wer`.
        fuse_loss_wer = experimental_fuse_loss_wer

    self._fuse_loss_wer = fuse_loss_wer
    self._fused_batch_size = fused_batch_size

    if fused_batch_size is None and fuse_loss_wer:
        raise ValueError("If `fuse_loss_wer` is set, then `fused_batch_size` cannot be None!")

    # Loss and WER holders start empty.
    self._loss = None
    self._wer = None

    # Log softmax should be applied explicitly only for CPU
    self.log_softmax = log_softmax
    self.preserve_memory = preserve_memory
    if preserve_memory:
        logging.warning(
            "`preserve_memory` was set for the Joint Model. Please be aware this will severely impact "
            "the forward-backward step time. It also might not solve OOM issues if the GPU simply "
            "does not have enough memory to compute the joint."
        )

    # Mandatory configuration entries (a missing key raises KeyError).
    for required_key in ('encoder_hidden', 'pred_hidden', 'joint_hidden', 'activation'):
        setattr(self, required_key, jointnet[required_key])

    # Optional configuration entries.
    dropout_rate = jointnet.get('dropout', 0.0)

    joint_parts = self._joint_net_modules(
        num_classes=self._num_classes,  # already includes the blank symbol
        pred_n_hidden=self.pred_hidden,
        enc_n_hidden=self.encoder_hidden,
        joint_n_hidden=self.joint_hidden,
        activation=self.activation,
        dropout=dropout_rate,
    )
    self.pred, self.enc, self.joint_net = joint_parts

    # Flag needed for RNNT export support
    self._rnnt_export = False
def check_resume(
    trainer: 'pytorch_lightning.Trainer',
    log_dir: str,
    resume_past_end: bool = False,
    resume_ignore_no_checkpoint: bool = False,
):
    """Check that resuming is possible and point the trainer at the checkpoint to resume from.

    Looks inside `<log_dir>/checkpoints` for "end" checkpoints (`*end.ckpt` or
    `*.nemo`) and "last" checkpoints (`*last.ckpt`). On success, sets
    `trainer.resume_from_checkpoint` to the chosen checkpoint path and, on global
    rank zero, moves any files sitting directly in `log_dir` into a new
    `run_<N>` sub-folder, where N is the number of existing `run_*` folders.

    Args:
        trainer: Trainer whose `resume_from_checkpoint` attribute is set.
        log_dir: Experiment log directory; must be non-empty.
        resume_past_end: If True, resuming from an "end" checkpoint is allowed.
        resume_ignore_no_checkpoint: If True, a missing checkpoint folder or a
            folder without usable checkpoints only logs a warning and the
            function returns without touching the trainer.

    Returns:
        None. The function works through its side effects only.

    Raises:
        NotFoundError: If the checkpoint folder or a "last" checkpoint could not be
            found and `resume_ignore_no_checkpoint` is False.
        ValueError: If `log_dir` is empty, if an "end" checkpoint exists while
            `resume_past_end` is False, or if more than one "end" / "last"
            checkpoint was found.
    """
    if not log_dir:
        raise ValueError(f"Resuming requires the log_dir {log_dir} to be passed to exp_manager")

    checkpoint_dir = Path(Path(log_dir) / "checkpoints")

    checkpoint = None
    # Globbing a missing directory simply yields nothing, so this is safe before the exists() check.
    end_checkpoints = list(checkpoint_dir.glob("*end.ckpt"))
    end_checkpoints.extend(list(checkpoint_dir.glob("*.nemo")))
    last_checkpoints = list(checkpoint_dir.glob("*last.ckpt"))

    if not checkpoint_dir.exists():
        if resume_ignore_no_checkpoint:
            logging.warning(
                f"There was no checkpoint folder at checkpoint_dir :{checkpoint_dir}. Training from scratch."
            )
            return
        raise NotFoundError(f"There was no checkpoint folder at checkpoint_dir :{checkpoint_dir}. Cannot resume.")
    elif len(end_checkpoints) > 0:
        if not resume_past_end:
            raise ValueError(
                f"Found {end_checkpoints[0]} indicating that the last training run has already completed."
            )
        if len(end_checkpoints) > 1:
            raise ValueError(f"Multiple checkpoints {end_checkpoints} match *end.ckpt or *.nemo.")
        logging.info(f"Resuming from {end_checkpoints[0]}")
        checkpoint = end_checkpoints[0]
    elif not len(last_checkpoints) > 0:
        if resume_ignore_no_checkpoint:
            logging.warning(f"There were no checkpoints found in {checkpoint_dir}. Training from scratch.")
            return
        raise NotFoundError(f"There were no checkpoints found in {checkpoint_dir}. Cannot resume.")
    elif len(last_checkpoints) > 1:
        raise ValueError(f"Multiple checkpoints {last_checkpoints} match *last.ckpt.")
    else:
        logging.info(f"Resuming from {last_checkpoints[0]}")
        checkpoint = last_checkpoints[0]

    trainer.resume_from_checkpoint = str(checkpoint)

    if is_global_rank_zero():
        # Check to see if any files exist that need to be moved
        files_to_move = [child for child in Path(log_dir).iterdir() if child.is_file()]

        if len(files_to_move) > 0:
            # Move old files to a new folder, numbered after the existing run_* folders
            run_count = sum(1 for fold in Path(log_dir).glob("run_*") if fold.is_dir())
            new_run_dir = Path(Path(log_dir) / f"run_{run_count}")
            new_run_dir.mkdir()
            for _file in files_to_move:
                move(str(_file), str(new_run_dir))
def get_final_text(pred_text: str, orig_text: str, do_lower_case: bool, verbose_logging: bool = False):
    """Map a normalized, tokenized prediction back onto the original text span.

    `orig_text` is the span of the original (whitespace tokenized) text that
    corresponds to the predicted span, but it may carry extra characters that do
    not belong to the prediction. Example:

        pred_text = steve smith
        orig_text = Steve Smith's

    Returning `orig_text` would keep the unwanted "'s"; returning `pred_text`
    would return text that was already normalized by the tokenizer. The desired
    answer is "Steve Smith".

    The heuristic tokenizes `orig_text`, locates `pred_text` inside the result,
    and uses a character-to-character alignment computed on the space-stripped
    strings to project the start and end of the match back to `orig_text`.
    Whenever any step of the alignment fails, `orig_text` is returned unchanged.

    Args:
        pred_text: Predicted text in tokenizer-normalized form.
        orig_text: Original text span covering the prediction.
        do_lower_case: Forwarded to `BasicTokenizer`.
        verbose_logging: If True, a warning is logged for each fallback.

    Returns:
        The sub-string of `orig_text` aligned with `pred_text`, or `orig_text`
        itself when the alignment cannot be established.
    """

    def _drop_spaces(text):
        # Returns the text without spaces plus a map: stripped index -> original index.
        kept_chars = []
        stripped_to_full = collections.OrderedDict()
        for position, char in enumerate(text):
            if char != " ":
                stripped_to_full[len(kept_chars)] = position
                kept_chars.append(char)
        return "".join(kept_chars), stripped_to_full

    # Tokenize `orig_text`; if the prediction is not found in it, give up.
    basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
    tok_text = " ".join(basic_tokenizer.tokenize(orig_text))

    start_position = tok_text.find(pred_text)
    if start_position == -1:
        if verbose_logging:
            logging.warning("Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
        return orig_text
    end_position = start_position + len(pred_text) - 1

    # With spaces removed, both strings must be equally long; only then are the
    # characters assumed to be aligned one-to-one.
    orig_ns_text, orig_ns_to_s_map = _drop_spaces(orig_text)
    tok_ns_text, tok_ns_to_s_map = _drop_spaces(tok_text)

    if len(orig_ns_text) != len(tok_ns_text):
        if verbose_logging:
            logging.warning(
                "Length not equal after stripping spaces: '%s' vs '%s'", orig_ns_text, tok_ns_text,
            )
        return orig_text

    # Invert the tokenized map: index in `tok_text` -> index in the stripped text.
    tok_s_to_ns_map = {full_idx: ns_idx for ns_idx, full_idx in tok_ns_to_s_map.items()}

    def _to_orig_index(tok_index):
        # tok_text index -> stripped index -> orig_text index; None if either hop is missing.
        ns_index = tok_s_to_ns_map.get(tok_index)
        if ns_index is None:
            return None
        return orig_ns_to_s_map.get(ns_index)

    orig_start_position = _to_orig_index(start_position)
    if orig_start_position is None:
        if verbose_logging:
            logging.warning("Couldn't map start position")
        return orig_text

    orig_end_position = _to_orig_index(end_position)
    if orig_end_position is None:
        if verbose_logging:
            logging.warning("Couldn't map end position")
        return orig_text

    return orig_text[orig_start_position : (orig_end_position + 1)]
def _setup_dataloader_from_config(self, config: Optional[Dict]): if 'augmentor' in config: augmentor = process_augmentations(config['augmentor']) else: augmentor = None shuffle = config['shuffle'] device = 'gpu' if torch.cuda.is_available() else 'cpu' if config.get('use_dali', False): device_id = self.local_rank if device == 'gpu' else None dataset = audio_to_text_dataset.get_dali_char_dataset( config=config, shuffle=shuffle, device_id=device_id, global_rank=self.global_rank, world_size=self.world_size, preprocessor_cfg=self._cfg.preprocessor, ) return dataset # Instantiate tarred dataset loader or normal dataset loader if config.get('is_tarred', False): if ('tarred_audio_filepaths' in config and config['tarred_audio_filepaths'] is None) or ( 'manifest_filepath' in config and config['manifest_filepath'] is None): logging.warning( "Could not load dataset as `manifest_filepath` was None or " f"`tarred_audio_filepaths` is None. Provided config : {config}" ) return None shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0 dataset = audio_to_text_dataset.get_tarred_char_dataset( config=config, shuffle_n=shuffle_n, global_rank=self.global_rank, world_size=self.world_size, augmentor=augmentor, ) shuffle = False else: if 'manifest_filepath' in config and config[ 'manifest_filepath'] is None: logging.warning( f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}" ) return None dataset = audio_to_text_dataset.get_char_dataset( config=config, augmentor=augmentor) return torch.utils.data.DataLoader( dataset=dataset, batch_size=config['batch_size'], collate_fn=dataset.collate_fn, drop_last=config.get('drop_last', False), shuffle=shuffle, num_workers=config.get('num_workers', 0), pin_memory=config.get('pin_memory', False), )
def perform_clustering(embs_and_timestamps, AUDIO_RTTM_MAP, out_rttm_dir, clustering_params):
    """
    Performs spectral clustering on embeddings with time stamps generated from VAD output.

    Args:
        embs_and_timestamps (dict): Per-recording data indexed by unique ID. Each entry is passed
            to `COSclustering` and must contain a 'scale_dict' whose largest key (the base scale)
            maps to a dict with a 'time_stamps' list.
            NOTE: the base-scale 'time_stamps' entries are modified in place - the predicted
            speaker tag is appended to each of them.
        AUDIO_RTTM_MAP (dict): Maps each unique ID to a dict that may hold 'num_speakers' and
            'rttm_filepath'.
        out_rttm_dir (str): Path to write predicted rttms; nothing is written when falsy.
        clustering_params: Clustering parameters provided through config. Read with item access
            ('max_num_speakers') and attribute access (oracle_num_speakers, enhanced_count_thres,
            max_rp_threshold, sparse_search_volume), so it must support both
            (presumably an OmegaConf DictConfig - confirm against the caller).

    Returns:
        all_reference (list[uniq_name,Annotation]): reference annotations for score calculation.
            Empty as soon as one recording has no readable reference RTTM.
        all_hypothesis (list[uniq_name,Annotation]): hypothesis annotations for score calculation

    Raises:
        ValueError: If oracle_num_speakers is set but a manifest entry has no 'num_speakers'.
    """
    all_hypothesis = []
    all_reference = []
    no_references = False
    max_num_speakers = clustering_params['max_num_speakers']
    lines_cluster_labels = []
    # Stays None when AUDIO_RTTM_MAP is empty, so the final write can be skipped
    # instead of failing on an unbound name.
    base_scale_idx = None

    cuda = torch.cuda.is_available()
    if not cuda:
        logging.warning("cuda=False, using CPU for Eigen decompostion. This might slow down the clustering process.")

    for uniq_id, value in tqdm(AUDIO_RTTM_MAP.items()):
        if clustering_params.oracle_num_speakers:
            num_speakers = value.get('num_speakers', None)
            if num_speakers is None:
                raise ValueError("Provided option as oracle num of speakers but num_speakers in manifest is null")
        else:
            num_speakers = None

        cluster_labels = COSclustering(
            uniq_embs_and_timestamps=embs_and_timestamps[uniq_id],
            oracle_num_speakers=num_speakers,
            max_num_speaker=max_num_speakers,
            enhanced_count_thres=clustering_params.enhanced_count_thres,
            max_rp_threshold=clustering_params.max_rp_threshold,
            sparse_search_volume=clustering_params.sparse_search_volume,
            cuda=cuda,
        )

        # The largest scale index is treated as the base scale whose segments get labelled.
        base_scale_idx = max(embs_and_timestamps[uniq_id]['scale_dict'].keys())
        lines = embs_and_timestamps[uniq_id]['scale_dict'][base_scale_idx]['time_stamps']
        assert len(cluster_labels) == len(lines)
        for idx, label in enumerate(cluster_labels):
            tag = 'speaker_' + str(label)
            lines[idx] += tag

        contiguous_stamps = get_contiguous_stamps(lines)
        labels = merge_stamps(contiguous_stamps)
        if out_rttm_dir:
            labels_to_rttmfile(labels, uniq_id, out_rttm_dir)
            lines_cluster_labels.extend([f'{uniq_id} {seg_line}\n' for seg_line in lines])
        hypothesis = labels_to_pyannote_object(labels, uniq_name=uniq_id)
        all_hypothesis.append([uniq_id, hypothesis])

        rttm_file = value.get('rttm_filepath', None)
        if rttm_file is not None and os.path.exists(rttm_file) and not no_references:
            ref_labels = rttm_to_labels(rttm_file)
            reference = labels_to_pyannote_object(ref_labels, uniq_name=uniq_id)
            all_reference.append([uniq_id, reference])
        else:
            # One missing reference invalidates scoring for the whole set.
            no_references = True
            all_reference = []

    if out_rttm_dir and base_scale_idx is not None:
        write_cluster_labels(base_scale_idx, lines_cluster_labels, out_rttm_dir)

    return all_reference, all_hypothesis
def __init__(
    self,
    ids: List[int],
    audio_files: List[str],
    durations: List[float],
    texts: List[str],
    offsets: List[str],
    speakers: List[Optional[int]],
    orig_sampling_rates: List[Optional[int]],
    parser: parsers.CharParser,
    min_duration: Optional[float] = None,
    max_duration: Optional[float] = None,
    max_number: Optional[int] = None,
    do_sort_by_duration: bool = False,
    index_by_file_id: bool = False,
):
    """Build the audio-text collection, applying duration and parsing filters.

    The input lists are walked in lockstep. An example is dropped (and counted
    as filtered) when its duration is outside `[min_duration, max_duration]` or
    when `parser(text)` returns None.

    Args:
        ids: List of examples positions.
        audio_files: List of audio files.
        durations: List of float durations.
        texts: List of raw text transcripts.
        offsets: List of duration offsets or None.
        speakers: List of optional speakers ids.
        orig_sampling_rates: List of original sampling rates of audio files.
        parser: Callable converting a transcript into tokens; a None result drops the example.
        min_duration: Minimum duration to keep entry with (default: None).
        max_duration: Maximum duration to keep entry with (default: None).
        max_number: Collection stops once this many samples were kept.
        do_sort_by_duration: Sort the kept samples by duration. Ignored (with a
            warning) when `index_by_file_id` is set.
        index_by_file_id: If True, `self.mapping` maps each audio file's base name
            (without extension) to the index of its sample in the data.
    """
    entity_cls = self.OUTPUT_TYPE
    kept = []
    duration_filtered = 0.0
    num_filtered = 0
    total_duration = 0.0

    if index_by_file_id:
        self.mapping = {}

    columns = zip(ids, audio_files, durations, offsets, texts, speakers, orig_sampling_rates)
    for id_, audio_file, duration, offset, text, speaker, orig_sr in columns:
        # Duration filters.
        out_of_range = (min_duration is not None and duration < min_duration) or (
            max_duration is not None and duration > max_duration
        )
        # The parser only runs on examples that passed the duration filters.
        text_tokens = None if out_of_range else parser(text)

        if text_tokens is None:
            duration_filtered += duration
            num_filtered += 1
            continue

        total_duration += duration
        kept.append(entity_cls(id_, audio_file, duration, text_tokens, offset, text, speaker, orig_sr))

        if index_by_file_id:
            file_id = os.path.splitext(os.path.basename(audio_file))[0]
            self.mapping[file_id] = len(kept) - 1

        # Max number of entities filter.
        if len(kept) == max_number:
            break

    if do_sort_by_duration:
        if not index_by_file_id:
            kept.sort(key=lambda entity: entity.duration)
        else:
            # Sorting would invalidate the indices stored in `self.mapping`.
            logging.warning("Tried to sort dataset by duration, but cannot since index_by_file_id is set.")

    logging.info("Dataset loaded with %d files totalling %.2f hours", len(kept), total_duration / 3600)
    logging.info("%d files were filtered totalling %.2f hours", num_filtered, duration_filtered / 3600)

    super().__init__(kept)
def export(
    self,
    output: str,
    input_example=None,
    output_example=None,
    verbose=False,
    export_params=True,
    do_constant_folding=True,
    keep_initializers_as_inputs=False,
    onnx_opset_version: int = 12,
    try_script: bool = False,
    set_eval: bool = True,
    check_trace: bool = True,
    use_dynamic_axes: bool = True,
):
    """Export the BERT encoder and the classifier, then save the combined model.

    Both sub-models are exported next to `output`, into files prefixed with
    `bert_` and `classifier_`. The two results are joined with
    `attach_onnx_to_onnx` and the joined model is saved to `output`.

    Args:
        output: Destination path of the combined model.
        input_example: Ignored (a warning is logged when given).
        output_example: Ignored (a warning is logged when given).
        verbose, export_params, do_constant_folding, keep_initializers_as_inputs,
        onnx_opset_version, try_script, set_eval, check_trace, use_dynamic_axes:
            Forwarded positionally, unchanged, to both sub-model `export` calls.

    Returns:
        A pair `(paths, descriptions)`, each a list ordered as
        combined model, BERT model, classifier model.
    """
    if not (input_example is None and output_example is None):
        logging.warning(
            "Passed input and output examples will be ignored and recomputed since"
            " QAModel consists of two separate models with different"
            " inputs and outputs."
        )

    qual_name = self.__module__ + '.' + self.__class__.__qualname__
    target_dir, target_name = os.path.dirname(output), os.path.basename(output)

    # Positional arguments shared by both sub-exports; the two leading Nones are the
    # input/output examples, which each sub-model computes for itself.
    forwarded = (
        None,
        None,
        verbose,
        export_params,
        do_constant_folding,
        keep_initializers_as_inputs,
        onnx_opset_version,
        try_script,
        set_eval,
        check_trace,
        use_dynamic_axes,
    )

    bert_path = os.path.join(target_dir, 'bert_' + target_name)
    bert_descr = qual_name + ' BERT exported to ONNX'
    bert_onnx = self.bert_model.export(bert_path, *forwarded)

    classifier_path = os.path.join(target_dir, 'classifier_' + target_name)
    classifier_descr = qual_name + ' Classifier exported to ONNX'
    classifier_onnx = self.classifier.export(classifier_path, *forwarded)

    combined_model = attach_onnx_to_onnx(bert_onnx, classifier_onnx, "QA")
    combined_descr = qual_name + ' BERT+Classifier exported to ONNX'
    onnx.save(combined_model, output)

    paths = [output, bert_path, classifier_path]
    descriptions = [combined_descr, bert_descr, classifier_descr]
    return (paths, descriptions)
def __init__(
    self,
    *,
    audio_tar_filepaths: Union[str, List[str]],
    manifest_filepath: str,
    labels: List[str],
    featurizer,
    shuffle_n: int = 0,
    min_duration: Optional[float] = 0.1,
    max_duration: Optional[float] = None,
    trim: bool = False,
    load_audio: bool = True,
    shard_strategy: str = "scatter",
    global_rank: int = 0,
    world_size: int = 0,
):
    """Set up a tarred speech-label dataset backed by WebDataset.

    Args:
        audio_tar_filepaths: Tar shard path(s). A string may use brace-expansion
            syntax; `(`, `[`, `<`, `_OP_` are treated as `{` and `)`, `]`, `>`,
            `_CL_` as `}`.
        manifest_filepath: Comma-separated manifest path(s).
        labels: Label list; when empty/None the labels found in the manifest are used.
        featurizer: Stored on the instance as-is.
        shuffle_n: Passed to the WebDataset `shuffle` stage.
        min_duration: Forwarded to the manifest collection.
        max_duration: Forwarded to the manifest collection.
        trim: Stored on the instance as-is.
        load_audio: Stored on the instance as-is.
        shard_strategy: 'scatter' gives each rank an equal, contiguous slice of the
            shards; 'replicate' gives every rank all shards. Only applied when
            `world_size > 1`.
        global_rank: Rank of this process, used to pick its shard slice.
        world_size: Number of distributed workers.

    Raises:
        ValueError: If `shard_strategy` is not 'scatter' or 'replicate'.
    """
    self.collection = collections.ASRSpeechLabel(
        manifests_files=manifest_filepath.split(','),
        min_duration=min_duration,
        max_duration=max_duration,
        index_by_file_id=True,  # required: manifest lines are looked up by file ID
    )

    self.file_occurence = count_occurence(self.collection.mapping)

    self.featurizer = featurizer
    self.trim = trim
    self.load_audio = load_audio

    self.labels = labels if labels else self.collection.uniq_labels
    self.num_classes = len(self.labels)

    # Bidirectional label <-> id lookup tables.
    self.label2id = {label: label_id for label_id, label in enumerate(self.labels)}
    self.id2label = dict(enumerate(self.labels))

    for idx, _ in enumerate(self.labels[:5]):
        logging.debug(" label id {} and its mapped label {}".format(idx, self.id2label[idx]))

    valid_shard_strategies = ['scatter', 'replicate']
    if shard_strategy not in valid_shard_strategies:
        raise ValueError(f"`shard_strategy` must be one of {valid_shard_strategies}")

    paths_given_as_str = isinstance(audio_tar_filepaths, str)
    if paths_given_as_str:
        # Normalise the alternative bracket spellings to real braces.
        for opener in ('(', '[', '<', '_OP_'):
            audio_tar_filepaths = audio_tar_filepaths.replace(opener, "{")
        for closer in (')', ']', '>', '_CL_'):
            audio_tar_filepaths = audio_tar_filepaths.replace(closer, "}")

    # Check for distributed and partition shards accordingly
    if world_size > 1:
        if paths_given_as_str:
            # Brace expand
            audio_tar_filepaths = list(braceexpand.braceexpand(audio_tar_filepaths))

        if shard_strategy == 'replicate':
            logging.info("All tarred dataset shards will be replicated across all nodes.")
        elif shard_strategy == 'scatter':
            logging.info("All tarred dataset shards will be scattered evenly across all nodes.")

            num_shards = len(audio_tar_filepaths)
            if num_shards % world_size != 0:
                logging.warning(
                    f"Number of shards in tarred dataset ({len(audio_tar_filepaths)}) is not divisible "
                    f"by number of distributed workers ({world_size})."
                )

            shards_per_rank = num_shards // world_size
            begin_idx = shards_per_rank * global_rank
            end_idx = begin_idx + shards_per_rank
            audio_tar_filepaths = audio_tar_filepaths[begin_idx:end_idx]
            logging.info(
                "Partitioning tarred dataset: process (%d) taking shards [%d, %d)", global_rank, begin_idx, end_idx
            )
        else:
            raise ValueError(f"Invalid shard strategy ! Allowed values are : {valid_shard_strategies}")

    # Put together WebDataset
    pipeline = wd.Dataset(audio_tar_filepaths)
    pipeline = pipeline.shuffle(shuffle_n)
    pipeline = pipeline.rename(audio='wav', key='__key__')
    pipeline = pipeline.to_tuple('audio', 'key')
    pipeline = pipeline.pipe(self._filter)
    self._dataset = pipeline.map(f=self._build_sample)
def main():
    """Evaluate a CTC ASR model with greedy decoding and an optional beam-search grid.

    Steps:
      1. Parse the command line and restore the `.nemo` model.
      2. Read transcripts and audio paths from the manifest (one JSON object per line).
      3. Obtain per-utterance probabilities, either from the pickle cache
         (`--probs_cache_file`) or by running the model, and optionally write the cache.
      4. Report greedy WER/CER.
      5. Unless `--decoding_mode greedy` is chosen, run `beam_search_eval` for every
         combination of `--beam_width`, `--beam_alpha` and `--beam_beta`.

    Raises:
        ValueError: If the cached probabilities do not match the manifest length.
        FileNotFoundError: If `beamsearch_ngram` decoding is requested and the KenLM
            model file is not given or does not exist.
    """
    parser = argparse.ArgumentParser(
        description='Evaluate an ASR model with beam search decoding and n-gram KenLM language model.'
    )
    parser.add_argument(
        "--nemo_model_file", required=True, type=str, help="The path of the '.nemo' file of the ASR model"
    )
    parser.add_argument(
        "--kenlm_model_file", required=False, default=None, type=str, help="The path of the KenLM binary model file"
    )
    parser.add_argument("--input_manifest", required=True, type=str, help="The manifest file of the evaluation set")
    parser.add_argument(
        "--preds_output_folder", default=None, type=str, help="The optional folder where the predictions are stored"
    )
    parser.add_argument(
        "--probs_cache_file", default=None, type=str, help="The cache file for storing the outputs of the model"
    )
    parser.add_argument(
        "--acoustic_batch_size", default=16, type=int, help="The batch size to calculate log probabilities"
    )
    parser.add_argument(
        "--device", default="cuda", type=str, help="The device to load the model onto to calculate log probabilities"
    )
    parser.add_argument(
        "--use_amp", action="store_true", help="Whether to use AMP if available to calculate log probabilities"
    )
    parser.add_argument(
        "--decoding_mode",
        choices=["greedy", "beamsearch", "beamsearch_ngram"],
        default="beamsearch_ngram",
        type=str,
        help="The decoding scheme to be used for evaluation.",
    )
    parser.add_argument(
        "--beam_width",
        required=True,
        type=int,
        nargs="+",
        help="The width or list of the widths for the beam search decoding",
    )
    parser.add_argument(
        "--beam_alpha",
        required=True,
        type=float,
        nargs="+",
        help="The alpha parameter or list of the alphas for the beam search decoding",
    )
    parser.add_argument(
        "--beam_beta",
        required=True,
        type=float,
        nargs="+",
        help="The beta parameter or list of the betas for the beam search decoding",
    )
    parser.add_argument(
        "--beam_batch_size", default=128, type=int, help="The batch size to be used for beam search decoding"
    )
    args = parser.parse_args()

    asr_model = nemo_asr.models.EncDecCTCModelBPE.restore_from(
        args.nemo_model_file, map_location=torch.device(args.device)
    )

    target_transcripts = []
    audio_file_paths = []
    with open(args.input_manifest, 'r') as manifest_file:
        for line in tqdm(manifest_file, desc=f"Reading Manifest {args.input_manifest} ...", ncols=120):
            data = json.loads(line)
            target_transcripts.append(data['text'])
            audio_file_paths.append(data['audio_filepath'])

    if args.probs_cache_file and os.path.exists(args.probs_cache_file):
        logging.info(f"Found a pickle file of probabilities at '{args.probs_cache_file}'.")
        logging.info(f"Loading the cached pickle file of probabilities from '{args.probs_cache_file}' ...")
        # NOTE(security): unpickling executes arbitrary code; only point
        # --probs_cache_file at cache files produced by a trusted run of this script.
        with open(args.probs_cache_file, 'rb') as probs_file:
            all_probs = pickle.load(probs_file)

        if len(all_probs) != len(audio_file_paths):
            raise ValueError(
                f"The number of samples in the probabilities file '{args.probs_cache_file}' does not "
                f"match the manifest file. You may need to delete the probabilities cached file."
            )
    else:
        # Default to a no-op context manager so `autocast` is always bound; it is only
        # replaced by the real AMP autocast when AMP was requested and is available.
        @contextlib.contextmanager
        def autocast():
            yield

        if args.use_amp:
            if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'):
                logging.info("AMP is enabled!\n")
                autocast = torch.cuda.amp.autocast
            else:
                logging.warning("AMP was requested but is not available; continuing without autocast.")

        with autocast():
            with torch.no_grad():
                all_logits = asr_model.transcribe(audio_file_paths, batch_size=args.acoustic_batch_size, logprobs=True)
        all_probs = [kenlm_utils.softmax(logits) for logits in all_logits]
        if args.probs_cache_file:
            logging.info(f"Writing pickle files of probabilities at '{args.probs_cache_file}'...")
            with open(args.probs_cache_file, 'wb') as f_dump:
                pickle.dump(all_probs, f_dump)

    # Greedy baseline: arg-max per frame, decoded by the model's CTC decoder.
    wer_dist_greedy = 0
    cer_dist_greedy = 0
    words_count = 0
    chars_count = 0
    for batch_idx, probs in enumerate(all_probs):
        preds = np.argmax(probs, axis=1)
        preds_tensor = torch.tensor(preds, device='cpu').unsqueeze(0)
        pred_text = asr_model._wer.ctc_decoder_predictions_tensor(preds_tensor)[0]

        pred_split_w = pred_text.split()
        target_split_w = target_transcripts[batch_idx].split()
        pred_split_c = list(pred_text)
        target_split_c = list(target_transcripts[batch_idx])

        wer_dist_greedy += editdistance.eval(target_split_w, pred_split_w)
        cer_dist_greedy += editdistance.eval(target_split_c, pred_split_c)
        words_count += len(target_split_w)
        chars_count += len(target_split_c)

    logging.info('Greedy WER/CER = {:.2%}/{:.2%}'.format(wer_dist_greedy / words_count, cer_dist_greedy / chars_count))

    encoding_level = kenlm_utils.SUPPORTED_MODELS.get(type(asr_model).__name__, None)
    if not encoding_level:
        logging.warning(
            f"Model type '{type(asr_model).__name__}' may not be supported. Would try to train a char-level LM."
        )
        encoding_level = 'char'

    vocab = asr_model.decoder.vocabulary
    ids_to_text_func = None
    if encoding_level == "subword":
        # Each sub-word id is represented by a single character offset by TOKEN_OFFSET.
        vocab = [chr(idx + TOKEN_OFFSET) for idx in range(len(vocab))]
        ids_to_text_func = asr_model.tokenizer.ids_to_text
    # delete the model to free the memory
    del asr_model

    if args.decoding_mode == "beamsearch_ngram":
        # `--kenlm_model_file` defaults to None; check for that before touching the file system.
        if not args.kenlm_model_file or not os.path.exists(args.kenlm_model_file):
            raise FileNotFoundError(f"Could not find the KenLM model file '{args.kenlm_model_file}'.")
        lm_path = args.kenlm_model_file
    else:
        lm_path = None

    # 'greedy' decoding_mode would skip the beam search decoding
    if args.decoding_mode in ["beamsearch_ngram", "beamsearch"]:
        params = {'beam_width': args.beam_width, 'beam_alpha': args.beam_alpha, 'beam_beta': args.beam_beta}
        hp_grid = list(ParameterGrid(params))

        logging.info("==============================Starting the beam search decoding===============================")
        logging.info(f"Grid search size: {len(hp_grid)}")
        logging.info("It may take some time...")
        logging.info("==============================================================================================")

        if args.preds_output_folder and not os.path.exists(args.preds_output_folder):
            # makedirs also creates missing parent folders and tolerates a concurrent creation.
            os.makedirs(args.preds_output_folder, exist_ok=True)

        for hp in hp_grid:
            if args.preds_output_folder:
                preds_output_file = os.path.join(
                    args.preds_output_folder,
                    f"preds_out_width{hp['beam_width']}_alpha{hp['beam_alpha']}_beta{hp['beam_beta']}.tsv",
                )
            else:
                preds_output_file = None

            beam_search_eval(
                all_probs=all_probs,
                target_transcripts=target_transcripts,
                vocab=vocab,
                ids_to_text_func=ids_to_text_func,
                preds_output_file=preds_output_file,
                lm_path=lm_path,
                beam_width=hp["beam_width"],
                beam_alpha=hp["beam_alpha"],
                beam_beta=hp["beam_beta"],
                beam_batch_size=args.beam_batch_size,
                progress_bar=True,
            )
def setup_optimization(self, optim_config: Optional[Union[DictConfig, Dict]] = None):
    """
    Prepares an optimizer (and, if configured, a scheduler) from a config.

    Args:
        optim_config: A dictionary / DictConfig with the following keys. When None, the
            `optim` section of the model's own config (`self._cfg.optim`) is used if present.

            * "lr": mandatory key for learning rate. Will raise ValueError if not provided.
            * "name": string name of one of the optimizers in the registry. \
            Used only when "cls" is absent; defaults to "adam".
            * "cls": optional optimizer class, or a class path string that is instantiated \
            through hydra.
            * "args": optional optimizer arguments, parsed with `optim.parse_optimizer_args`. \
            When absent, every remaining key except "name", "cls", "lr" and "sched" is passed \
            to the optimizer constructor as a keyword argument.
            * "sched": optional scheduler config, handed to `prepare_lr_scheduler`.

    Returns:
        None when no optimizer config is available, otherwise the tuple
        `(self._optimizer, self._scheduler)`.

    Raises:
        ValueError: If "lr" is missing, or if a scheduler is configured and the trainer's
            `accumulate_grad_batches` is not an integer.
    """
    # If config was not explicitly passed to us
    if optim_config is None:
        # See if internal config has `optim` namespace
        if self._cfg is not None and hasattr(self._cfg, 'optim'):
            optim_config = self._cfg.optim

    # If config is still None, or internal config has no Optim, return without instantiation
    if optim_config is None:
        logging.info('No optimizer config provided, therefore no optimizer was created')
        return

    # Preserve the configuration
    if not isinstance(optim_config, DictConfig):
        optim_config = OmegaConf.create(optim_config)

    # See if internal config has `optim` namespace before preservation
    if self._cfg is not None and hasattr(self._cfg, 'optim'):
        self._cfg.optim = optim_config

    # Setup optimizer and scheduler: work on a plain container from here on
    if isinstance(optim_config, DictConfig):
        optim_config = OmegaConf.to_container(optim_config)

    if 'sched' in optim_config and self._trainer is not None:
        if not isinstance(self._trainer.accumulate_grad_batches, int):
            raise ValueError("We do not currently support gradient acculumation that is not an integer.")
        if self._trainer.max_steps is None:
            # Store information needed to calculate max_steps
            sched_config = optim_config['sched']
            sched_config['t_max_epochs'] = self._trainer.max_epochs
            sched_config['t_accumulate_grad_batches'] = self._trainer.accumulate_grad_batches

            backend = self._trainer.distributed_backend
            # Backend names are compared by value (`==`): identity comparison against a
            # string literal is not guaranteed to match an equal string.
            if backend is None:
                sched_config['t_num_workers'] = self._trainer.num_gpus or 1
            elif backend == "ddp_cpu":
                sched_config['t_num_workers'] = self._trainer.num_processes * self._trainer.num_nodes
            elif backend == "ddp":
                sched_config['t_num_workers'] = self._trainer.num_gpus * self._trainer.num_nodes
            else:
                logging.warning(
                    f"The lightning trainer received accelerator: {self._trainer.distributed_backend}. We "
                    "recommend to use 'ddp' instead."
                )
                sched_config['t_num_workers'] = self._trainer.num_gpus * self._trainer.num_nodes
        else:
            optim_config['sched']['max_steps'] = self._trainer.max_steps

    # Force into DictConfig from nested structure
    optim_config = OmegaConf.create(optim_config)
    # Get back a nested dict (with interpolations resolved) so that it is mutable
    optim_config = OmegaConf.to_container(optim_config, resolve=True)

    # Extract scheduler config if inside optimizer config
    scheduler_config = optim_config.pop('sched') if 'sched' in optim_config else None

    # Check if caller provided optimizer class, fall back to the optimizer name otherwise
    optimizer_cls = optim_config.get('cls', None)

    if optimizer_cls is None:
        # Try to get optimizer name for dynamic resolution, defaulting to Adam
        optimizer_name = optim_config.get('name', 'adam')
    elif inspect.isclass(optimizer_cls):
        optimizer_name = optimizer_cls.__name__.lower()
    else:
        # resolve the class name (lowercase) from the class path if not provided
        optimizer_name = optimizer_cls.split(".")[-1].lower()

    # We are guaranteed to have lr since it is required by the argparser
    # But maybe user forgot to pass it to this function
    lr = optim_config.get('lr', None)

    if lr is None:
        raise ValueError('`lr` must be passed to `optimizer_config` when setting up the optimization !')

    # Check if caller has optimizer kwargs, default to the remaining config entries
    if 'args' in optim_config:
        optimizer_args = optim_config.pop('args')
        optimizer_args = optim.parse_optimizer_args(optimizer_name, optimizer_args)
    else:
        optimizer_args = copy.deepcopy(optim_config)

        # Remove extra parameters from optimizer_args nest
        # Assume all other parameters are to be passed into optimizer constructor
        optimizer_args.pop('name', None)
        optimizer_args.pop('cls', None)
        optimizer_args.pop('lr', None)

    # Actually instantiate the optimizer
    if optimizer_cls is not None:
        if inspect.isclass(optimizer_cls):
            optimizer = optimizer_cls(self.parameters(), lr=lr, **optimizer_args)
            logging.info("Optimizer config = %s", str(optimizer))

            self._optimizer = optimizer
        else:
            # Attempt class path resolution. The kwargs are assembled before the `try`
            # so that the error message below can always refer to them.
            optimizer_config = {'lr': lr}
            optimizer_config.update(optimizer_args)
            try:
                optimizer_cls = OmegaConf.create({'cls': optimizer_cls})

                optimizer_instance = hydra.utils.instantiate(optimizer_cls, self.parameters(), **optimizer_config)

                logging.info("Optimizer config = %s", str(optimizer_instance))

                self._optimizer = optimizer_instance
            except Exception:
                logging.error(
                    "Could not instantiate class path - {} with kwargs {}".format(
                        optimizer_cls, str(optimizer_config)
                    )
                )
                raise
    else:
        optimizer = optim.get_optimizer(optimizer_name)
        optimizer = optimizer(self.parameters(), lr=lr, **optimizer_args)

        logging.info("Optimizer config = %s", str(optimizer))

        self._optimizer = optimizer

    # Try to instantiate scheduler for optimizer
    self._scheduler = prepare_lr_scheduler(
        optimizer=self._optimizer, scheduler_config=scheduler_config, train_dataloader=self._train_dl
    )

    # Return the optimizer with/without scheduler
    # This return allows multiple optimizers or schedulers to be created
    return self._optimizer, self._scheduler
def configure_checkpointing(
    trainer: 'pytorch_lightning.Trainer',
    log_dir: Path,
    name: str,
    params: 'DictConfig',
):
    """Create a NeMoModelCheckpoint from ``params`` and append it to ``trainer.callbacks``.

    ``params`` is normalised in place before the callback is built:

    * a deprecated ``filepath`` entry is split into ``dirpath`` / ``filename`` (only where
      those are still ``None``) and then deleted from ``params``;
    * an unset ``dirpath`` defaults to ``<log_dir>/checkpoints``;
    * an unset ``filename`` defaults to ``{name}--{<monitor>:.2f}-{epoch}``;
    * an unset ``prefix`` defaults to ``name``.

    ``NeMoModelCheckpoint.CHECKPOINT_NAME_LAST`` is set (class-wide) to ``<filename>-last``.

    Args:
        trainer: Trainer that receives the checkpoint callback.
        log_dir: Experiment log directory. Accepts a ``Path`` or a plain string.
        name: Experiment name, used for the default file name and prefix.
        params: Keyword arguments for ``NeMoModelCheckpoint``; mutated in place.

    Raises:
        CheckpointMisconfigurationError: If the trainer already holds a ModelCheckpoint
            callback, or if ``trainer.weights_save_path`` differs from the current
            working directory.
    """
    for callback in trainer.callbacks:
        if isinstance(callback, ModelCheckpoint):
            raise CheckpointMisconfigurationError(
                "The pytorch lightning trainer that was passed to exp_manager contained a ModelCheckpoint "
                "and create_checkpoint_callback was set to True. Please either set create_checkpoint_callback "
                "to False, or remove ModelCheckpoint from the lightning trainer"
            )
    if Path(trainer.weights_save_path) != Path.cwd():
        raise CheckpointMisconfigurationError(
            "The pytorch lightning was passed weights_save_path. This variable is ignored by exp_manager"
        )

    # Create the callback and attach it to trainer
    if "filepath" in params:
        if params.filepath is not None:
            logging.warning("filepath is deprecated. Please switch to dirpath and filename instead")
            # Explicit dirpath / filename values take precedence over the deprecated filepath.
            if params.dirpath is None:
                params.dirpath = Path(params.filepath).parent
            if params.filename is None:
                params.filename = Path(params.filepath).name
        # The key must not reach NeMoModelCheckpoint(**params) below.
        with open_dict(params):
            del params["filepath"]
    if params.dirpath is None:
        # Wrap log_dir first so that a plain-string log_dir works as well as a Path.
        params.dirpath = Path(log_dir) / 'checkpoints'
    if params.filename is None:
        params.filename = f'{name}--{{{params.monitor}:.2f}}-{{epoch}}'
    if params.prefix is None:
        params.prefix = name
    NeMoModelCheckpoint.CHECKPOINT_NAME_LAST = params.filename + '-last'

    logging.debug(params.dirpath)
    logging.debug(params.filename)
    logging.debug(params.prefix)

    # Monitoring a validation metric only works if validation actually runs during training.
    if "val" in params.monitor:
        if (
            trainer.max_epochs is not None
            and trainer.max_epochs != -1
            and trainer.max_epochs < trainer.check_val_every_n_epoch
        ):
            logging.error(
                "The checkpoint callback was told to monitor a validation value but trainer.max_epochs("
                f"{trainer.max_epochs}) was less than trainer.check_val_every_n_epoch({trainer.check_val_every_n_epoch}"
                f"). It is very likely this run will fail with ModelCheckpoint(monitor='{params.monitor}') not found "
                "in the returned metrics. Please ensure that validation is run within trainer.max_epochs."
            )
        elif trainer.max_steps is not None:
            logging.warning(
                "The checkpoint callback was told to monitor a validation value and trainer's max_steps was set to "
                f"{trainer.max_steps}. Please ensure that max_steps will run for at least "
                f"{trainer.check_val_every_n_epoch} epochs to ensure that checkpointing will not error out."
            )

    checkpoint_callback = NeMoModelCheckpoint(**params)
    trainer.callbacks.append(checkpoint_callback)
def __init__(self, cfg: DictConfig, trainer: Trainer = None):
    """
    Base class from which all NeMo models should inherit

    Args:
        cfg (DictConfig): configuration object.
            The cfg object should have (optionally) the following sub-configs:

            * train_ds - to instantiate training dataset
            * validation_ds - to instantiate validation dataset
            * test_ds - to instantiate testing dataset
            * optim - to instantiate optimizer with learning rate scheduler

        trainer (Optional): Pytorch Lightning Trainer instance

    Raises:
        ValueError: If ``cfg`` is not a DictConfig, or ``trainer`` is neither None nor a Trainer.
    """
    if not isinstance(cfg, DictConfig):
        raise ValueError(f"cfg constructor argument must be of type DictConfig but got {type(cfg)} instead.")
    if trainer is not None and not isinstance(trainer, Trainer):
        raise ValueError(
            f"trainer constructor argument must be either None or pytorch_lightning.Trainer. But got {type(trainer)} instead."
        )
    super().__init__()

    if 'target' not in cfg:
        # This is for Jarvis service.
        # Struct mode is lifted only for the time needed to add the new `target` key.
        OmegaConf.set_struct(cfg, False)
        cfg.target = "{0}.{1}".format(self.__class__.__module__, self.__class__.__name__)
        OmegaConf.set_struct(cfg, True)

    # Round-trip through a plain container so interpolations are resolved, then lock the result.
    config = OmegaConf.to_container(cfg, resolve=True)
    config = OmegaConf.create(config)
    OmegaConf.set_struct(config, True)

    self._cfg = config

    self.save_hyperparameters(self._cfg)
    self._train_dl = None
    self._validation_dl = None
    self._test_dl = None
    self._optimizer = None
    self._scheduler = None
    self._trainer = trainer

    # Set device_id in AppState
    if torch.cuda.is_available() and torch.cuda.current_device() is not None:
        app_state = AppState()
        app_state.device_id = torch.cuda.current_device()

    if self._cfg is not None and not self._is_model_being_restored():
        # Fresh construction: build the data loaders described by the config.
        if 'train_ds' in self._cfg and self._cfg.train_ds is not None:
            self.setup_training_data(self._cfg.train_ds)

        if 'validation_ds' in self._cfg and self._cfg.validation_ds is not None:
            self.setup_multiple_validation_data(val_data_config=None)

        if 'test_ds' in self._cfg and self._cfg.test_ds is not None:
            self.setup_multiple_test_data(test_data_config=None)
    else:
        # No data loaders are built on this path; only remind the user how to set them up.
        if 'train_ds' in self._cfg and self._cfg.train_ds is not None:
            logging.warning(
                "Please call the ModelPT.setup_training_data() method "
                "and provide a valid configuration file to setup the train data loader.\n"
                f"Train config : \n{OmegaConf.to_yaml(self._cfg.train_ds)}"
            )

        if 'validation_ds' in self._cfg and self._cfg.validation_ds is not None:
            logging.warning(
                "Please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method "
                "and provide a valid configuration file to setup the validation data loader(s). \n"
                f"Validation config : \n{OmegaConf.to_yaml(self._cfg.validation_ds)}"
            )

        if 'test_ds' in self._cfg and self._cfg.test_ds is not None:
            logging.warning(
                "Please call the ModelPT.setup_test_data() or ModelPT.setup_multiple_test_data() method "
                "and provide a valid configuration file to setup the test data loader(s).\n"
                f"Test config : \n{OmegaConf.to_yaml(self._cfg.test_ds)}"
            )

    # ModelPT wrappers over subclass implementations
    self.training_step = model_utils.wrap_training_step(self.training_step)
def export(
    self,
    output: str,
    input_example=None,
    output_example=None,
    verbose=False,
    export_params=True,
    do_constant_folding=True,
    keep_initializers_as_inputs=False,
    onnx_opset_version: int = 12,
    try_script: bool = False,
    set_eval: bool = True,
    check_trace: bool = True,
    use_dynamic_axes: bool = True,
):
    """
    Export the model to ONNX as five files placed next to ``output``.

    Unlike other models' export() this one creates 5 output files, not 3:
        bert_<output> - common BERT neural net
        punct_classifier_<output> - Punctuation Classifier neural net
        capit_classifier_<output> - Capitalization Classifier neural net
        punct_<output> - fused punctuation model (BERT+PunctuationClassifier)
        capit_<output> - fused capitalization model (BERT+CapitalizationClassifier)

    Args:
        output: Target path; its directory and base name are reused for all five files.
        input_example: Ignored (a warning is logged when given).
        output_example: Ignored (a warning is logged when given).
        The remaining arguments are forwarded unchanged to the ``export`` call of each
        of the three sub-models.

    Returns:
        A tuple ``(paths, descriptions)`` of two five-element lists, in the order listed above.
    """
    if not (input_example is None and output_example is None):
        logging.warning(
            "Passed input and output examples will be ignored and recomputed since"
            " PunctuationCapitalizationModel consists of three separate models with different"
            " inputs and outputs."
        )

    qual_name = self.__module__ + '.' + self.__class__.__qualname__
    target_dir = os.path.dirname(output)
    target_base = os.path.basename(output)

    def prefixed(prefix):
        # Same directory as `output`, base name prepended with `prefix`.
        return os.path.join(target_dir, prefix + target_base)

    # Positional arguments shared by the three sub-model exports; the two leading
    # None values are the input/output examples, computed by input_example().
    shared_args = (
        None,
        None,
        verbose,
        export_params,
        do_constant_folding,
        keep_initializers_as_inputs,
        onnx_opset_version,
        try_script,
        set_eval,
        check_trace,
        use_dynamic_axes,
    )

    bert_path = prefixed('bert_')
    bert_descr = qual_name + ' BERT exported to ONNX'
    bert_onnx = self.bert_model.export(bert_path, *shared_args)

    punct_cls_path = prefixed('punct_classifier_')
    punct_cls_descr = qual_name + ' Punctuation Classifier exported to ONNX'
    punct_cls_onnx = self.punct_classifier.export(punct_cls_path, *shared_args)

    capit_cls_path = prefixed('capit_classifier_')
    capit_cls_descr = qual_name + ' Capitalization Classifier exported to ONNX'
    capit_cls_onnx = self.capit_classifier.export(capit_cls_path, *shared_args)

    # Fuse BERT with each classifier and save the combined graphs.
    punct_path = prefixed('punct_')
    punct_descr = qual_name + ' Punctuation BERT+Classifier exported to ONNX'
    onnx.save(attach_onnx_to_onnx(bert_onnx, punct_cls_onnx, "PTCL"), punct_path)

    capit_path = prefixed('capit_')
    capit_descr = qual_name + ' Capitalization BERT+Classifier exported to ONNX'
    onnx.save(attach_onnx_to_onnx(bert_onnx, capit_cls_onnx, "CPCL"), capit_path)

    paths = [bert_path, punct_cls_path, capit_cls_path, punct_path, capit_path]
    descriptions = [bert_descr, punct_cls_descr, capit_cls_descr, punct_descr, capit_descr]
    return (paths, descriptions)
def __init__(
    self,
    audio_tar_filepaths,
    manifest_filepath,
    labels,
    batch_size,
    sample_rate=16000,
    int_values=False,
    bos_id=None,
    eos_id=None,
    pad_id=None,
    min_duration=0.1,
    max_duration=None,
    normalize_transcripts=True,
    trim_silence=False,
    shuffle_n=0,
    num_workers=0,
    augmentor: Optional[Union[AudioAugmentor, Dict[str, Dict[str, Any]]]] = None,
):
    """Set up a tarred audio dataset: manifest collection, featurizer and WebDataset pipeline.

    Args:
        audio_tar_filepaths: Tar shard path(s); a string is brace-expanded here only when
            torch.distributed is initialized.
        manifest_filepath: Comma-separated manifest file path(s).
        labels: Labels handed to the transcript parser.
        batch_size: Stored for later sampler creation.
        sample_rate: Sample rate handed to the featurizer.
        int_values: Forwarded to the featurizer.
        bos_id, eos_id: Stored on the instance.
        pad_id: Token pad value for the collate function; ``None`` means 0.
        min_duration, max_duration: Duration filters forwarded to the manifest collection.
        normalize_transcripts: Forwarded to the parser as ``do_normalize``.
        trim_silence: Stored as ``self.trim``.
        shuffle_n: Forwarded to the WebDataset ``shuffle`` stage.
        num_workers: Stored on the instance.
        augmentor: Optional augmentor (or its config), processed before use.
    """
    super().__init__()
    self._sample_rate = sample_rate

    if augmentor is not None:
        augmentor = _process_augmentations(augmentor)

    transcript_parser = make_parser(labels=labels, name='en', do_normalize=normalize_transcripts)
    self.collection = ASRAudioText(
        manifests_files=manifest_filepath.split(','),
        parser=transcript_parser,
        min_duration=min_duration,
        max_duration=max_duration,
        # Must set this so the manifest lines can be indexed by file ID
        index_by_file_id=True,
    )

    self.featurizer = WaveformFeaturizer(
        sample_rate=self._sample_rate, int_values=int_values, augmentor=augmentor
    )
    self.trim = trim_silence
    self.eos_id = eos_id
    self.bos_id = bos_id

    # Used in creating a sampler (in Actions).
    self._batch_size = batch_size
    self._num_workers = num_workers

    if pad_id is None:
        pad_id = 0
    self.collate_fn = partial(seq_collate_fn, token_pad_value=pad_id)

    # Check for distributed and partition shards accordingly
    if torch.distributed.is_initialized():
        global_rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()

        if isinstance(audio_tar_filepaths, str):
            audio_tar_filepaths = list(braceexpand.braceexpand(audio_tar_filepaths))

        if len(audio_tar_filepaths) % world_size != 0:
            logging.warning(
                f"Number of shards in tarred dataset ({len(audio_tar_filepaths)}) is not divisible "
                f"by number of distributed workers ({world_size})."
            )

        # Each rank takes one contiguous slice; leftover shards are not assigned to any rank.
        shards_per_rank = len(audio_tar_filepaths) // world_size
        begin_idx = shards_per_rank * global_rank
        end_idx = begin_idx + shards_per_rank
        audio_tar_filepaths = audio_tar_filepaths[begin_idx:end_idx]

    # Put together WebDataset
    pipeline = wd.Dataset(audio_tar_filepaths)
    pipeline = pipeline.shuffle(shuffle_n)
    pipeline = pipeline.rename(audio='wav', key='__key__')
    pipeline = pipeline.to_tuple('audio', 'key')
    pipeline = pipeline.pipe(self._filter)
    self._dataset = pipeline.map(f=self._build_sample)
def _setup_dataloader_from_config(self, config: DictConfig):
    """Build a DataLoader for a classification / VAD dataset described by ``config``.

    ``config.is_regression_task`` is overwritten with ``self.is_regression_task`` first.
    A tarred dataset is used when ``is_tarred`` is truthy; otherwise a regular dataset,
    with a streaming frame-level VAD variant (batch size 1) when ``vad_stream`` is truthy.

    Returns:
        A ``torch.utils.data.DataLoader``, or ``None`` when a required path entry is
        present but ``None``, or when VAD streaming is requested on a tarred dataset.
    """

    def _is_unset(key):
        # True only when the key exists and was explicitly set to None.
        return key in config and config[key] is None

    # Struct mode is lifted just long enough to add the flag.
    OmegaConf.set_struct(config, False)
    config.is_regression_task = self.is_regression_task
    OmegaConf.set_struct(config, True)

    augmentor = process_augmentations(config['augmentor']) if 'augmentor' in config else None
    featurizer = WaveformFeaturizer(
        sample_rate=config['sample_rate'],
        int_values=config.get('int_values', False),
        augmentor=augmentor,
    )
    shuffle = config['shuffle']

    # Instantiate tarred dataset loader or normal dataset loader
    if config.get('is_tarred', False):
        if _is_unset('tarred_audio_filepaths') or _is_unset('manifest_filepath'):
            logging.warning(
                "Could not load dataset as `manifest_filepath` is None or "
                f"`tarred_audio_filepaths` is None. Provided config : {config}"
            )
            return None

        if 'vad_stream' in config and config['vad_stream']:
            logging.warning("VAD inference does not support tarred dataset now")
            return None

        shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0
        dataset = audio_to_label_dataset.get_tarred_classification_label_dataset(
            featurizer=featurizer,
            config=OmegaConf.to_container(config),
            shuffle_n=shuffle_n,
            global_rank=self.global_rank,
            world_size=self.world_size,
        )
        # The DataLoader itself is not asked to shuffle; only shuffle_n is passed to the dataset.
        shuffle = False
        batch_size = config['batch_size']
        collate_func = dataset.collate_fn
    else:
        if _is_unset('manifest_filepath'):
            logging.warning(f"Could not load dataset as `manifest_filepath` is None. Provided config : {config}")
            return None

        if 'vad_stream' in config and config['vad_stream']:
            logging.info("Perform streaming frame-level VAD")
            dataset = audio_to_label_dataset.get_speech_label_dataset(
                featurizer=featurizer, config=OmegaConf.to_container(config)
            )
            batch_size = 1
            collate_func = dataset.vad_frame_seq_collate_fn
        else:
            dataset = audio_to_label_dataset.get_classification_label_dataset(
                featurizer=featurizer, config=OmegaConf.to_container(config)
            )
            batch_size = config['batch_size']
            collate_func = dataset.collate_fn

    loader_kwargs = dict(
        dataset=dataset,
        batch_size=batch_size,
        collate_fn=collate_func,
        drop_last=config.get('drop_last', False),
        shuffle=shuffle,
        num_workers=config.get('num_workers', 0),
        pin_memory=config.get('pin_memory', False),
    )
    return torch.utils.data.DataLoader(**loader_kwargs)
def main():
    """Train Jasper on AN4 and, with ``--test_after_training``, validate the result.

    Steps:
        1. Parse command-line arguments (NeMo defaults plus the options added here).
        2. Create the NeuralModuleFactory, load the model yaml and build the DAGs.
        3. Train with cosine annealing.
        4. With ``--test_after_training``: compute greedy WER (and beam-search WER when a
           single process is used and ``--lm`` is given), train for 10 more epochs at
           ``lr / 100``, check that the new WER is at most 10% worse than the previous one,
           and check that the epoch numbers found in the log file are 0, 1, 2, ...

    Raises:
        ValueError: If greedy WER exceeds the threshold, fine-tuning degrades WER by more
            than 10%, or the epochs in the log file are not the sequence 0, 1, 2, ...
        AssertionError: If beam-search WER exceeds its threshold or the greedy WER.
    """
    parser = argparse.ArgumentParser(
        parents=[nm_argparse.NemoArgParser()], description='AN4 ASR', conflict_handler='resolve',
    )

    # Overwrite default args
    parser.add_argument("--train_dataset", type=str, help="training dataset path")
    parser.add_argument("--eval_datasets", type=str, help="validation dataset path")

    # Create new args
    # parser.add_argument("--lm", default="./an4-lm.3gram.binary", type=str)
    parser.add_argument("--batch_size", default=48, type=int, help="size of the training batch")
    parser.add_argument("--lm", default=None, type=str)
    parser.add_argument("--test_after_training", action='store_true')
    parser.add_argument("--momentum", type=float)
    parser.add_argument("--beta1", default=0.95, type=float)
    parser.add_argument("--beta2", default=0.25, type=float)
    parser.add_argument("--do_not_eval_at_start", action='store_true')
    parser.set_defaults(
        model_config="./configs/jasper_an4.yaml",
        train_dataset="~/TestData/an4_dataset/an4_train.json",
        eval_datasets="~/TestData/an4_dataset/an4_val.json",
        work_dir="./tmp",
        optimizer="novograd",
        num_epochs=50,
        lr=0.02,
        weight_decay=0.005,
        checkpoint_save_freq=1000,
        eval_freq=100,
        amp_opt_level="O1",
    )

    args = parser.parse_args()
    betas = (args.beta1, args.beta2)

    # Acceptance thresholds (fractions, not percent) for the post-training checks.
    wer_thr = 0.20
    beam_wer_thr = 0.15

    nf = nemo.core.NeuralModuleFactory(
        local_rank=args.local_rank,
        files_to_copy=[__file__],
        optimization_level=args.amp_opt_level,
        random_seed=0,
        log_dir=args.work_dir,
        create_tb_writer=True,
        cudnn_benchmark=args.cudnn_benchmark,
    )
    tb_writer = nf.tb_writer
    checkpoint_dir = nf.checkpoint_dir

    # Load model definition
    yaml = YAML(typ="safe")
    with open(args.model_config) as f:
        jasper_params = yaml.load(f)
    # Get vocabulary.
    vocab = jasper_params['labels']

    (loss, eval_tensors, callbacks, total_steps, log_probs_e, encoded_len_e,) = create_dags(
        args.model_config, vocab, args, nf
    )

    nf.train(
        tensors_to_optimize=[loss],
        callbacks=callbacks,
        optimizer=args.optimizer,
        lr_policy=CosineAnnealing(total_steps=total_steps, min_lr=args.lr / 100),
        optimization_params={
            "num_epochs": args.num_epochs,
            "max_steps": args.max_steps,
            "lr": args.lr,
            "momentum": args.momentum,
            "betas": betas,
            "weight_decay": args.weight_decay,
            "grad_norm_clip": None,
        },
        batches_per_step=args.iter_per_step,
        amp_max_loss_scale=256.0,
        # synced_batchnorm=(nf.global_rank is not None),
    )

    if args.test_after_training:
        logging.info("Testing greedy and beam search with LM WER.")
        # Create BeamSearch NM
        if nf.world_size > 1 or args.lm is None:
            logging.warning("Skipping beam search WER as it does not work if doing distributed training.")
        else:
            beam_search_with_lm = nemo_asr.BeamSearchDecoderWithLM(
                vocab=vocab,
                beam_width=64,
                alpha=2.0,
                beta=1.5,
                lm_path=args.lm,
                # os.cpu_count() may return None; fall back to a single CPU in that case.
                num_cpus=os.cpu_count() or 1,
            )
            beam_predictions = beam_search_with_lm(log_probs=log_probs_e, log_probs_length=encoded_len_e)
            eval_tensors.append(beam_predictions)

        evaluated_tensors = nf.infer(eval_tensors)
        if nf.global_rank in [0, None]:
            greedy_hypotheses = post_process_predictions(evaluated_tensors[1], vocab)
            references = post_process_transcripts(evaluated_tensors[2], evaluated_tensors[3], vocab)
            wer = word_error_rate(hypotheses=greedy_hypotheses, references=references)
            logging.info("Greedy WER: {:.2f}%".format(wer * 100))
            if wer > wer_thr:
                nf.sync_all_processes(False)
                raise ValueError(f"Final eval greedy WER {wer * 100:.2f}% > :" f"than {wer_thr * 100:.2f}%")
        nf.sync_all_processes()

        if nf.world_size == 1 and args.lm is not None:
            # Walk mini-batches, then samples, and take element [0][1] of each sample's entry.
            beam_hypotheses = [sample[0][1] for batch in evaluated_tensors[-1] for sample in batch]

            beam_wer = word_error_rate(hypotheses=beam_hypotheses, references=references)
            logging.info("Beam WER {:.2f}%".format(beam_wer * 100))
            assert beam_wer <= beam_wer_thr, "Final eval beam WER {:.2f}% > than {:.2f}%".format(
                beam_wer * 100, beam_wer_thr * 100
            )
            assert beam_wer <= wer, "Final eval beam WER > than the greedy WER."

        # Reload model weights and train for extra 10 epochs
        # NOTE(review): this callback is constructed but never passed to nf.train() below;
        # confirm whether constructing it is enough or whether it should join `callbacks`.
        checkpointer_callback = nemo.core.CheckpointCallback(
            folder=checkpoint_dir, step_freq=args.checkpoint_save_freq, force_load=True,
        )

        # Distributed Data Parallel changes the underlying class so we need
        # to reinstantiate Encoder and Decoder
        args.num_epochs += 10
        previous_step_count = total_steps
        loss, eval_tensors, callbacks, total_steps, _, _ = create_dags(args.model_config, vocab, args, nf)

        nf.reset_trainer()
        nf.train(
            tensors_to_optimize=[loss],
            callbacks=callbacks,
            optimizer=args.optimizer,
            lr_policy=CosineAnnealing(warmup_steps=previous_step_count, total_steps=total_steps),
            optimization_params={
                "num_epochs": args.num_epochs,
                "lr": args.lr / 100,
                "momentum": args.momentum,
                "betas": betas,
                "weight_decay": args.weight_decay,
                "grad_norm_clip": None,
            },
            reset=True,
            amp_max_loss_scale=256.0,
            # synced_batchnorm=(nf.global_rank is not None),
        )

        evaluated_tensors = nf.infer(eval_tensors)
        if nf.global_rank in [0, None]:
            greedy_hypotheses = post_process_predictions(evaluated_tensors[1], vocab)
            references = post_process_transcripts(evaluated_tensors[2], evaluated_tensors[3], vocab)
            wer_new = word_error_rate(hypotheses=greedy_hypotheses, references=references)
            logging.info("New greedy WER: {:.2f}%".format(wer_new * 100))
            # Allow a 10% relative degradation before failing.
            if wer_new > wer * 1.1:
                nf.sync_all_processes(False)
                raise ValueError(
                    f"Fine tuning: new WER {wer_new * 100:.2f}% > than the " f"previous WER {wer * 100:.2f}%"
                )
        nf.sync_all_processes()

        # Open the log file and ensure that epochs is strictly increasing
        if nf._exp_manager.log_file:
            epochs = []
            marker = "Starting epoch"
            with open(nf._exp_manager.log_file, "r") as log_file:
                for line in log_file:
                    index = line.find(marker)
                    if index != -1:
                        epochs.append(int(line[index + len(marker) :]))
            # The i-th logged epoch must be numbered i.
            for i, e in enumerate(epochs):
                if i != e:
                    raise ValueError("Epochs from logfile was not understood")
def __init__(
    self,
    audio_files: List[str],
    durations: List[float],
    labels: List[Union[int, str]],
    offsets: List[Optional[float]],
    min_duration: Optional[float] = None,
    max_duration: Optional[float] = None,
    max_number: Optional[int] = None,
    do_sort_by_duration: bool = False,
    index_by_file_id: bool = False,
):
    """Instantiates audio-label manifest with filters and preprocessing.

    Args:
        audio_files: List of audio files.
        durations: List of float durations.
        labels: List of labels.
        offsets: List of offsets or None.
        min_duration: Minimum duration to keep entry with (default: None).
        max_duration: Maximum duration to keep entry with (default: None).
        max_number: Maximum number of samples to collect.
        do_sort_by_duration: True if sort samples list by duration. Ignored (with a
            warning) when ``index_by_file_id`` is set.
        index_by_file_id: If True, saves a mapping from filename base (ID) to index in data.
    """
    if index_by_file_id:
        self.mapping = {}
    entry_type = self.OUTPUT_TYPE

    entries = []
    dropped_duration = 0.0
    for path, length, label, offset in zip(audio_files, durations, labels, offsets):
        # Duration filters.
        too_short = min_duration is not None and length < min_duration
        if too_short or (max_duration is not None and length > max_duration):
            dropped_duration += length
            continue

        entries.append(entry_type(path, length, label, offset))
        if index_by_file_id:
            # Key is the file name without directory or extension.
            file_id = os.path.splitext(os.path.basename(path))[0]
            self.mapping[file_id] = len(entries) - 1

        # Max number of entities filter.
        if len(entries) == max_number:
            break

    if do_sort_by_duration and index_by_file_id:
        logging.warning("Tried to sort dataset by duration, but cannot since index_by_file_id is set.")
    elif do_sort_by_duration:
        entries.sort(key=lambda item: item.duration)

    logging.info(
        "Filtered duration for loading collection is %f.", dropped_duration,
    )
    self.uniq_labels = sorted({item.label for item in entries})
    logging.info("# {} files loaded accounting to # {} labels".format(len(entries), len(self.uniq_labels)))

    super().__init__(entries)
def __setup_dataloader_from_config(self, config: Optional[Dict]):
    """Build a DataLoader for a speaker-label dataset described by ``config``.

    A tarred dataset is used when ``is_tarred`` is truthy, otherwise an
    ``AudioToSpeechLabelDataset``. The collate function depends on ``self.task``:
    ``'diarization'`` uses the sliced-sequence collate function and disables shuffling,
    anything else uses the fixed-sequence collate function.

    Returns:
        A ``torch.utils.data.DataLoader``, or ``None`` when a required path entry is
        present in ``config`` but ``None``.
    """

    def _is_unset(key):
        # True only when the key exists and was explicitly set to None.
        return key in config and config[key] is None

    augmentor = process_augmentations(config['augmentor']) if 'augmentor' in config else None
    featurizer = WaveformFeaturizer(
        sample_rate=config['sample_rate'],
        int_values=config.get('int_values', False),
        augmentor=augmentor,
    )
    shuffle = config.get('shuffle', False)

    if config.get('is_tarred', False):
        if _is_unset('tarred_audio_filepaths') or _is_unset('manifest_filepath'):
            logging.warning(
                "Could not load dataset as `manifest_filepath` was None or "
                f"`tarred_audio_filepaths` is None. Provided config : {config}"
            )
            return None

        shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0
        dataset = get_tarred_speech_label_dataset(
            featurizer=featurizer,
            config=config,
            shuffle_n=shuffle_n,
            global_rank=self.global_rank,
            world_size=self.world_size,
        )
        # The DataLoader itself is not asked to shuffle; only shuffle_n is passed to the dataset.
        shuffle = False
    else:
        if _is_unset('manifest_filepath'):
            logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}")
            return None

        dataset = AudioToSpeechLabelDataset(
            manifest_filepath=config['manifest_filepath'],
            labels=config['labels'],
            featurizer=featurizer,
            max_duration=config.get('max_duration', None),
            min_duration=config.get('min_duration', None),
            trim=config.get('trim_silence', False),
            time_length=config.get('time_length', 8),
            shift_length=config.get('shift_length', 0.75),
            normalize_audio=config.get('normalize_audio', False),
        )

    # A ChainDataset has no collate functions of its own; borrow them from its first member.
    collate_ds = dataset.datasets[0] if type(dataset) is ChainDataset else dataset

    # self.labels = collate_ds.labels

    if self.task == 'diarization':
        logging.info("Setting up diarization parameters")
        collate_fn = collate_ds.sliced_seq_collate_fn
        shuffle = False
    else:
        logging.info("Setting up identification parameters")
        collate_fn = collate_ds.fixed_seq_collate_fn

    loader_kwargs = dict(
        dataset=dataset,
        batch_size=config['batch_size'],
        collate_fn=collate_fn,
        drop_last=config.get('drop_last', False),
        shuffle=shuffle,
        num_workers=config.get('num_workers', 0),
        pin_memory=config.get('pin_memory', False),
    )
    return torch.utils.data.DataLoader(**loader_kwargs)
def setup(self, stage: str) -> None:
    """ PTL hook that is called after DDP is initialized.
        Called at the beginning of fit and test.

    For the 'fit' stage with model parallelism enabled (``AppState().model_parallel_size``
    is not None) this method initializes the model parallel group if needed, tags every
    parameter with a ``model_parallel`` attribute, replaces the trainer's ``configure_ddp``,
    ``_clip_gradients`` and ``checkpoint_connector``, optionally configures checkpointing,
    and then, for a MegatronBertEncoder, restores weights and swaps the training
    dataloader's sampler. In every other case it only calls ``complete_megatron_init()``.

    Args:
        stage (str): either 'fit' or 'test'

    Raises:
        NotImplementedError: If model parallelism is enabled but ``self.bert_model`` is
            not a MegatronBertEncoder.
    """

    # TODO: implement model parallel for test stage
    if stage == 'fit':
        # adds self.bert_model config to .nemo file
        if hasattr(self, 'bert_model') and self.bert_model is not None:
            self.register_bert_model()

        app_state = AppState()

        if app_state.model_parallel_size is not None:

            # Only initialize when no model parallel group has been recorded yet.
            if app_state.model_parallel_group is None:
                self.init_model_parallel(app_state.global_rank, app_state.world_size)

            # mpu grad clipping needs parameters to have the attribute model_parallel
            parameters = self._trainer.get_model().parameters()
            for p in parameters:
                if not hasattr(p, 'model_parallel'):
                    p.model_parallel = False

            # Update PTL trainer to use our configure_ddp
            self._trainer.accelerator_backend.ddp_plugin.configure_ddp = self.configure_ddp
            # Update PTL trainer to use our _clip_gradients
            self._trainer.accelerator_backend._clip_gradients = self._clip_gradients
            self._trainer.checkpoint_connector = NLPCheckpointConnector(self._trainer)

            # Configure checkpointing for model parallel
            if app_state.create_checkpoint_callback:
                # global rank 0 is configured by exp_manager
                if not is_global_rank_zero() and app_state.data_parallel_rank == 0:
                    configure_checkpointing(
                        self._trainer,
                        app_state.log_dir,
                        app_state.checkpoint_name,
                        app_state.checkpoint_callback_params,
                    )

            if isinstance(self.bert_model, MegatronBertEncoder):
                self.bert_model.complete_lazy_init()

                # model parallel checkpoints need to be restored after torch.distributed is initialized
                if self._trainer.resume_from_checkpoint is not None:
                    # update path based on model parallel rank
                    # <grandparent dir>/mp_rank_XX/<file name>: the checkpoint's own parent
                    # directory is replaced by this rank's directory.
                    filepath = self._trainer.resume_from_checkpoint
                    dirname = os.path.dirname(os.path.dirname(filepath))
                    basename = os.path.basename(filepath)
                    filepath = f'{dirname}/mp_rank_{app_state.model_parallel_rank:02d}/{basename}'
                    self._trainer.resume_from_checkpoint = filepath
                    logging.info(f'Resuming training from checkpoint {self._trainer.resume_from_checkpoint}')
                    # need to set checkpoint version for megatron-lm
                    # NOTE(review): torch.load unpickles the whole checkpoint just to read one
                    # key; confirm checkpoints always come from a trusted source.
                    checkpoint_version = torch.load(self._trainer.resume_from_checkpoint).get(
                        'checkpoint_version', None
                    )
                    if checkpoint_version is not None:
                        set_checkpoint_version(checkpoint_version)
                    else:
                        logging.warning('Megatron-lm checkpoint version not found. Setting checkpoint_version to 0.')
                        set_checkpoint_version(0)
                else:
                    # Not resuming: load the pretrained weights recorded on the encoder.
                    logging.info(
                        f"Restoring from pretrained model parallel checkpoint: {self.bert_model._restore_path}"
                    )
                    self.bert_model.restore_weights(self.bert_model._restore_path)

                # The sampler is rebuilt with the data parallel size and rank taken from AppState.
                logging.info("Replacing sampler with model parallel sampler")
                mp_sampler = torch.utils.data.distributed.DistributedSampler(
                    self._train_dl.dataset,
                    num_replicas=app_state.data_parallel_size,
                    rank=app_state.data_parallel_rank,
                )
                mp_dl = self._trainer.replace_sampler(self._train_dl, mp_sampler)
                self._train_dl = mp_dl
            else:
                raise NotImplementedError(
                    f'The BERT encoder: {self.bert_model} does not support model parallelism yet.'
                )
        else:
            # Megatron without model parallelism
            self.complete_megatron_init()
    else:
        # testing stage
        self.complete_megatron_init()
def __init__(
    self,
    text_tar_filepaths: str,
    metadata_path: str,
    tokenizer,
    max_seq_length: int = 512,
    batch_step: int = None,
    shuffle_n: int = 1,
    shard_strategy: str = "scatter",
    global_rank: int = 0,
    world_size: int = 0,
):
    """Set up a tarred left-to-right language modeling dataset backed by WebDataset.

    Args:
        text_tar_filepaths: Tar shard path(s). A string may use ``(``, ``[``, ``<`` or
            ``_OP_`` (and the matching closers) in place of braces; it is normalised and
            brace-expanded into a list of paths. A non-string value is used as given.
        metadata_path: Path to a JSON file; its parsed content is stored as ``self.metadata``.
        tokenizer: Stored as ``self.tokenizer``.
        max_seq_length: Stored as ``self.max_seq_length``.
        batch_step: Stored as ``self.batch_step``; a falsy value falls back to ``max_seq_length``.
        shuffle_n: Forwarded to the WebDataset ``shuffle`` stage.
        shard_strategy: ``'scatter'`` gives each worker a contiguous slice of the shards,
            ``'replicate'`` gives every worker all shards.
        global_rank: Rank of this worker, used to pick its slice under ``'scatter'``.
        world_size: Number of distributed workers. Values below 1 (including the default 0)
            are treated as a single worker.

    Raises:
        ValueError: If ``shard_strategy`` is not ``'scatter'`` or ``'replicate'``.
    """
    super(TarredL2RLanguageModelingDataset, self).__init__()

    self.tokenizer = tokenizer
    self.max_seq_length = max_seq_length
    self.batch_step = batch_step or self.max_seq_length

    valid_shard_strategies = ['scatter', 'replicate']
    if shard_strategy not in valid_shard_strategies:
        raise ValueError(f"`shard_strategy` must be one of {valid_shard_strategies}")

    with open(metadata_path, 'r') as f:
        metadata = json.load(f)

    self.metadata = metadata

    if isinstance(text_tar_filepaths, str):
        # Replace '(', '[', '<' and '_OP_' with '{'
        for bkey in ('(', '[', '<', '_OP_'):
            text_tar_filepaths = text_tar_filepaths.replace(bkey, "{")

        # Replace ')', ']', '>' and '_CL_' with '}'
        for bkey in (')', ']', '>', '_CL_'):
            text_tar_filepaths = text_tar_filepaths.replace(bkey, "}")

        # Brace expand
        text_tar_filepaths = list(braceexpand.braceexpand(text_tar_filepaths))

    if shard_strategy == 'scatter':
        logging.info("All tarred dataset shards will be scattered evenly across all nodes.")

        # The default world_size of 0 would otherwise divide by zero below.
        num_workers = max(world_size, 1)

        if len(text_tar_filepaths) % num_workers != 0:
            logging.warning(
                f"Number of shards in tarred dataset ({len(text_tar_filepaths)}) is not divisible "
                f"by number of distributed workers ({world_size})."
            )

        # Each worker takes one contiguous slice; leftover shards are not assigned to any worker.
        shards_per_worker = len(text_tar_filepaths) // num_workers
        begin_idx = shards_per_worker * global_rank
        end_idx = begin_idx + shards_per_worker
        text_tar_filepaths = text_tar_filepaths[begin_idx:end_idx]
        logging.info(
            "Partitioning tarred dataset: process (%d) taking shards [%d, %d)", global_rank, begin_idx, end_idx
        )
    else:
        # Only 'replicate' can reach this point: the strategy was validated above.
        logging.info("All tarred dataset shards will be replicated across all nodes.")

    self.tarpath = text_tar_filepaths

    # Put together WebDataset
    self._dataset = (
        wd.Dataset(text_tar_filepaths)
        .shuffle(shuffle_n)
        .rename(npy='npy', key='__key__')
        .to_tuple('npy', 'key')
        .map(f=self._build_sample)
    )
def change_vocabulary(self, new_vocabulary: List[str], decoding_cfg: Optional[DictConfig] = None):
    """
    Replace the vocabulary used by RNNT decoding, e.g. when fine-tuning a pre-trained model
    on another language or on text with capitalization, punctuation or special characters.

    The joint, decoder, loss, decoding and WER objects are rebuilt for the new vocabulary
    and the ``joint``, ``decoder`` and ``decoding`` sections of ``self.cfg`` are updated.
    Nothing happens (apart from a warning) when the new vocabulary equals the current one.

    Args:
        new_vocabulary: list with new vocabulary; must be non-empty. Typically, \
            this is target alphabet.
        decoding_cfg: Optional decoding config, e.g. to switch from greedy to beam decoding.
            When None, the current ``self.cfg.decoding`` is reused.

    Raises:
        ValueError: If ``new_vocabulary`` is None or empty.

    Returns: None

    """
    current_vocabulary = self.joint.vocabulary
    if current_vocabulary == new_vocabulary:
        logging.warning(f"Old {self.joint.vocabulary} and new {new_vocabulary} match. Not changing anything.")
        return

    if new_vocabulary is None or len(new_vocabulary) == 0:
        raise ValueError(f'New vocabulary must be non-empty list of chars. But I got: {new_vocabulary}')

    vocab_size = len(new_vocabulary)

    # Rebuild the joint network from a copy of its config with the new vocabulary.
    new_joint_config = copy.deepcopy(self.joint.to_config_dict())
    new_joint_config['vocabulary'] = new_vocabulary
    new_joint_config['num_classes'] = vocab_size
    del self.joint
    self.joint = EncDecRNNTModel.from_config_dict(new_joint_config)

    # Rebuild the decoder with the matching vocabulary size.
    new_decoder_config = copy.deepcopy(self.decoder.to_config_dict())
    new_decoder_config.vocab_size = vocab_size
    del self.decoder
    self.decoder = EncDecRNNTModel.from_config_dict(new_decoder_config)

    # The loss is sized from the freshly built joint.
    del self.loss
    loss_name, loss_kwargs = self.extract_rnnt_loss_cfg(self.cfg.get('loss', None))
    self.loss = RNNTLoss(
        num_classes=self.joint.num_classes_with_blank - 1, loss_name=loss_name, loss_kwargs=loss_kwargs
    )

    if decoding_cfg is None:
        # Assume same decoding config as before
        decoding_cfg = self.cfg.decoding

    self.decoding = RNNTDecoding(
        decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, vocabulary=self.joint.vocabulary,
    )

    # Carry the settings of the previous WER metric over to the new one.
    previous_wer = self.wer
    self.wer = RNNTWER(
        decoding=self.decoding,
        batch_dim_index=previous_wer.batch_dim_index,
        use_cer=previous_wer.use_cer,
        log_prediction=previous_wer.log_prediction,
        dist_sync_on_step=True,
    )

    # Setup fused Joint step
    if self.joint.fuse_loss_wer:
        self.joint.set_loss(self.loss)
        self.joint.set_wer(self.wer)

    # Update config
    with open_dict(self.cfg.joint):
        self.cfg.joint = new_joint_config

    with open_dict(self.cfg.decoder):
        self.cfg.decoder = new_decoder_config

    with open_dict(self.cfg.decoding):
        self.cfg.decoding = decoding_cfg

    logging.info(f"Changed decoder to output to {self.joint.vocabulary} vocabulary.")
# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo_text_processing.inverse_text_normalization.taggers.tokenize_and_classify import ClassifyFst
from nemo_text_processing.inverse_text_normalization.verbalizers.verbalize_final import VerbalizeFinalFst

from nemo.utils import logging

# Probe for the optional `pynini` dependency once at import time; PYNINI_AVAILABLE
# records the outcome so a missing package produces a warning instead of an import failure.
try:
    import pynini

    PYNINI_AVAILABLE = True
except (ModuleNotFoundError, ImportError):
    # Adjacent string literals are joined as-is, so the space after "script" must be explicit.
    logging.warning(
        "`pynini` is not installed ! \n"
        "Please run the `nemo_text_processing/setup.sh` script "
        "prior to usage of this toolkit."
    )

    PYNINI_AVAILABLE = False
def __init__(
    self,
    audio_tar_filepaths: Union[str, List[str]],
    manifest_filepath: str,
    parser: Callable,
    sample_rate: int,
    int_values: bool = False,
    augmentor: Optional['nemo.collections.asr.parts.perturb.AudioAugmentor'] = None,
    shuffle_n: int = 0,
    min_duration: Optional[float] = None,
    max_duration: Optional[float] = None,
    max_utts: int = 0,
    trim: bool = False,
    bos_id: Optional[int] = None,
    eos_id: Optional[int] = None,
    pad_id: int = 0,
    shard_strategy: str = "scatter",
    global_rank: int = 0,
    world_size: int = 0,
):
    """Set up the manifest collection, featurizer and WebDataset pipeline.

    Args:
        audio_tar_filepaths: Either a list of tar paths, or a single string.
            In a string, the aliases ``( [ < _OP_`` and ``) ] > _CL_`` are
            rewritten to ``{`` and ``}``; the string is brace-expanded only
            when ``world_size > 1``.
        manifest_filepath: Comma-separated manifest path(s).
        parser: Passed through to ``collections.ASRAudioText``.
        sample_rate, int_values, augmentor: Passed to ``WaveformFeaturizer``.
        shuffle_n: When > 0, passed to the WebDataset's ``shuffle``.
        min_duration, max_duration, max_utts: Manifest filters forwarded to
            ``collections.ASRAudioText``.
        trim, bos_id, eos_id, pad_id: Stored on the instance unchanged.
        shard_strategy: ``'scatter'`` or ``'replicate'``.
        global_rank, world_size: Used to slice the shard list when
            ``world_size > 1`` and the strategy is ``'scatter'``.

    Raises:
        ValueError: If ``shard_strategy`` is not one of the allowed values.
    """
    # index_by_file_id must be set so the manifest lines can be indexed by file ID.
    self.collection = collections.ASRAudioText(
        manifests_files=manifest_filepath.split(','),
        parser=parser,
        min_duration=min_duration,
        max_duration=max_duration,
        max_number=max_utts,
        index_by_file_id=True,
    )
    self.featurizer = WaveformFeaturizer(sample_rate=sample_rate, int_values=int_values, augmentor=augmentor)

    self.trim, self.eos_id, self.bos_id, self.pad_id = trim, eos_id, bos_id, pad_id

    valid_shard_strategies = ['scatter', 'replicate']
    if shard_strategy not in valid_shard_strategies:
        raise ValueError(
            f"`shard_strategy` must be one of {valid_shard_strategies}")

    if isinstance(audio_tar_filepaths, str):
        # Normalise alternative bracket spellings to real braces: every
        # opening alias first, then every closing alias.
        brace_aliases = (
            ('(', "{"),
            ('[', "{"),
            ('<', "{"),
            ('_OP_', "{"),
            (')', "}"),
            (']', "}"),
            ('>', "}"),
            ('_CL_', "}"),
        )
        for alias, brace in brace_aliases:
            audio_tar_filepaths = audio_tar_filepaths.replace(alias, brace)

    # In a distributed run, decide which shards this process reads.
    if world_size > 1:
        if isinstance(audio_tar_filepaths, str):
            audio_tar_filepaths = list(braceexpand.braceexpand(audio_tar_filepaths))

        if shard_strategy == 'replicate':
            logging.info(
                "All tarred dataset shards will be replicated across all nodes."
            )
        elif shard_strategy == 'scatter':
            logging.info(
                "All tarred dataset shards will be scattered evenly across all nodes."
            )
            if len(audio_tar_filepaths) % world_size != 0:
                logging.warning(
                    f"Number of shards in tarred dataset ({len(audio_tar_filepaths)}) is not divisible "
                    f"by number of distributed workers ({world_size}).")

            # Each rank takes a contiguous, equally sized slice; any
            # remainder shards at the end are not assigned to any rank.
            shards_per_rank = len(audio_tar_filepaths) // world_size
            begin_idx = shards_per_rank * global_rank
            end_idx = begin_idx + shards_per_rank
            audio_tar_filepaths = audio_tar_filepaths[begin_idx:end_idx]
            logging.info(
                "Partitioning tarred dataset: process (%d) taking shards [%d, %d)",
                global_rank, begin_idx, end_idx)
        else:
            raise ValueError(
                f"Invalid shard strategy ! Allowed values are : {valid_shard_strategies}"
            )

    # Assemble the WebDataset pipeline.
    pipeline = wd.WebDataset(audio_tar_filepaths)
    if shuffle_n > 0:
        pipeline = pipeline.shuffle(shuffle_n)
    else:
        logging.info(
            "WebDataset will not shuffle files within the tar files.")

    pipeline = pipeline.rename(audio='wav', key='__key__')
    pipeline = pipeline.to_tuple('audio', 'key')
    pipeline = pipeline.pipe(self._filter)
    self._dataset = pipeline.map(f=self._build_sample)
def main(cfg):
    """Run the VAD inference pipeline described by ``cfg``.

    Steps, in order:
      1. Read the manifest at ``cfg.dataset`` and index entries by the audio
         file's base name (which must be unique).
      2. Optionally split long audio via ``prepare_manifest``.
      3. Load the VAD model and generate frame-level predictions.
      4. Optionally apply overlap smoothing (``cfg.gen_overlap_seq``).
      5. Optionally build speech-segment tables (``cfg.gen_seg_table``).
      6. Optionally write an output manifest (``cfg.write_to_manifest``).

    Args:
        cfg: Configuration object; attributes are read with dot access.

    Raises:
        ValueError: If ``cfg.dataset`` is empty, if two manifest entries share
            the same audio base name, or if ``cfg.write_to_manifest`` is set
            without ``cfg.gen_seg_table`` (the manifest points at the segment
            tables, so they must have been generated in this run).
    """
    if not cfg.dataset:
        raise ValueError("You must input the path of json file of evaluation data")

    # Each line of the dataset must have a different audio_filepath and a
    # unique base name, which keeps the later per-file lookups simple.
    key_meta_map = {}
    with open(cfg.dataset, 'r') as manifest:
        for line in manifest:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            audio_filepath = json.loads(line)['audio_filepath']
            uniq_audio_name = audio_filepath.split('/')[-1].rsplit('.', 1)[0]
            if uniq_audio_name in key_meta_map:
                raise ValueError("Please make sure each line is with different audio_filepath! ")
            key_meta_map[uniq_audio_name] = {'audio_filepath': audio_filepath}

    # Prepare manifest for streaming VAD
    manifest_vad_input = cfg.dataset
    if cfg.prepare_manifest.auto_split:
        logging.info("Split long audio file to avoid CUDA memory issue")
        logging.debug("Try smaller split_duration if you still have CUDA memory issue")
        config = {
            'input': manifest_vad_input,
            'window_length_in_sec': cfg.vad.parameters.window_length_in_sec,
            'split_duration': cfg.prepare_manifest.split_duration,
            'num_workers': cfg.num_workers,
            'prepared_manfiest_vad_input': cfg.prepared_manfiest_vad_input,
        }
        manifest_vad_input = prepare_manifest(config)
    else:
        logging.warning(
            "If you encounter CUDA memory issue, try splitting manifest entry by split_duration to avoid it."
        )

    torch.set_grad_enabled(False)
    vad_model = init_vad_model(cfg.vad.model_path)

    # setup_test_data
    vad_model.setup_test_data(
        test_data_config={
            'vad_stream': True,
            'sample_rate': 16000,
            'manifest_filepath': manifest_vad_input,
            'labels': ['infer',],
            'num_workers': cfg.num_workers,
            'shuffle': False,
            'window_length_in_sec': cfg.vad.parameters.window_length_in_sec,
            'shift_length_in_sec': cfg.vad.parameters.shift_length_in_sec,
            'trim_silence': False,
            'normalize_audio': cfg.vad.parameters.normalize_audio,
        }
    )

    vad_model = vad_model.to(device)
    vad_model.eval()

    if not os.path.exists(cfg.frame_out_dir):
        os.mkdir(cfg.frame_out_dir)
    else:
        logging.warning(
            "Note frame_out_dir exists. If new file has same name as file inside existing folder, it will append result to existing file and might cause mistakes for next steps."
        )

    logging.info("Generating frame level prediction ")
    pred_dir = generate_vad_frame_pred(
        vad_model=vad_model,
        window_length_in_sec=cfg.vad.parameters.window_length_in_sec,
        shift_length_in_sec=cfg.vad.parameters.shift_length_in_sec,
        manifest_vad_input=manifest_vad_input,
        out_dir=cfg.frame_out_dir,
    )
    logging.info(
        f"Finish generating VAD frame level prediction with window_length_in_sec={cfg.vad.parameters.window_length_in_sec} and shift_length_in_sec={cfg.vad.parameters.shift_length_in_sec}"
    )

    # overlap smoothing filter
    if cfg.gen_overlap_seq:
        # Generate predictions with overlapping input segments. Then a smoothing filter is applied to decide the
        # label for a frame spanned by multiple segments.
        # smoothing_method would be either in majority vote (median) or average (mean)
        logging.info("Generating predictions with overlapping input segments")
        smoothing_pred_dir = generate_overlap_vad_seq(
            frame_pred_dir=pred_dir,
            smoothing_method=cfg.vad.parameters.smoothing,
            overlap=cfg.vad.parameters.overlap,
            window_length_in_sec=cfg.vad.parameters.window_length_in_sec,
            shift_length_in_sec=cfg.vad.parameters.shift_length_in_sec,
            num_workers=cfg.num_workers,
            out_dir=cfg.smoothing_out_dir,
        )
        logging.info(
            f"Finish generating predictions with overlapping input segments with smoothing_method={cfg.vad.parameters.smoothing} and overlap={cfg.vad.parameters.overlap}"
        )
        pred_dir = smoothing_pred_dir

    # postprocessing and generate speech segments
    table_out_dir = None
    if cfg.gen_seg_table:
        logging.info("Converting frame level prediction to speech/no-speech segment in start and end times format.")
        table_out_dir = generate_vad_segment_table(
            vad_pred_dir=pred_dir,
            postprocessing_params=cfg.vad.parameters.postprocessing,
            shift_length_in_sec=cfg.vad.parameters.shift_length_in_sec,
            num_workers=cfg.num_workers,
            out_dir=cfg.table_out_dir,
        )
        logging.info(
            f"Finish generating speech semgents table with postprocessing_params: {cfg.vad.parameters.postprocessing}"
        )

    if cfg.write_to_manifest:
        # The manifest references the per-file segment tables, which only
        # exist when the segment-table step ran above. Fail with a clear
        # message instead of an unbound-variable NameError.
        if not cfg.gen_seg_table:
            raise ValueError(
                "`write_to_manifest` requires `gen_seg_table` to be enabled, because the output manifest "
                "references the segment tables generated by that step."
            )

        for i in key_meta_map:
            key_meta_map[i]['rttm_filepath'] = os.path.join(table_out_dir, i + ".txt")

        if not cfg.out_manifest_filepath:
            out_manifest_filepath = "vad_out.json"
        else:
            out_manifest_filepath = cfg.out_manifest_filepath
        out_manifest_filepath = write_rttm2manifest(key_meta_map, out_manifest_filepath)
        logging.info(f"Writing VAD output to manifest: {out_manifest_filepath}")
def _write_vad_segment(outfile, audio_path, time_tup, max_duration):
    """Write one merged ``(start, end)`` segment as a JSON manifest line, or warn if it is unusable."""
    seg_duration = float("{:.3f}".format(time_tup[1] - time_tup[0]))
    if time_tup[0] < max_duration and seg_duration > 0:
        meta = {
            "audio_filepath": audio_path,
            "offset": time_tup[0],
            "duration": seg_duration,
            "label": 'UNK',
        }
        json.dump(meta, outfile)
        outfile.write("\n")
    else:
        logging.warning(
            "RTTM label has been truncated since start is greater than duration of audio file"
        )


def write_rttm2manifest(AUDIO_RTTM_MAP, manifest_file):
    """
    writes manifest file based on rttm files (or vad table out files). This manifest file would be used by
    speaker diarizer to compute embeddings and cluster them. This function also takes care of overlap time stamps

    Overlapping or touching segments are merged into one manifest entry. Lines with more than three
    whitespace-separated fields are read as RTTM (start at field 3, duration at field 4); shorter lines are
    read as VAD table rows (start at field 0, duration at field 1). Blank lines are skipped.

    Args:
        AUDIO_RTTM_MAP: dict containing keys to uniqnames, that contains audio filepath and rttm_filepath as its
            contents, these are used to extract oracle vad timestamps. An optional truthy 'duration' entry is
            used as the audio length; otherwise the length is read from the audio file.
        manifest_file (str): path to write manifest file

    Returns:
        manifest_file (str): path to write manifest file

    Raises:
        FileNotFoundError: if an entry's rttm_filepath is empty or does not exist.
    """
    with open(manifest_file, 'w') as outfile:
        for key in AUDIO_RTTM_MAP:
            rttm_filename = AUDIO_RTTM_MAP[key]['rttm_filepath']
            if not (rttm_filename and os.path.exists(rttm_filename)):
                raise FileNotFoundError(
                    "Requested to construct manifest from rttm with oracle VAD option or from NeMo VAD but received filename as {}"
                    .format(rttm_filename))

            # Read everything up front so the handle is closed even if a
            # later step raises.
            with open(rttm_filename, 'r') as f:
                lines = f.readlines()

            audio_path = AUDIO_RTTM_MAP[key]['audio_filepath']
            if AUDIO_RTTM_MAP[key].get('duration', None):
                max_duration = AUDIO_RTTM_MAP[key]['duration']
            else:
                with sf.SoundFile(audio_path) as sound:
                    max_duration = sound.frames / sound.samplerate

            # (-1, -1) marks "no segment accumulated yet".
            time_tup = (-1, -1)
            for line in lines:
                vad_out = line.strip().split()
                if not vad_out:
                    continue
                if len(vad_out) > 3:
                    start, dur = float(vad_out[3]), float(vad_out[4])
                else:
                    start, dur = float(vad_out[0]), float(vad_out[1])
                start, dur = float("{:.3f}".format(start)), float("{:.3f}".format(dur))

                if start == 0 and dur == 0:  # No speech segments
                    continue

                if time_tup[0] >= 0 and start > time_tup[1]:
                    # Gap before this segment: flush the accumulated one and start anew.
                    _write_vad_segment(outfile, audio_path, time_tup, max_duration)
                    time_tup = (start, start + dur)
                elif time_tup[0] == -1:
                    # First segment of the file; clamp its end to the audio length.
                    time_tup = (start, min(start + dur, max_duration))
                else:
                    # Overlaps the accumulated segment: extend it, clamped to the audio length.
                    end_time = min(max(time_tup[1], start + dur), max_duration)
                    time_tup = (min(time_tup[0], start), end_time)

            # Flush whatever is still accumulated after the last line.
            _write_vad_segment(outfile, audio_path, time_tup, max_duration)
    return manifest_file
def _warn_unused_additional_kwargs(loss_name, kwargs):
    """Log a warning listing extra kwargs passed for ``loss_name``; does nothing when ``kwargs`` is empty."""
    if len(kwargs) == 0:
        return

    message = (
        f"Loss function `{loss_name}` was provided with following additional kwargs,\n"
        f"however they were ignored as it is unused.\n"
        f"{kwargs}"
    )
    logging.warning(message)
def get_log_dir(
    trainer: 'pytorch_lightning.Trainer',
    exp_dir: str = None,
    name: str = None,
    version: str = None,
    explicit_log_dir: str = None,
    use_datetime_version: bool = True,
) -> (Path, str, str, str):
    """
    Work out the logging directory that exp_manager should use.

    When ``explicit_log_dir`` is given, the decision is delegated to ``check_explicit_log_dir``.
    Otherwise the directory is ``<exp_dir>/<name>/<version>``, where the parts come either from the
    trainer's logger (if it has one) or from the arguments / environment.

    Returns:
        log_dir (Path): the log_dir
        exp_dir (str): the base exp_dir without name nor version
        name (str): The name of the experiment
        version (str): The version of the experiment

    Raise:
        LoggerMisconfigurationError: If the trainer already has a logger and ``name`` is passed, or the
            logger has a save_dir and ``exp_dir`` is passed.
    """
    # An explicit log_dir short-circuits everything else.
    if explicit_log_dir:
        return check_explicit_log_dir(trainer, explicit_log_dir, exp_dir, name, version)

    # Fall back to ./nemo_experiments when no exp_dir was supplied.
    _exp_dir = str(Path.cwd() / 'nemo_experiments') if exp_dir is None else exp_dir

    trainer_logger = trainer.logger
    if trainer_logger is None:
        # No logger on the trainer: use the user-supplied name/version options.
        name = name or "default"
        version = version or os.environ.get(NEMO_ENV_VARNAME_VERSION, None)

        if version is None:
            if trainer.is_slurm_managing_tasks:
                logging.warning(
                    "Running on a slurm cluster. exp_manager will not add a version number."
                )
                version = ""
            elif is_global_rank_zero():
                if use_datetime_version:
                    version = time.strftime('%Y-%m-%d_%H-%M-%S')
                else:
                    tensorboard_logger = TensorBoardLogger(
                        save_dir=Path(_exp_dir), name=name, version=version)
                    version = f"version_{tensorboard_logger.version}"
                # Publish the chosen version through the environment.
                os.environ[NEMO_ENV_VARNAME_VERSION] = version
    else:
        # The trainer already has a logger: its settings win, and conflicting
        # arguments are rejected.
        logger_save_dir = trainer_logger.save_dir
        if logger_save_dir:
            if exp_dir:
                raise LoggerMisconfigurationError(
                    "The pytorch lightning trainer that was passed to exp_manager contained a logger, the logger's "
                    f"save_dir was not None, and exp_dir ({exp_dir}) was not None. If trainer.logger.save_dir "
                    "exists, exp_manager will use trainer.logger.save_dir as the logging directory and exp_dir "
                    "must be None.")
            _exp_dir = logger_save_dir
        if name:
            raise LoggerMisconfigurationError(
                "The pytorch lightning trainer that was passed to exp_manager contained a logger, and name: "
                f"{name} was also passed to exp_manager. If the trainer contains a "
                "logger, exp_manager will use trainer.logger.name, and name passed to exp_manager must be None."
            )
        name = trainer_logger.name
        version = f"version_{trainer_logger.version}"

    log_dir = Path(_exp_dir) / Path(str(name)) / Path(str(version))
    return log_dir, str(_exp_dir), name, version
def __init__(self, cfg: DictConfig, trainer: Trainer = None):
    """Build the model from ``cfg``: normalizer, optional tokenizer, losses and sub-modules.

    Args:
        cfg: Model configuration. It is first passed through
            ``model_utils.convert_model_config_to_dict_config`` and
            ``model_utils.maybe_update_config_version``.
        trainer: Forwarded to the parent constructor.

    Raises:
        ValueError: If ``learn_alignment`` is enabled and the training dataset
            class named in ``cfg.train_ds.dataset._target_`` is neither
            ``TTSDataset`` nor ``AudioToCharWithPriorAndPitchDataset``.
    """
    # Convert to Hydra 1.0 compatible DictConfig
    cfg = model_utils.convert_model_config_to_dict_config(cfg)
    cfg = model_utils.maybe_update_config_version(cfg)

    # Setup normalizer
    self.normalizer = None
    self.text_normalizer_call = None
    self.text_normalizer_call_kwargs = {}
    self._setup_normalizer(cfg)

    self.learn_alignment = cfg.get("learn_alignment", False)

    # Setup vocabulary (=tokenizer) and input_fft_kwargs (supported only with self.learn_alignment=True)
    input_fft_kwargs = {}
    if self.learn_alignment:
        self.vocab = None
        # Last dotted component of the dataset target, i.e. the bare class name.
        self.ds_class_name = cfg.train_ds.dataset._target_.split(".")[-1]

        if self.ds_class_name == "TTSDataset":
            self._setup_tokenizer(cfg)
            # presumably _setup_tokenizer populates self.vocab — confirm against its definition
            assert self.vocab is not None
            input_fft_kwargs["n_embed"] = len(self.vocab.tokens)
            input_fft_kwargs["padding_idx"] = self.vocab.pad
        elif self.ds_class_name == "AudioToCharWithPriorAndPitchDataset":
            logging.warning(
                "AudioToCharWithPriorAndPitchDataset class has been deprecated. No support for"
                " training or finetuning. Only inference is supported."
            )
            # The deprecated dataset path uses the default tokenizer config instead of cfg.
            tokenizer_conf = self._get_default_text_tokenizer_conf()
            self._setup_tokenizer(tokenizer_conf)
            assert self.vocab is not None
            input_fft_kwargs["n_embed"] = len(self.vocab.tokens)
            input_fft_kwargs["padding_idx"] = self.vocab.pad
        else:
            raise ValueError(f"Unknown dataset class: {self.ds_class_name}")

    self._parser = None
    self._tb_logger = None
    # NOTE(review): everything above is assigned before the parent constructor
    # runs — presumably the base __init__ relies on it; confirm before reordering.
    super().__init__(cfg=cfg, trainer=trainer)

    self.bin_loss_warmup_epochs = cfg.get("bin_loss_warmup_epochs", 100)
    self.log_train_images = False

    # Default loss scale is 0.1 with alignment learning and 1.0 without; the
    # duration and pitch scales can each be overridden from cfg.
    loss_scale = 0.1 if self.learn_alignment else 1.0
    dur_loss_scale = loss_scale
    pitch_loss_scale = loss_scale
    if "dur_loss_scale" in cfg:
        dur_loss_scale = cfg.dur_loss_scale
    if "pitch_loss_scale" in cfg:
        pitch_loss_scale = cfg.pitch_loss_scale

    self.mel_loss = MelLoss()
    self.pitch_loss = PitchLoss(loss_scale=pitch_loss_scale)
    self.duration_loss = DurationLoss(loss_scale=dur_loss_scale)

    # The aligner and its two losses exist only when alignment is learned.
    self.aligner = None
    if self.learn_alignment:
        self.aligner = instantiate(self._cfg.alignment_module)
        self.forward_sum_loss = ForwardSumLoss()
        self.bin_loss = BinLoss()

    self.preprocessor = instantiate(self._cfg.preprocessor)
    # input_fft receives n_embed / padding_idx only when a tokenizer was set up above.
    input_fft = instantiate(self._cfg.input_fft, **input_fft_kwargs)
    output_fft = instantiate(self._cfg.output_fft)
    duration_predictor = instantiate(self._cfg.duration_predictor)
    pitch_predictor = instantiate(self._cfg.pitch_predictor)

    self.fastpitch = FastPitchModule(
        input_fft,
        output_fft,
        duration_predictor,
        pitch_predictor,
        self.aligner,
        cfg.n_speakers,
        cfg.symbols_embedding_dim,
        cfg.pitch_embedding_kernel_size,
        cfg.n_mel_channels,
    )
    self._input_types = self._output_types = None
import pickle as pkl
from argparse import ArgumentParser
from collections import OrderedDict
from typing import Dict

import numpy as np
import torch
from build_index import load_model
from omegaconf import DictConfig, OmegaConf

from nemo.utils import logging

try:
    import faiss
except ModuleNotFoundError:
    # A missing faiss only produces a warning here; the import error is not re-raised.
    logging.warning(
        "Faiss is required for building the index. Please install faiss-gpu")

# Use CUDA when torch reports it as available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def get_query_embedding(query, model):
    """Use entity linking encoder to get embedding for index query

    Args:
        query: Text passed to ``model.tokenizer``.
        model: Object exposing a callable ``tokenizer`` attribute.

    NOTE(review): the portion of this function visible here stops after
    tokenization and does not return a value — confirm whether the encoder
    forward pass and return statement are missing.
    """
    # Tokenize with special tokens, padding and truncation to 512 tokens,
    # requesting token type ids and an attention mask.
    model_input = model.tokenizer(
        query,
        add_special_tokens=True,
        padding=True,
        truncation=True,
        max_length=512,
        return_token_type_ids=True,
        return_attention_mask=True,
    )
def _setup_dataloader_from_config(self, config: Optional[Dict]):
    """Create a ``DataLoader`` over a tarred or regular audio-to-char dataset.

    Args:
        config: Dataset/dataloader options. ``shuffle`` and ``batch_size`` are
            read with item access; ``is_tarred`` selects the dataset type.

    Returns:
        A ``torch.utils.data.DataLoader``, or ``None`` (after logging a
        warning) when ``manifest_filepath`` — or, for tarred data,
        ``tarred_audio_filepaths`` — is present in ``config`` but set to None.
    """
    augmentor = process_augmentations(config['augmentor']) if 'augmentor' in config else None
    shuffle = config['shuffle']

    def explicitly_none(key):
        # True only when the key is present AND its value is None.
        return key in config and config[key] is None

    if not config.get('is_tarred', False):
        # Regular (non-tarred) dataset.
        if explicitly_none('manifest_filepath'):
            logging.warning(
                f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}"
            )
            return None

        dataset_kwargs = {
            'manifest_filepath': config['manifest_filepath'],
            'labels': config['labels'],
            'sample_rate': config['sample_rate'],
            'int_values': config.get('int_values', False),
            'augmentor': augmentor,
            'max_duration': config.get('max_duration', None),
            'min_duration': config.get('min_duration', None),
            'max_utts': config.get('max_utts', 0),
            'blank_index': config.get('blank_index', -1),
            'unk_index': config.get('unk_index', -1),
            'normalize': config.get('normalize_transcripts', False),
            'trim': config.get('trim_silence', True),
            'load_audio': config.get('load_audio', True),
            'parser': config.get('parser', 'en'),
            'add_misc': config.get('add_misc', False),
        }
        dataset = AudioToCharDataset(**dataset_kwargs)
    else:
        # Tarred dataset.
        if explicitly_none('tarred_audio_filepaths') or explicitly_none('manifest_filepath'):
            logging.warning(
                "Could not load dataset as `manifest_filepath` was None or "
                f"`tarred_audio_filepaths` is None. Provided config : {config}"
            )
            return None

        shuffle_n = config.get('shuffle_n', 4 * config['batch_size'])
        dataset_kwargs = {
            'audio_tar_filepaths': config['tarred_audio_filepaths'],
            'manifest_filepath': config['manifest_filepath'],
            'labels': config['labels'],
            'sample_rate': config['sample_rate'],
            'int_values': config.get('int_values', False),
            'augmentor': augmentor,
            'shuffle_n': shuffle_n,
            'max_duration': config.get('max_duration', None),
            'min_duration': config.get('min_duration', None),
            'max_utts': config.get('max_utts', 0),
            'blank_index': config.get('blank_index', -1),
            'unk_index': config.get('unk_index', -1),
            'normalize': config.get('normalize_transcripts', False),
            'trim': config.get('trim_silence', True),
            'parser': config.get('parser', 'en'),
            'add_misc': config.get('add_misc', False),
            'global_rank': self.global_rank,
            'world_size': self.world_size,
        }
        dataset = TarredAudioToCharDataset(**dataset_kwargs)
        # For tarred data the DataLoader itself never shuffles; `shuffle_n`
        # is handed to the dataset instead.
        shuffle = False

    loader_kwargs = {
        'dataset': dataset,
        'batch_size': config['batch_size'],
        'collate_fn': dataset.collate_fn,
        'drop_last': config.get('drop_last', False),
        'shuffle': shuffle,
        'num_workers': config.get('num_workers', 0),
        'pin_memory': config.get('pin_memory', False),
    }
    return torch.utils.data.DataLoader(**loader_kwargs)