Example #1
    def get_data_loader(
        self,
        config: Coqpit,
        assets: Dict,
        is_eval: bool,
        samples: Union[List[Dict], List[List]],
        verbose: bool,
        num_gpus: int,
        rank: int = None,
    ) -> "DataLoader":
        if is_eval and not config.run_eval:
            loader = None
        else:
            # setup multi-speaker attributes
            if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
                if hasattr(config, "model_args"):
                    speaker_id_mapping = self.speaker_manager.ids if config.model_args.use_speaker_embedding else None
                    d_vector_mapping = self.speaker_manager.embeddings if config.model_args.use_d_vector_file else None
                    config.use_d_vector_file = config.model_args.use_d_vector_file
                else:
                    speaker_id_mapping = self.speaker_manager.ids if config.use_speaker_embedding else None
                    d_vector_mapping = self.speaker_manager.embeddings if config.use_d_vector_file else None
            else:
                speaker_id_mapping = None
                d_vector_mapping = None

            # setup multi-lingual attributes
            if hasattr(self, "language_manager") and self.language_manager is not None:
                language_id_mapping = self.language_manager.ids if self.args.use_language_embedding else None
            else:
                language_id_mapping = None

            # init dataloader
            dataset = TTSDataset(
                outputs_per_step=config.r if "r" in config else 1,
                compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
                compute_f0=config.get("compute_f0", False),
                f0_cache_path=config.get("f0_cache_path", None),
                samples=samples,
                ap=self.ap,
                return_wav=config.return_wav if "return_wav" in config else False,
                batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
                min_text_len=config.min_text_len,
                max_text_len=config.max_text_len,
                min_audio_len=config.min_audio_len,
                max_audio_len=config.max_audio_len,
                phoneme_cache_path=config.phoneme_cache_path,
                precompute_num_workers=config.precompute_num_workers,
                use_noise_augment=False if is_eval else config.use_noise_augment,
                verbose=verbose,
                speaker_id_mapping=speaker_id_mapping,
                d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
                tokenizer=self.tokenizer,
                start_by_longest=config.start_by_longest,
                language_id_mapping=language_id_mapping,
            )

            # wait all the DDP process to be ready
            if num_gpus > 1:
                dist.barrier()

            # sort input sequences from short to long
            dataset.preprocess_samples()

            # get samplers
            sampler = self.get_sampler(config, dataset, num_gpus)

            loader = DataLoader(
                dataset,
                batch_size=config.eval_batch_size if is_eval else config.batch_size,
                shuffle=False,  # shuffle is done in the dataset.
                collate_fn=dataset.collate_fn,
                drop_last=False,  # setting this False might cause issues in AMP training.
                sampler=sampler,
                num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
                pin_memory=False,
            )
        return loader
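
The loader built above is normally requested by the trainer rather than called by hand. The snippet below is a minimal usage sketch only, assuming `model` is a BaseTTS-style instance whose config, audio processor, and tokenizer are already initialized and `train_samples` is the list of sample dicts fed to `samples`; every name outside the method signature is an assumption, not part of the example above.

# Minimal sketch, not the trainer's actual call site; `model` and
# `train_samples` are assumed to exist as described above.
train_loader = model.get_data_loader(
    config=model.config,
    assets={},
    is_eval=False,
    samples=train_samples,
    verbose=True,
    num_gpus=1,
)

# Each batch is assembled by TTSDataset.collate_fn.
for batch in train_loader:
    ...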
Example #2
    def get_data_loader(
        self,
        config: Coqpit,
        assets: Dict,
        is_eval: bool,
        data_items: List,
        verbose: bool,
        num_gpus: int,
        rank: int = None,
    ) -> "DataLoader":
        if is_eval and not config.run_eval:
            loader = None
        else:
            ap = assets["audio_processor"]

            # setup multi-speaker attributes
            if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
                if hasattr(config, "model_args"):
                    speaker_id_mapping = self.speaker_manager.speaker_ids if config.model_args.use_speaker_embedding else None
                    d_vector_mapping = self.speaker_manager.d_vectors if config.model_args.use_d_vector_file else None
                    config.use_d_vector_file = config.model_args.use_d_vector_file
                else:
                    speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None
                    d_vector_mapping = self.speaker_manager.d_vectors if config.use_d_vector_file else None
            else:
                speaker_id_mapping = None
                d_vector_mapping = None

            # setup custom symbols if needed
            custom_symbols = None
            if hasattr(self, "make_symbols"):
                custom_symbols = self.make_symbols(self.config)

            if hasattr(self, "language_manager"):
                language_id_mapping = self.language_manager.language_id_mapping if self.args.use_language_embedding else None
            else:
                language_id_mapping = None

            # init dataloader
            dataset = TTSDataset(
                outputs_per_step=config.r if "r" in config else 1,
                text_cleaner=config.text_cleaner,
                compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
                compute_f0=config.get("compute_f0", False),
                f0_cache_path=config.get("f0_cache_path", None),
                meta_data=data_items,
                ap=ap,
                characters=config.characters,
                custom_symbols=custom_symbols,
                add_blank=config["add_blank"],
                return_wav=config.return_wav if "return_wav" in config else False,
                batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
                min_seq_len=config.min_seq_len,
                max_seq_len=config.max_seq_len,
                phoneme_cache_path=config.phoneme_cache_path,
                use_phonemes=config.use_phonemes,
                phoneme_language=config.phoneme_language,
                enable_eos_bos=config.enable_eos_bos_chars,
                use_noise_augment=False if is_eval else config.use_noise_augment,
                verbose=verbose,
                speaker_id_mapping=speaker_id_mapping,
                d_vector_mapping=d_vector_mapping,
                language_id_mapping=language_id_mapping,
            )

            # pre-compute phonemes
            if config.use_phonemes and config.compute_input_seq_cache and rank in [None, 0]:
                if hasattr(self, "eval_data_items") and is_eval:
                    dataset.items = self.eval_data_items
                elif hasattr(self, "train_data_items") and not is_eval:
                    dataset.items = self.train_data_items
                else:
                    # precompute phonemes for precise estimate of sequence lengths.
                    # otherwise `dataset.sort_items()` uses raw text lengths
                    dataset.compute_input_seq(config.num_loader_workers)

                    # TODO: find a more efficient solution
                    # cheap hack - store items in the model state to avoid recomputing when reinit the dataset
                    if is_eval:
                        self.eval_data_items = dataset.items
                    else:
                        self.train_data_items = dataset.items

            # halt DDP processes for the main process to finish computing the phoneme cache
            if num_gpus > 1:
                dist.barrier()

            # sort input sequences from short to long
            dataset.sort_and_filter_items(
                config.get("sort_by_audio_len", default=False))

            # compute pitch frames and write to files.
            if config.compute_f0 and rank in [None, 0]:
                if not os.path.exists(config.f0_cache_path):
                    dataset.pitch_extractor.compute_pitch(
                        ap, config.get("f0_cache_path", None),
                        config.num_loader_workers)

            # halt DDP processes for the main process to finish computing the F0 cache
            if num_gpus > 1:
                dist.barrier()

            # load pitch stats computed above by all the workers
            if config.compute_f0:
                dataset.pitch_extractor.load_pitch_stats(
                    config.get("f0_cache_path", None))

            # sampler for DDP
            sampler = DistributedSampler(dataset) if num_gpus > 1 else None

            # Weighted samplers
            assert not (
                num_gpus > 1
                and getattr(config, "use_language_weighted_sampler", False)
            ), "language_weighted_sampler is not supported with DistributedSampler"
            assert not (
                num_gpus > 1
                and getattr(config, "use_speaker_weighted_sampler", False)
            ), "speaker_weighted_sampler is not supported with DistributedSampler"

            if sampler is None:
                if getattr(config, "use_language_weighted_sampler", False):
                    print(" > Using Language weighted sampler")
                    sampler = get_language_weighted_sampler(dataset.items)
                elif getattr(config, "use_speaker_weighted_sampler", False):
                    print(" > Using Language weighted sampler")
                    sampler = get_speaker_weighted_sampler(dataset.items)

            loader = DataLoader(
                dataset,
                batch_size=config.eval_batch_size if is_eval else config.batch_size,
                shuffle=False,
                collate_fn=dataset.collate_fn,
                drop_last=False,
                sampler=sampler,
                num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
                pin_memory=False,
            )
        return loader
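
Example #2 falls back to a weighted sampler only when DDP is not active; the get_language_weighted_sampler and get_speaker_weighted_sampler helpers themselves are not shown above. As a rough illustration under stated assumptions, a speaker-weighted sampler of that kind can be assembled from torch.utils.data.WeightedRandomSampler; treating `items` as a list of dicts with a "speaker_name" key is an assumption here, not the library's guaranteed item format.

from collections import Counter

import torch
from torch.utils.data import WeightedRandomSampler


def speaker_weighted_sampler_sketch(items):
    # items: list of sample dicts; the "speaker_name" key is assumed for illustration
    speaker_names = [item["speaker_name"] for item in items]
    counts = Counter(speaker_names)
    # inverse-frequency weights so under-represented speakers are drawn more often
    weights = torch.tensor([1.0 / counts[name] for name in speaker_names], dtype=torch.double)
    return WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)

The resulting sampler can then be passed to DataLoader in place of the DistributedSampler, exactly as the `sampler=sampler` argument is used in both examples (remember that a custom sampler requires shuffle=False, which both loaders already set).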