def get_xlm_align_dataset_with_mask(args, dataset_path, vocab, mask_idx, combine=False):
    """Build a dataset that pairs clean token blocks with a masked copy.

    The unmasked blocks feed ``net_input`` and ``offsets``; the masked copy
    produced by ``MaskTokensDataset.apply_mask`` feeds ``net_input_tlm`` and
    its masking targets feed ``target``.

    Args:
        args: namespace providing ``seed`` and ``mask_prob`` (plus whatever
            ``get_prepended_token_block_dataset`` reads).
        dataset_path: path prefix of the indexed dataset to load.
        vocab: dictionary used for padding and masking.
        mask_idx: index of the mask token in ``vocab``.
        combine: forwarded to ``get_prepended_token_block_dataset``.

    Returns:
        ``NestedDictionaryDataset`` sized by the unmasked blocks.
    """
    pad = vocab.pad()
    blocks = get_prepended_token_block_dataset(
        args, dataset_path, vocab, combine=combine)
    masked_src, masked_tgt = MaskTokensDataset.apply_mask(
        blocks,
        vocab=vocab,
        pad_idx=pad,
        mask_idx=mask_idx,
        seed=args.seed,
        mask_prob=args.mask_prob,
    )

    def right_pad(ds):
        return PadDataset(ds, pad_idx=pad, left_pad=False)

    fields = {
        'net_input': {
            'src_tokens': right_pad(blocks),
            'src_lengths': NumelDataset(blocks, reduce=False),
        },
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(blocks, reduce=True),
        'offsets': OffsetDataset(blocks, vocab),
        'net_input_tlm': {
            'src_tokens': right_pad(masked_src),
            'src_lengths': NumelDataset(masked_src, reduce=False),
        },
        'target': right_pad(masked_tgt),
    }
    return NestedDictionaryDataset(fields, sizes=[blocks.sizes])
def get_xlco_dataset(args, dataset_path, vocab, mask_idx, combine=False):
    """Load an indexed dataset, apply token masking, and wrap it in ``XlcoDataset``.

    Only the masked tokens returned by ``MaskTokensDataset.apply_mask`` are
    kept; the masking targets are discarded.

    Args:
        args: namespace providing ``dataset_impl``, ``seed`` and ``mask_prob``.
        dataset_path: path prefix of the indexed dataset.
        vocab: dictionary used for loading, padding and masking.
        mask_idx: index of the mask token in ``vocab``.
        combine: forwarded to ``data_utils.load_indexed_dataset``.

    Returns:
        ``XlcoDataset`` over the masked tokens.

    Raises:
        FileNotFoundError: if nothing could be loaded from ``dataset_path``.
    """
    dataset = data_utils.load_indexed_dataset(
        dataset_path, vocab, args.dataset_impl, combine=combine)
    if dataset is None:
        # A None result means nothing was found at dataset_path (the other
        # loaders in this code treat it the same way); fail with a clear
        # message instead of handing None to apply_mask.
        raise FileNotFoundError('Dataset not found: {}'.format(dataset_path))
    dataset, _ = MaskTokensDataset.apply_mask(
        dataset,
        vocab=vocab,
        pad_idx=vocab.pad(),
        mask_idx=mask_idx,
        seed=args.seed,
        mask_prob=args.mask_prob,
        mask_whole_words=None,
    )
    return XlcoDataset(dataset, vocab)
