Example #1
    def load_sentence(self, split, sentence):
        loaded_datasets = []
        words = sentence.split(' ')
        ds = IndexedRawTextDataset(words, self.dictionary)
        loaded_datasets.append(
            TokenBlockDataset(
                ds,
                ds.sizes,
                self.args.tokens_per_sample,
                pad=self.dictionary.pad(),
                eos=self.dictionary.eos(),
                break_mode=self.args.sample_break_mode,
                include_targets=True,
            ))
        if len(loaded_datasets) == 1:
            dataset = loaded_datasets[0]
            sizes = dataset.sizes
        else:
            dataset = ConcatDataset(loaded_datasets)
            sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

        add_eos_for_other_targets = self.args.sample_break_mode is not None and self.args.sample_break_mode != 'none'

        self.datasets[split] = MonolingualDataset(
            dataset,
            sizes,
            self.dictionary,
            self.output_dictionary,
            add_eos_for_other_targets=add_eos_for_other_targets,
            shuffle=True,
            targets=self.targets,
        )
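A hypothetical call to the helper above (the `task` instance, split name, and input sentence are placeholders, not part of the original example): it splits the sentence on spaces, indexes the words with the task's dictionary, and registers the result under the given split name.

task.load_sentence('interactive', 'this is a test sentence')
dataset = task.datasets['interactive']  # a MonolingualDataset wrapping the single sentence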
Example #2
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = self.args.data.split(':')
        assert len(paths) > 0
        data_path = paths[epoch % len(paths)]
        split_path = os.path.join(data_path, split)

        dataset = data_utils.load_indexed_dataset(
            split_path,
            self.dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        if dataset is None:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path))

        dataset = TokenBlockDataset(
            dataset, dataset.sizes, self.args.tokens_per_sample,
            pad=self.dictionary.pad(), eos=self.dictionary.eos(),
            break_mode=self.args.sample_break_mode, include_targets=True,
        )

        add_eos_for_other_targets = self.args.sample_break_mode is not None and self.args.sample_break_mode != 'none'

        self.datasets[split] = MonolingualDataset(
            dataset, dataset.sizes, self.dictionary, self.output_dictionary,
            add_eos_for_other_targets=add_eos_for_other_targets, shuffle=True,
            targets=self.targets, add_bos_token=self.args.add_bos_token,
        )
Example #3
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        use_ctx_dataset = getattr(self.vqvae_args, 'use_context_dataset', 0)
        paths = self.vqvae_args.data.split(":")
        assert len(paths) > 0

        data_path = paths[epoch % len(paths)]
        split_path = os.path.join(data_path, split)

        dataset = data_utils.load_indexed_dataset(
            split_path, self.dictionary, self.vqvae_args.dataset_impl, combine=combine
        )
        if dataset is None:
            raise FileNotFoundError(
                "Dataset not found: {} ({})".format(split, split_path)
            )

        if use_ctx_dataset:
            dataset = DocBlockDataset(
                dataset,
                dataset.sizes,
                self.vqvae_args.tokens_per_sample,
                pad=self.dictionary.pad(),
                eos=self.dictionary.eos(),
                break_mode=self.vqvae_args.sample_break_mode,
                include_targets=True,
                context_mode=self.vqvae_args.context_mode,
                window_size=self.vqvae_args.window_size,
            )
        else:
            dataset = TokenBlockDataset(
                dataset,
                dataset.sizes,
                self.vqvae_args.tokens_per_sample,
                pad=self.dictionary.pad(),
                eos=self.dictionary.eos(),
                break_mode=self.vqvae_args.sample_break_mode,
                include_targets=True,
            )

        add_eos_for_other_targets = (
                self.vqvae_args.sample_break_mode is not None
                and self.vqvae_args.sample_break_mode != "none"
        )

        self.datasets[split] = MonolingualDataset(
            dataset,
            dataset.sizes,
            self.dictionary,
            self.output_dictionary,
            add_eos_for_other_targets=add_eos_for_other_targets,
            shuffle=True,
            targets=self.targets,
            add_bos_token=self.vqvae_args.add_bos_token,
        )
Example #4
    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """

        loaded_datasets = []

        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')
            path = os.path.join(self.args.data, split_k)

            if self.args.raw_text and IndexedRawTextDataset.exists(path):
                ds = IndexedRawTextDataset(path, self.dictionary)
            elif not self.args.raw_text and IndexedDataset.exists(path):
                ds = IndexedDataset(path, fix_lua_indexing=True)
            else:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError(
                        'Dataset not found: {} ({})'.format(
                            split, self.args.data))

            loaded_datasets.append(
                TokenBlockDataset(
                    ds,
                    self.args.tokens_per_sample,
                    pad=self.dictionary.pad(),
                    eos=self.dictionary.eos(),
                    break_mode=self.args.sample_break_mode,
                    include_targets=True,
                ))

            print('| {} {} {} examples'.format(self.args.data, split_k,
                                               len(loaded_datasets[-1])))

            if not combine:
                break

        if len(loaded_datasets) == 1:
            dataset = loaded_datasets[0]
            sizes = dataset.sizes
        else:
            dataset = ConcatDataset(loaded_datasets)
            sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

        add_eos_for_other_targets = self.args.sample_break_mode is not None and self.args.sample_break_mode != 'none'

        self.datasets[split] = MonolingualDataset(
            dataset,
            sizes,
            self.dictionary,
            self.output_dictionary,
            add_eos_for_other_targets=add_eos_for_other_targets,
            shuffle=True,
            targets=self.targets,
        )
Example #5
    def load_dataset(self,
                     split: str,
                     epoch=1,
                     combine=False,
                     **kwargs) -> MonolingualDataset:
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0

        data_path = paths[(epoch - 1) % len(paths)]
        split_path = os.path.join(data_path, split)

        # each process has its own copy of the raw data (likely to be an np.memmap)
        dataset = data_utils.load_indexed_dataset(split_path,
                                                  self.dictionary,
                                                  self.args.dataset_impl,
                                                  combine=combine)
        if dataset is None:
            raise FileNotFoundError(
                f"Dataset not found: {split} ({split_path})")

        dataset = maybe_shorten_dataset(
            dataset,
            split,
            self.args.shorten_data_split_list,
            self.args.shorten_method,
            self.args.tokens_per_sample,
            self.args.seed,
        )
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.args.tokens_per_sample,
            pad=self.dictionary.pad(),
            eos=self.dictionary.eos(),
            break_mode=self.args.sample_break_mode,
            include_targets=True,
            use_plasma_view=self.args.use_plasma_view,
            split_path=split_path,
            plasma_path=self.args.plasma_path,
        )

        add_eos_for_other_targets = (self.args.sample_break_mode is not None
                                     and self.args.sample_break_mode != "none")

        self.datasets[split] = MonolingualDataset(
            dataset=dataset,
            sizes=dataset.sizes,
            src_vocab=self.dictionary,
            tgt_vocab=self.output_dictionary,
            add_eos_for_other_targets=add_eos_for_other_targets,
            shuffle=True,
            targets=self.targets,
            add_bos_token=self.args.add_bos_token,
        )
Example #6
    def load_dataset(self, split, combine=False):
        """Load a dataset split."""

        loaded_datasets = []

        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')
            path = os.path.join(self.args.data, split_k)

            if self.args.raw_text and IndexedRawTextDataset.exists(path):
                ds = IndexedRawTextDataset(path, self.dictionary)
                tokens = [t for l in ds.tokens_list for t in l]
            elif not self.args.raw_text and IndexedInMemoryDataset.exists(
                    path):
                ds = IndexedInMemoryDataset(path, fix_lua_indexing=True)
                tokens = ds.buffer
            else:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError(
                        'Dataset not found: {} ({})'.format(
                            split, self.args.data))

            cbt_booktitle_idx = None
            if self.args.sample_break_mode == 'cbt_booktitle':
                if self.dictionary.index(
                        '_BOOK_TITLE_') != self.dictionary.unk():
                    cbt_booktitle_idx = self.dictionary.index('_BOOK_TITLE_')

            loaded_datasets.append(
                TokenBlockDataset(
                    tokens,
                    ds.sizes,
                    self.args.tokens_per_sample,
                    self.args.sample_break_mode,
                    include_targets=True,
                    cbt_booktitle_idx=cbt_booktitle_idx,
                ))

            print('| {} {} {} examples'.format(self.args.data, split_k,
                                               len(loaded_datasets[-1])))

            if not combine:
                break

        if len(loaded_datasets) == 1:
            dataset = loaded_datasets[0]
            sizes = dataset.sizes
        else:
            dataset = ConcatDataset(loaded_datasets)
            sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

        self.datasets[split] = MonolingualDataset(dataset,
                                                  sizes,
                                                  self.dictionary,
                                                  shuffle=False)
Example #7
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """

        print("This is the split", split)

        from fairseq.data.cvit.utils import monoling_select
        dataset = monoling_select(self.data['corpora'], split)

        from ilmulti.sentencepiece import SentencePieceTokenizer

        hard_code_dict = self.data['hard_coded_dict']

        tokenizer = SentencePieceTokenizer(hard_code_dict)
        dataset = CVITIndexedRawTextDataset(dataset, tokenizer,
                                            self.dictionary)

        if dataset is None:
            raise FileNotFoundError('Dataset not found: {}'.format(split))

        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.args.tokens_per_sample,
            pad=self.dictionary.pad(),
            eos=self.dictionary.eos(),
            break_mode=self.args.sample_break_mode,
            include_targets=True,
        )

        add_eos_for_other_targets = self.args.sample_break_mode is not None and self.args.sample_break_mode != 'none'

        self.datasets[split] = MonolingualDataset(
            dataset,
            dataset.sizes,
            self.dictionary,
            self.output_dictionary,
            add_eos_for_other_targets=add_eos_for_other_targets,
            shuffle=True,
            targets=self.targets,
            add_bos_token=self.args.add_bos_token,
        )
Example #8
    def load_dataset_ordering(self, input_ordered_file, input_shuffled_file):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """

        loaded_datasets = []

        assert self.args.raw_text and IndexedRawTextDataset.exists(
            input_shuffled_file)
        ds = IndexedRawTextDataset(input_shuffled_file, self.dictionary)
        tokens = [t for l in ds.tokens_list for t in l]

        loaded_datasets.append(
            TokenBlockDataset(
                tokens,
                ds.sizes,
                self.args.tokens_per_sample,
                pad=self.dictionary.pad(),
                eos=self.dictionary.eos(),
                break_mode=self.args.sample_break_mode,
                include_targets=True,
            ))

        print('| {} {} examples'.format(input_shuffled_file,
                                        len(loaded_datasets[-1])))

        # if not combine:
        #     break

        assert len(loaded_datasets) == 1
        dataset = loaded_datasets[0]
        sizes = dataset.sizes

        add_eos_for_other_targets = self.args.sample_break_mode is not None and self.args.sample_break_mode != 'none'

        self.datasets['test'] = MonolingualDataset(
            dataset,
            sizes,
            self.dictionary,
            self.output_dictionary,
            add_eos_for_other_targets=add_eos_for_other_targets,
            shuffle=False,
            targets=self.targets,
        )
Example #9
    def load_dataset(self, split):
        """Load a dataset split."""
        path = os.path.join(self.args.data, split)
        if self.args.raw_text and IndexedRawTextDataset.exists(path):
            ds = IndexedRawTextDataset(path, self.dictionary)
            tokens = ds.tokens_list
        elif not self.args.raw_text and IndexedInMemoryDataset.exists(path):
            ds = IndexedInMemoryDataset(path, fix_lua_indexing=True)
            tokens = ds.buffer
        else:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))

        dataset = TokenBlockDataset(
            tokens, ds.sizes, self.args.tokens_per_sample, self.args.sample_break_mode,
            include_targets=True,  # return next tokens as targets
        )
        self.datasets[split] = MonolingualDataset(dataset, dataset.sizes, self.dictionary, shuffle=False)
Example #10
    def test_eval_dataloader(self):
        dictionary = test_utils.dummy_dictionary(10)
        assert len(dictionary) == 14  # 4 extra special symbols
        assert dictionary.pad() == 1

        dataset = test_utils.TestDataset([
            torch.tensor([4, 5, 6, 7], dtype=torch.long),
            torch.tensor([8, 9, 10, 11], dtype=torch.long),
            torch.tensor([12, 13], dtype=torch.long),
        ])
        dataset = MonolingualDataset(dataset,
                                     sizes=[4, 4, 2],
                                     src_vocab=dictionary)

        config = LanguageModelingConfig(tokens_per_sample=4)
        task = LanguageModelingTask(config, dictionary)

        eval_dataloader = task.eval_lm_dataloader(
            dataset=dataset,
            batch_size=1,
            context_window=2,
            num_workers=0,
        )

        batch = next(eval_dataloader)
        assert batch["net_input"]["src_tokens"][0].tolist() == [
            4, 5, 6, 7, 1, 1
        ]
        assert batch["target"][0].tolist() == [4, 5, 6, 7, 1, 1]

        batch = next(eval_dataloader)
        assert batch["net_input"]["src_tokens"][0].tolist() == [
            6, 7, 8, 9, 10, 11
        ]
        assert batch["target"][0].tolist() == [1, 1, 8, 9, 10, 11]

        batch = next(eval_dataloader)
        assert batch["net_input"]["src_tokens"][0].tolist() == [10, 11, 12, 13]
        assert batch["target"][0].tolist() == [1, 1, 12, 13]
Example #11
    def build_dataset_for_inference(self, src_tokens, src_lengths):
        return TransformEosDataset(
            MonolingualDataset(
                TokenBlockDataset(
                    src_tokens,
                    src_lengths,
                    block_size=None,
                    pad=self.source_dictionary.pad(),
                    eos=self.source_dictionary.eos(),
                    break_mode='eos',
                    include_targets=False,
                ),
                src_lengths,
                self.source_dictionary,
                self.target_dictionary,
                add_eos_for_other_targets=False,
                shuffle=False,
            ),
            eos=self.source_dictionary.eos(),
            # remove EOS since this will be used as a prefix for generation
            remove_eos_from_src=True,
            has_target=False,
        )
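As a hedged usage note for the snippet above: build_dataset_for_inference expects pre-tokenized prefixes and their lengths. The call-site sketch below is hypothetical (`task`, the example lines, and the encode_line settings are assumptions, not part of the original example). encode_line appends an EOS by default, which is the boundary the break_mode='eos' TokenBlockDataset splits on and which TransformEosDataset then strips so the tokens can serve as a generation prefix.

# Hypothetical call site; `task` stands for the task instance defining the method above.
lines = ["hello world", "another prefix"]
src_tokens = [
    task.source_dictionary.encode_line(line, add_if_not_exist=False).long()
    for line in lines
]
src_lengths = [tokens.numel() for tokens in src_tokens]
inference_dataset = task.build_dataset_for_inference(src_tokens, src_lengths)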
Example #12
    def _initialize_dataset(self, **kwargs):
        return MonolingualDataset(**kwargs)
Example #13
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = self.args.data.split(":")
        assert len(paths) > 0

        if self.args.multiple_datasets:
            if len(paths) == 1:
                paths = [
                    os.path.join(paths[0], p)
                    for p in next(os.walk(paths[0]))[1]
                ]
            datasets = [
                ShardedDataset(
                    self.dictionary,
                    self.args.dataset_impl,
                    path,
                    split,
                    epoch,
                    combine=combine,
                ) for path in paths
            ]

            if split in self.subsample_splits:
                sizes = [sum(d.sizes) for d in datasets]
                min_sz = min(sizes)
                ratios = [min_sz / sz for sz in sizes]
                datasets = [
                    SubsampleDataset(d, r) if r < 1 else d
                    for d, r in zip(datasets, ratios)
                ]

            dataset = ConcatDataset(datasets)
        else:
            data_path = paths[epoch % len(paths)]
            split_path = os.path.join(data_path, split)

            dataset = data_utils.load_indexed_dataset(split_path,
                                                      self.dictionary,
                                                      self.args.dataset_impl,
                                                      combine=combine)
            if dataset is None:
                raise FileNotFoundError("Dataset not found: {} ({})".format(
                    split, split_path))

        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.args.tokens_per_sample,
            pad=self.dictionary.pad(),
            eos=self.dictionary.eos(),
            break_mode=self.args.sample_break_mode,
            include_targets=True,
        )

        if self.args.prepend_ds_name:
            dataset = self.make_prepended_ds(dataset)

        dataset = ReplaceDataset(
            dataset, {self.dictionary.eos(): self.dictionary.indices['\\n']},
            offset=1)

        add_eos_for_other_targets = (self.args.sample_break_mode is not None
                                     and self.args.sample_break_mode != "none")

        self.datasets[split] = MonolingualDataset(
            dataset,
            dataset.sizes,
            self.dictionary,
            self.output_dictionary,
            add_eos_for_other_targets=add_eos_for_other_targets,
            shuffle=True,
            targets=self.targets,
            add_bos_token=self.args.add_bos_token,
        )
Example #14
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """

        loaded_datasets = []

        paths = self.args.data.split(':')
        assert len(paths) > 0
        data_path = paths[epoch % len(paths)]

        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')
            path = os.path.join(data_path, split_k)
            ds = indexed_dataset.make_dataset(path,
                                              impl=self.args.dataset_impl,
                                              fix_lua_indexing=True,
                                              dictionary=self.dictionary)

            if ds is None:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError(
                        'Dataset not found: {} ({})'.format(split, data_path))

            loaded_datasets.append(
                TokenBlockDataset(
                    ds,
                    ds.sizes,
                    self.args.tokens_per_sample,
                    pad=self.dictionary.pad(),
                    eos=self.dictionary.eos(),
                    break_mode=self.args.sample_break_mode,
                    include_targets=True,
                ))

            print('| {} {} {} examples'.format(data_path, split_k,
                                               len(loaded_datasets[-1])))

            if not combine:
                break

        if len(loaded_datasets) == 1:
            dataset = loaded_datasets[0]
            sizes = dataset.sizes
        else:
            dataset = ConcatDataset(loaded_datasets)
            sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

        add_eos_for_other_targets = self.args.sample_break_mode is not None and self.args.sample_break_mode != 'none'

        self.datasets[split] = MonolingualDataset(
            dataset,
            sizes,
            self.dictionary,
            self.output_dictionary,
            add_eos_for_other_targets=add_eos_for_other_targets,
            shuffle=True,
            targets=self.targets,
            add_bos_token=self.args.add_bos_token,
        )
Example #15
    def load_dataset(self, split: str, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        languages, data_path = MultilingualLanguageModelingTask._get_langs(
            self.args, epoch)
        lang_to_offline_shard_ratio = None
        if self.args.lang_to_offline_shard_ratio != "":
            lang_to_offline_shard_ratio = {}
            assert os.path.exists(
                self.args.lang_to_offline_shard_ratio
            ), "provided offline shard ratio file doesn't exist: {0}".format(
                self.args.lang_to_offline_shard_ratio)
            with open(self.args.lang_to_offline_shard_ratio) as fin:
                for line in fin:
                    lang, ratio = line.strip().split("\t")
                    ratio = float(ratio)
                    lang_to_offline_shard_ratio[lang] = ratio

            logger.info(
                "Found offline sharded ratio: %s",
                lang_to_offline_shard_ratio,
            )

        if split == self.args.train_subset:
            logger.info("Training on {0} languages: {1}".format(
                len(languages), languages))
        else:
            logger.info("Evaluating on {0} languages: {1}".format(
                len(languages), languages))

        tokens_per_sample = self.args.tokens_per_sample - int(
            self.args.add_bos_token)

        fixed_pad_length = None
        if self.args.pad_to_fixed_length:
            fixed_pad_length = self.args.tokens_per_sample

        pad_to_bsz = None
        if self.args.pad_to_fixed_bsz:
            pad_to_bsz = (self.args.batch_size_valid
                          if "valid" in split else self.args.batch_size)

        lang_datasets = []
        for lang_id, language in enumerate(languages):
            split_path = os.path.join(data_path, language, split)
            dataset = data_utils.load_indexed_dataset(split_path,
                                                      self.dictionary,
                                                      self.args.dataset_impl,
                                                      combine=combine)
            # print('len(dataset) =', len(dataset))
            if dataset is None:
                raise FileNotFoundError("Dataset not found: {} ({})".format(
                    split, split_path))

            dataset = maybe_shorten_dataset(
                dataset,
                split,
                self.args.shorten_data_split_list,
                self.args.shorten_method,
                tokens_per_sample,
                self.args.seed,
            )

            dataset = TokenBlockDataset(
                dataset,
                dataset.sizes,
                tokens_per_sample,
                pad=self.dictionary.pad(),
                eos=self.dictionary.eos(),
                break_mode=self.args.sample_break_mode,
                include_targets=True,
            )

            add_eos_for_other_targets = (
                self.args.sample_break_mode is not None
                and self.args.sample_break_mode != "none")
            src_lang_idx, tgt_lang_idx = None, None
            if self.args.add_bos_token:
                src_lang_idx = self.dictionary.index(lang_token(language))
                tgt_lang_idx = self.output_dictionary.index(
                    lang_token(language))

            lang_datasets.append(
                MonolingualDataset(
                    dataset=dataset,
                    sizes=dataset.sizes,
                    src_vocab=self.dictionary,
                    tgt_vocab=self.output_dictionary,
                    add_eos_for_other_targets=add_eos_for_other_targets,
                    shuffle=True,
                    targets=self.targets,
                    fixed_pad_length=fixed_pad_length,
                    pad_to_bsz=pad_to_bsz,
                    add_bos_token=self.args.add_bos_token,
                    src_lang_idx=src_lang_idx,
                    tgt_lang_idx=tgt_lang_idx,
                ))

        dataset_lengths = np.array(
            [len(d) for d in lang_datasets],
            dtype=float,
        )
        logger.info("loaded total {} blocks for all languages".format(
            dataset_lengths.sum(), ))
        if split == self.args.train_subset:
            dataset_lengths_ratio_multiplier = np.ones(len(dataset_lengths))
            if lang_to_offline_shard_ratio is not None:
                dataset_lengths_ratio_multiplier = []
                for lang in languages:
                    assert (
                        lang in lang_to_offline_shard_ratio
                    ), "Lang: {0} missing in offline shard ratio file: {1}".format(
                        lang,
                        self.args.lang_to_offline_shard_ratio,
                    )
                    dataset_lengths_ratio_multiplier.append(
                        lang_to_offline_shard_ratio[lang])
                dataset_lengths_ratio_multiplier = np.array(
                    dataset_lengths_ratio_multiplier)
                true_dataset_lengths = (dataset_lengths *
                                        dataset_lengths_ratio_multiplier)
            else:
                true_dataset_lengths = dataset_lengths
            # For train subset, additionally up or down sample languages.
            sample_probs = self._get_sample_prob(true_dataset_lengths)

            logger.info(
                "Sample probability by language: %s",
                {
                    lang: "{0:.4f}".format(sample_probs[id])
                    for id, lang in enumerate(languages)
                },
            )
            size_ratio = (sample_probs *
                          true_dataset_lengths.sum()) / dataset_lengths
            # TODO: add an option for shrinking all size ratios to below 1
            # if self.args.multilang_sampling_alpha != 1:
            #   size_ratio /= size_ratio.max()

            # Fix numeric errors in size ratio computation
            #   0.999999999999999999 -> 1
            #   1.000000000000000002 -> 1
            for i in range(len(size_ratio)):
                size_ratio[i] = round(size_ratio[i], 8)

            logger.info(
                "Up/Down Sampling ratio by language: %s",
                {
                    lang: "{0:.2f}".format(size_ratio[id])
                    for id, lang in enumerate(languages)
                },
            )
            logger.info(
                "Actual dataset size by language: %s",
                {
                    lang: "{0:.2f}".format(len(lang_datasets[id]))
                    for id, lang in enumerate(languages)
                },
            )
            resampled_lang_datasets = [
                ResamplingDataset(
                    lang_datasets[i],
                    size_ratio=size_ratio[i],
                    seed=self.args.seed,
                    epoch=epoch,
                    replace=size_ratio[i] > 1.0,
                ) for i, d in enumerate(lang_datasets)
            ]
            logger.info(
                "Resampled dataset size by language: %s",
                {
                    lang: "{0:.2f}".format(len(resampled_lang_datasets[id]))
                    for id, lang in enumerate(languages)
                },
            )
            dataset = ConcatDataset(resampled_lang_datasets)
        else:
            dataset = ConcatDataset(lang_datasets)
            lang_splits = [split]
            for lang_id, lang_dataset in enumerate(lang_datasets):
                split_name = split + "_" + languages[lang_id]
                lang_splits.append(split_name)
                self.datasets[split_name] = lang_dataset

            # [TODO]: This is hacky for now to print validation ppl for each
            # language individually. Maybe need task API changes to allow it
            # in more generic ways.
            if split in self.args.valid_subset:
                self.args.valid_subset = self.args.valid_subset.replace(
                    split, ",".join(lang_splits))

        with data_utils.numpy_seed(self.args.seed + epoch):
            shuffle = np.random.permutation(len(dataset))

        self.datasets[split] = SortDataset(
            dataset,
            sort_order=[
                shuffle,
                dataset.sizes,
            ],
        )
Example #16
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0

        if self.args.multiple_datasets:
            if len(paths) == 1:
                paths = [
                    os.path.join(paths[0], p)
                    for p in next(os.walk(paths[0]))[1]
                ]
            datasets = [
                ShardedDataset(
                    self.dictionary,
                    self.args.dataset_impl,
                    path,
                    split,
                    epoch - 1,
                    combine=combine,
                ) for path in paths
            ]

            ds_names = [ds.name for ds in datasets]

            if split in self.subsample_splits:
                sizes = [sum(d.sizes) for d in datasets]
                min_sz = min(sizes)
                ratios = [min_sz / sz for sz in sizes]
                datasets = [
                    SubsampleDataset(d, r) if r < 1 else d
                    for d, r in zip(datasets, ratios)
                ]

            dataset = ConcatDataset(datasets)
        else:
            data_path = paths[(epoch - 1) % len(paths)]
            split_path = os.path.join(data_path, split)

            dataset = data_utils.load_indexed_dataset(split_path,
                                                      self.dictionary,
                                                      self.args.dataset_impl,
                                                      combine=combine)
            if dataset is None:
                raise FileNotFoundError("Dataset not found: {} ({})".format(
                    split, split_path))
            ds_names = [None]

        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.args.tokens_per_sample,
            pad=self.dictionary.pad(),
            eos=self.dictionary.eos(),
            break_mode=self.args.sample_break_mode,
            include_targets=True,
        )

        if self.args.prepend_ds_name:
            dataset = PrependDataset(
                dataset,
                prepend_getter=ds_name_getter(
                    offset=0,
                    generic_ds_name_chance=self.args.generic_ds_name_chance,
                    dictionary=self.dictionary,
                ),
                ensure_first_token_is=self.dictionary.eos(),
            )

        dataset = ReplaceDataset(
            dataset,
            replace_map={
                self.dictionary.eos(): self.dictionary.indices["\\n"]
            },
            offsets=[1, -1],
        )

        add_eos_for_other_targets = (self.args.sample_break_mode is not None
                                     and self.args.sample_break_mode != "none")

        dataset = MonolingualDataset(
            dataset,
            dataset.sizes,
            self.dictionary,
            self.output_dictionary,
            add_eos_for_other_targets=add_eos_for_other_targets,
            shuffle=True,
            targets=self.targets,
            add_bos_token=self.args.add_bos_token,
        )

        if self.args.colorize_ds_name:
            ds_names.append("generic")
            min_ds = min(self.dictionary.indices[n] for n in ds_names)
            dataset = ColorizeDataset(
                dataset,
                color_getter=ds_name_getter(
                    offset=-min_ds,
                    generic_ds_name_chance=self.args.generic_ds_name_chance,
                    dictionary=self.dictionary,
                ),
            )

        self.datasets[split] = dataset
Example #17
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = self.args.data.split(':')
        assert len(paths) > 0
        data_path = paths[epoch % len(paths)]

        context_compress = None
        if self.args.context_form != 'code' and self.args.context_compress is not None:
            context_compress = list(
                map(int,
                    self.args.context_compress.strip().split(',')))

        # infer langcode
        src, tgt = self.args.source_lang, self.args.target_lang

        langpair_dataset = load_langpair_dataset(
            data_path,
            split,
            src,
            self.src_dict,
            tgt,
            self.tgt_dict,
            combine=combine,
            dataset_impl=self.args.dataset_impl,
            upsample_primary=self.args.upsample_primary,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
            prepend_bos=(self.args.input_form == 'cat'),
        )

        ctx_path = os.path.join(data_path,
                                split + '.' + self.args.context_suffix)
        if self.args.context_form == 'codes':
            # ctx_dataset = RawLabelDataset([torch.IntTensor(map(int, line.strip().split())) for line in open(ctx_path).readlines()])
            # ctx_dataset = ReferenceDataset(ctx_dataset, index_list, sizes=ctx_dataset.sizes)
            raise NotImplementedError
        elif self.args.context_form == 'sent':
            ctx_dataset = langpair_dataset.src
        elif self.args.context_form == 'doc' or self.args.context_form == 'window':
            ctx_dataset = data_utils.load_indexed_dataset(
                ctx_path, self.ctx_dict, self.args.dataset_impl, combine=False
            )  # in fact, the binary datasets doesn't need the dict
            if ctx_dataset is None:
                raise FileNotFoundError("Dataset not found: {}".format(
                    os.path.join(data_path, ctx_path)))

            dataset = DocBlockDataset(
                ctx_dataset,
                ctx_dataset.sizes,
                self.args.tokens_per_sample,
                pad=self.ctx_dict.pad(),
                eos=self.ctx_dict.eos(),
                break_mode='complete_doc',
                include_targets=False,
                context_mode=self.args.context_form,
                window_size=self.args.window_size,
            )
            print("| Loaded {} documents/context!".format(len(dataset)))
            assert len(dataset) == len(langpair_dataset.src)
            # return {'id': index, 'source': source, 'target': target}: target = None
            ctx_dataset = MonolingualDataset(
                dataset,
                dataset.sizes,
                self.ctx_dict,
                self.ctx_dict,
                add_eos_for_other_targets=False,
                shuffle=False,
                targets=None,
                add_bos_token=False,
            )
        else:
            raise ValueError

        self.datasets[split] = ContextLanguagePairDataset(
            ctx_dataset,
            langpair_dataset,
            input_form=self.args.input_form,
            context_form=self.args.context_form,
            context_compress=context_compress,
            context_dict=self.ctx_dict)