Ejemplo n.º 1
0
    def __init__(
        self,
        *,
        audio_tar_filepaths: Union[str, List[str]],
        manifest_filepath: Union[str, List[str]],
        labels: List[str],
        featurizer,
        shuffle_n: int = 0,
        min_duration: Optional[float] = 0.1,
        max_duration: Optional[float] = None,
        trim: bool = False,
        shard_strategy: str = "scatter",
        global_rank: int = 0,
        world_size: int = 0,
        is_regression_task: bool = False,
    ):
        """Set up the manifest collection, the label maps and the tar-backed pipeline.

        Args:
            audio_tar_filepaths: Tar shard path(s); resolved through
                ``expand_audio_filepaths`` before being given to ``wd.WebDataset``.
            manifest_filepath: Manifest path(s), forwarded unchanged to
                ``collections.ASRSpeechLabel``.
            labels: Label vocabulary. When falsy (``None`` or empty), the
                collection's ``uniq_labels`` is used instead.
            featurizer: Stored as ``self.featurizer``; not used by this constructor.
            shuffle_n: Forwarded to the dataset's ``shuffle`` when greater than 0;
                otherwise no shuffle stage is added and an info message is logged.
            min_duration: Forwarded to the manifest collection.
            max_duration: Forwarded to the manifest collection.
            trim: Stored as ``self.trim``; not used by this constructor.
            shard_strategy: Forwarded to ``expand_audio_filepaths``.
            global_rank: Forwarded to ``expand_audio_filepaths``.
            world_size: Forwarded to ``expand_audio_filepaths``.
            is_regression_task: Accepted but not read anywhere in this constructor.
                NOTE(review): confirm whether it should reach the collection.
        """
        manifest_collection = collections.ASRSpeechLabel(
            manifests_files=manifest_filepath,
            min_duration=min_duration,
            max_duration=max_duration,
            # Required so that manifest entries can be looked up by file ID.
            index_by_file_id=True,
        )
        self.collection = manifest_collection
        self.file_occurence = count_occurence(manifest_collection.mapping)

        self.featurizer = featurizer
        self.trim = trim

        # Fall back to the labels discovered in the manifest when none are given.
        self.labels = labels or self.collection.uniq_labels
        self.num_classes = len(self.labels)

        # Two-way mapping between a label and its position in ``self.labels``.
        self.label2id = {name: index for index, name in enumerate(self.labels)}
        self.id2label = dict(enumerate(self.labels))

        # Log at most the first five mappings.
        preview_count = min(len(self.labels), 5)
        for label_id in range(preview_count):
            logging.debug(" label id {} and its mapped label {}".format(label_id, self.id2label[label_id]))

        shard_urls = expand_audio_filepaths(
            audio_tar_filepaths=audio_tar_filepaths,
            shard_strategy=shard_strategy,
            world_size=world_size,
            global_rank=global_rank,
        )

        # Assemble the WebDataset pipeline one stage at a time.
        pipeline = wd.WebDataset(urls=shard_urls, nodesplitter=None)

        if shuffle_n > 0:
            pipeline = pipeline.shuffle(shuffle_n)
        else:
            logging.info("WebDataset will not shuffle files within the tar files.")

        pipeline = pipeline.rename(audio=VALID_FILE_FORMATS, key='__key__')
        pipeline = pipeline.to_tuple('audio', 'key')
        pipeline = pipeline.pipe(self._filter)
        self._dataset = pipeline.map(f=self._build_sample)
Ejemplo n.º 2
0
    def __init__(
        self,
        *,
        manifest_filepath: str,
        labels: List[str],
        featurizer,
        min_duration: Optional[float] = 0.1,
        max_duration: Optional[float] = None,
        trim: bool = False,
        is_regression_task: bool = False,
    ):
        """Load the manifest collection and prepare the label bookkeeping.

        Args:
            manifest_filepath: One or more manifest paths joined by commas; the
                string is split on ``','`` before being given to the collection.
            labels: Label vocabulary for classification. When falsy, the
                collection's ``uniq_labels`` is used. Ignored for regression.
            featurizer: Stored as ``self.featurizer``; not used by this constructor.
            min_duration: Forwarded to the manifest collection.
            max_duration: Forwarded to the manifest collection.
            trim: Stored as ``self.trim``; not used by this constructor.
            is_regression_task: Forwarded to the collection. When true,
                ``self.labels`` is an empty list, ``self.num_classes`` is 1 and no
                ``label2id`` / ``id2label`` maps are created.
        """
        super().__init__()

        manifest_paths = manifest_filepath.split(',')
        self.collection = collections.ASRSpeechLabel(
            manifests_files=manifest_paths,
            min_duration=min_duration,
            max_duration=max_duration,
            is_regression_task=is_regression_task,
        )

        self.featurizer = featurizer
        self.trim = trim
        self.is_regression_task = is_regression_task

        if is_regression_task:
            # Regression has no label vocabulary; a single output is assumed.
            self.labels = []
            self.num_classes = 1
            return

        self.labels = labels or self.collection.uniq_labels

        # NOTE(review): the None guard below only protects ``num_classes``; the
        # enumerate() that follows still raises TypeError if labels is None.
        if self.labels is None:
            self.num_classes = 1
        else:
            self.num_classes = len(self.labels)

        # Two-way mapping between a label and its position in ``self.labels``.
        self.label2id = {}
        self.id2label = {}
        for index, name in enumerate(self.labels):
            self.label2id[name] = index
            self.id2label[index] = name

        # Log at most the first five mappings.
        preview_count = min(len(self.labels), 5)
        for label_id in range(preview_count):
            logging.debug(" label id {} and its mapped label {}".format(label_id, self.id2label[label_id]))
