Beispiel #1
0
    def load_dataset(self,
                     split: str,
                     task_cfg: FairseqDataclass = None,
                     **kwargs):
        """Create the audio dataset for ``split`` and store it in ``self.datasets``.

        A ``BinarizedAudioDataset`` is built when ``task_cfg.binarized_dataset``
        is truthy; otherwise a ``FileAudioDataset`` is built from the manifest
        ``<self.cfg.data>/<split>.tsv``.

        Args:
            split: name of the split; also the stem of the ``.tsv`` manifest.
            task_cfg: configuration to read dataset options from. Falls back
                to ``self.cfg`` when falsy.
            **kwargs: accepted but not used by this method.
        """
        root = self.cfg.data
        task_cfg = task_cfg or self.cfg

        # Old-style ``Namespace`` configs may lack the ``autoregressive``
        # flag; derive it from the criterion (anything other than "ctc").
        if isinstance(task_cfg, Namespace) and not hasattr(
                task_cfg, "autoregressive"):
            task_cfg.autoregressive = task_cfg.criterion != "ctc"

        compression = getattr(TextCompressionLevel,
                              str(self.cfg.text_compression_level))
        use_binarized = getattr(task_cfg, "binarized_dataset", False)

        # Options handed to both dataset flavours.
        shared = dict(
            sample_rate=task_cfg.get("sample_rate", self.cfg.sample_rate),
            max_sample_size=self.cfg.max_sample_size,
            min_sample_size=self.cfg.min_sample_size,
            pad=task_cfg.labels is not None or task_cfg.enable_padding,
            normalize=task_cfg.normalize,
            num_buckets=self.cfg.num_batch_buckets or int(self.cfg.tpu),
            compute_mask_indices=(self.cfg.precompute_mask_indices
                                  or self.cfg.tpu),
        )
        mask_kwargs = self._get_mask_precompute_kwargs(task_cfg)

        # The two dicts are unpacked separately (not merged) so a clashing
        # key still raises ``TypeError`` at call time.
        if use_binarized:
            dataset = BinarizedAudioDataset(
                root,
                split=split,
                **shared,
                **mask_kwargs,
            )
        else:
            dataset = FileAudioDataset(
                manifest_path=os.path.join(root, "{}.tsv".format(split)),
                **shared,
                text_compression_level=compression,
                **mask_kwargs,
            )
        self.datasets[split] = dataset

        on_tpu = self.cfg.tpu
        if on_tpu and task_cfg.inferred_w2v_config.mask_channel_prob == 0.0:
            logger.info(
                "Pretraining on TPUs may suffer convergence "
                "issues when training with `mask_channel_prob` value of "
                "0. You may want to set this to a low value close to 0.")