def get_mlm_dataset(args, dataset_path, vocab, mask_idx, mask_whole_words=None, combine=False):
    """Build a masked-language-modeling dataset from an indexed token file.

    Args:
        args: namespace providing ``seed`` and ``mask_prob`` (plus whatever
            ``get_prepended_token_block_dataset`` reads).
        dataset_path: path prefix of the indexed dataset to load.
        vocab: dictionary used for padding and masking.
        mask_idx: index of the mask token in ``vocab``.
        mask_whole_words: optional whole-word mask forwarded to ``apply_mask``.
        combine: forwarded to ``get_prepended_token_block_dataset``.

    Returns:
        ``NestedDictionaryDataset`` with right-padded masked inputs, their
        lengths, right-padded targets and sentence/token counts.
    """
    pad = vocab.pad()
    blocks = get_prepended_token_block_dataset(
        args, dataset_path, vocab, combine=combine)
    src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
        blocks,
        vocab=vocab,
        pad_idx=pad,
        mask_idx=mask_idx,
        seed=args.seed,
        mask_prob=args.mask_prob,
        mask_whole_words=mask_whole_words,
    )

    net_input = {
        'src_tokens': PadDataset(src_dataset, pad_idx=pad, left_pad=False),
        'src_lengths': NumelDataset(src_dataset, reduce=False),
    }
    padded_target = PadDataset(tgt_dataset, pad_idx=pad, left_pad=False)
    return NestedDictionaryDataset(
        {
            'net_input': net_input,
            'target': padded_target,
            'nsentences': NumSamplesDataset(),
            'ntokens': NumelDataset(src_dataset, reduce=True),
        },
        sizes=[src_dataset.sizes],
    )
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load a given dataset split.

    Loads ``split`` from the data shard selected by ``epoch`` (1-based),
    optionally shortens it, packs it into blocks of ``tokens_per_sample - 1``
    tokens, prepends ``<s>``, applies MLM masking, and stores the result in
    ``self.datasets[split]`` as a ``SortDataset`` ordered by a seeded shuffle
    and then by size.

    Args:
        split (str): name of the split (e.g., train, valid, test)
        epoch (int): 1-based epoch; picks the shard and the shuffle seed.
        combine (bool): forwarded to ``data_utils.load_indexed_dataset``.

    Raises:
        FileNotFoundError: if the split cannot be loaded.
    """
    args = self.args
    vocab = self.source_dictionary

    shards = utils.split_paths(args.data)
    assert len(shards) > 0
    split_path = os.path.join(shards[(epoch - 1) % len(shards)], split)

    dataset = data_utils.load_indexed_dataset(
        split_path, vocab, args.dataset_impl, combine=combine,
    )
    if dataset is None:
        raise FileNotFoundError("Dataset not found: {} ({})".format(
            split, split_path))

    dataset = maybe_shorten_dataset(
        dataset,
        split,
        args.shorten_data_split_list,
        args.shorten_method,
        args.tokens_per_sample,
        args.seed,
    )

    # Pack tokens into contiguous blocks, reserving one slot for <s>.
    dataset = TokenBlockDataset(
        dataset,
        dataset.sizes,
        args.tokens_per_sample - 1,
        pad=vocab.pad(),
        eos=vocab.eos(),
        break_mode=args.sample_break_mode,
    )
    logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))

    # <s> plays the role of BERT's [CLS].
    dataset = PrependTokenDataset(dataset, vocab.bos())

    if args.mask_whole_words:
        mask_whole_words = get_whole_word_mask(args, vocab)
    else:
        mask_whole_words = None

    src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
        dataset,
        vocab,
        pad_idx=vocab.pad(),
        mask_idx=self.mask_idx,
        seed=args.seed,
        mask_prob=args.mask_prob,
        leave_unmasked_prob=args.leave_unmasked_prob,
        random_token_prob=args.random_token_prob,
        freq_weighted_replacement=args.freq_weighted_replacement,
        mask_whole_words=mask_whole_words,
        mask_multiple_length=args.mask_multiple_length,
        mask_stdev=args.mask_stdev,
    )

    with data_utils.numpy_seed(args.seed + epoch):
        shuffle = np.random.permutation(len(src_dataset))

    pad_idx = vocab.pad()
    nested = NestedDictionaryDataset(
        {
            "id": IdDataset(),
            "net_input": {
                "src_tokens": RightPadDataset(src_dataset, pad_idx=pad_idx),
                "src_lengths": NumelDataset(src_dataset, reduce=False),
            },
            "target": RightPadDataset(tgt_dataset, pad_idx=pad_idx),
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(src_dataset, reduce=True),
        },
        sizes=[src_dataset.sizes],
    )
    self.datasets[split] = SortDataset(
        nested, sort_order=[shuffle, src_dataset.sizes])
def load_dataset(self, split, epoch=0, combine=False, data_selector=None):
    """Load a given dataset split.

    Loads ``split`` from the shard chosen by ``epoch`` (0-based) among the
    ``:``-separated paths in ``self.args.data``, packs it into blocks of
    ``tokens_per_sample - 1`` tokens, prepends ``<s>``, applies MLM masking
    and stores a ``SortDataset`` in ``self.datasets[split]``.

    Args:
        split (str): name of the split (e.g., train, valid, test)
        epoch (int): 0-based epoch; picks the shard and the shuffle seed.
        combine (bool): forwarded to ``data_utils.load_indexed_dataset``.
        data_selector: accepted for interface compatibility; not used.

    Raises:
        FileNotFoundError: if the split cannot be loaded.
    """
    paths = self.args.data.split(':')
    assert len(paths) > 0
    data_path = paths[epoch % len(paths)]
    split_path = os.path.join(data_path, split)

    dataset = data_utils.load_indexed_dataset(
        split_path,
        self.source_dictionary,
        self.args.dataset_impl,
        combine=combine,
    )
    if dataset is None:
        raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path))

    # create continuous blocks of tokens
    dataset = TokenBlockDataset(
        dataset,
        dataset.sizes,
        self.args.tokens_per_sample - 1,  # one less for <s>
        pad=self.source_dictionary.pad(),
        eos=self.source_dictionary.eos(),
        break_mode=self.args.sample_break_mode,
    )
    print('| loaded {} batches from: {}'.format(len(dataset), split_path))

    # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
    dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())

    # create masked input and targets.
    # Bind the default first: when whole-word masking is requested but no
    # BPE can be built, fall back to token-level masking instead of leaving
    # ``mask_whole_words`` unbound (apply_mask below would then raise
    # UnboundLocalError).
    mask_whole_words = None
    if self.args.mask_whole_words:
        bpe = encoders.build_bpe(self.args)
        if bpe is not None:

            def is_beginning_of_word(i):
                if i < self.source_dictionary.nspecial:
                    # special elements are always considered beginnings
                    return True
                tok = self.source_dictionary[i]
                if tok.startswith('madeupword'):
                    return True
                try:
                    return bpe.is_beginning_of_word(tok)
                except ValueError:
                    # tokens the BPE cannot classify count as word starts
                    return True

            mask_whole_words = torch.ByteTensor(list(
                map(is_beginning_of_word, range(len(self.source_dictionary)))
            ))

    src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
        dataset,
        self.source_dictionary,
        pad_idx=self.source_dictionary.pad(),
        mask_idx=self.mask_idx,
        seed=self.args.seed,
        mask_prob=self.args.mask_prob,
        leave_unmasked_prob=self.args.leave_unmasked_prob,
        random_token_prob=self.args.random_token_prob,
        freq_weighted_replacement=self.args.freq_weighted_replacement,
        mask_whole_words=mask_whole_words,
    )

    with data_utils.numpy_seed(self.args.seed + epoch):
        shuffle = np.random.permutation(len(src_dataset))

    self.datasets[split] = SortDataset(
        NestedDictionaryDataset(
            {
                'id': IdDataset(),
                'net_input': {
                    'src_tokens': PadDataset(
                        src_dataset,
                        pad_idx=self.source_dictionary.pad(),
                        left_pad=False,
                    ),
                    'src_lengths': NumelDataset(src_dataset, reduce=False),
                },
                'target': PadDataset(
                    tgt_dataset,
                    pad_idx=self.source_dictionary.pad(),
                    left_pad=False,
                ),
                'nsentences': NumSamplesDataset(),
                'ntokens': NumelDataset(src_dataset, reduce=True),
            },
            sizes=[src_dataset.sizes],
        ),
        sort_order=[
            shuffle,
            src_dataset.sizes,
        ],
    )
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
    """Load a given dataset split.

    Picks shard ``epoch % len(paths)`` from the ``:``-separated
    ``self.args.data``, packs ``split`` into blocks of
    ``tokens_per_sample - 1`` tokens, prefixes ``<s>``, applies MLM masking
    and stores a shuffled-then-size-sorted ``SortDataset`` in
    ``self.datasets[split]``.

    Args:
        split (str): name of the split (e.g., train, valid, test)
        epoch (int): 0-based epoch; picks the shard and the shuffle seed.
        combine (bool): forwarded to ``data_utils.load_indexed_dataset``.

    Raises:
        FileNotFoundError: if the split cannot be loaded.
    """
    vocab = self.source_dictionary
    shards = self.args.data.split(':')
    assert len(shards) > 0
    split_path = os.path.join(shards[epoch % len(shards)], split)

    raw = data_utils.load_indexed_dataset(
        split_path, vocab, self.args.dataset_impl, combine=combine,
    )
    if raw is None:
        raise FileNotFoundError('Dataset not found: {} ({})'.format(
            split, split_path))

    # One slot per block is left free for the <s> token added below.
    blocks = TokenBlockDataset(
        raw,
        raw.sizes,
        self.args.tokens_per_sample - 1,
        pad=vocab.pad(),
        eos=vocab.eos(),
        break_mode=self.args.sample_break_mode,
    )
    print('| loaded {} blocks from: {}'.format(len(blocks), split_path))

    # <s> is the sentence-level token, analogous to BERT's [CLS].
    blocks = PrependTokenDataset(blocks, vocab.bos())

    whole_word_mask = None
    if self.args.mask_whole_words:
        whole_word_mask = get_whole_word_mask(self.args, vocab)

    src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
        blocks,
        vocab,
        pad_idx=vocab.pad(),
        mask_idx=self.mask_idx,
        seed=self.args.seed,
        mask_prob=self.args.mask_prob,
        leave_unmasked_prob=self.args.leave_unmasked_prob,
        random_token_prob=self.args.random_token_prob,
        freq_weighted_replacement=self.args.freq_weighted_replacement,
        mask_whole_words=whole_word_mask,
    )

    with data_utils.numpy_seed(self.args.seed + epoch):
        shuffle = np.random.permutation(len(src_dataset))

    def pad_right(ds):
        return PadDataset(ds, pad_idx=vocab.pad(), left_pad=False)

    fields = {
        'id': IdDataset(),
        'net_input': {
            'src_tokens': pad_right(src_dataset),
            'src_lengths': NumelDataset(src_dataset, reduce=False),
        },
        'target': pad_right(tgt_dataset),
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(src_dataset, reduce=True),
    }
    self.datasets[split] = SortDataset(
        NestedDictionaryDataset(fields, sizes=[src_dataset.sizes]),
        sort_order=[shuffle, src_dataset.sizes],
    )
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load a given dataset split for every language under the data path.

    Each sub-directory of the epoch's data shard is treated as a language.
    Every language is loaded, packed into blocks, prefixed with ``<s>``,
    masked, and tagged with its integer ``lang_id`` (its position in the
    sorted language list). For the training subset the languages are
    re-sampled using ``self._get_sample_prob``. Otherwise they are simply
    concatenated, and each one is also registered as ``<split>_<lang>`` so it
    can be validated separately.

    Args:
        split (str): name of the split (e.g., train, valid, test)
        epoch (int): 1-based epoch; picks the shard and the resampling and
            shuffle seeds.
        combine (bool): forwarded to ``data_utils.load_indexed_dataset``.

    Raises:
        FileNotFoundError: if a language directory is missing the split.
    """
    paths = utils.split_paths(self.args.data)
    assert len(paths) > 0
    data_path = paths[(epoch - 1) % len(paths)]

    languages = sorted(
        name for name in os.listdir(data_path)
        if os.path.isdir(os.path.join(data_path, name))
    )

    logger.info("Training on {0} languages: {1}".format(
        len(languages), languages))
    # Format the mapping into the message itself. Passed as a bare extra
    # argument with no %-placeholder, logging never printed it.
    logger.info("Language to id mapping: {}".format(
        {lang: lang_id for lang_id, lang in enumerate(languages)}))

    mask_whole_words = self._get_whole_word_mask()
    lang_datasets = []
    for lang_id, language in enumerate(languages):
        split_path = os.path.join(data_path, language, split)

        dataset = data_utils.load_indexed_dataset(
            split_path,
            self.source_dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        if dataset is None:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(
                split, split_path))

        # create continuous blocks of tokens
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.args.tokens_per_sample - 1,  # one less for <s>
            pad=self.source_dictionary.pad(),
            eos=self.source_dictionary.eos(),
            break_mode=self.args.sample_break_mode,
        )
        logger.info('loaded {} blocks from: {}'.format(
            len(dataset), split_path))

        # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
        dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())

        src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
            dataset,
            self.source_dictionary,
            pad_idx=self.source_dictionary.pad(),
            mask_idx=self.mask_idx,
            seed=self.args.seed,
            mask_prob=self.args.mask_prob,
            leave_unmasked_prob=self.args.leave_unmasked_prob,
            random_token_prob=self.args.random_token_prob,
            freq_weighted_replacement=self.args.freq_weighted_replacement,
            mask_whole_words=mask_whole_words,
        )

        lang_dataset = NestedDictionaryDataset(
            {
                'net_input': {
                    'src_tokens': PadDataset(
                        src_dataset,
                        pad_idx=self.source_dictionary.pad(),
                        left_pad=False,
                    ),
                    'src_lengths': NumelDataset(src_dataset, reduce=False),
                },
                'target': PadDataset(
                    tgt_dataset,
                    pad_idx=self.source_dictionary.pad(),
                    left_pad=False,
                ),
                'nsentences': NumSamplesDataset(),
                'ntokens': NumelDataset(src_dataset, reduce=True),
                'lang_id': RawLabelDataset(
                    [lang_id] * src_dataset.sizes.shape[0]),
            },
            sizes=[src_dataset.sizes],
        )
        lang_datasets.append(lang_dataset)

    dataset_lengths = np.array(
        [len(d) for d in lang_datasets],
        dtype=float,
    )
    logger.info('loaded total {} blocks for all languages'.format(
        dataset_lengths.sum(),
    ))

    if split == self.args.train_subset:
        # For train subset, additionally up or down sample languages.
        sample_probs = self._get_sample_prob(dataset_lengths)
        logger.info("Sample probability by language: {}".format({
            lang: "{0:.4f}".format(sample_probs[lang_id])
            for lang_id, lang in enumerate(languages)
        }))
        size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths
        logger.info("Up/Down Sampling ratio by language: {}".format({
            lang: "{0:.2f}".format(size_ratio[lang_id])
            for lang_id, lang in enumerate(languages)
        }))

        resampled_lang_datasets = [
            ResamplingDataset(
                lang_dataset,
                size_ratio=ratio,
                seed=self.args.seed,
                epoch=epoch,
                # sample with replacement only when up-sampling
                replace=ratio >= 1.0,
            )
            for lang_dataset, ratio in zip(lang_datasets, size_ratio)
        ]
        dataset = ConcatDataset(resampled_lang_datasets)
    else:
        dataset = ConcatDataset(lang_datasets)
        lang_splits = [split]
        for language, lang_dataset in zip(languages, lang_datasets):
            split_name = split + '_' + language
            lang_splits.append(split_name)
            self.datasets[split_name] = lang_dataset

        # [TODO]: This is hacky for now to print validation ppl for each
        # language individually. Maybe need task API changes to allow it
        # in more generic ways.
        if split in self.args.valid_subset:
            self.args.valid_subset = self.args.valid_subset.replace(
                split, ','.join(lang_splits))

    with data_utils.numpy_seed(self.args.seed + epoch):
        shuffle = np.random.permutation(len(dataset))

    self.datasets[split] = SortDataset(
        dataset,
        sort_order=[
            shuffle,
            dataset.sizes,
        ],
    )
