def load_tokens_dataset(
    data_path, split, srcs, src_dicts, tgts, tgt_dicts, dataset_impl,
    max_srcs=None, max_tgts=None, shuffle=False, sample_neg=False,
    **kwargs,
):
    langs = kwargs.get('langs', None)

    # source langs
    src_paths = OrderedDict()
    for src in srcs:
        if langs is None:
            src_paths[src] = [os.path.join(data_path, f'{split}.{src}')]
        else:
            src_paths[src] = [os.path.join(data_path, f'{split}.{lang}.{src}') for lang in langs]
    src_datasets, src_sizes = OrderedDict(), OrderedDict()
    for idx, (src, paths) in enumerate(src_paths.items()):
        datasets = _load_dataset(paths, dataset_impl)
        if max_srcs is not None:
            datasets = [TruncateDataset(ds, max_srcs[idx]) for ds in datasets]
        datasets = ConcatDataset(datasets, labels=langs)
        src_datasets[src] = datasets
        src_sizes[src] = datasets.sizes
    LOGGER.info('loaded {} modality(ies) from: {}'.format(len(src_datasets), src_paths))

    # target langs
    tgt_paths = OrderedDict()
    for tgt in tgts:
        if langs is None:
            tgt_paths[tgt] = [os.path.join(data_path, f'{split}.{tgt}')]
        else:
            tgt_paths[tgt] = [os.path.join(data_path, f'{split}.{lang}.{tgt}') for lang in langs]
    tgt_datasets, tgt_sizes = OrderedDict(), OrderedDict()
    for idx, (tgt, paths) in enumerate(tgt_paths.items()):
        datasets = _load_dataset(paths, dataset_impl)
        if max_tgts is not None:
            datasets = [TruncateDataset(ds, max_tgts[idx]) for ds in datasets]
        datasets = ConcatDataset(datasets, labels=langs)
        tgt_datasets[tgt] = datasets
        tgt_sizes[tgt] = datasets.sizes
    LOGGER.info('loaded {} modality(ies) from: {}'.format(len(tgt_datasets), tgt_paths))

    return MultiModalitiesRetrievalDataset(
        src_datasets, src_sizes, src_dicts,
        tgt_datasets, tgt_sizes, tgt_dicts,
        max_source_positions=max_srcs, max_target_positions=max_tgts,
        fraction_using_func_name=kwargs.get('fraction_using_func_name', None),
        shuffle=shuffle, labels=langs, sample_neg=sample_neg,
    )
def load_inference_token_dataset(
    data_paths, split, tgt, tgt_dict, dataset_impl,
    attrs=None, attr_dict=None, attrs_mapping=None, reversed_attrs_mapping=None,
    truncate_target=False, max_target_positions=None,
):
    # load tokens
    tgt_dataset = []
    for path in data_paths:
        tgt_path = os.path.join(path, '{}.{}'.format(split, tgt))
        tgt_dataset.append(_load_dataset(tgt_path, dataset_impl))
    tgt_dataset = ConcatDataset(tgt_dataset)
    if truncate_target:
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)
        LOGGER.info('Truncate dataset into max length: {}'.format(max_target_positions))
    LOGGER.info('loaded {} examples from: {}'.format(len(tgt_dataset), data_paths))

    return CompletionDataset(
        tgt_dataset, tgt_dataset.sizes, tgt_dict,
        extends=None,
        attrs=None, attr_indices=None, attr_dict=None,
        attrs_mapping=None, reversed_attrs_mapping=None,
        max_target_positions=max_target_positions,
    )
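# Hedged usage sketch (not from the original source): the directory layout, the
# 'code_tokens' file suffix and the pre-built `tgt_dict` are assumptions made for
# illustration; only the keyword arguments come from the loader defined above.
def _example_build_inference_dataset(tgt_dict):
    return load_inference_token_dataset(
        data_paths=['/path/to/dataset-a/data-mmap', '/path/to/dataset-b/data-mmap'],
        split='test',
        tgt='code_tokens',           # assumed file suffix: test.code_tokens.{bin,idx}
        tgt_dict=tgt_dict,           # assumed to be an already-loaded token dictionary
        dataset_impl='mmap',
        truncate_target=True,
        max_target_positions=1024,
    )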
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    # paths = self.args.data.split(':')
    paths = utils.split_paths(self.args['task']['data'])
    assert len(paths) > 0
    data_path = paths[(epoch - 1) % len(paths)]
    # split_path = os.path.join(data_path, split)

    if self.langs is None:
        languages = sorted([
            name for name in os.listdir(data_path)
            if os.path.isdir(os.path.join(data_path, name))
        ])
    else:
        languages = self.langs  # .split(',')
        # for name in languages:
        #     assert os.path.exists(os.path.join(data_path, name)), FileNotFoundError(os.path.join(data_path, name))

    LOGGER.info("| Training on {0} languages: {1}".format(len(languages), languages))
    LOGGER.info("| Language to id mapping: {}".format(
        {lang: id for id, lang in enumerate(languages)}))

    mask_whole_words = get_whole_word_mask(self.args, self.dictionary)
    lang_datasets = []
    for language in languages:
        # split_path = os.path.join(data_path, language, split)
        if language == 'docstring':
            split_path = os.path.join(data_path, language, f"{split}.docstring.spm")
        else:
            split_path = os.path.join(data_path, language, f"{split}.code.spm")
        # split_path = os.path.join(data_path, language, f"{split}.spm.{language}")

        # dataset = data_utils.load_indexed_dataset(
        #     split_path,
        #     self.source_dictionary,
        #     self.args['dataset']['dataset_impl'],
        #     combine=combine,
        # )
        dataset = load_lang_dataset_denoising(
            path=split_path,
            impl=self.args['dataset']['dataset_impl'],
            dict=self.source_dictionary,
        )
        if dataset is None:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path))

        dataset = AppendTokenDataset(
            TruncateDataset(
                StripTokenDataset(dataset, self.source_dictionary.eos()),
                self.args['task']['max_source_positions'] - 3,  # <lang>, <bos>, <eos>
            ),
            token=self.source_dictionary.eos(),
        )

        end_token = self.source_dictionary.index('[{}]'.format(language)) \
            if self.args['task']['add_lang_token'] else self.source_dictionary.eos()

        # create continuous blocks of tokens
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.args['task']['tokens_per_sample'] - 2,  # one less for <s> and one for </s>
            pad=self.source_dictionary.pad(),
            eos=end_token,
            break_mode=self.args['task']['sample_break_mode'],
            document_sep_len=0,
        )
        LOGGER.info('| loaded {} blocks from: {}'.format(len(dataset), split_path))

        # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
        dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
        dataset = AppendTokenDataset(dataset, end_token)

        lang_dataset = DenoisingDataset(
            dataset,
            dataset.sizes,
            self.dictionary,
            self.mask_idx,
            mask_whole_words,
            shuffle=self.args['dataset']['shuffle_instance'],
            seed=self.seed,
            args=self.args,
            eos=None if not self.args['task']['add_lang_token']
            else self.source_dictionary.index('[{}]'.format(language)),
        )
        lang_datasets.append(lang_dataset)

    dataset_lengths = np.array(
        [len(d) for d in lang_datasets],
        dtype=float,
    )
    LOGGER.info('| loaded total {} blocks for all languages'.format(dataset_lengths.sum()))

    if split == self.args['dataset']['train_subset']:
        # For train subset, additionally up or down sample languages.
        sample_probs = self._get_sample_prob(dataset_lengths)
        LOGGER.info("| Sample probability by language: {}".format(
            {lang: "{0:.4f}".format(sample_probs[id]) for id, lang in enumerate(languages)}))
        size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths
        LOGGER.info("| Up/Down Sampling ratio by language: {}".format(
            {lang: "{0:.2f}".format(size_ratio[id]) for id, lang in enumerate(languages)}))
        resampled_lang_datasets = [
            ResamplingDataset(
                lang_datasets[i],
                size_ratio=size_ratio[i],
                seed=self.args['common']['seed'],
                epoch=epoch,
                replace=size_ratio[i] >= 1.0,
            )
            for i, d in enumerate(lang_datasets)
        ]
        dataset = ConcatDataset(resampled_lang_datasets)
    else:
        dataset = ConcatDataset(lang_datasets)
        lang_splits = [split]
        # for lang_id, lang_dataset in enumerate(lang_datasets):
        #     split_name = split + '_' + languages[lang_id]
        #     lang_splits.append(split_name)
        #     self.datasets[split_name] = lang_dataset

        if split in self.args['dataset']['valid_subset']:
            self.args['dataset']['valid_subset'] = self.args['dataset']['valid_subset'].replace(
                split, ','.join(lang_splits))

    with data_utils.numpy_seed(self.args['common']['seed'] + epoch):
        shuffle = np.random.permutation(len(dataset))

    self.datasets[split] = SortDataset(
        dataset,
        sort_order=[
            shuffle,
            dataset.sizes,
        ],
    )
def load_tokens_dataset(
    data_path, split, src, src_dict, tgt, tgt_dict, dataset_impl,
    src_max_tokens=None, tgt_max_tokens=None,
    src_aux=None, src_aux_dict=None, tgt_aux=None, tgt_aux_dict=None,
    src_aux_max_tokens=None, tgt_aux_max_tokens=None,
    fraction_using_func_name=0., labels=None, shuffle=False,
):
    if len(labels) == 1 and os.path.exists(os.path.join(data_path, '{}.{}.idx'.format(split, src))):
        src_paths = [os.path.join(data_path, '{}.{}'.format(split, src))]
    else:
        src_paths = [os.path.join(data_path, '{}.{}.{}'.format(split, lbl, src)) for lbl in labels]
    src_datasets = _load_dataset(src_paths, dataset_impl)
    src_datasets = [TruncateDataset(ds, src_max_tokens) for ds in src_datasets]
    src_datasets = ConcatDataset(src_datasets, labels=labels)

    if len(labels) == 1 and os.path.exists(os.path.join(data_path, '{}.{}.idx'.format(split, tgt))):
        tgt_paths = [os.path.join(data_path, '{}.{}'.format(split, tgt))]
    else:
        tgt_paths = [os.path.join(data_path, '{}.{}.{}'.format(split, lbl, tgt)) for lbl in labels]
    tgt_datasets = _load_dataset(tgt_paths, dataset_impl)
    tgt_datasets = [TruncateDataset(ds, tgt_max_tokens) for ds in tgt_datasets]
    tgt_datasets = ConcatDataset(tgt_datasets, labels=labels)

    LOGGER.info('loaded {} examples from: {}'.format(len(src_datasets), src_paths))
    LOGGER.info('loaded {} examples from: {}'.format(len(tgt_datasets), tgt_paths))

    if split == 'train' and src_aux is not None:
        if len(labels) == 1 and os.path.exists(os.path.join(data_path, '{}.{}.idx'.format(split, src_aux))):
            src_aux_paths = [os.path.join(data_path, '{}.{}'.format(split, src_aux))]
        else:
            src_aux_paths = [os.path.join(data_path, '{}.{}.{}'.format(split, lbl, src_aux)) for lbl in labels]
        src_aux_datasets = _load_dataset(src_aux_paths, dataset_impl)
        if src_aux_max_tokens is None:
            src_aux_max_tokens = src_max_tokens
        src_aux_datasets = [TruncateDataset(ds, src_aux_max_tokens) for ds in src_aux_datasets]
        src_aux_datasets = ConcatDataset(src_aux_datasets, labels=labels)
        LOGGER.info('loaded {} examples from: {}'.format(len(src_aux_datasets), src_aux_paths))
    else:
        src_aux_datasets = None

    if split == 'train' and tgt_aux is not None:
        if len(labels) == 1 and os.path.exists(os.path.join(data_path, '{}.{}.idx'.format(split, tgt_aux))):
            tgt_aux_paths = [os.path.join(data_path, '{}.{}'.format(split, tgt_aux))]
        else:
            tgt_aux_paths = [os.path.join(data_path, '{}.{}.{}'.format(split, lbl, tgt_aux)) for lbl in labels]
        tgt_aux_datasets = _load_dataset(tgt_aux_paths, dataset_impl)
        if tgt_aux_max_tokens is None:
            tgt_aux_max_tokens = tgt_max_tokens
        tgt_aux_datasets = [TruncateDataset(ds, tgt_aux_max_tokens) for ds in tgt_aux_datasets]
        tgt_aux_datasets = ConcatDataset(tgt_aux_datasets, labels=labels)
        LOGGER.info('loaded {} examples from: {}'.format(len(tgt_aux_datasets), tgt_aux_paths))
    else:
        tgt_aux_datasets = None

    return HybridRetrievalDataset(
        src_datasets, src_datasets.sizes, src_dict,
        tgt_datasets, tgt_datasets.sizes, tgt_dict,
        max_source_positions=src_max_tokens, max_target_positions=tgt_max_tokens,
        src_aux=src_aux_datasets,
        src_aux_sizes=None if src_aux_datasets is None else src_aux_datasets.sizes,
        src_aux_dict=src_dict if src_aux_dict is None else src_aux_dict,
        tgt_aux=tgt_aux_datasets,
        tgt_aux_sizes=None if tgt_aux_datasets is None else tgt_aux_datasets.sizes,
        tgt_aux_dict=tgt_dict if tgt_aux_dict is None else tgt_aux_dict,
        fraction_using_func_name=fraction_using_func_name,
        shuffle=shuffle,
        labels=labels,
    )
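# Hedged usage sketch (not from the original source): the modality names, the single
# 'python' label and the dictionaries are illustrative assumptions; the call simply
# exercises the hybrid retrieval loader defined above with its own parameters.
def _example_build_retrieval_dataset(data_path, src_dict, tgt_dict):
    return load_tokens_dataset(
        data_path, 'train',
        src='code_tokens', src_dict=src_dict,        # assumed code-side modality
        tgt='docstring_tokens', tgt_dict=tgt_dict,   # assumed query-side modality
        dataset_impl='mmap',
        src_max_tokens=256, tgt_max_tokens=64,
        src_aux='func_name', src_aux_dict=src_dict,  # assumed auxiliary modality, only used when split == 'train'
        fraction_using_func_name=0.1,
        labels=['python'],                           # assumed single-language setup
        shuffle=True,
    )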
def load_kd_token_dataset(
    data_path, split, tgt, tgt_dict, dataset_impl,
    attrs=None, attr_dict=None, attrs_mapping=None, reversed_attrs_mapping=None,
    truncate_target=False, max_target_positions=None,
    # lifelong learning
    prev_tasks=[], cur_task=None, sample_portion=None,
    # kd
    teacher_out_dir=None, topk=None, distill_topk=None,
):
    # load tokens
    tgt_path = os.path.join(data_path, cur_task, '{}.{}'.format(split, tgt))
    tgt_dataset = [_load_dataset(tgt_path, dataset_impl)]
    # teacher output
    topk_ids_prefix = os.path.join(str.replace(teacher_out_dir, '*', cur_task), f'{split}.top{topk}_idx')
    topk_ids = [TeacherOutDataset(topk_ids_prefix)]
    topk_probs_prefix = os.path.join(str.replace(teacher_out_dir, '*', cur_task), f'{split}.top{topk}_prob')
    topk_probs = [TeacherOutDataset(topk_probs_prefix)]

    if len(prev_tasks) > 0 and cur_task is not None and sample_portion is not None:
        sample_size_per_task = int(len(tgt_dataset[0]) * sample_portion // len(prev_tasks))
    else:
        sample_size_per_task = -1
    if sample_size_per_task > 0:
        for p_task in prev_tasks:
            p_path = os.path.join(data_path, p_task, '{}.{}'.format(split, tgt))
            p_dataset = _load_dataset(p_path, dataset_impl)
            tgt_dataset.append(SliceDataset(p_dataset, end=sample_size_per_task))
            topk_ids.append(PlaceholderDataset(placeholder=None, length=sample_size_per_task))
            topk_probs.append(PlaceholderDataset(placeholder=None, length=sample_size_per_task))
    tgt_dataset = ConcatDataset(tgt_dataset)
    topk_ids = ConcatDataset(topk_ids)
    topk_probs = ConcatDataset(topk_probs)

    if truncate_target:
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)
        LOGGER.info('Truncate dataset into max length: {}'.format(max_target_positions))
    LOGGER.info('loaded {} examples from: [{}](current task) + {}(previous tasks)'.
                format(len(tgt_dataset), cur_task, prev_tasks))

    return TopkKDCompletionDataset(
        topk_ids=topk_ids, topk_probs=topk_probs,
        topk=topk, distill_topk=distill_topk,
        tgt=tgt_dataset, tgt_sizes=tgt_dataset.sizes, tgt_dict=tgt_dict,
        extends=None,
        attrs=None, attr_indices=None, attr_dict=None,
        attrs_mapping=None, reversed_attrs_mapping=None,
        max_target_positions=max_target_positions,
        shuffle=(split == 'train'),
    )
def load_langpair_dataset(
    data_path, split, domains,
    src, src_dict, tgt, tgt_dict,
    combine, dataset_impl, upsample_primary,
    left_pad_source, left_pad_target,
    max_source_positions, max_target_positions,
    prepend_bos=False, load_alignments=False,
    truncate_source=False, append_source_id=False,
    append_eos_to_target=False,
):
    def split_exists(split, src, data_path, domain):
        filename = os.path.join(data_path, domain, '{}.{}'.format(split, src))  # -{}.{} , tgt, lang
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []
    for dm in domains:
        # load datasets of src domains
        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')
            # infer langcode
            if split_exists(split_k, src, data_path, dm):
                prefix = os.path.join(data_path, dm, '{}.'.format(split_k))  # {}-{}. , src, tgt
            elif split_exists(split_k, tgt, data_path, dm):
                prefix = os.path.join(data_path, dm, '{}.'.format(split_k))  # {}-{}. , tgt, src
            else:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

            src_dataset = data_utils.load_indexed_dataset(prefix + src, src, src_dict, dataset_impl)
            if truncate_source:
                src_dataset = AppendTokenDataset(
                    TruncateDataset(
                        StripTokenDataset(src_dataset, src_dict.eos()),
                        max_source_positions - 1,
                    ),
                    src_dict.eos(),
                )
            src_datasets.append(src_dataset)

            tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt, tgt_dict, dataset_impl)
            if tgt_dataset is not None:
                tgt_datasets.append(tgt_dataset)

            if not combine:
                break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    eos = None
    if append_source_id:
        src_dataset = AppendTokenDataset(src_dataset, src_dict.index('[{}]'.format(src)))
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(tgt_dataset, tgt_dict.index('[{}]'.format(tgt)))
        eos = tgt_dict.index('[{}]'.format(tgt))

    # align_dataset = None
    # if load_alignments:
    #     align_path = os.path.join(data_path, '{}.align.{}-{}'.format(split, src, tgt))
    #     if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
    #         align_dataset = data_utils.load_indexed_dataset(align_path, None, dataset_impl)

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return LanguagePairDataset(
        src_dataset, src_dataset.sizes, src_dict,
        tgt_dataset, tgt_dataset_sizes, tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        align_dataset=None, eos=eos,
        remove_eos_from_source=True,
        append_eos_to_target=append_eos_to_target,
        shuffle=True,  # TODO debug: shuffle=False
    )
def load_token_dataset(
    data_path, split, tgt, tgt_dict, dataset_impl,
    attrs=None, attr_dict=None, attrs_mapping=None, reversed_attrs_mapping=None,
    truncate_target=False, max_target_positions=None,
    # lifelong learning
    prev_tasks=[], cur_task=None, sample_portion=None,
):
    # load tokens
    tgt_path = os.path.join(data_path, cur_task, '{}.{}'.format(split, tgt))
    tgt_dataset = [_load_dataset(tgt_path, dataset_impl)]
    kd_ids = [PlaceholderDataset(placeholder=True, length=len(tgt_dataset[0]))]

    if len(prev_tasks) > 0 and cur_task is not None and sample_portion is not None:
        sample_size_per_task = int(len(tgt_dataset[0]) * sample_portion // len(prev_tasks))
    else:
        sample_size_per_task = -1
    if sample_size_per_task > 0:
        for p_task in prev_tasks:
            p_path = os.path.join(data_path, p_task, '{}.{}'.format(split, tgt))
            p_dataset = _load_dataset(p_path, dataset_impl)
            tgt_dataset.append(SliceDataset(p_dataset, end=sample_size_per_task))
            kd_ids.append(PlaceholderDataset(placeholder=False, length=sample_size_per_task))
    tgt_dataset = ConcatDataset(tgt_dataset)
    kd_ids = ConcatDataset(kd_ids)

    if truncate_target:
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)
        LOGGER.info('Truncate dataset into max length: {}'.format(max_target_positions))
    LOGGER.info('loaded {} examples from: [{}](current task) + {}(previous tasks)'.
                format(len(tgt_dataset), cur_task, prev_tasks))

    return LifelongKDCompletionDataset(
        kd_indices=kd_ids,
        tgt=tgt_dataset, tgt_sizes=tgt_dataset.sizes, tgt_dict=tgt_dict,
        extends=None,
        attrs=None, attr_indices=None, attr_dict=None,
        attrs_mapping=None, reversed_attrs_mapping=None,
        max_target_positions=max_target_positions,
        shuffle=(split == 'train'),
    )
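# Hedged usage sketch (not from the original source): task names, the sample portion
# and the dictionary are illustrative assumptions. It shows how the lifelong-learning
# loader above might be called when training on a new task while replaying a slice of
# each previous task (the replay budget is sample_portion of the current task size,
# split evenly across previous tasks).
def _example_build_lifelong_dataset(data_path, tgt_dict):
    return load_token_dataset(
        data_path, 'train',
        tgt='code_tokens', tgt_dict=tgt_dict,  # assumed completion modality
        dataset_impl='mmap',
        truncate_target=True, max_target_positions=1024,
        prev_tasks=['java', 'go'],             # assumed previously-learned tasks
        cur_task='python',                     # assumed current task
        sample_portion=0.2,                    # assumed replay budget
    )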
def load_token_dataset(
    data_path, split, tgt, tgt_dict, dataset_impl,
    attrs=None, attr_dict=None, attrs_mapping=None, reversed_attrs_mapping=None,
    truncate_target=False, max_target_positions=None,
    # lifelong learning
    prev_tasks=[], cur_task=None, sample_portion=None,
):
    # load tokens
    tgt_path = os.path.join(data_path, cur_task, '{}.{}'.format(split, tgt))
    tgt_dataset = [_load_dataset(tgt_path, dataset_impl)]

    if len(prev_tasks) > 0 and cur_task is not None and sample_portion is not None:
        sample_size_per_task = int(len(tgt_dataset[0]) * sample_portion // len(prev_tasks))
    else:
        sample_size_per_task = -1
    if sample_size_per_task > 0:
        for p_task in prev_tasks:
            p_path = os.path.join(data_path, p_task, '{}.{}'.format(split, tgt))
            p_dataset = _load_dataset(p_path, dataset_impl)
            tgt_dataset.append(SliceDataset(p_dataset, end=sample_size_per_task))
    tgt_dataset = ConcatDataset(tgt_dataset)

    if truncate_target:
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)
        LOGGER.info('Truncate dataset into max length: {}'.format(max_target_positions))
    LOGGER.info('loaded {} examples from: [{}](current task) + {}(previous tasks)'.
                format(len(tgt_dataset), cur_task, prev_tasks))

    # load tokens.ext
    tgt_ext_path = os.path.join(data_path, cur_task, '{}.{}.ext'.format(split, tgt))
    if indexed_dataset.SeqIndexedDataset.exists(tgt_ext_path):
        tgt_ext_dataset = indexed_dataset.SeqIndexedDataset(tgt_ext_path)
        if sample_size_per_task > 0:
            for p_task in prev_tasks:
                p_ext_path = os.path.join(data_path, p_task, '{}.{}.ext'.format(split, tgt))
                p_ext_dataset = indexed_dataset.SeqIndexedDataset(p_ext_path)
                p_ext_dataset.truncate(end=sample_size_per_task)
                tgt_ext_dataset.append(p_ext_dataset)
        if truncate_target:
            tgt_ext_dataset.clip(max_position=max_target_positions)
        assert len(tgt_dataset) == len(tgt_ext_dataset), (len(tgt_dataset), len(tgt_ext_dataset))
    else:
        tgt_ext_dataset = None

    # load attrs
    if attrs is None:
        attr_dataset = None
    else:
        attr_path = os.path.join(data_path, cur_task, '{}.code_types'.format(split))
        attr_dataset = [_load_dataset(attr_path, dataset_impl)]
        if sample_size_per_task > 0:
            for p_task in prev_tasks:
                p_attr_path = os.path.join(data_path, p_task, '{}.code_types'.format(split))
                p_attr_dataset = _load_dataset(p_attr_path, dataset_impl)
                attr_dataset.append(SliceDataset(p_attr_dataset, end=sample_size_per_task))
        attr_dataset = ConcatDataset(attr_dataset)
        if truncate_target:
            attr_dataset = TruncateDataset(attr_dataset, max_target_positions)
            LOGGER.info('Truncate dataset\'s attributes into max length: {}'.format(max_target_positions))
        LOGGER.info('loaded attributes {} examples from: [{}](current task) + {}(previous tasks)'.
                    format(len(attr_dataset), cur_task, prev_tasks))

        # load attr.ext
        attr_ext_path = os.path.join(data_path, cur_task, '{}.code_types.ext'.format(split))
        if indexed_dataset.SeqIndexedDataset.exists(attr_ext_path):
            attr_ext_dataset = indexed_dataset.SeqIndexedDataset(attr_ext_path)
            if sample_size_per_task > 0:
                for p_task in prev_tasks:
                    p_attr_ext_path = os.path.join(data_path, p_task, '{}.code_types.ext'.format(split))
                    p_attr_ext_dataset = indexed_dataset.SeqIndexedDataset(p_attr_ext_path)
                    p_attr_ext_dataset.truncate(end=sample_size_per_task)
                    attr_ext_dataset.append(p_attr_ext_dataset)
            if truncate_target:
                attr_ext_dataset.clip(max_position=max_target_positions)
            assert np.all(tgt_ext_dataset == attr_ext_dataset)
            del attr_ext_dataset

    return CompletionDataset(
        tgt_dataset, tgt_dataset.sizes, tgt_dict,
        extends=tgt_ext_dataset,
        attrs=attrs, attr_indices=attr_dataset, attr_dict=attr_dict,
        attrs_mapping=attrs_mapping, reversed_attrs_mapping=reversed_attrs_mapping,
        max_target_positions=max_target_positions,
    )
def load_langpair_dataset(
    data_path, split,
    src, src_dict, tgt, tgt_dict,
    dataset_impl,
    # combine, dataset_impl, upsample_primary,
    left_pad_source, left_pad_target,
    max_source_positions, max_target_positions,
    prepend_bos=True, append_eos=True,
    load_alignments=False, truncate_source=False, append_source_id=False,
    truncate_target=False,
    # lifelong learning
    prev_tasks=[], cur_task=None, sample_portion=None,
):
    # truncate sentence for prepend <bos> and append <eos>
    max_target_positions -= int(prepend_bos) + int(append_eos)

    # load source dataset
    src_path = os.path.join(data_path, cur_task, '{}.{}'.format(split, src))
    src_dataset = [
        _load_dataset(path=src_path, impl=dataset_impl, dict=src_dict)
    ]
    # load previous tasks
    if len(prev_tasks) > 0 and cur_task is not None and sample_portion is not None:
        sample_size_per_task = int(len(src_dataset[0]) * sample_portion // len(prev_tasks))
    else:
        sample_size_per_task = -1
    if sample_size_per_task > 0:
        for p_task in prev_tasks:
            p_path = os.path.join(data_path, p_task, '{}.{}'.format(split, src))
            p_dataset = _load_dataset(p_path, dataset_impl, src_dict)
            src_dataset.append(SliceDataset(p_dataset, end=sample_size_per_task))
    src_dataset = ConcatDataset(src_dataset)
    # truncate dataset
    if truncate_source:
        # sntn => sntn[:max_source_positions]
        LOGGER.info('truncate {}.{} to {}'.format(split, src, max_source_positions))
        src_dataset = TruncateDataset(src_dataset, max_source_positions)

    # load target dataset
    tgt_path = os.path.join(data_path, cur_task, '{}.{}'.format(split, tgt))
    tgt_dataset = [
        _load_dataset(path=tgt_path, impl=dataset_impl, dict=tgt_dict)
    ]
    if sample_size_per_task > 0:
        for p_task in prev_tasks:
            p_path = os.path.join(data_path, p_task, '{}.{}'.format(split, tgt))
            p_dataset = _load_dataset(p_path, dataset_impl, tgt_dict)
            tgt_dataset.append(SliceDataset(p_dataset, end=sample_size_per_task))
    tgt_dataset = ConcatDataset(tgt_dataset)
    if truncate_target:
        # sntn => sntn[:max_target_positions]
        LOGGER.info('truncate {}.{} to {}'.format(split, tgt, max_target_positions))
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)  # 2 for BOS and EOS
    # sntn[:max_target_positions] => <bos> sntn[:max_target_positions]
    if prepend_bos:
        tgt_dataset = PrependTokenDataset(tgt_dataset, token=tgt_dict.bos())
    if append_eos:
        tgt_dataset = AppendTokenDataset(tgt_dataset, token=tgt_dict.eos())
    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None

    assert len(src_dataset) == len(tgt_dataset), (len(src_dataset), len(tgt_dataset))
    LOGGER.info('loaded {} examples from: [{}](current task) + {}(previous tasks)'.
                format(len(src_dataset), cur_task, prev_tasks))

    return BELanguagePairDataset(
        src_dataset, src_dataset.sizes, src_dict,
        tgt_dataset, tgt_dataset_sizes, tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        align_dataset=None,
        bos=src_dict.bos(), eos=src_dict.eos(),
        shuffle=(split == 'train'),
    )
def load_inference_langpair_dataset(
    data_paths, split,
    src, src_dict, tgt, tgt_dict,
    dataset_impl,
    # combine, dataset_impl, upsample_primary,
    left_pad_source, left_pad_target,
    max_source_positions, max_target_positions,
    prepend_bos=True, append_eos=True,
    load_alignments=False, truncate_source=False, append_source_id=False,
    truncate_target=False,
):
    # truncate sentence for prepend <bos> and append <eos>
    max_target_positions -= int(prepend_bos) + int(append_eos)

    # load source dataset
    src_dataset = []
    for data_path in data_paths:
        src_path = os.path.join(data_path, '{}.{}'.format(split, src))
        src_dataset.append(_load_dataset(path=src_path, impl=dataset_impl, dict=src_dict))
    src_dataset = ConcatDataset(src_dataset)
    # truncate dataset
    if truncate_source:
        # sntn => sntn[:max_source_positions]
        LOGGER.info('truncate {}.{} to {}'.format(split, src, max_source_positions))
        src_dataset = TruncateDataset(src_dataset, max_source_positions)

    # load target dataset
    tgt_dataset = []
    for data_path in data_paths:
        tgt_path = os.path.join(data_path, '{}.{}'.format(split, tgt))
        tgt_dataset.append(_load_dataset(path=tgt_path, impl=dataset_impl, dict=tgt_dict))
    tgt_dataset = ConcatDataset(tgt_dataset)
    if truncate_target:
        # sntn => sntn[:max_target_positions]
        LOGGER.info('truncate {}.{} to {}'.format(split, tgt, max_target_positions))
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)  # 2 for BOS and EOS
    # sntn[:max_target_positions] => <bos> sntn[:max_target_positions]
    if prepend_bos:
        tgt_dataset = PrependTokenDataset(tgt_dataset, token=tgt_dict.bos())
    if append_eos:
        tgt_dataset = AppendTokenDataset(tgt_dataset, token=tgt_dict.eos())
    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None

    LOGGER.info('loaded {} examples from: {}.{}'.format(len(src_dataset), data_paths, src))
    LOGGER.info('loaded {} examples from: {}.{}'.format(len(tgt_dataset), data_paths, tgt))

    return BELanguagePairDataset(
        src_dataset, src_dataset.sizes, src_dict,
        tgt_dataset, tgt_dataset_sizes, tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        align_dataset=None,
        bos=src_dict.bos(), eos=src_dict.eos(),
        shuffle=(split == 'train'),
    )
def load_langpair_dataset(
    args, programming_langs, data_path, split,
    src, src_dict, tgt, tgt_dict,
    combine, dataset_impl, upsample_primary,
    left_pad_source, left_pad_target,
    max_source_positions, max_target_positions,
    prepend_bos=False, load_alignments=False,
    truncate_source=False, append_source_id=False,
    is_distill=False,
):
    def split_exists(split, src, data_path):
        filename = os.path.join(data_path, '{}.{}'.format(split, src))  # -{}.{} , tgt, lang
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []
    topk_idxs = []
    topk_probs = []
    expert_scores = []
    dataset_ids = []
    lng_borders = [0]
    is_train = split == 'train'

    for ds_idx, program_lang in enumerate(programming_langs):
        lang_data_path = os.path.join(data_path, program_lang)
        split_k = split
        # infer langcode
        if split_exists(split_k, src, lang_data_path):
            prefix = os.path.join(lang_data_path, '{}.'.format(split_k))  # {}-{}. , src, tgt
        elif split_exists(split_k, tgt, lang_data_path):
            prefix = os.path.join(lang_data_path, '{}.'.format(split_k))  # {}-{}. , tgt, src
        else:
            raise NotImplementedError('No data in {}'.format(lang_data_path))

        src_dataset = data_utils.load_indexed_dataset(prefix + src, src, src_dict, dataset_impl)
        if truncate_source:
            src_dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)
        length = len(src_dataset)
        lng_borders.append(lng_borders[-1] + length)

        tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt, tgt_dict, dataset_impl)
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        for i in range(length):
            dataset_ids.append(ds_idx)

        if is_distill and is_train:
            # distill only for train
            path = '{}_{}_{}_topk_idx'.format(lang_data_path, src, tgt)
            topk_idxs.append(TeacherOutputDataset(path))
            path = '{}_{}_{}_topk_prob'.format(lang_data_path, src, tgt)
            topk_probs.append(TeacherOutputDataset(path))
            expert_bleu = os.path.join(data_path, 'expert_bleu_{}_{}_{}.json'.format(program_lang, src, tgt))
            expert_bleu = json.load(open(expert_bleu))
            expert_scores.append(expert_bleu[f"bleu_{program_lang}"])

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    sample_ratios = [1] * len(src_datasets)
    sample_ratios[0] = upsample_primary
    src_dataset = ConcatDataset(src_datasets, sample_ratios)
    if len(tgt_datasets) > 0:
        tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
    else:
        tgt_dataset = None
    LOGGER.info('src data: {}, tgt data: {}'.format(len(src_dataset), len(tgt_dataset)))

    if is_distill and is_train:
        # distill only for train
        topk_idx_dataset = ConcatDataset(topk_idxs)
        topk_probs_dataset = ConcatDataset(topk_probs)
        assert len(topk_probs_dataset) == len(src_dataset), (len(topk_probs_dataset), len(src_dataset))
        assert len(topk_idx_dataset) == len(src_dataset)
    else:
        topk_idx_dataset = None
        topk_probs_dataset = None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return UniversalDataset(
        args,
        src_dataset, src_dataset.sizes, src_dict,
        tgt_dataset, tgt_dataset_sizes, tgt_dict,
        dataset_ids=dataset_ids,
        lng_borders=lng_borders,
        dataset_names=programming_langs,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        topk_idxs=topk_idx_dataset,
        topk_probs=topk_probs_dataset,
        expert_scores=expert_scores,
        is_train=is_train,
    )
def load_token_dataset(
    data_paths, split, tgt, tgt_dict, dataset_impl,
    attrs=None, attr_dict=None, attrs_mapping=None, reversed_attrs_mapping=None,
    truncate_target=False, max_target_positions=None,
):
    # load tokens
    tgt_dataset = []
    for data_path in data_paths:
        tgt_path = os.path.join(data_path, '{}.{}'.format(split, tgt))
        tgt_dataset.append(_load_dataset(tgt_path, dataset_impl))
    tgt_dataset = ConcatDataset(tgt_dataset)
    if truncate_target:
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)
        LOGGER.info('Truncate dataset into max length: {}'.format(max_target_positions))
    LOGGER.info('loaded {} examples from: {}'.format(len(tgt_dataset), data_paths))

    # load tokens.ext
    tgt_ext_paths = [
        os.path.join(data_path, '{}.{}.ext'.format(split, tgt))
        for data_path in data_paths
    ]
    if all(indexed_dataset.SeqIndexedDataset.exists(tgt_ext_path) for tgt_ext_path in tgt_ext_paths):
        tgt_ext_dataset = indexed_dataset.SeqIndexedDataset(tgt_ext_paths[0])
        for tgt_ext_path in tgt_ext_paths[1:]:
            tgt_ext_dataset.append(indexed_dataset.SeqIndexedDataset(tgt_ext_path))
        if truncate_target:
            tgt_ext_dataset.clip(max_position=max_target_positions)
        assert len(tgt_dataset) == len(tgt_ext_dataset), (len(tgt_dataset), len(tgt_ext_dataset))
    else:
        tgt_ext_dataset = None

    # load attrs
    if attrs is None:
        attr_dataset = None
    else:
        attr_dataset = []
        for data_path in data_paths:
            attr_path = os.path.join(data_path, '{}.code_types'.format(split))
            attr_dataset.append(_load_dataset(attr_path, dataset_impl))
        attr_dataset = ConcatDataset(attr_dataset)
        if truncate_target:
            attr_dataset = TruncateDataset(attr_dataset, max_target_positions)
            LOGGER.info('Truncate dataset\'s attributes into max length: {}'.format(max_target_positions))
        LOGGER.info('loaded {} examples from: {}'.format(len(attr_dataset), data_paths))

        # load attr.ext
        attr_ext_paths = [
            os.path.join(data_path, '{}.code_types.ext'.format(split))
            for data_path in data_paths
        ]
        if all(indexed_dataset.SeqIndexedDataset.exists(attr_ext_path) for attr_ext_path in attr_ext_paths):
            attr_ext_dataset = indexed_dataset.SeqIndexedDataset(attr_ext_paths[0])
            for attr_ext_path in attr_ext_paths[1:]:
                attr_ext_dataset.append(indexed_dataset.SeqIndexedDataset(attr_ext_path))
            if truncate_target:
                attr_ext_dataset.clip(max_position=max_target_positions)
            assert np.all(tgt_ext_dataset == attr_ext_dataset)
            del attr_ext_dataset

    return CompletionDataset(
        tgt_dataset, tgt_dataset.sizes, tgt_dict,
        extends=tgt_ext_dataset,
        attrs=attrs,
        attr_indices=attr_dataset,
        attr_dict=attr_dict,
        attrs_mapping=attrs_mapping,
        reversed_attrs_mapping=reversed_attrs_mapping,
        max_target_positions=max_target_positions,
    )
def load_langpair_dataset(
    data_path, split,
    src, src_dict, tgt, tgt_dict,
    combine, dataset_impl, upsample_primary,
    left_pad_source, left_pad_target,
    max_source_positions, max_target_positions,
    load_alignments=False, truncate_source=False, truncate_target=False,
):
    def split_exists(split, src, data_path):
        filename = os.path.join(data_path, '{}.{}'.format(split, src))  # -{}.{} , tgt, lang
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []
    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')
        # infer langcode
        if split_exists(split_k, src, data_path):
            prefix = os.path.join(data_path, '{}.'.format(split_k))  # {}-{}. , src, tgt
        elif split_exists(split_k, tgt, data_path):
            prefix = os.path.join(data_path, '{}.'.format(split_k))  # {}-{}. , tgt, src
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

        src_dataset = data_utils.load_indexed_dataset(prefix + src, 'text', src_dict, dataset_impl)
        if truncate_source and max_source_positions:
            src_dataset = TruncateDataset(src_dataset, max_source_positions)
        src_datasets.append(src_dataset)

        tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, 'text', tgt_dict, dataset_impl)
        if truncate_target and max_target_positions:
            tgt_dataset = PrependTokenDataset(
                AppendTokenDataset(
                    TruncateDataset(tgt_dataset, max_target_positions - 2),
                    token=tgt_dict.eos(),
                ),
                tgt_dict.bos(),
            )
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return LanguagePairDataset(
        src_dataset, src_dataset.sizes, src_dict,
        tgt_dataset, tgt_dataset_sizes, tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        shuffle=True,  # TODO debug: shuffle=False
    )
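# Hedged usage sketch (not from the original source): the 'code'/'docstring' file
# suffixes, padding sides, position limits and dictionaries are illustrative
# assumptions for a summarization-style call to the language-pair loader above.
def _example_build_langpair_dataset(data_path, src_dict, tgt_dict):
    return load_langpair_dataset(
        data_path, 'train',
        src='code', src_dict=src_dict,        # assumed source suffix: train.code.{bin,idx}
        tgt='docstring', tgt_dict=tgt_dict,   # assumed target suffix: train.docstring.{bin,idx}
        combine=False, dataset_impl='mmap', upsample_primary=1,
        left_pad_source=True, left_pad_target=False,
        max_source_positions=512, max_target_positions=128,
        truncate_source=True, truncate_target=True,
    )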