Beispiel #2
0
    def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
        """Create the dataset for ``split``, attaching targets when labels are set.

        The audio dataset is a ``BinarizedAudioDataset`` when
        ``task_cfg.binarized_dataset`` is truthy, otherwise a
        ``FileAudioDataset`` read from ``<self.cfg.data>/<split>.tsv``. When
        ``task_cfg.labels`` is truthy, the lines of
        ``<self.cfg.data>/<split>.<labels>`` (minus the dataset's
        ``skipped_indices``) are attached through ``AddTargetDataset``. The
        result is stored in ``self.datasets[split]``.

        Args:
            split: name of the split; stem of the manifest and label files.
            task_cfg: configuration to read dataset options from. Falls back
                to ``self.cfg`` when falsy.
            **kwargs: accepted but not used by this method.

        Raises:
            AssertionError: if the number of retained label lines differs
                from the dataset length.
        """
        root = self.cfg.data
        task_cfg = task_cfg or self.cfg

        # Old-style ``Namespace`` configs may lack the ``autoregressive``
        # flag; derive it from the criterion (anything other than "ctc").
        if isinstance(task_cfg, Namespace) and not hasattr(task_cfg, "autoregressive"):
            task_cfg.autoregressive = task_cfg.criterion != "ctc"

        use_binarized = getattr(task_cfg, 'binarized_dataset', False)

        # Options handed to both dataset flavours.
        shared = dict(
            sample_rate=task_cfg.get("sample_rate", self.cfg.sample_rate),
            max_sample_size=self.cfg.max_sample_size,
            min_sample_size=self.cfg.min_sample_size,
            pad=task_cfg.labels is not None or task_cfg.enable_padding,
            normalize=task_cfg.normalize,
            num_buckets=self.cfg.num_batch_buckets or int(self.cfg.tpu),
            compute_mask_indices=(self.cfg.precompute_mask_indices or self.cfg.tpu),
        )
        mask_kwargs = self._get_mask_precompute_kwargs(task_cfg)

        # The two dicts are unpacked separately (not merged) so a clashing
        # key still raises ``TypeError`` at call time.
        if use_binarized:
            self.datasets[split] = BinarizedAudioDataset(
                root, split=split, **shared, **mask_kwargs
            )
        else:
            self.datasets[split] = FileAudioDataset(
                manifest_path=os.path.join(root, "{}.tsv".format(split)),
                **shared,
                **mask_kwargs,
            )

        on_tpu = self.cfg.tpu
        if on_tpu and task_cfg["mask_channel_prob"] == 0.0:
            logger.info(
                "Pretraining on TPUs may suffer convergence "
                "issues when training with `mask_channel_prob` value of "
                "0. You may want to set this to a low value close to 0."
            )

        # Nothing more to do for unlabeled data.
        if not task_cfg.labels:
            return

        label_path = os.path.join(root, f"{split}.{task_cfg.labels}")
        base = self.datasets[split]
        skipped = getattr(base, 'skipped_indices', set())

        # Keep only label lines whose index was not skipped by the dataset,
        # so labels stay aligned with the retained audio items.
        labels = []
        with open(label_path, "r") as handle:
            for line_no, line in enumerate(handle):
                if line_no in skipped:
                    continue
                labels.append(line)

        assert len(labels) == len(base), (
            f"labels length ({len(labels)}) and dataset length "
            f"({len(base)}) do not match"
        )

        encoder = LabelEncoder(self.target_dictionary)

        self.datasets[split] = AddTargetDataset(
            base,
            labels,
            pad=self.target_dictionary.pad(),
            eos=self.target_dictionary.eos(),
            batch_targets=True,
            process_label=encoder,
            add_to_input=task_cfg.get("autoregressive", False),
        )
    def load_dataset(self,
                     split: str,
                     task_cfg: FairseqDataclass = None,
                     **kwargs):
        """Load ``split`` from every corpus sub-directory of ``self.cfg.data``.

        Each sub-directory of ``self.cfg.data`` is loaded as its own audio
        dataset: a ``BinarizedAudioDataset`` when ``task_cfg.binarized_dataset``
        is truthy, otherwise a ``FileAudioDataset`` read from
        ``<corpus>/<split>.tsv``. When ``task_cfg.labels`` is truthy and
        ``<corpus>/<split>.<labels>`` exists, the corpus is wrapped in an
        ``AddTargetDataset``; corpora without that file stay unlabeled.

        The per-corpus datasets are combined into a
        ``MultiCorpusSampledDataset`` stored in ``self.datasets[split]``. The
        sampling probabilities returned by ``self._get_sample_prob`` are
        stored in ``self.sample_probs``, indexed in sorted corpus order.

        Args:
            split: name of the split; stem of the manifest and label files.
            task_cfg: configuration to read dataset options from. Falls back
                to ``self.cfg`` when falsy.
            **kwargs: accepted but not used by this method.

        Raises:
            FileNotFoundError: if ``self.cfg.data`` contains no
                sub-directories.
            AssertionError: if a corpus' retained label lines differ in
                number from its dataset length.
        """
        data_path_parent = self.cfg.data
        task_cfg = task_cfg or self.cfg

        # ``os.listdir`` yields entries in arbitrary order. Sort them so the
        # corpus order, and with it the index -> probability mapping in
        # ``self.sample_probs``, is reproducible. Plain files are skipped
        # because a corpus must be a directory holding ``<split>.tsv`` etc.
        data_path_list = [
            os.path.join(data_path_parent, name)
            for name in sorted(os.listdir(data_path_parent))
        ]
        data_path_list = [
            path for path in data_path_list if os.path.isdir(path)
        ]
        if not data_path_list:
            raise FileNotFoundError(
                "no corpus sub-directories found under "
                f"{data_path_parent!r}")

        # upgrade old task
        if isinstance(task_cfg, Namespace):
            if not hasattr(task_cfg, "autoregressive"):
                task_cfg.autoregressive = not task_cfg.criterion == "ctc"

        # Resolve the sample rate once so dataset construction and the
        # duration computation below agree on the same value.
        sample_rate = task_cfg.get("sample_rate", self.cfg.sample_rate)

        # Options that do not depend on the individual corpus.
        common_kwargs = dict(
            sample_rate=sample_rate,
            max_sample_size=self.cfg.max_sample_size,
            min_sample_size=self.cfg.min_sample_size,
            pad=task_cfg.labels is not None or task_cfg.enable_padding,
            normalize=task_cfg.normalize,
            num_buckets=self.cfg.num_batch_buckets or int(self.cfg.tpu),
            compute_mask_indices=(self.cfg.precompute_mask_indices
                                  or self.cfg.tpu),
        )
        use_binarized = getattr(task_cfg, "binarized_dataset", False)

        # This warning does not depend on the corpus, so emit it once
        # instead of once per sub-directory.
        if self.cfg.tpu and task_cfg["mask_channel_prob"] == 0.0:
            logger.info(
                "Pretraining on TPUs may suffer convergence "
                "issues when training with `mask_channel_prob` value of "
                "0. You may want to set this to a low value close to 0.")

        dataset_map = OrderedDict()
        datasets_lengths = []
        for data_path in data_path_list:
            if use_binarized:
                dataset = BinarizedAudioDataset(
                    data_path,
                    split=split,
                    **common_kwargs,
                    **self._get_mask_precompute_kwargs(task_cfg),
                )
            else:
                dataset = FileAudioDataset(
                    manifest_path=os.path.join(data_path,
                                               "{}.tsv".format(split)),
                    **common_kwargs,
                    **self._get_mask_precompute_kwargs(task_cfg),
                )

            if task_cfg.labels:
                label_path = os.path.join(data_path,
                                          f"{split}.{task_cfg.labels}")
                # A corpus without a label file is kept as an unlabeled
                # dataset rather than treated as an error.
                if os.path.exists(label_path):
                    skipped_indices = getattr(dataset, "skipped_indices",
                                              set())

                    # Drop label lines whose index the dataset skipped so
                    # labels stay aligned with the retained items.
                    with open(label_path, "r") as f:
                        labels = [
                            line for i, line in enumerate(f)
                            if i not in skipped_indices
                        ]

                    assert len(labels) == len(dataset), (
                        f"labels length ({len(labels)}) and dataset length "
                        f"({len(dataset)}) do not match")

                    dataset = AddTargetDataset(
                        dataset,
                        labels,
                        pad=self.target_dictionary.pad(),
                        eos=self.target_dictionary.eos(),
                        batch_targets=True,
                        process_label=LabelEncoder(self.target_dictionary),
                        add_to_input=task_cfg.get("autoregressive", False),
                    )

            dataset_map[data_path] = dataset
            # Presumably ``sizes`` holds per-item sample counts, making this
            # the corpus duration in hours -- TODO confirm against dataset.
            datasets_lengths.append(sum(dataset.sizes) / sample_rate / 3600)

        datasets_lengths = np.array(datasets_lengths)
        self.sample_probs = self._get_sample_prob(datasets_lengths)
        # Ratio between each corpus' sampled share and its natural share.
        size_ratio = (self.sample_probs *
                      datasets_lengths.sum()) / datasets_lengths
        for index, data_path in enumerate(data_path_list):
            logger.info(
                "Up/Down Sampling ratio by datasets: {} : {:.2f} to prob:{:.2f}"
                .format(os.path.basename(data_path), size_ratio[index],
                        self.sample_probs[index]))

        self.datasets[split] = MultiCorpusSampledDataset(
            dataset_map, sampling_func=self.dataset_sampler)
        logger.info('{} {} examples'.format(split, len(self.datasets[split])))