def load_dataset(self, split, epoch=0, combine=False):
    """Load a given dataset split.

    Reads ``input0`` (required) and ``input1`` (optional) under
    ``self.args.data`` and joins the two segments. It then truncates to
    ``max_positions``, applies token-level MLM masking, and stores a
    ``SortDataset`` in ``self.datasets[split]``. Its inputs and targets are
    padded to exactly ``max_positions``.

    Args:
        split (str): name of the split (e.g., train, valid, test)
        epoch (int): offset added to ``seed`` for the shuffle order.
        combine (bool): forwarded to ``data_utils.load_indexed_dataset``.

    Raises:
        AssertionError: if ``input0`` is missing, or if whole-word masking is
            requested (not supported by this loader).
    """
    def get_path(subdir, split):
        return os.path.join(self.args.data, subdir, split)

    def make_dataset(subdir, dictionary):
        split_path = get_path(subdir, split)
        return data_utils.load_indexed_dataset(
            split_path,
            dictionary,
            self.args.dataset_impl,
            combine=combine,
        )

    input0 = make_dataset('input0', self.source_dictionary)
    # Report the actual missing path. This message used to pass the builtin
    # ``type`` to get_path (make_dataset's parameter is not in scope here).
    # os.path.join then raised TypeError instead of this assertion.
    assert input0 is not None, 'could not find dataset: {}'.format(
        get_path('input0', split))
    input1 = make_dataset('input1', self.source_dictionary)

    input0 = PrependTokenDataset(input0, self.source_dictionary.bos())
    if input1 is None:
        src_tokens = input0
    else:
        # The second segment starts with eos, which sits between the two
        # segments after concatenation.
        input1 = PrependTokenDataset(input1, self.source_dictionary.eos())
        src_tokens = ConcatSentencesDataset(input0, input1)

    src_tokens = TruncateDataset(src_tokens, self.args.max_positions)

    assert not self.args.mask_whole_words

    src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
        src_tokens,
        self.source_dictionary,
        pad_idx=self.source_dictionary.pad(),
        mask_idx=self.mask_idx,
        seed=self.args.seed,
        mask_prob=self.args.mask_prob,
        leave_unmasked_prob=self.args.leave_unmasked_prob,
        random_token_prob=self.args.random_token_prob,
        freq_weighted_replacement=self.args.freq_weighted_replacement,
        mask_whole_words=None,
    )

    with data_utils.numpy_seed(self.args.seed + epoch):
        shuffle = np.random.permutation(len(src_dataset))

    self.datasets[split] = SortDataset(
        NestedDictionaryDataset(
            {
                'id': IdDataset(),
                'net_input': {
                    'src_tokens': PadToLenDataset(
                        src_dataset,
                        pad_idx=self.source_dictionary.pad(),
                        left_pad=False,
                        pad_len=self.args.max_positions,
                    ),
                    'src_lengths': NumelDataset(src_dataset, reduce=False),
                },
                'target': PadToLenDataset(
                    tgt_dataset,
                    pad_idx=self.source_dictionary.pad(),
                    left_pad=False,
                    pad_len=self.args.max_positions,
                ),
                'nsentences': NumSamplesDataset(),
                'ntokens': NumelDataset(src_dataset, reduce=True),
            },
            sizes=[src_dataset.sizes],
        ),
        sort_order=[
            shuffle,
            src_dataset.sizes,
        ],
    )
