def __init__(
    self,
    *,
    manifest_filepath: str,
    labels: List[str],
    featurizer,
    min_duration: Optional[float] = 0.1,
    max_duration: Optional[float] = None,
    trim: bool = False,
    load_audio: bool = True,
):
    """Load the manifest collection and build the label <-> id lookup tables.

    Args:
        manifest_filepath: One or more manifest paths joined by commas; the
            string is split on ``','`` before being passed to the collection.
        labels: Label vocabulary. When it is empty or None, the collection's
            ``uniq_labels`` is used instead.
        featurizer: Stored unchanged on ``self.featurizer``.
        min_duration: Forwarded to ``collections.ASRSpeechLabel``.
        max_duration: Forwarded to ``collections.ASRSpeechLabel``.
        trim: Stored unchanged on ``self.trim``.
        load_audio: Stored unchanged on ``self.load_audio``.
    """
    super().__init__()

    manifest_list = manifest_filepath.split(',')
    self.collection = collections.ASRSpeechLabel(
        manifests_files=manifest_list,
        min_duration=min_duration,
        max_duration=max_duration,
    )

    self.featurizer = featurizer
    self.trim = trim
    self.load_audio = load_audio

    # Fall back to the labels discovered in the manifests when none are given.
    self.labels = labels if labels else self.collection.uniq_labels
    self.num_classes = len(self.labels)

    # The position in ``self.labels`` is the class id; a repeated label keeps
    # the id of its last occurrence in ``label2id``.
    self.id2label = dict(enumerate(self.labels))
    self.label2id = {
        label: label_id for label_id, label in self.id2label.items()
    }

    # Emit at most the first five mappings for debugging.
    for idx, _ in enumerate(self.labels[:5]):
        logging.debug(" label id {} and its mapped label {}".format(
            idx, self.id2label[idx]))
def __init__(
    self,
    manifest_filepath,
    featurizer,
    labels=None,
    max_duration=None,
    min_duration=None,
    trim=False,
    load_audio=True,
):
    """Load the manifest collection and build the label <-> id lookup tables.

    Args:
        manifest_filepath: One or more manifest paths joined by commas; the
            string is split on ``','`` before being passed to the collection.
        featurizer: Stored unchanged on ``self.featurizer``.
        labels: Label vocabulary. When it is empty or None, the collection's
            ``uniq_labels`` is used instead.
        max_duration: Forwarded to ``collections.ASRSpeechLabel``.
        min_duration: Forwarded to ``collections.ASRSpeechLabel``.
        trim: Stored unchanged on ``self.trim``.
        load_audio: Stored unchanged on ``self.load_audio``.
    """
    manifest_list = manifest_filepath.split(',')
    self.collection = collections.ASRSpeechLabel(
        manifests_files=manifest_list,
        min_duration=min_duration,
        max_duration=max_duration,
    )

    self.featurizer = featurizer
    self.trim = trim
    self.load_audio = load_audio

    # Fall back to the labels discovered in the manifests when none are given.
    self.labels = labels if labels else self.collection.uniq_labels
    self.num_commands = len(self.labels)

    # The position in ``self.labels`` is the class id; a repeated label keeps
    # the id of its last occurrence in ``label2id``.
    self.id2label = dict(enumerate(self.labels))
    self.label2id = {
        label: label_id for label_id, label in self.id2label.items()
    }
