Example #1
 def __load_dataset(self, split, lang_pair):
     src, tgt = lang_pair.split('-')
     datasets = []
     transcr_datasets = []
     teacher_probs_datasets = []
     teacher_idxs_datasets = []
     for path in self.paths:
         try:
             ds = get_datasets_from_indexed_filterbanks(
                 path, tgt, self.dicts[tgt], split, self.args.dataset_impl,
                 self.args.skip_normalization,
                 self.args.legacy_audio_fix_lua_indexing)
             transcr_ds = data_utils.load_indexed_dataset(
                 os.path.join(path, split) + "." + src, self.dicts[src],
                 self.args.dataset_impl)
             tgt_prefix = os.path.join(path, split) + "." + tgt
             teacher_idxs_fname = tgt_prefix + '.top{}_idx'.format(
                 self.args.distill_topk)
             teacher_probs_fname = tgt_prefix + '.top{}_out'.format(
                 self.args.distill_topk)
             assert IndexedDataset.exists(teacher_idxs_fname) and IndexedDataset.exists(teacher_probs_fname), \
                 "Teacher datasets not found in {}".format(tgt_prefix)
             teacher_probs_datasets.append(
                 TeacherOutputDataset(teacher_probs_fname, np.float32))
             teacher_idxs_datasets.append(
                 TeacherOutputDataset(teacher_idxs_fname, np.int32))
             assert transcr_ds is not None, "Transcription dataset not found in {}".format(
                 os.path.join(path, split))
             transcr_datasets.append(transcr_ds)
             datasets.append(ds)
         except Exception:
             logger.warning("Split {} not found in {}. Skipping...".format(
                 split, path))
     assert len(datasets) > 0
     assert len(datasets) == len(transcr_datasets)
     assert len(datasets) == len(teacher_probs_datasets)
     assert len(datasets) == len(teacher_idxs_datasets)
     if len(datasets) > 1:
         dataset = ConcatDataset(datasets)
         transcr_dataset = ConcatDataset(transcr_datasets)
         teacher_idxs_dataset = ConcatDataset(teacher_idxs_datasets)
         teacher_probs_dataset = ConcatDataset(teacher_probs_datasets)
     else:
         dataset = datasets[0]
         transcr_dataset = transcr_datasets[0]
         teacher_idxs_dataset = teacher_idxs_datasets[0]
         teacher_probs_dataset = teacher_probs_datasets[0]
     dataset_with_transcr = TranscriptionWrapperDataset(
         dataset, transcr_dataset, self.dicts[src])
     dataset_with_kd = DatasetWithTeacherOutput(dataset_with_transcr,
                                                teacher_probs_dataset,
                                                teacher_idxs_dataset,
                                                self.dicts[tgt],
                                                self.args.distill_topk)
     return self.alter_dataset_langtok(dataset_with_kd,
                                       src_eos=self.dicts[src].eos(),
                                       src_lang=src,
                                       tgt_eos=self.dicts[tgt].eos(),
                                       tgt_lang=tgt)
 def indexed_dataset_hter(path, dictionary):
     if IndexedDataset.exists(path):
         if self.args.lazy_load:
             return IndexedDataset(path, fix_lua_indexing=True)
         else:
             return IndexedCachedDataset(path, fix_lua_indexing=True)
     return None
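For reference, the knowledge-distillation files probed in Example #1 sit next to the target-side binarized data and differ only in a top-k suffix. A minimal sketch of just that naming check, assuming an older fairseq layout where IndexedDataset is importable from fairseq.data.indexed_dataset; the path, split, language and top-k values below are hypothetical stand-ins for what Example #1 reads from self.paths and self.args:

    import os

    from fairseq.data.indexed_dataset import IndexedDataset

    # Hypothetical values; Example #1 takes them from self.paths / self.args.
    path, split, tgt, distill_topk = '/data/bin', 'train', 'de', 8

    tgt_prefix = os.path.join(path, split) + '.' + tgt                      # /data/bin/train.de
    teacher_idxs_fname = tgt_prefix + '.top{}_idx'.format(distill_topk)     # .../train.de.top8_idx
    teacher_probs_fname = tgt_prefix + '.top{}_out'.format(distill_topk)    # .../train.de.top8_out

    # Both the teacher's top-k token indices and their probabilities must exist.
    have_teacher_output = (IndexedDataset.exists(teacher_idxs_fname)
                           and IndexedDataset.exists(teacher_probs_fname))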
Example #3
 def split_exists(split, src, tgt, lang, data_path):
     filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
     if self.args.dataset_impl == 'raw' and IndexedRawTextDataset.exists(filename):
         return True
     elif self.args.dataset_impl != 'raw' and IndexedDataset.exists(filename):
         return True
     return False
Example #4
 def split_exists(split, src, tgt, lang):
     filename = os.path.join(self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang))
     if self.args.raw_text and IndexedRawTextDataset.exists(filename):
         return True
     elif not self.args.raw_text and IndexedDataset.exists(filename):
         return True
     return False
 def split_exists(split, data_type, data_path):
     filename = os.path.join(data_path, f'{split}.{data_type}')
     if self.args.raw_text and IndexedRawTextDataset.exists(filename):
         return True
     elif not self.args.raw_text and IndexedDataset.exists(filename):
         return True
     return False
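For context, a closure like the first split_exists above is normally called from the enclosing load_dataset to figure out which naming direction the binarized language pair was written under. A rough sketch of that calling pattern, assuming the usual fairseq variable names (split_k, src, tgt, prefix) rather than anything taken from this example:

    # Probe "train.en-de.en" first, then the reversed "train.de-en.en" naming;
    # fall through to an error (or stop collecting shards) if neither exists.
    if split_exists(split_k, src, tgt, src):
        prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split_k, src, tgt))
    elif split_exists(split_k, tgt, src, src):
        prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split_k, tgt, src))
    else:
        raise FileNotFoundError('Dataset not found: {} ({})'.format(split_k, self.args.data))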
Example #6
    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """

        loaded_datasets = []

        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')
            path = os.path.join(self.args.data, split_k)

            if self.args.raw_text and IndexedRawTextDataset.exists(path):
                ds = IndexedRawTextDataset(path, self.dictionary)
            elif not self.args.raw_text and IndexedDataset.exists(path):
                ds = IndexedDataset(path, fix_lua_indexing=True)
            else:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError(
                        'Dataset not found: {} ({})'.format(
                            split, self.args.data))

            loaded_datasets.append(
                TokenBlockDataset(
                    ds,
                    self.args.tokens_per_sample,
                    pad=self.dictionary.pad(),
                    eos=self.dictionary.eos(),
                    break_mode=self.args.sample_break_mode,
                    include_targets=True,
                ))

            print('| {} {} {} examples'.format(self.args.data, split_k,
                                               len(loaded_datasets[-1])))

            if not combine:
                break

        if len(loaded_datasets) == 1:
            dataset = loaded_datasets[0]
            sizes = dataset.sizes
        else:
            dataset = ConcatDataset(loaded_datasets)
            sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

        add_eos_for_other_targets = self.args.sample_break_mode is not None and self.args.sample_break_mode != 'none'

        self.datasets[split] = MonolingualDataset(
            dataset,
            sizes,
            self.dictionary,
            self.output_dictionary,
            add_eos_for_other_targets=add_eos_for_other_targets,
            shuffle=True,
            targets=self.targets,
        )
Example #7
 def split_para_exists(split, key, lang):
     filename = os.path.join(self.args.data, '{}.{}.{}'.format(split, key, lang))
     print(filename)
     print(self.args.raw_text)
     if self.args.raw_text and IndexedRawTextDataset.exists(filename):
         return True
     elif not self.args.raw_text and IndexedDataset.exists(filename):
         return True
     return False
Example #8
 def indexed_dataset(path, dictionary, copy_ext_dict=False, src_dataset=None):
     if self.args.raw_text:
         return IndexedRawTextDataset(path, dictionary, copy_ext_dict=copy_ext_dict, src_dataset=src_dataset)
     elif IndexedDataset.exists(path):
         if self.args.lazy_load:
             return IndexedDataset(path, fix_lua_indexing=True)
         else:
             return IndexedCachedDataset(path, fix_lua_indexing=True)
     return None
Example #9
 def split_exists(split, src, tgt, lang, data_path):
     if src is not None:
         filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
     else:
         filename = os.path.join(data_path, '{}.{}-None.{}'.format(split, src, tgt))
     if self.args.raw_text and IndexedRawTextDataset.exists(filename):
         return True
     elif not self.args.raw_text and IndexedDataset.exists(filename):
         return True
     return False
Example #10
 def indexed_dataset(path, dictionary, ex_dict=None, is_tgt=False):
     if self.args.segment:
         #if self.args.raw_text:
         return IndexedRawTextSegDataset(path, dictionary, ex_dict, is_tgt)
     else:
         if self.args.raw_text:
             return IndexedRawTextDataset(path, dictionary)
         elif IndexedDataset.exists(path):
             return IndexedCachedDataset(path, fix_lua_indexing=True)
     return None
 def indexed_dataset(path):
     assert IndexedDataset.exists(path), f'IndexedDataset.exists({path})'
     # if self.args.raw_text:
     #     return IndexedRawTextDataset(path, dictionary)
     # elif IndexedDataset.exists(path):
     #     if self.args.lazy_load:
     #         return IndexedDataset(path, fix_lua_indexing=True)
     #     else:
     #         return IndexedCachedDataset(path, fix_lua_indexing=True)
     # return None
     return IndexedCachedDataset(path, fix_lua_indexing=True)
Example #12
 def indexed_sememe_dataset(path, dictionary):
     if self.args.raw_text:
         return IndexedRawSememeTextDataset(path,
                                            dictionary,
                                            append_eos=False)
     elif IndexedDataset.exists(path):
         if self.args.lazy_load:
             raise NotImplementedError
         else:
             raise NotImplementedError
     return None
    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        if self.args.dataset_from_json:
            # TODO: not implemented yet
            raise NotImplementedError
        else:
            super().load_dataset(split, combine=combine, **kwargs)
            teacher_probs_datasets = []
            teacher_idxs_datasets = []
            for path in self.paths:
                prefix = os.path.join(path,
                                      split) + "." + self.args.target_lang
                teacher_idxs_fname = prefix + '.top{}_idx'.format(
                    self.args.distill_topk)
                teacher_probs_fname = prefix + '.top{}_out'.format(
                    self.args.distill_topk)
                if IndexedDataset.exists(
                        teacher_idxs_fname) and IndexedDataset.exists(
                            teacher_probs_fname):
                    teacher_probs_datasets.append(
                        TeacherOutputDataset(teacher_probs_fname, np.float32))
                    teacher_idxs_datasets.append(
                        TeacherOutputDataset(teacher_idxs_fname, np.int32))

            assert len(teacher_idxs_datasets) > 0
            assert len(teacher_probs_datasets) > 0
            if len(teacher_idxs_datasets) > 1:
                teacher_idxs_dataset = ConcatDataset(teacher_idxs_datasets)
                teacher_probs_dataset = ConcatDataset(teacher_probs_datasets)
            else:
                teacher_idxs_dataset = teacher_idxs_datasets[0]
                teacher_probs_dataset = teacher_probs_datasets[0]
        assert len(self.datasets[split]) == len(teacher_idxs_dataset)
        assert len(teacher_probs_dataset) == len(teacher_idxs_dataset)
        self.datasets[split] = DatasetWithTeacherOutput(
            self.datasets[split], teacher_probs_dataset, teacher_idxs_dataset,
            self.tgt_dict, self.args.distill_topk)
Example #14
    def load_dataset(self, split, combine=False):
        """Load a given dataset split.
        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        dataset_map = OrderedDict()

        for lang in self.langs2id.keys():
            if self.default_key is None:
                self.default_key = lang
            # Datasets are expected to be in "split.lang" format (Eg: train.en)
            language_split = '{}.{}'.format(split, lang)
            path = os.path.join(self.args.data, language_split)

            if self.args.raw_text and IndexedRawTextDataset.exists(path):
                ds = IndexedRawTextDataset(path, self.dictionary)
            elif not self.args.raw_text and IndexedDataset.exists(path):
                if self.args.lazy_load:
                    ds = IndexedDataset(path, fix_lua_indexing=True)
                else:
                    ds = IndexedCachedDataset(path, fix_lua_indexing=True)
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(
                    language_split, self.args.data))

            # Since we append each block with the classification_token,
            # we need to effectively create blocks of length
            # tokens_per_sample-1
            block_dataset = TokenBlockDataset(
                dataset=ds,
                sizes=ds.sizes,
                block_size=self.args.tokens_per_sample - 1,
                pad=self.dictionary.pad(),
                eos=self.dictionary.eos())

            dataset_map[lang] = MaskedLMDataset(
                dataset=block_dataset,
                sizes=block_dataset.sizes,
                vocab=self.dictionary,
                pad_idx=self.dictionary.pad(),
                mask_idx=self.dictionary.mask(),
                classif_token_idx=self.dictionary.eos(),
                sep_token_idx=self.dictionary.eos(),
                shuffle=getattr(self.args, 'shuffle', False),
                has_pairs=False,
                segment_id=self.langs2id[lang],
                seed=self.seed,
            )

        self.datasets[split] = MultiCorpusSampledDataset(
            dataset_map, default_key=self.default_key)
        print('| {} {} {} examples'.format(self.args.data, split,
                                           len(self.datasets[split])))
 def split_exists(split, data_type, data_path):
     filename = os.path.join(data_path, f'{split}.{data_type}')
     assert not self.args.raw_text
     exists = [
         IndexedDataset.exists(
             os.path.join(data_path, f'{split}.{data_type}.{k}'))
         for k in DPTREE_KEYS
     ]
     if all(exists):
         return True
     else:
         print(f'Following modality not exists: {exists}')
         return False
Example #16
 def indexed_dataset(path, dictionary, cached=True, audio=False):
     if self.args.raw_text:
         return IndexedRawTextDataset(path, dictionary)
     elif IndexedDataset.exists(path):
         if cached:
             return IndexedCachedDataset(path,
                                         fix_lua_indexing=True,
                                         audio=audio)
         else:
             return IndexedDataset(path,
                                   fix_lua_indexing=True,
                                   audio=audio)
     return None
Example #17
    def _load_single_lang_dataset(self, split):
        loaded_datasets = []

        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')
            path = os.path.join(self.args.data, split_k)

            if self.args.raw_text and IndexedRawTextDataset.exists(path):
                ds = IndexedRawTextDataset(path, self.dictionary)
            elif not self.args.raw_text and IndexedDataset.exists(path):
                if self.args.lazy_load:
                    ds = IndexedDataset(path, fix_lua_indexing=True)
                else:
                    ds = IndexedCachedDataset(path, fix_lua_indexing=True)
            else:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError(
                        'Dataset not found: {} ({})'.format(
                            split, self.args.data))

            # Since we append each block with the classification_token,
            # we need to effectively create blocks of length
            # tokens_per_sample-1
            loaded_datasets.append(
                TokenBlockDataset(
                    ds,
                    ds.sizes,
                    self.args.tokens_per_sample - 1,
                    pad=self.dictionary.pad(),
                    eos=self.dictionary.eos(),
                ))

            print('| {} {} {} examples'.format(self.args.data, split_k,
                                               len(loaded_datasets[-1])))

        if len(loaded_datasets) == 1:
            dataset = loaded_datasets[0]
            sizes = dataset.sizes
        else:
            dataset = ConcatDataset(loaded_datasets)
            sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

        return dataset, sizes
Example #18
    def load_dataset(self, split, combine=False):
        """
        Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        loaded_datasets = []

        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')
            path = os.path.join(self.args.data, split_k)

            if self.args.raw_text and IndexedRawTextDataset.exists(path):
                ds = IndexedRawTextDataset(path, self.dictionary)
            elif not self.args.raw_text and IndexedDataset.exists(path):
                if self.args.lazy_load:
                    ds = IndexedDataset(path, fix_lua_indexing=True)
                else:
                    ds = IndexedCachedDataset(path, fix_lua_indexing=True)
            else:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError(
                        'Dataset not found: {} ({})'.format(
                            split, self.args.data))
            with data_utils.numpy_seed(self.seed + k):
                loaded_datasets.append(
                    BlockPairDataset(
                        ds,
                        self.dictionary,
                        ds.sizes,
                        self.args.tokens_per_sample,
                        break_mode=self.args.break_mode,
                    ))

            logger.info('{} {} {} examples'.format(self.args.data, split_k,
                                                   len(loaded_datasets[-1])))

            if not combine:
                break

        if len(loaded_datasets) == 1:
            dataset = loaded_datasets[0]
            sizes = dataset.sizes
        else:
            dataset = ConcatDataset(loaded_datasets)
            sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

        self.datasets[split] = MaskedLMDataset(
            dataset=dataset,
            sizes=sizes,
            vocab=self.dictionary,
            pad_idx=self.dictionary.pad(),
            mask_idx=self.dictionary.mask(),
            classif_token_idx=self.dictionary.cls(),
            sep_token_idx=self.dictionary.sep(),
            shuffle=False,
            seed=self.seed,
        )
Example #19
 def indexed_dataset(path, dictionary):
     if self.args.raw_text:
         return IndexedRawTextDataset(path, dictionary)
     elif IndexedDataset.exists(path):
         return IndexedCachedDataset(path, fix_lua_indexing=True)
     return None
 def indexed_dataset(path):
     assert IndexedDataset.exists(path), f'IndexedDataset.exists({path})'
     return IndexedCachedDataset(path, fix_lua_indexing=True)
Example #21
 def split_exists(split, bert_pref, data_path):
     filename = os.path.join(data_path,
                             '{}-{}'.format(split, bert_pref))
     if IndexedDataset.exists(filename):
         return True
     return False
Example #22
 def indexed_dataset(path):
     if IndexedDataset.exists(path):
         return BertIndexedCachedDataset(path, fix_lua_indexing=True)
     return None
 def indexed_dataset(path, dictionary):
     if self.args.raw_text:
         raise NotImplementedError
     elif IndexedDataset.exists(path):
         return DPTreeIndexedCachedDataset(path, fix_lua_indexing=True)
     return None
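Taken together, the examples share one gate: probe with IndexedDataset.exists (or IndexedRawTextDataset.exists for raw text), then build either the lazy IndexedDataset or the in-memory IndexedCachedDataset. A minimal standalone sketch of that gate, assuming an older fairseq where these classes are importable from fairseq.data.indexed_dataset; the helper name and the raw_text/lazy_load flags are illustrative, not part of any example above:

    import os

    from fairseq.data.indexed_dataset import (
        IndexedCachedDataset,
        IndexedDataset,
        IndexedRawTextDataset,
    )

    def maybe_indexed_dataset(path, dictionary=None, raw_text=False, lazy_load=False):
        """Return a dataset for `path`, or None if no files are found."""
        if raw_text and IndexedRawTextDataset.exists(path):
            # plain-text file, tokenized on the fly with `dictionary`
            return IndexedRawTextDataset(path, dictionary)
        if not raw_text and IndexedDataset.exists(path):
            # binarized .idx/.bin pair produced by fairseq-preprocess
            if lazy_load:
                return IndexedDataset(path, fix_lua_indexing=True)
            return IndexedCachedDataset(path, fix_lua_indexing=True)
        return None

    # Hypothetical usage: probe sharded splits ("train", "train1", ...) the way
    # the load_dataset loops above do, stopping at the first missing shard.
    for k in range(10):
        split_k = 'train' + (str(k) if k > 0 else '')
        ds = maybe_indexed_dataset(os.path.join('/data/bin', split_k + '.en-de.en'))
        if ds is None:
            break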