def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    num_buckets=0,
    shuffle=True,
    pad_to_multiple=1,
    # Masked LM parameters.
    mask_idx: int = 0,
    seed: int = 1,
    mask_prob: float = 0.01,
    leave_unmasked_prob: float = 0.0,
    random_token_prob: float = 0.0,
    freq_weighted_replacement: bool = False,
    mask_whole_words: torch.Tensor = None,
    mask_multiple_length: int = 1,
    mask_stdev: float = 0.0,
):
    """Load a (src, tgt) language-pair split and add MLM masking to the source.

    Shards ``split``, ``split1``, ``split2``, ... are read while they exist;
    only the first is read when ``combine`` is false. For each shard, files
    named ``{split}.{src}-{tgt}.*`` are tried first, then
    ``{split}.{tgt}-{src}.*``. Multiple shards are concatenated with the
    first one weighted by ``upsample_primary``. Optional ``<s>`` prefixes,
    ``[lang]`` suffixes and word alignments are then applied. Finally the
    source side is masked with ``MaskTokensDataset.apply_mask``.

    Returns:
        ``LanguagePairDataset`` whose source is the masked source. It also
        carries the masking targets (and their sizes) right after
        ``tgt_dict``, for the masked-LM loss.

    Raises:
        FileNotFoundError: if the first shard exists in neither direction.
    """
    def shard_prefix(split_name):
        # File prefix of ``split_name`` (src-tgt tried before tgt-src), or
        # None when no source-language file exists in either direction.
        for first, second in ((src, tgt), (tgt, src)):
            candidate = os.path.join(
                data_path, "{}.{}-{}.{}".format(split_name, first, second, src))
            if indexed_dataset.dataset_exists(candidate, impl=dataset_impl):
                return os.path.join(
                    data_path, "{}.{}-{}.".format(split_name, first, second))
        return None

    src_datasets = []
    tgt_datasets = []
    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else "")
        prefix = shard_prefix(split_k)
        if prefix is None:
            if k > 0:
                break
            raise FileNotFoundError(
                "Dataset not found: {} ({})".format(split, data_path)
            )

        shard_src = data_utils.load_indexed_dataset(
            prefix + src, src_dict, dataset_impl
        )
        if truncate_source:
            # Strip eos, truncate to leave room for it, then re-append it.
            shard_src = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(shard_src, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(shard_src)

        shard_tgt = data_utils.load_indexed_dataset(
            prefix + tgt, tgt_dict, dataset_impl
        )
        if shard_tgt is not None:
            tgt_datasets.append(shard_tgt)

        logger.info(
            "{} {} {}-{} {} examples".format(
                data_path, split_k, src, tgt, len(shard_src)
            )
        )

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if tgt_datasets else None
    else:
        # The primary (first) shard gets upsample_primary; the rest get 1.
        sample_ratios = [upsample_primary] + [1] * (len(src_datasets) - 1)
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        tgt_dataset = (
            ConcatDataset(tgt_datasets, sample_ratios) if tgt_datasets else None
        )

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(
            tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    eos = None
    if append_source_id:
        src_dataset = AppendTokenDataset(
            src_dataset, src_dict.index("[{}]".format(src))
        )
        tgt_lang_tok = tgt_dict.index("[{}]".format(tgt))
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(tgt_dataset, tgt_lang_tok)
        eos = tgt_lang_tok

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(
            data_path, "{}.align.{}-{}".format(split, src, tgt))
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(
                align_path, None, dataset_impl
            )

    tgt_dataset_sizes = None if tgt_dataset is None else tgt_dataset.sizes

    # The masked tokens replace the source; the second result holds the
    # masking targets consumed by the masked-LM loss.
    src_dataset, mlm_targets = MaskTokensDataset.apply_mask(
        src_dataset,
        src_dict,
        pad_idx=src_dict.pad(),
        mask_idx=mask_idx,
        seed=seed,
        mask_prob=mask_prob,
        leave_unmasked_prob=leave_unmasked_prob,
        random_token_prob=random_token_prob,
        freq_weighted_replacement=freq_weighted_replacement,
        mask_whole_words=mask_whole_words,
        mask_multiple_length=mask_multiple_length,
        mask_stdev=mask_stdev,
    )

    return LanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        mlm_targets,
        mlm_targets.sizes,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        align_dataset=align_dataset,
        eos=eos,
        num_buckets=num_buckets,
        shuffle=shuffle,
        pad_to_multiple=pad_to_multiple,
    )
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
    """Load a given dataset split from a memory-mapped indexed file.

    Uses shard ``epoch % len(paths)`` of the ``:``-separated
    ``self.args.data``. Tokens are packed into blocks of
    ``tokens_per_sample - 1``, prefixed with ``<s>``, masked, and stored as a
    ``SortDataset`` in ``self.datasets[split]``. With ``mask_whole_words``,
    only the ``" "`` token and ``<s>`` count as word beginnings.

    Args:
        split (str): name of the split (e.g., train, valid, test)
        epoch (int): 0-based epoch; picks the shard and the shuffle seed.
        combine (bool): accepted for interface compatibility; not used here.
    """
    vocab = self.source_dictionary
    paths = self.args.data.split(":")
    assert len(paths) > 0
    split_path = os.path.join(paths[epoch % len(paths)], split)

    dataset = MMapIndexedDataset(split_path)
    # NOTE(review): a constructor call presumably never yields None, so this
    # guard looks unreachable; confirm how a missing file is reported.
    if dataset is None:
        raise FileNotFoundError("Dataset not found: {} ({})".format(
            split, split_path))

    # Blocks are one token short to leave room for <s>.
    dataset = TokenBlockDataset(
        dataset,
        dataset.sizes,
        self.args.tokens_per_sample - 1,
        pad=vocab.pad(),
        eos=vocab.eos(),
        break_mode=self.args.sample_break_mode,
    )
    # <s> plays the role of BERT's [CLS].
    dataset = PrependTokenDataset(dataset, vocab.bos())

    mask_whole_words = None
    if self.args.mask_whole_words:
        space_idx = vocab.indexer[" "]
        bos_idx = vocab.bos()
        mask_whole_words = torch.ByteTensor(
            [idx in (space_idx, bos_idx) for idx in range(len(vocab))]
        )

    src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
        dataset,
        vocab,
        pad_idx=vocab.pad(),
        mask_idx=self.mask_idx,
        seed=self.args.seed,
        mask_prob=self.args.mask_prob,
        leave_unmasked_prob=self.args.leave_unmasked_prob,
        random_token_prob=self.args.random_token_prob,
        freq_weighted_replacement=self.args.freq_weighted_replacement,
        mask_whole_words=mask_whole_words,
    )

    with data_utils.numpy_seed(self.args.seed + epoch):
        shuffle = np.random.permutation(len(src_dataset))

    pad_idx = vocab.pad()
    net_input = {
        "src_tokens": PadDataset(src_dataset, pad_idx=pad_idx, left_pad=False),
        "src_lengths": NumelDataset(src_dataset, reduce=False),
    }
    nested = NestedDictionaryDataset(
        {
            "id": IdDataset(),
            "net_input": net_input,
            "target": PadDataset(tgt_dataset, pad_idx=pad_idx, left_pad=False),
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(src_dataset, reduce=True),
        },
        sizes=[src_dataset.sizes],
    )
    self.datasets[split] = SortDataset(
        nested, sort_order=[shuffle, src_dataset.sizes])