def __init__(
    self,
    *,
    audio_tar_filepaths: Union[str, List[str]],
    manifest_filepath: str,
    labels: List[str],
    featurizer,
    shuffle_n: int = 0,
    min_duration: Optional[float] = 0.1,
    max_duration: Optional[float] = None,
    trim: bool = False,
    load_audio: bool = True,
    shard_strategy: str = "scatter",
    global_rank: int = 0,
    world_size: int = 0,
):
    """Set up label tables and a WebDataset pipeline over tarred audio shards.

    Args:
        audio_tar_filepaths: Either a list of shard paths or a single string.
            In a string, the tokens ``( [ < _OP_`` are rewritten to ``{`` and
            ``) ] > _CL_`` to ``}``. The string is brace-expanded here only
            when ``world_size > 1``; otherwise it is handed to
            ``wd.WebDataset`` as-is.
        manifest_filepath: One or more manifest paths joined by commas.
        labels: Label vocabulary. When it is empty or None, the collection's
            ``uniq_labels`` is used instead.
        featurizer: Stored unchanged on ``self.featurizer``.
        shuffle_n: Passed to the dataset's ``shuffle`` when greater than 0;
            otherwise no shuffle stage is added.
        min_duration: Forwarded to ``collections.ASRSpeechLabel``.
        max_duration: Forwarded to ``collections.ASRSpeechLabel``.
        trim: Stored unchanged on ``self.trim``.
        load_audio: Stored unchanged on ``self.load_audio``.
        shard_strategy: ``'scatter'`` gives each rank a contiguous slice of
            the shards; ``'replicate'`` leaves the shard list untouched.
        global_rank: Rank of this process, used to pick its shard slice.
        world_size: Number of processes; sharding applies only when > 1.

    Raises:
        ValueError: If ``shard_strategy`` is not ``'scatter'`` or
            ``'replicate'``.
    """
    self.collection = collections.ASRSpeechLabel(
        manifests_files=manifest_filepath.split(','),
        min_duration=min_duration,
        max_duration=max_duration,
        # Must set this so the manifest lines can be indexed by file ID
        index_by_file_id=True,
    )
    self.file_occurence = count_occurence(self.collection.mapping)

    self.featurizer = featurizer
    self.trim = trim
    self.load_audio = load_audio

    # Fall back to the labels discovered in the manifests when none are given.
    self.labels = labels if labels else self.collection.uniq_labels
    self.num_classes = len(self.labels)

    # The position in ``self.labels`` is the class id.
    self.id2label = dict(enumerate(self.labels))
    self.label2id = {
        label: label_id for label_id, label in self.id2label.items()
    }

    # Emit at most the first five mappings for debugging.
    for idx, _ in enumerate(self.labels[:5]):
        logging.debug(" label id {} and its mapped label {}".format(
            idx, self.id2label[idx]))

    valid_shard_strategies = ['scatter', 'replicate']
    if shard_strategy not in valid_shard_strategies:
        raise ValueError(
            f"`shard_strategy` must be one of {valid_shard_strategies}")

    if isinstance(audio_tar_filepaths, str):
        # Normalise the alternative bracket spellings to real braces: every
        # opening token first, then every closing token.
        brace_aliases = (
            (('(', '[', '<', '_OP_'), "{"),
            ((')', ']', '>', '_CL_'), "}"),
        )
        for tokens, brace in brace_aliases:
            for token in tokens:
                audio_tar_filepaths = audio_tar_filepaths.replace(
                    token, brace)

    # Check for distributed and partition shards accordingly
    if world_size > 1:
        if isinstance(audio_tar_filepaths, str):
            # Expand the brace pattern into an explicit list of shard paths.
            audio_tar_filepaths = list(
                braceexpand.braceexpand(audio_tar_filepaths))

        if shard_strategy == 'replicate':
            logging.info(
                "All tarred dataset shards will be replicated across all nodes."
            )
        elif shard_strategy == 'scatter':
            logging.info(
                "All tarred dataset shards will be scattered evenly across all nodes."
            )
            if len(audio_tar_filepaths) % world_size != 0:
                logging.warning(
                    f"Number of shards in tarred dataset ({len(audio_tar_filepaths)}) is not divisible "
                    f"by number of distributed workers ({world_size}).")

            # Each rank takes an equal-sized contiguous slice; shards left
            # over after the integer division are not assigned to any rank.
            shards_per_rank = len(audio_tar_filepaths) // world_size
            begin_idx = shards_per_rank * global_rank
            end_idx = begin_idx + shards_per_rank
            audio_tar_filepaths = audio_tar_filepaths[begin_idx:end_idx]
            logging.info(
                "Partitioning tarred dataset: process (%d) taking shards [%d, %d)",
                global_rank, begin_idx, end_idx)
        else:
            raise ValueError(
                f"Invalid shard strategy ! Allowed values are : {valid_shard_strategies}"
            )

    # Put together WebDataset
    self._dataset = wd.WebDataset(audio_tar_filepaths)

    if shuffle_n > 0:
        self._dataset = self._dataset.shuffle(shuffle_n)
    else:
        logging.info(
            "WebDataset will not shuffle files within the tar files.")

    # Rename the fields, reduce each sample to an (audio, key) tuple, then
    # run it through this object's filter and sample builder.
    renamed = self._dataset.rename(audio='wav', key='__key__')
    as_tuples = renamed.to_tuple('audio', 'key')
    filtered = as_tuples.pipe(self._filter)
    self._dataset = filtered.map(f=self._build_sample)