def load_sentence(self, split, sentence):
    loaded_datasets = []
    words = sentence.split(' ')
    ds = IndexedRawTextDataset(words, self.dictionary)

    loaded_datasets.append(
        TokenBlockDataset(
            ds, ds.sizes, self.args.tokens_per_sample,
            pad=self.dictionary.pad(), eos=self.dictionary.eos(),
            break_mode=self.args.sample_break_mode, include_targets=True,
        ))

    if len(loaded_datasets) == 1:
        dataset = loaded_datasets[0]
        sizes = dataset.sizes
    else:
        dataset = ConcatDataset(loaded_datasets)
        sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

    add_eos_for_other_targets = self.args.sample_break_mode is not None and self.args.sample_break_mode != 'none'

    self.datasets[split] = MonolingualDataset(
        dataset, sizes, self.dictionary, self.output_dictionary,
        add_eos_for_other_targets=add_eos_for_other_targets,
        shuffle=True, targets=self.targets,
    )
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    paths = self.args.data.split(':')
    assert len(paths) > 0
    data_path = paths[epoch % len(paths)]
    split_path = os.path.join(data_path, split)

    dataset = data_utils.load_indexed_dataset(
        split_path, self.dictionary, self.args.dataset_impl, combine=combine,
    )
    if dataset is None:
        raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path))

    dataset = TokenBlockDataset(
        dataset, dataset.sizes, self.args.tokens_per_sample,
        pad=self.dictionary.pad(), eos=self.dictionary.eos(),
        break_mode=self.args.sample_break_mode, include_targets=True,
    )

    add_eos_for_other_targets = self.args.sample_break_mode is not None and self.args.sample_break_mode != 'none'

    self.datasets[split] = MonolingualDataset(
        dataset, dataset.sizes, self.dictionary, self.output_dictionary,
        add_eos_for_other_targets=add_eos_for_other_targets,
        shuffle=True, targets=self.targets,
        add_bos_token=self.args.add_bos_token,
    )
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    use_ctx_dataset = getattr(self.vqvae_args, 'use_context_dataset', 0)
    paths = self.vqvae_args.data.split(":")
    assert len(paths) > 0
    data_path = paths[epoch % len(paths)]
    split_path = os.path.join(data_path, split)

    dataset = data_utils.load_indexed_dataset(
        split_path, self.dictionary, self.vqvae_args.dataset_impl, combine=combine
    )
    if dataset is None:
        raise FileNotFoundError(
            "Dataset not found: {} ({})".format(split, split_path)
        )

    if use_ctx_dataset:
        dataset = DocBlockDataset(
            dataset,
            dataset.sizes,
            self.vqvae_args.tokens_per_sample,
            pad=self.dictionary.pad(),
            eos=self.dictionary.eos(),
            break_mode=self.vqvae_args.sample_break_mode,
            include_targets=True,
            context_mode=self.vqvae_args.context_mode,
            window_size=self.vqvae_args.window_size,
        )
    else:
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.vqvae_args.tokens_per_sample,
            pad=self.dictionary.pad(),
            eos=self.dictionary.eos(),
            break_mode=self.vqvae_args.sample_break_mode,
            include_targets=True,
        )

    add_eos_for_other_targets = (
        self.vqvae_args.sample_break_mode is not None
        and self.vqvae_args.sample_break_mode != "none"
    )

    self.datasets[split] = MonolingualDataset(
        dataset,
        dataset.sizes,
        self.dictionary,
        self.output_dictionary,
        add_eos_for_other_targets=add_eos_for_other_targets,
        shuffle=True,
        targets=self.targets,
        add_bos_token=self.vqvae_args.add_bos_token,
    )
def load_dataset(self, split, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    loaded_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')
        path = os.path.join(self.args.data, split_k)

        if self.args.raw_text and IndexedRawTextDataset.exists(path):
            ds = IndexedRawTextDataset(path, self.dictionary)
        elif not self.args.raw_text and IndexedDataset.exists(path):
            ds = IndexedDataset(path, fix_lua_indexing=True)
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError(
                    'Dataset not found: {} ({})'.format(split, self.args.data))

        loaded_datasets.append(
            TokenBlockDataset(
                ds, self.args.tokens_per_sample,
                pad=self.dictionary.pad(), eos=self.dictionary.eos(),
                break_mode=self.args.sample_break_mode, include_targets=True,
            ))

        print('| {} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1])))

        if not combine:
            break

    if len(loaded_datasets) == 1:
        dataset = loaded_datasets[0]
        sizes = dataset.sizes
    else:
        dataset = ConcatDataset(loaded_datasets)
        sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

    add_eos_for_other_targets = self.args.sample_break_mode is not None and self.args.sample_break_mode != 'none'

    self.datasets[split] = MonolingualDataset(
        dataset, sizes, self.dictionary, self.output_dictionary,
        add_eos_for_other_targets=add_eos_for_other_targets,
        shuffle=True, targets=self.targets,
    )
def load_dataset(self, split: str, epoch=1, combine=False, **kwargs) -> MonolingualDataset:
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    paths = utils.split_paths(self.args.data)
    assert len(paths) > 0

    data_path = paths[(epoch - 1) % len(paths)]
    split_path = os.path.join(data_path, split)

    # each process has its own copy of the raw data (likely to be an np.memmap)
    dataset = data_utils.load_indexed_dataset(
        split_path, self.dictionary, self.args.dataset_impl, combine=combine)
    if dataset is None:
        raise FileNotFoundError(f"Dataset not found: {split} ({split_path})")

    dataset = maybe_shorten_dataset(
        dataset,
        split,
        self.args.shorten_data_split_list,
        self.args.shorten_method,
        self.args.tokens_per_sample,
        self.args.seed,
    )
    dataset = TokenBlockDataset(
        dataset,
        dataset.sizes,
        self.args.tokens_per_sample,
        pad=self.dictionary.pad(),
        eos=self.dictionary.eos(),
        break_mode=self.args.sample_break_mode,
        include_targets=True,
        use_plasma_view=self.args.use_plasma_view,
        split_path=split_path,
        plasma_path=self.args.plasma_path,
    )

    add_eos_for_other_targets = (
        self.args.sample_break_mode is not None
        and self.args.sample_break_mode != "none"
    )

    self.datasets[split] = MonolingualDataset(
        dataset=dataset,
        sizes=dataset.sizes,
        src_vocab=self.dictionary,
        tgt_vocab=self.output_dictionary,
        add_eos_for_other_targets=add_eos_for_other_targets,
        shuffle=True,
        targets=self.targets,
        add_bos_token=self.args.add_bos_token,
    )
def load_dataset(self, split, combine=False):
    """Load a dataset split."""
    loaded_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')
        path = os.path.join(self.args.data, split_k)

        if self.args.raw_text and IndexedRawTextDataset.exists(path):
            ds = IndexedRawTextDataset(path, self.dictionary)
            tokens = [t for l in ds.tokens_list for t in l]
        elif not self.args.raw_text and IndexedInMemoryDataset.exists(path):
            ds = IndexedInMemoryDataset(path, fix_lua_indexing=True)
            tokens = ds.buffer
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError(
                    'Dataset not found: {} ({})'.format(split, self.args.data))

        cbt_booktitle_idx = None
        if self.args.sample_break_mode == 'cbt_booktitle':
            if self.dictionary.index('_BOOK_TITLE_') != self.dictionary.unk():
                cbt_booktitle_idx = self.dictionary.index('_BOOK_TITLE_')

        loaded_datasets.append(
            TokenBlockDataset(
                tokens, ds.sizes, self.args.tokens_per_sample,
                self.args.sample_break_mode, include_targets=True,
                cbt_booktitle_idx=cbt_booktitle_idx,
            ))

        print('| {} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1])))

        if not combine:
            break

    if len(loaded_datasets) == 1:
        dataset = loaded_datasets[0]
        sizes = dataset.sizes
    else:
        dataset = ConcatDataset(loaded_datasets)
        sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

    self.datasets[split] = MonolingualDataset(dataset, sizes, self.dictionary, shuffle=False)
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    print("This is the split", split)
    from fairseq.data.cvit.utils import monoling_select
    dataset = monoling_select(self.data['corpora'], split)

    from ilmulti.sentencepiece import SentencePieceTokenizer
    hard_code_dict = self.data['hard_coded_dict']
    tokenizer = SentencePieceTokenizer(hard_code_dict)
    dataset = CVITIndexedRawTextDataset(dataset, tokenizer, self.dictionary)

    if dataset is None:
        raise FileNotFoundError('Dataset not found: {}'.format(split))

    dataset = TokenBlockDataset(
        dataset, dataset.sizes, self.args.tokens_per_sample,
        pad=self.dictionary.pad(), eos=self.dictionary.eos(),
        break_mode=self.args.sample_break_mode, include_targets=True,
    )

    add_eos_for_other_targets = self.args.sample_break_mode is not None and self.args.sample_break_mode != 'none'

    self.datasets[split] = MonolingualDataset(
        dataset, dataset.sizes, self.dictionary, self.output_dictionary,
        add_eos_for_other_targets=add_eos_for_other_targets,
        shuffle=True, targets=self.targets,
        add_bos_token=self.args.add_bos_token,
    )
def load_dataset_ordering(self, input_ordered_file, input_shuffled_file):
    """Load the ordering test set from a shuffled input file."""
    loaded_datasets = []

    assert self.args.raw_text and IndexedRawTextDataset.exists(input_shuffled_file)
    ds = IndexedRawTextDataset(input_shuffled_file, self.dictionary)
    tokens = [t for l in ds.tokens_list for t in l]

    loaded_datasets.append(
        TokenBlockDataset(
            tokens, ds.sizes, self.args.tokens_per_sample,
            pad=self.dictionary.pad(), eos=self.dictionary.eos(),
            break_mode=self.args.sample_break_mode, include_targets=True,
        ))

    print('| {} {} examples'.format(input_shuffled_file, len(loaded_datasets[-1])))

    assert len(loaded_datasets) == 1
    dataset = loaded_datasets[0]
    sizes = dataset.sizes

    add_eos_for_other_targets = self.args.sample_break_mode is not None and self.args.sample_break_mode != 'none'

    self.datasets['test'] = MonolingualDataset(
        dataset, sizes, self.dictionary, self.output_dictionary,
        add_eos_for_other_targets=add_eos_for_other_targets,
        shuffle=False, targets=self.targets,
    )
def load_dataset(self, split):
    """Load a dataset split."""
    path = os.path.join(self.args.data, split)
    if self.args.raw_text and IndexedRawTextDataset.exists(path):
        ds = IndexedRawTextDataset(path, self.dictionary)
        tokens = ds.tokens_list
    elif not self.args.raw_text and IndexedInMemoryDataset.exists(path):
        ds = IndexedInMemoryDataset(path, fix_lua_indexing=True)
        tokens = ds.buffer
    else:
        raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))

    dataset = TokenBlockDataset(
        tokens, ds.sizes, self.args.tokens_per_sample,
        self.args.sample_break_mode,
        include_targets=True,  # return next tokens as targets
    )
    self.datasets[split] = MonolingualDataset(dataset, dataset.sizes, self.dictionary, shuffle=False)
def test_eval_dataloader(self):
    dictionary = test_utils.dummy_dictionary(10)
    assert len(dictionary) == 14  # 4 extra special symbols
    assert dictionary.pad() == 1

    dataset = test_utils.TestDataset([
        torch.tensor([4, 5, 6, 7], dtype=torch.long),
        torch.tensor([8, 9, 10, 11], dtype=torch.long),
        torch.tensor([12, 13], dtype=torch.long),
    ])
    dataset = MonolingualDataset(dataset, sizes=[4, 4, 2], src_vocab=dictionary)

    config = LanguageModelingConfig(tokens_per_sample=4)
    task = LanguageModelingTask(config, dictionary)

    eval_dataloader = task.eval_lm_dataloader(
        dataset=dataset,
        batch_size=1,
        context_window=2,
        num_workers=0,
    )

    batch = next(eval_dataloader)
    assert batch["net_input"]["src_tokens"][0].tolist() == [4, 5, 6, 7, 1, 1]
    assert batch["target"][0].tolist() == [4, 5, 6, 7, 1, 1]

    batch = next(eval_dataloader)
    assert batch["net_input"]["src_tokens"][0].tolist() == [6, 7, 8, 9, 10, 11]
    assert batch["target"][0].tolist() == [1, 1, 8, 9, 10, 11]

    batch = next(eval_dataloader)
    assert batch["net_input"]["src_tokens"][0].tolist() == [10, 11, 12, 13]
    assert batch["target"][0].tolist() == [1, 1, 12, 13]
def build_dataset_for_inference(self, src_tokens, src_lengths):
    return TransformEosDataset(
        MonolingualDataset(
            TokenBlockDataset(
                src_tokens,
                src_lengths,
                block_size=None,
                pad=self.source_dictionary.pad(),
                eos=self.source_dictionary.eos(),
                break_mode='eos',
                include_targets=False,
            ),
            src_lengths,
            self.source_dictionary,
            self.target_dictionary,
            add_eos_for_other_targets=False,
            shuffle=False,
        ),
        eos=self.source_dictionary.eos(),
        # remove EOS since this will be used as a prefix for generation
        remove_eos_from_src=True,
        has_target=False,
    )
def _initialize_dataset(self, **kwargs):
    return MonolingualDataset(**kwargs)
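# A minimal usage sketch (assumed, not from the source): _initialize_dataset simply
# forwards its keyword arguments to the MonolingualDataset constructor, so a caller
# would pass the same fields used by the loaders elsewhere in this file, e.g.:
#
#   self.datasets[split] = self._initialize_dataset(
#       dataset=dataset,
#       sizes=dataset.sizes,
#       src_vocab=self.dictionary,
#       tgt_vocab=self.output_dictionary,
#       add_eos_for_other_targets=add_eos_for_other_targets,
#       shuffle=True,
#       targets=self.targets,
#       add_bos_token=self.args.add_bos_token,
#   )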
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    paths = self.args.data.split(":")
    assert len(paths) > 0

    if self.args.multiple_datasets:
        if len(paths) == 1:
            paths = [
                os.path.join(paths[0], p) for p in next(os.walk(paths[0]))[1]
            ]
        datasets = [
            ShardedDataset(
                self.dictionary,
                self.args.dataset_impl,
                path,
                split,
                epoch,
                combine=combine,
            ) for path in paths
        ]

        if split in self.subsample_splits:
            sizes = [sum(d.sizes) for d in datasets]
            min_sz = min(sizes)
            ratios = [min_sz / sz for sz in sizes]
            datasets = [
                SubsampleDataset(d, r) if r < 1 else d
                for d, r in zip(datasets, ratios)
            ]

        dataset = ConcatDataset(datasets)
    else:
        data_path = paths[epoch % len(paths)]
        split_path = os.path.join(data_path, split)

        dataset = data_utils.load_indexed_dataset(split_path,
                                                  self.dictionary,
                                                  self.args.dataset_impl,
                                                  combine=combine)
        if dataset is None:
            raise FileNotFoundError("Dataset not found: {} ({})".format(
                split, split_path))

    dataset = TokenBlockDataset(
        dataset,
        dataset.sizes,
        self.args.tokens_per_sample,
        pad=self.dictionary.pad(),
        eos=self.dictionary.eos(),
        break_mode=self.args.sample_break_mode,
        include_targets=True,
    )

    if self.args.prepend_ds_name:
        dataset = self.make_prepended_ds(dataset)

    dataset = ReplaceDataset(
        dataset, {self.dictionary.eos(): self.dictionary.indices['\\n']},
        offset=1)

    add_eos_for_other_targets = (self.args.sample_break_mode is not None
                                 and self.args.sample_break_mode != "none")

    self.datasets[split] = MonolingualDataset(
        dataset,
        dataset.sizes,
        self.dictionary,
        self.output_dictionary,
        add_eos_for_other_targets=add_eos_for_other_targets,
        shuffle=True,
        targets=self.targets,
        add_bos_token=self.args.add_bos_token,
    )
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    loaded_datasets = []

    paths = self.args.data.split(':')
    assert len(paths) > 0
    data_path = paths[epoch % len(paths)]

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')
        path = os.path.join(data_path, split_k)
        ds = indexed_dataset.make_dataset(path,
                                          impl=self.args.dataset_impl,
                                          fix_lua_indexing=True,
                                          dictionary=self.dictionary)

        if ds is None:
            if k > 0:
                break
            else:
                raise FileNotFoundError(
                    'Dataset not found: {} ({})'.format(split, data_path))

        loaded_datasets.append(
            TokenBlockDataset(
                ds, ds.sizes, self.args.tokens_per_sample,
                pad=self.dictionary.pad(), eos=self.dictionary.eos(),
                break_mode=self.args.sample_break_mode, include_targets=True,
            ))

        print('| {} {} {} examples'.format(data_path, split_k, len(loaded_datasets[-1])))

        if not combine:
            break

    if len(loaded_datasets) == 1:
        dataset = loaded_datasets[0]
        sizes = dataset.sizes
    else:
        dataset = ConcatDataset(loaded_datasets)
        sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

    add_eos_for_other_targets = self.args.sample_break_mode is not None and self.args.sample_break_mode != 'none'

    self.datasets[split] = MonolingualDataset(
        dataset, sizes, self.dictionary, self.output_dictionary,
        add_eos_for_other_targets=add_eos_for_other_targets,
        shuffle=True, targets=self.targets,
        add_bos_token=self.args.add_bos_token,
    )
def load_dataset(self, split: str, epoch=1, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    languages, data_path = MultilingualLanguageModelingTask._get_langs(
        self.args, epoch)

    lang_to_offline_shard_ratio = None
    if self.args.lang_to_offline_shard_ratio != "":
        lang_to_offline_shard_ratio = {}
        assert os.path.exists(
            self.args.lang_to_offline_shard_ratio
        ), "provided offline shard ratio file doesn't exist: {0}".format(
            self.args.lang_to_offline_shard_ratio)
        with open(self.args.lang_to_offline_shard_ratio) as fin:
            for line in fin:
                lang, ratio = line.strip().split("\t")
                ratio = float(ratio)
                lang_to_offline_shard_ratio[lang] = ratio

        logger.info(
            "Found offline sharded ratio: %s",
            lang_to_offline_shard_ratio,
        )

    if split == self.args.train_subset:
        logger.info("Training on {0} languages: {1}".format(
            len(languages), languages))
    else:
        logger.info("Evaluating on {0} languages: {1}".format(
            len(languages), languages))

    tokens_per_sample = self.args.tokens_per_sample - int(self.args.add_bos_token)

    fixed_pad_length = None
    if self.args.pad_to_fixed_length:
        fixed_pad_length = self.args.tokens_per_sample

    pad_to_bsz = None
    if self.args.pad_to_fixed_bsz:
        pad_to_bsz = (self.args.batch_size_valid
                      if "valid" in split else self.args.batch_size)

    lang_datasets = []
    for lang_id, language in enumerate(languages):
        split_path = os.path.join(data_path, language, split)
        dataset = data_utils.load_indexed_dataset(split_path,
                                                  self.dictionary,
                                                  self.args.dataset_impl,
                                                  combine=combine)
        # print('len(dataset) =', len(dataset))
        if dataset is None:
            raise FileNotFoundError("Dataset not found: {} ({})".format(
                split, split_path))

        dataset = maybe_shorten_dataset(
            dataset,
            split,
            self.args.shorten_data_split_list,
            self.args.shorten_method,
            tokens_per_sample,
            self.args.seed,
        )

        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            tokens_per_sample,
            pad=self.dictionary.pad(),
            eos=self.dictionary.eos(),
            break_mode=self.args.sample_break_mode,
            include_targets=True,
        )

        add_eos_for_other_targets = (
            self.args.sample_break_mode is not None
            and self.args.sample_break_mode != "none")

        src_lang_idx, tgt_lang_idx = None, None
        if self.args.add_bos_token:
            src_lang_idx = self.dictionary.index(lang_token(language))
            tgt_lang_idx = self.output_dictionary.index(lang_token(language))

        lang_datasets.append(
            MonolingualDataset(
                dataset=dataset,
                sizes=dataset.sizes,
                src_vocab=self.dictionary,
                tgt_vocab=self.output_dictionary,
                add_eos_for_other_targets=add_eos_for_other_targets,
                shuffle=True,
                targets=self.targets,
                fixed_pad_length=fixed_pad_length,
                pad_to_bsz=pad_to_bsz,
                add_bos_token=self.args.add_bos_token,
                src_lang_idx=src_lang_idx,
                tgt_lang_idx=tgt_lang_idx,
            ))

    dataset_lengths = np.array(
        [len(d) for d in lang_datasets],
        dtype=float,
    )
    logger.info("loaded total {} blocks for all languages".format(
        dataset_lengths.sum(), ))

    if split == self.args.train_subset:
        dataset_lengths_ratio_multiplier = np.ones(len(dataset_lengths))
        if lang_to_offline_shard_ratio is not None:
            dataset_lengths_ratio_multiplier = []
            for lang in languages:
                assert (
                    lang in lang_to_offline_shard_ratio
                ), "Lang: {0} missing in offline shard ratio file: {1}".format(
                    lang,
                    self.args.lang_to_offline_shard_ratio,
                )
                dataset_lengths_ratio_multiplier.append(
                    lang_to_offline_shard_ratio[lang])
            dataset_lengths_ratio_multiplier = np.array(
                dataset_lengths_ratio_multiplier)
            true_dataset_lengths = (dataset_lengths *
                                    dataset_lengths_ratio_multiplier)
        else:
            true_dataset_lengths = dataset_lengths

        # For train subset, additionally up or down sample languages.
        sample_probs = self._get_sample_prob(true_dataset_lengths)

        logger.info(
            "Sample probability by language: %s",
            {
                lang: "{0:.4f}".format(sample_probs[id])
                for id, lang in enumerate(languages)
            },
        )
        size_ratio = (sample_probs * true_dataset_lengths.sum()) / dataset_lengths
        # TODO: add an option for shrinking all size ratios to below 1
        # if self.args.multilang_sampling_alpha != 1:
        #     size_ratio /= size_ratio.max()

        # Fix numeric errors in size ratio computation
        #   0.999999999999999999 -> 1
        #   1.000000000000000002 -> 1
        for i in range(len(size_ratio)):
            size_ratio[i] = round(size_ratio[i], 8)

        logger.info(
            "Up/Down Sampling ratio by language: %s",
            {
                lang: "{0:.2f}".format(size_ratio[id])
                for id, lang in enumerate(languages)
            },
        )
        logger.info(
            "Actual dataset size by language: %s",
            {
                lang: "{0:.2f}".format(len(lang_datasets[id]))
                for id, lang in enumerate(languages)
            },
        )
        resampled_lang_datasets = [
            ResamplingDataset(
                lang_datasets[i],
                size_ratio=size_ratio[i],
                seed=self.args.seed,
                epoch=epoch,
                replace=size_ratio[i] > 1.0,
            ) for i, d in enumerate(lang_datasets)
        ]
        logger.info(
            "Resampled dataset size by language: %s",
            {
                lang: "{0:.2f}".format(len(resampled_lang_datasets[id]))
                for id, lang in enumerate(languages)
            },
        )
        dataset = ConcatDataset(resampled_lang_datasets)
    else:
        dataset = ConcatDataset(lang_datasets)

        lang_splits = [split]
        for lang_id, lang_dataset in enumerate(lang_datasets):
            split_name = split + "_" + languages[lang_id]
            lang_splits.append(split_name)
            self.datasets[split_name] = lang_dataset

        # [TODO]: This is hacky for now to print validation ppl for each
        # language individually. Maybe need task API changes to allow it
        # in more generic ways.
        if split in self.args.valid_subset:
            self.args.valid_subset = self.args.valid_subset.replace(
                split, ",".join(lang_splits))

    with data_utils.numpy_seed(self.args.seed + epoch):
        shuffle = np.random.permutation(len(dataset))

    self.datasets[split] = SortDataset(
        dataset,
        sort_order=[
            shuffle,
            dataset.sizes,
        ],
    )
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    paths = utils.split_paths(self.args.data)
    assert len(paths) > 0

    if self.args.multiple_datasets:
        if len(paths) == 1:
            paths = [
                os.path.join(paths[0], p) for p in next(os.walk(paths[0]))[1]
            ]
        datasets = [
            ShardedDataset(
                self.dictionary,
                self.args.dataset_impl,
                path,
                split,
                epoch - 1,
                combine=combine,
            ) for path in paths
        ]

        ds_names = [ds.name for ds in datasets]

        if split in self.subsample_splits:
            sizes = [sum(d.sizes) for d in datasets]
            min_sz = min(sizes)
            ratios = [min_sz / sz for sz in sizes]
            datasets = [
                SubsampleDataset(d, r) if r < 1 else d
                for d, r in zip(datasets, ratios)
            ]

        dataset = ConcatDataset(datasets)
    else:
        data_path = paths[(epoch - 1) % len(paths)]
        split_path = os.path.join(data_path, split)

        dataset = data_utils.load_indexed_dataset(split_path,
                                                  self.dictionary,
                                                  self.args.dataset_impl,
                                                  combine=combine)
        if dataset is None:
            raise FileNotFoundError("Dataset not found: {} ({})".format(
                split, split_path))
        ds_names = [None]

    dataset = TokenBlockDataset(
        dataset,
        dataset.sizes,
        self.args.tokens_per_sample,
        pad=self.dictionary.pad(),
        eos=self.dictionary.eos(),
        break_mode=self.args.sample_break_mode,
        include_targets=True,
    )

    if self.args.prepend_ds_name:
        dataset = PrependDataset(
            dataset,
            prepend_getter=ds_name_getter(
                offset=0,
                generic_ds_name_chance=self.args.generic_ds_name_chance,
                dictionary=self.dictionary,
            ),
            ensure_first_token_is=self.dictionary.eos(),
        )

    dataset = ReplaceDataset(
        dataset,
        replace_map={
            self.dictionary.eos(): self.dictionary.indices["\\n"]
        },
        offsets=[1, -1],
    )

    add_eos_for_other_targets = (self.args.sample_break_mode is not None
                                 and self.args.sample_break_mode != "none")

    dataset = MonolingualDataset(
        dataset,
        dataset.sizes,
        self.dictionary,
        self.output_dictionary,
        add_eos_for_other_targets=add_eos_for_other_targets,
        shuffle=True,
        targets=self.targets,
        add_bos_token=self.args.add_bos_token,
    )

    if self.args.colorize_ds_name:
        ds_names.append("generic")
        min_ds = min(self.dictionary.indices[n] for n in ds_names)
        dataset = ColorizeDataset(
            dataset,
            color_getter=ds_name_getter(
                offset=-min_ds,
                generic_ds_name_chance=self.args.generic_ds_name_chance,
                dictionary=self.dictionary,
            ),
        )

    self.datasets[split] = dataset
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    paths = self.args.data.split(':')
    assert len(paths) > 0
    data_path = paths[epoch % len(paths)]

    context_compress = None
    if self.args.context_form != 'code' and self.args.context_compress is not None:
        context_compress = list(
            map(int, self.args.context_compress.strip().split(',')))

    # infer langcode
    src, tgt = self.args.source_lang, self.args.target_lang

    langpair_dataset = load_langpair_dataset(
        data_path, split, src, self.src_dict, tgt, self.tgt_dict,
        combine=combine, dataset_impl=self.args.dataset_impl,
        upsample_primary=self.args.upsample_primary,
        left_pad_source=self.args.left_pad_source,
        left_pad_target=self.args.left_pad_target,
        max_source_positions=self.args.max_source_positions,
        max_target_positions=self.args.max_target_positions,
        prepend_bos=(self.args.input_form == 'cat'),
    )

    ctx_path = os.path.join(data_path, split + '.' + self.args.context_suffix)
    if self.args.context_form == 'codes':
        # ctx_dataset = RawLabelDataset([torch.IntTensor(map(int, line.strip().split())) for line in open(ctx_path).readlines()])
        # ctx_dataset = ReferenceDataset(ctx_dataset, index_list, sizes=ctx_dataset.sizes)
        raise NotImplementedError
    elif self.args.context_form == 'sent':
        ctx_dataset = langpair_dataset.src
    elif self.args.context_form == 'doc' or self.args.context_form == 'window':
        ctx_dataset = data_utils.load_indexed_dataset(
            ctx_path, self.ctx_dict, self.args.dataset_impl, combine=False
        )  # the binary dataset doesn't actually need the dict
        if ctx_dataset is None:
            raise FileNotFoundError("Dataset not found: {}".format(
                os.path.join(data_path, ctx_path)))

        dataset = DocBlockDataset(
            ctx_dataset,
            ctx_dataset.sizes,
            self.args.tokens_per_sample,
            pad=self.ctx_dict.pad(),
            eos=self.ctx_dict.eos(),
            break_mode='complete_doc',
            include_targets=False,
            context_mode=self.args.context_form,
            window_size=self.args.window_size,
        )
        print("| Loaded {} documents/context!".format(len(dataset)))
        assert len(dataset) == len(langpair_dataset.src)

        # each sample is {'id': index, 'source': source, 'target': target}, with target = None
        ctx_dataset = MonolingualDataset(
            dataset,
            dataset.sizes,
            self.ctx_dict,
            self.ctx_dict,
            add_eos_for_other_targets=False,
            shuffle=False,
            targets=None,
            add_bos_token=False,
        )
    else:
        raise ValueError

    self.datasets[split] = ContextLanguagePairDataset(
        ctx_dataset,
        langpair_dataset,
        input_form=self.args.input_form,
        context_form=self.args.context_form,
        context_compress=context_compress,
        context_dict=self.ctx_dict)