def get_tarred_char_dataset(
    config: dict,
    shuffle_n: int,
    global_rank: int,
    world_size: int,
    augmentor: Optional['AudioAugmentor'] = None,
) -> audio_to_text.TarredAudioToCharDataset:
    """
    Build a character-level TarredAudioToCharDataset from a single config entry.

    Args:
        config: Config of the TarredAudioToCharDataset. The keys 'tarred_audio_filepaths',
            'manifest_filepath', 'labels' and 'sample_rate' are required (a missing one raises
            KeyError); every other key read here is optional and has a fallback value.
        shuffle_n: How many samples to look ahead and load to be shuffled.
            See WebDataset documentation for more details.
        global_rank: Global rank of this device.
        world_size: Global world size in the training method.
        augmentor: Optional AudioAugmentor object for augmentations on audio data.

    Returns:
        An instance of TarredAudioToCharDataset.
    """
    # Mandatory entries are indexed directly so that an absent key raises KeyError.
    dataset_kwargs = dict(
        audio_tar_filepaths=config['tarred_audio_filepaths'],
        manifest_filepath=config['manifest_filepath'],
        labels=config['labels'],
        sample_rate=config['sample_rate'],
    )

    # Optional entries as (constructor kwarg, config key, fallback when the key is absent).
    # NOTE(review): 'trim_silence' falls back to True here, whereas get_tarred_dataset in this
    # module falls back to False for the same key -- confirm which default is intended.
    optional_entries = (
        ('int_values', 'int_values', False),
        ('max_duration', 'max_duration', None),
        ('min_duration', 'min_duration', None),
        ('max_utts', 'max_utts', 0),
        ('blank_index', 'blank_index', -1),
        ('unk_index', 'unk_index', -1),
        ('normalize', 'normalize_transcripts', False),
        ('trim', 'trim_silence', True),
        ('parser', 'parser', 'en'),
        ('add_misc', 'add_misc', False),
        ('shard_strategy', 'tarred_shard_strategy', 'scatter'),
    )
    for kwarg_name, config_key, fallback in optional_entries:
        dataset_kwargs[kwarg_name] = config.get(config_key, fallback)

    # Runtime arguments supplied by the caller rather than by the config.
    dataset_kwargs['augmentor'] = augmentor
    dataset_kwargs['shuffle_n'] = shuffle_n
    dataset_kwargs['global_rank'] = global_rank
    dataset_kwargs['world_size'] = world_size

    return audio_to_text.TarredAudioToCharDataset(**dataset_kwargs)
def get_tarred_dataset(
    config: dict,
    shuffle_n: int,
    global_rank: int,
    world_size: int,
    tokenizer: Optional['TokenizerSpec'] = None,
    augmentor: Optional['AudioAugmentor'] = None,
) -> Union[audio_to_text.TarredAudioToBPEDataset, audio_to_text.TarredAudioToCharDataset]:
    """
    Instantiates a Word Piece/BPE Encoding based TarredAudioToBPEDataset or a char based TarredAudioToCharDataset.

    ``config['tarred_audio_filepaths']`` and ``config['manifest_filepath']`` are normalised with
    ``convert_to_config_list`` and paired element-wise ("buckets"); one dataset is built per pair.

    Args:
        config: Config of the TarredAudioToBPEDataset or TarredAudioToCharDataset.
        shuffle_n: How many samples to look ahead and load to be shuffled.
            See WebDataset documentation for more details.
        global_rank: Global rank of this device.
        world_size: Global world size in the training method.
        tokenizer: An instance of a TokenizerSpec object if BPE dataset is needed.
            Passing None would return a char-based dataset.
        augmentor: Optional AudioAugmentor object for augmentations on audio data.

    Returns:
        An instance of TarredAudioToBPEDataset or TarredAudioToCharDataset when a single bucket is
        configured; a ChainDataset wrapping one such dataset per bucket when there are several.

    Raises:
        ValueError: If the number of tarred-audio buckets differs from the number of manifest buckets.
    """
    tarred_audio_filepaths = config['tarred_audio_filepaths']
    manifest_filepaths = config['manifest_filepath']

    tarred_audio_filepaths = convert_to_config_list(tarred_audio_filepaths)
    manifest_filepaths = convert_to_config_list(manifest_filepaths)

    if len(manifest_filepaths) != len(tarred_audio_filepaths):
        raise ValueError(
            "manifest_filepaths and tarred_audio_filepaths need to have the same number of buckets, "
            f"but got {len(manifest_filepaths)} manifest bucket(s) and "
            f"{len(tarred_audio_filepaths)} tarred audio bucket(s)."
        )

    datasets = []
    for tarred_audio_filepath, manifest_filepath in zip(tarred_audio_filepaths, manifest_filepaths):
        # A bucket holding exactly one entry is unwrapped so the dataset receives that entry
        # itself instead of a one-element list.
        if len(tarred_audio_filepath) == 1:
            tarred_audio_filepath = tarred_audio_filepath[0]

        if tokenizer is None:
            dataset = audio_to_text.TarredAudioToCharDataset(
                audio_tar_filepaths=tarred_audio_filepath,
                manifest_filepath=manifest_filepath,
                labels=config['labels'],
                sample_rate=config['sample_rate'],
                int_values=config.get('int_values', False),
                augmentor=augmentor,
                shuffle_n=shuffle_n,
                max_duration=config.get('max_duration', None),
                min_duration=config.get('min_duration', None),
                max_utts=config.get('max_utts', 0),
                blank_index=config.get('blank_index', -1),
                unk_index=config.get('unk_index', -1),
                normalize=config.get('normalize_transcripts', False),
                trim=config.get('trim_silence', False),
                parser=config.get('parser', 'en'),
                # Forwarded for consistency with get_tarred_char_dataset, which honours this key;
                # previously a config setting add_misc was silently ignored on this path.
                add_misc=config.get('add_misc', False),
                shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
                global_rank=global_rank,
                world_size=world_size,
            )
        else:
            dataset = audio_to_text.TarredAudioToBPEDataset(
                audio_tar_filepaths=tarred_audio_filepath,
                manifest_filepath=manifest_filepath,
                tokenizer=tokenizer,
                sample_rate=config['sample_rate'],
                int_values=config.get('int_values', False),
                augmentor=augmentor,
                shuffle_n=shuffle_n,
                max_duration=config.get('max_duration', None),
                min_duration=config.get('min_duration', None),
                max_utts=config.get('max_utts', 0),
                trim=config.get('trim_silence', False),
                use_start_end_token=config.get('use_start_end_token', True),
                shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
                global_rank=global_rank,
                world_size=world_size,
            )
        datasets.append(dataset)

    # Several buckets are combined into one ChainDataset; a single bucket is returned unwrapped.
    if len(datasets) > 1:
        return ChainDataset(datasets)
    return datasets[0]