def get_batch(dataset, mydict, batch_size, decode_dataset=None, rerank=None,
              mode='train', dist=False, cudaid=0, size=1, start_pos=None):
    """Yield token-budgeted batches of masked inputs, MLM labels and decode labels.

    ``dataset`` is masked with ``MaskTokensDataset.apply_mask`` (seed 1,
    mask_prob 0.15). Samples are visited in the order given by ``rerank``.
    The last ``int(len(dataset) * 0.002)`` entries of ``rerank`` form the
    validation split and the rest form the training split. With ``dist``,
    the selected order is further cut into contiguous chunks and chunk
    ``cudaid`` is used.

    A batch grows while ``max(longest input, longest decode label) *
    n_samples`` stays within ``batch_size * 512``. Per-batch maximum lengths
    are capped at 512.

    Args:
        dataset: token dataset to mask; same length as ``decode_dataset``
            and ``rerank``.
        mydict: dictionary providing ``pad()`` and the ``'<mask>'`` index.
        batch_size: token budget in units of 512 tokens.
        decode_dataset: per-sample decoder labels.
        rerank: sequence of dataset indices giving the visiting order.
        mode: ``'train'`` or ``'valid'``.
        dist: shard the order across ``size`` workers.
        cudaid: index of this worker's shard when ``dist`` is set.
        size: number of workers when ``dist`` is set.
        start_pos: position in the (possibly sharded) order to resume from.

    Yields:
        ``(token_list, mask_label_list, decode_label_list)`` LongTensors.

    Raises:
        ValueError: if ``mode`` is neither ``'train'`` nor ``'valid'``
            (raised on the first ``next()``, since this is a generator).
    """
    len_cap = 512
    src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
        dataset,
        mydict,
        pad_idx=mydict.pad(),
        mask_idx=mydict.index('<mask>'),
        seed=1,
        mask_prob=0.15,
        leave_unmasked_prob=0.1,
        random_token_prob=0.1,
        freq_weighted_replacement=False,
        mask_whole_words=None,
    )
    src_dataset = PadDataset(src_dataset, pad_idx=mydict.pad(), left_pad=False)
    tgt_dataset = PadDataset(tgt_dataset, pad_idx=mydict.pad(), left_pad=False)

    data_size = len(dataset)
    print('size: ', data_size, len(src_dataset.sizes))
    assert len(dataset) == len(decode_dataset)
    assert len(rerank) == len(dataset)

    valid_size = int(data_size * 0.002)
    if mode == 'train':
        data_size = data_size - valid_size
        rerank = rerank[:data_size]
    elif mode == 'valid':
        data_size = valid_size
        # Slice from an explicit start index: ``rerank[-0:]`` would select
        # the whole order (not nothing) when valid_size is 0.
        rerank = rerank[len(rerank) - valid_size:]
    else:
        raise ValueError(
            "mode must be 'train' or 'valid', got {!r}".format(mode))

    if dist:
        assert size != 1
        dist_size = int(len(rerank) / size) + 1
        rerank = rerank[cudaid * dist_size:(cudaid + 1) * dist_size]
        print('dist_size: ', dist_size, cudaid)
        data_size = len(rerank)

    token_budget = batch_size * 512
    i = start_pos if start_pos is not None else 0
    while i < data_size:
        token_list = []
        mask_label_list = []
        decode_label_list = []
        max_len_cur1 = 0  # longest masked input in this batch (capped)
        max_len_cur2 = 0  # longest decode label in this batch (capped)
        index = 0
        length = 0
        while i < data_size and length <= token_budget:
            sample_idx = rerank[i]
            # Fetch each item once and reuse it for the length checks and
            # for the batch itself.
            src_item = src_dataset[sample_idx]
            decode_item = decode_dataset[sample_idx]
            new_max1 = min(max(max_len_cur1, len(src_item)), len_cap)
            new_max2 = min(max(max_len_cur2, len(decode_item)), len_cap)
            index += 1
            length = max(new_max1, new_max2) * index
            if length > token_budget:
                # This sample would overflow the budget; it starts the next batch.
                break
            max_len_cur1, max_len_cur2 = new_max1, new_max2
            token_list.append(list(np.array(src_item)))
            mask_label_list.append(list(np.array(tgt_dataset[sample_idx])))
            decode_label_list.append(list(np.array(decode_item)))
            i += 1

        # NOTE(review): padding index is hard-coded to 1; confirm it equals
        # mydict.pad() for the dictionaries used here.
        token_list = torch.LongTensor([
            padding(item, max_len=max_len_cur1, padding_idx=1)
            for item in token_list
        ])
        mask_label_list = torch.LongTensor([
            padding(item, max_len=max_len_cur1, padding_idx=1)
            for item in mask_label_list
        ])
        decode_label_list = torch.LongTensor([
            padding(item, max_len=max_len_cur2, padding_idx=1)
            for item in decode_label_list
        ])
        yield (token_list, mask_label_list, decode_label_list)
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load a given dataset split with one sub-dataset per input field.

    For each field in ``configs.fields``, the split under
    ``<data>/<field>/<split>`` is loaded, optionally shortened, packed into
    blocks and prefixed with ``<s>``. Masking depends on the field:

    - the static field gets token masking, with targets in ``tgt_tokens``;
    - byte fields get value masking, with targets in ``tgt_values``;
    - any other field is passed through unmasked.

    The resulting ``SortDataset`` in ``self.datasets[split]`` is sized and
    shuffled by the static field.

    Args:
        split (str): name of the split (e.g., train, valid, test)
        epoch (int): not used for shard selection or shuffling here.
        combine (bool): forwarded to ``data_utils.load_indexed_dataset``.

    Raises:
        FileNotFoundError: if any field is missing the split.
    """
    args = self.args
    paths = utils.split_paths(args.data)
    assert len(paths) > 0

    def mask_options(vocab, field):
        # Masking options shared by token and value masking. Only built for
        # masked fields, since ``self.mask_idx_dict`` is looked up by field.
        return dict(
            pad_idx=vocab.pad(),
            mask_idx=self.mask_idx_dict[field],
            seed=args.seed,
            mask_prob=args.mask_prob,
            leave_unmasked_prob=args.leave_unmasked_prob,
            random_token_prob=args.random_token_prob,
            freq_weighted_replacement=args.freq_weighted_replacement,
        )

    src_tokens, tgt_tokens, tgt_values = {}, {}, {}
    for field in configs.fields:
        vocab = self.source_dictionary[field]
        split_path = os.path.join(args.data, field, split)

        dataset = data_utils.load_indexed_dataset(
            split_path, vocab, args.dataset_impl, combine=combine,
        )
        if dataset is None:
            raise FileNotFoundError(
                "Dataset not found: {} ({})".format(split, split_path)
            )

        dataset = maybe_shorten_dataset(
            dataset,
            split,
            args.shorten_data_split_list,
            args.shorten_method,
            args.tokens_per_sample,
            args.seed,
        )

        # Blocks are one token short to leave room for <s>.
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            args.tokens_per_sample - 1,
            pad=vocab.pad(),
            eos=vocab.eos(),
            break_mode=args.sample_break_mode,
        )
        logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))

        # <s> plays the role of BERT's [CLS].
        dataset = PrependTokenDataset(dataset, vocab.bos())

        if field == configs.static_field:
            src_dataset_code, tgt_dataset_code = MaskTokensDataset.apply_mask(
                dataset, vocab, **mask_options(vocab, field))
            src_tokens[field] = RightPadDataset(
                src_dataset_code, pad_idx=vocab.pad())
            tgt_tokens[field] = RightPadDataset(
                tgt_dataset_code, pad_idx=vocab.pad())
        elif field in configs.byte_fields:
            src_dataset_value, tgt_dataset_value = MaskValuesDataset.apply_mask(
                dataset, vocab, **mask_options(vocab, field))
            src_tokens[field] = RightPadDataset(
                src_dataset_value, pad_idx=vocab.pad())
            # dummy tokens are treated as 1
            # TODO: assert there should not be any dummy tokens here
            tgt_values[field] = BytevalueDataset(tgt_dataset_value, vocab)
        else:
            src_tokens[field] = RightPadDataset(dataset, pad_idx=vocab.pad())

    # NOTE(review): src_dataset_code is only bound when configs.static_field
    # is among configs.fields; otherwise the lines below raise
    # UnboundLocalError.
    with data_utils.numpy_seed(args.seed):
        shuffle = np.random.permutation(len(src_dataset_code))

    nested = NestedDictionaryDataset(
        {
            "id": IdDataset(),
            "net_input": {
                "src_tokens": src_tokens,
                "src_lengths": NumelDataset(src_dataset_code, reduce=False),
            },
            "target": {
                "tgt_tokens": tgt_tokens,
                "tgt_values": tgt_values,
            },
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(src_dataset_code, reduce=True),
        },
        sizes=[src_dataset_code.sizes],
    )
    self.datasets[split] = SortDataset(
        nested, sort_order=[shuffle, src_dataset_code.sizes])
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load a given dataset split.

    Loads ``split`` from the shard chosen by ``epoch`` (1-based) and
    optionally shortens it. It is then packed into blocks of
    ``tokens_per_sample`` tokens and wrapped in ``RemoveTailDataset``. Last,
    MLM masking is applied and a ``SortDataset`` is stored in
    ``self.datasets[split]``. No ``<s>`` token is prepended by this loader.

    Args:
        split (str): name of the split (e.g., train, valid, test)
        epoch (int): 1-based epoch; picks the shard and the shuffle seed.
        combine (bool): forwarded to ``data_utils.load_indexed_dataset``.

    Raises:
        FileNotFoundError: if the split cannot be loaded.
    """
    args = self.args
    vocab = self.source_dictionary

    paths = utils.split_paths(args.data)
    assert len(paths) > 0
    split_path = os.path.join(paths[(epoch - 1) % len(paths)], split)

    dataset = data_utils.load_indexed_dataset(
        split_path, vocab, args.dataset_impl, combine=combine,
    )
    if dataset is None:
        raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path))

    dataset = maybe_shorten_dataset(
        dataset,
        split,
        args.shorten_data_split_list,
        args.shorten_method,
        args.tokens_per_sample,
        args.seed,
    )

    # Full-length blocks: no slot is reserved since no <s> is prepended here.
    dataset = TokenBlockDataset(
        dataset,
        dataset.sizes,
        args.tokens_per_sample,
        pad=vocab.pad(),
        eos=vocab.eos(),
        break_mode=args.sample_break_mode,
    )
    logger.info('loaded {} blocks from: {}'.format(len(dataset), split_path))

    dataset = RemoveTailDataset(dataset)

    mask_whole_words = None
    if args.mask_whole_words:
        mask_whole_words = get_whole_word_mask(args, vocab)

    src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
        dataset,
        vocab,
        pad_idx=vocab.pad(),
        mask_idx=self.mask_idx,
        seed=args.seed,
        mask_prob=args.mask_prob,
        leave_unmasked_prob=args.leave_unmasked_prob,
        random_token_prob=args.random_token_prob,
        freq_weighted_replacement=args.freq_weighted_replacement,
        mask_whole_words=mask_whole_words,
    )

    with data_utils.numpy_seed(args.seed + epoch):
        shuffle = np.random.permutation(len(src_dataset))

    pad_idx = vocab.pad()
    fields = {
        'id': IdDataset(),
        'net_input': {
            'src_tokens': RightPadDataset(src_dataset, pad_idx=pad_idx),
            'src_lengths': NumelDataset(src_dataset, reduce=False),
        },
        'target': RightPadDataset(tgt_dataset, pad_idx=pad_idx),
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(src_dataset, reduce=True),
    }
    self.datasets[split] = SortDataset(
        NestedDictionaryDataset(fields, sizes=[src_dataset.sizes]),
        sort_order=[shuffle, src_dataset.sizes],
    )
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load a given dataset split.

    The raw split comes from ``self._load_dataset_split`` and is then masked
    for MLM. The result is stored in ``self.datasets[split]`` as a
    ``SortDataset`` ordered by a ``cfg.seed`` shuffle and then by size. When
    ``cfg.include_target_tokens`` is set, the padded targets are also exposed
    as ``net_input["target_tokens"]``.

    Args:
        split (str): name of the split (e.g., train, valid, test)
        epoch (int): forwarded to ``_load_dataset_split``.
        combine (bool): forwarded to ``_load_dataset_split``.
    """
    cfg = self.cfg
    vocab = self.source_dictionary
    dataset = self._load_dataset_split(split, epoch, combine)

    # NOTE(review): the whole-word mask is built from ``self.args`` while all
    # other options read ``self.cfg``; confirm ``self.args`` is set on this task.
    mask_whole_words = (
        get_whole_word_mask(self.args, vocab) if cfg.mask_whole_words else None
    )

    src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
        dataset,
        vocab,
        pad_idx=vocab.pad(),
        mask_idx=self.mask_idx,
        seed=cfg.seed,
        mask_prob=cfg.mask_prob,
        leave_unmasked_prob=cfg.leave_unmasked_prob,
        random_token_prob=cfg.random_token_prob,
        freq_weighted_replacement=cfg.freq_weighted_replacement,
        mask_whole_words=mask_whole_words,
        mask_multiple_length=cfg.mask_multiple_length,
        mask_stdev=cfg.mask_stdev,
    )

    with data_utils.numpy_seed(cfg.seed):
        shuffle = np.random.permutation(len(src_dataset))

    pad_idx = vocab.pad()
    target_dataset = RightPadDataset(tgt_dataset, pad_idx=pad_idx)
    input_dict = {
        "src_tokens": RightPadDataset(src_dataset, pad_idx=pad_idx),
        "src_lengths": NumelDataset(src_dataset, reduce=False),
    }
    if cfg.include_target_tokens:
        input_dict["target_tokens"] = target_dataset

    nested = NestedDictionaryDataset(
        {
            "id": IdDataset(),
            "net_input": input_dict,
            "target": target_dataset,
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(src_dataset, reduce=True),
        },
        sizes=[src_dataset.sizes],
    )
    self.datasets[split] = SortDataset(
        nested, sort_order=[shuffle, src_dataset.sizes])
def load_dataset(self, split, combine=False, **kwargs):
    """Load a given dataset split (e.g., train, valid, test).

    Reads the source (``args.input0``), optional synthetic target
    (``args.input1``) and optional reference (``args.input2``) datasets from
    ``<args.data>/<input>/<split>``. It then:

    - concatenates them into ``src_tokens``, mixing in the reference only
      when ``add_ref_prob > 0`` and ``split != 'valid'``;
    - when ``add_tran_loss`` is set, builds a masked "parallel" input from
      the reference together with its masking targets;
    - attaches classification labels or regression targets;
    - stores the result in ``self.datasets[split]``.

    Args:
        split (str): name of the split (e.g., train, valid, test).
        combine (bool): forwarded to ``data_utils.load_indexed_dataset``.

    Returns:
        The dataset stored in ``self.datasets[split]``.

    Raises:
        AssertionError: if the ``input0`` dataset cannot be found.
    """
    args = self.args

    def get_path(subdir, split_name):
        return os.path.join(args.data, subdir, split_name)

    def make_dataset(subdir, dictionary):
        # May return None when the indexed dataset is missing; callers check.
        return data_utils.load_indexed_dataset(
            get_path(subdir, split),
            dictionary,
            args.dataset_impl,
            combine=combine,
        )

    # input0 is source, input1 is synthetic target, input2 is reference
    input0 = make_dataset(args.input0, self.source_dictionary)
    # Report the input0 path that was actually looked up.
    assert input0 is not None, 'could not find dataset: {}'.format(
        get_path(args.input0, split))
    input1 = make_dataset(args.input1, self.source_dictionary)

    if args.init_token is not None:
        input0 = PrependTokenDataset(input0, args.init_token)

    has_ref = args.input2 is not None
    # The reference is mixed into src_tokens only outside the 'valid' split.
    use_ref_in_src = has_ref and self.add_ref_prob > 0 and split != 'valid'
    # A masked copy of the reference is built only for the translation loss.
    use_tran_loss = has_ref and args.add_tran_loss

    if has_ref:
        input2 = make_dataset(args.input2, self.source_dictionary)

    if use_ref_in_src:
        input3 = PrependTokenDataset(input2, args.separator_token)
    else:
        input3 = None

    if input1 is None:
        src_tokens = input0
    else:
        if args.separator_token is not None:
            input1 = PrependTokenDataset(input1, args.separator_token)

        if use_ref_in_src:
            src_tokens = ConcatSentencesDataset(
                input0, input3, input1,
                add_ref_prob=self.add_ref_prob,
                drop_ref_rate=args.dropout_ref,
                pad_idx=self.source_dictionary.pad(),
                eos_idx=self.source_dictionary.eos(),
                bos_idx=self.source_dictionary.bos(),
            )
        else:
            src_tokens = ConcatSentencesDataset(input0, input1)

    with data_utils.numpy_seed(args.seed):
        shuffle = np.random.permutation(len(src_tokens))

    if args.truncate_sequence:
        src_tokens = TruncateDataset(src_tokens, args.max_positions)

    if use_tran_loss:
        # create masked input and targets from the reference
        mask_whole_words = (
            get_whole_word_mask(args, self.source_dictionary)
            if args.mask_whole_words else None
        )
        ref_dataset, ref_target_dataset = MaskTokensDataset.apply_mask(
            input2,
            self.source_dictionary,
            pad_idx=self.source_dictionary.pad(),
            mask_idx=self.mask_idx,
            seed=args.seed,
            mask_prob=args.mask_prob,
            leave_unmasked_prob=args.leave_unmasked_prob,
            random_token_prob=args.random_token_prob,
            freq_weighted_replacement=args.freq_weighted_replacement,
            mask_whole_words=mask_whole_words,
        )

        # Always feed the *masked* reference, so the parallel input stays
        # consistent with parallel_target whether or not a separator is used.
        parallel_ref = ref_dataset
        if args.separator_token is not None:
            parallel_ref = PrependTokenDataset(ref_dataset, args.separator_token)
        parallel_src_tokens = ConcatSentencesDataset(input0, parallel_ref)
        if args.truncate_sequence:
            parallel_src_tokens = TruncateDataset(parallel_src_tokens, args.max_positions)

    dataset = {
        'id': IdDataset(),
        'net_input': {
            'src_tokens': RightPadDataset(
                src_tokens,
                pad_idx=self.source_dictionary.pad(),
            ),
            'src_lengths': NumelDataset(src_tokens, reduce=False),
        },
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(src_tokens, reduce=True),
    }

    if use_tran_loss:
        dataset['net_input']['parallel_src_tokens'] = RightPadDataset(
            parallel_src_tokens,
            pad_idx=self.source_dictionary.pad(),
        )

    if args.add_prev_output_tokens:
        prev_tokens_dataset = RightPadDataset(
            RollDataset(src_tokens, 1),
            pad_idx=self.dictionary.pad(),
        )
        dataset['net_input'].update(
            prev_output_tokens=prev_tokens_dataset,
        )

    if not args.regression_target:
        label_dataset = make_dataset('label', self.label_dictionary)
        if label_dataset is not None:
            dataset.update(target=OffsetTokensDataset(
                StripTokenDataset(
                    label_dataset,
                    id_to_strip=self.label_dictionary.eos(),
                ),
                offset=-self.label_dictionary.nspecial,
            ))
        if use_tran_loss:
            # used as translation target when calculating loss
            dataset.update(parallel_target=RightPadDataset(
                ref_target_dataset,
                pad_idx=self.source_dictionary.pad(),
            ))
    else:
        label_path = "{0}.label".format(get_path('label', split))
        if os.path.exists(label_path):
            def parse_regression_target(i, line):
                values = line.split()
                assert len(values) == args.num_classes, \
                    f'expected num_classes={args.num_classes} regression target values on line {i}, found: "{line}"'
                return [float(x) for x in values]

            # Read inside a `with` so the file handle is closed promptly.
            with open(label_path) as label_file:
                regression_targets = [
                    parse_regression_target(i, line.strip())
                    for i, line in enumerate(label_file)
                ]
            dataset.update(target=RawLabelDataset(regression_targets))

    nested_dataset = NestedDictionaryDataset(
        dataset,
        sizes=[src_tokens.sizes],
        all_sizes=src_tokens.all_sizes if args.add_target_num_tokens else None,
        padding_idx=self.source_dictionary.pad(),
        add_ref_prob=self.add_ref_prob if split != 'valid' else 0.,
    )

    if args.no_shuffle:
        dataset = nested_dataset
    else:
        dataset = SortDataset(
            nested_dataset,
            # shuffle
            sort_order=[shuffle],
        )

    logger.info("Loaded {0} with #samples: {1}".format(
        split, len(dataset)))

    self.datasets[split] = dataset
    return self.datasets[split]
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load a given dataset split together with per-token count/feature inputs.

    Alongside the masked-LM ``src_tokens`` / ``target``, this function:

    - reads ``<split_path>.counts``: one line of whitespace-separated integers
      per example, clipped at ``thresh`` and padded with a 0 at each end;
    - unless ``args.input_format == 'tokens'``, reads
      ``<split_path>.features`` with ``torch.load``;
    - sets counts and features to zero at positions where the masked source
      equals ``self.mask_idx``.

    Args:
        split (str): name of the split (e.g., train, valid, test)
        epoch (int): 1-based epoch; selects the data directory and is added
            to the shuffle seed.
        combine (bool): forwarded to ``data_utils.load_indexed_dataset``.

    Raises:
        FileNotFoundError: if the indexed dataset for ``split`` is missing.
    """
    paths = utils.split_paths(self.args.data)
    assert len(paths) > 0
    data_path = paths[(epoch - 1) % len(paths)]
    split_path = os.path.join(data_path, split)

    dataset = data_utils.load_indexed_dataset(
        split_path,
        self.source_dictionary,
        self.args.dataset_impl,
        combine=combine,
    )
    if dataset is None:
        raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path))

    # create continuous blocks of tokens
    dataset = TokenBlockDataset(
        dataset,
        dataset.sizes,
        self.args.tokens_per_sample - 1,  # one less for <s>
        pad=self.source_dictionary.pad(),
        eos=self.source_dictionary.eos(),
        break_mode=self.args.sample_break_mode,
    )
    logger.info('loaded {} blocks from: {}'.format(len(dataset), split_path))

    # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
    dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())

    # create masked input and targets
    mask_whole_words = get_whole_word_mask(self.args, self.source_dictionary) \
        if self.args.mask_whole_words else None

    src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
        dataset,
        self.source_dictionary,
        pad_idx=self.source_dictionary.pad(),
        mask_idx=self.mask_idx,
        seed=self.args.seed,
        mask_prob=self.args.mask_prob,
        leave_unmasked_prob=self.args.leave_unmasked_prob,
        random_token_prob=self.args.random_token_prob,
        freq_weighted_replacement=self.args.freq_weighted_replacement,
        mask_whole_words=mask_whole_words,
    )

    with data_utils.numpy_seed(self.args.seed + epoch):
        shuffle = np.random.permutation(len(src_dataset))

    # load counts: clip each value at `thresh` and add a 0 at both ends
    # (presumably for the <s>/</s> positions -- TODO confirm alignment)
    thresh = 100
    with open(split_path + '.counts') as count_file:
        counts = []
        for line in count_file:
            clipped = [min(int(el), thresh) for el in line.split()]
            counts.append(torch.tensor([0] + clipped + [0], dtype=torch.long))

    # load embeddings only when the input is not plain tokens.
    # NOTE: torch.load unpickles the file; only load trusted feature files.
    use_features = self.args.input_format != 'tokens'
    embs = torch.load(split_path + '.features') if use_features else None

    # zero counts (and embeddings, if loaded) at masked positions
    for i, data in enumerate(src_dataset):
        keep = data != self.mask_idx
        counts[i] = counts[i] * keep
        if use_features:
            # [1:-1] drops the first and last positions, matching the
            # embeddings' length (presumably no <s>/</s> rows -- confirm).
            embs[i] = embs[i] * keep[1:-1, None]

    self.datasets[split] = SortDataset(
        NestedDictionaryDataset(
            {
                'id': IdDataset(),
                'net_input': {
                    'src_tokens': PadDataset(
                        src_dataset,
                        pad_idx=self.source_dictionary.pad(),
                        left_pad=False,
                    ),
                    'src_counts': PadDataset(
                        counts,
                        pad_idx=0,
                        left_pad=False,
                    ),
                    'src_embs': EmbeddingDataset(
                        embs,
                        pad_idx=0,
                        left_pad=False,
                    ) if use_features else None,
                    'src_lengths': NumelDataset(src_dataset, reduce=False),
                },
                'target': PadDataset(
                    tgt_dataset,
                    pad_idx=self.source_dictionary.pad(),
                    left_pad=False,
                ),
                'nsentences': NumSamplesDataset(),
                'ntokens': NumelDataset(src_dataset, reduce=True),
            },
            sizes=[src_dataset.sizes],
        ),
        sort_order=[
            shuffle,
            src_dataset.sizes,
        ],
    )