Ejemplo n.º 3
0
    def __init__(
        self,
        *,
        audio_tar_filepaths: Union[str, List[str]],
        manifest_filepath: str,
        labels: List[str],
        featurizer,
        shuffle_n: int = 0,
        min_duration: Optional[float] = 0.1,
        max_duration: Optional[float] = None,
        trim: bool = False,
        shard_strategy: str = "scatter",
        global_rank: int = 0,
        world_size: int = 0,
        is_regression_task: bool = False,
    ):
        """Load the manifest, build label maps, select shards and build the pipeline.

        Args:
            audio_tar_filepaths: Either a list of tar shard paths or a single
                string. In a string, ``(``, ``[``, ``<`` and ``_OP_`` are rewritten
                to ``{`` and ``)``, ``]``, ``>`` and ``_CL_`` to ``}``. The string
                is brace-expanded here only when ``world_size > 1``; otherwise it
                is handed to ``wd.WebDataset`` as is.
            manifest_filepath: One or more manifest paths joined by commas; the
                string is split on ``','`` before being given to the collection.
            labels: Label vocabulary. When falsy (``None`` or empty), the
                collection's ``uniq_labels`` is used instead.
            featurizer: Stored as ``self.featurizer``; not used by this constructor.
            shuffle_n: Forwarded to the dataset's ``shuffle`` when greater than 0;
                otherwise no shuffle stage is added and an info message is logged.
            min_duration: Forwarded to the manifest collection.
            max_duration: Forwarded to the manifest collection.
            trim: Stored as ``self.trim``; not used by this constructor.
            shard_strategy: ``'scatter'`` gives each rank its own contiguous slice
                of ``len(shards) // world_size`` shards; ``'replicate'`` gives
                every rank the full list. Only consulted when ``world_size > 1``,
                but always validated.
            global_rank: Rank of this process; selects the slice under ``'scatter'``.
            world_size: Number of distributed workers. Values of 1 or less
                disable shard partitioning.
            is_regression_task: Accepted but not read anywhere in this constructor.
                NOTE(review): confirm whether it should reach the collection.

        Raises:
            ValueError: If ``shard_strategy`` is not ``'scatter'`` or ``'replicate'``.
        """
        self.collection = collections.ASRSpeechLabel(
            manifests_files=manifest_filepath.split(','),
            min_duration=min_duration,
            max_duration=max_duration,
            index_by_file_id=True,  # Must set this so the manifest lines can be indexed by file ID
        )

        self.file_occurence = count_occurence(self.collection.mapping)

        self.featurizer = featurizer
        self.trim = trim

        self.labels = labels if labels else self.collection.uniq_labels
        self.num_classes = len(self.labels)

        # Two-way mapping between a label and its position in ``self.labels``.
        self.label2id = {label: label_id for label_id, label in enumerate(self.labels)}
        self.id2label = dict(enumerate(self.labels))

        # Log at most the first five mappings.
        for idx, label in enumerate(self.labels[:5]):
            logging.debug(" label id {} and its mapped label {}".format(idx, label))

        valid_shard_strategies = ['scatter', 'replicate']
        if shard_strategy not in valid_shard_strategies:
            raise ValueError(
                f"`shard_strategy` must be one of {valid_shard_strategies}")

        if isinstance(audio_tar_filepaths, str):
            # Normalize the alternative range delimiters to braces so the
            # string can be brace-expanded. str.replace is a no-op when the
            # token is absent, so no membership check is needed.
            for opener in ('(', '[', '<', '_OP_'):
                audio_tar_filepaths = audio_tar_filepaths.replace(opener, "{")
            for closer in (')', ']', '>', '_CL_'):
                audio_tar_filepaths = audio_tar_filepaths.replace(closer, "}")

        # Check for distributed and partition shards accordingly
        if world_size > 1:
            if isinstance(audio_tar_filepaths, str):
                # Expand to an explicit list so it can be sliced per rank.
                audio_tar_filepaths = list(
                    braceexpand.braceexpand(audio_tar_filepaths))

            if shard_strategy == 'scatter':
                logging.info(
                    "All tarred dataset shards will be scattered evenly across all nodes."
                )

                num_shards = len(audio_tar_filepaths)
                if num_shards % world_size != 0:
                    logging.warning(
                        f"Number of shards in tarred dataset ({num_shards}) is not divisible "
                        f"by number of distributed workers ({world_size}).")

                # Every rank takes the same number of shards; with integer
                # division, any remainder shards are assigned to no rank.
                shards_per_rank = num_shards // world_size
                begin_idx = shards_per_rank * global_rank
                end_idx = begin_idx + shards_per_rank
                audio_tar_filepaths = audio_tar_filepaths[begin_idx:end_idx]
                logging.info(
                    "Partitioning tarred dataset: process (%d) taking shards [%d, %d)",
                    global_rank, begin_idx, end_idx)

            else:
                # Only 'replicate' can reach here: the strategy was validated above.
                logging.info(
                    "All tarred dataset shards will be replicated across all nodes."
                )

        # Put together WebDataset
        self._dataset = wd.WebDataset(urls=audio_tar_filepaths,
                                      nodesplitter=None)

        if shuffle_n > 0:
            self._dataset = self._dataset.shuffle(shuffle_n)
        else:
            logging.info(
                "WebDataset will not shuffle files within the tar files.")

        self._dataset = (
            self._dataset.rename(audio=VALID_FILE_FORMATS, key='__key__')
            .to_tuple('audio', 'key')
            .pipe(self._filter)
            .map(f=self._build_sample)
        )