def load_token_dataset(
    data_path, split, tgt, tgt_dict, dataset_impl,
    attrs=None, attr_dict=None, attrs_mapping=None, reversed_attrs_mapping=None,
    truncate_target=False, max_target_positions=None,
    # kd
    topk_ids_prefix=None, topk_probs_prefix=None, topk=None, distill_topk=None,
):
    """Load a token completion dataset plus teacher top-k outputs for KD.

    Args:
        data_path: directory holding the binarized ``{split}.{tgt}`` files.
        split: dataset split name (e.g. ``train``/``valid``/``test``).
        tgt: target suffix of the token files.
        tgt_dict: target dictionary forwarded to the dataset.
        dataset_impl: implementation flag passed to ``_load_dataset``.
        attrs: optional attribute names; when ``None`` no ``code_types``
            files are loaded.
        attr_dict: dictionary for attributes.
        attrs_mapping: attribute index mapping forwarded to the dataset.
        reversed_attrs_mapping: inverse of ``attrs_mapping``.
        truncate_target: if True, clip examples to ``max_target_positions``.
        max_target_positions: maximum target sequence length.
        topk_ids_prefix: file prefix of the teacher's top-k token ids.
        topk_probs_prefix: file prefix of the teacher's top-k probabilities.
        topk: number of teacher candidates stored per position.
        distill_topk: number of teacher candidates used for distillation.

    Returns:
        A ``TopkKDCompletionDataset`` wrapping the tokens, optional
        extend/attribute datasets, and the teacher outputs.
    """
    # load tokens
    tgt_path = os.path.join(data_path, '{}.{}'.format(split, tgt))
    tgt_dataset = _load_dataset(tgt_path, dataset_impl)
    if truncate_target:
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)
        LOGGER.info('Truncate dataset into max length: {}'.format(max_target_positions))
    LOGGER.info('loaded {} examples from: {}'.format(len(tgt_dataset), tgt_path))

    # load tokens.ext (per-example extend indices), if the file exists
    tgt_ext_path = os.path.join(data_path, '{}.{}.ext'.format(split, tgt))
    if indexed_dataset.SeqIndexedDataset.exists(tgt_ext_path):
        tgt_ext_dataset = indexed_dataset.SeqIndexedDataset(tgt_ext_path)
        if truncate_target:
            tgt_ext_dataset.clip(max_position=max_target_positions)
        # extend indices must align with the token examples one-to-one
        assert len(tgt_dataset) == len(tgt_ext_dataset), (len(tgt_dataset), len(tgt_ext_dataset))
    else:
        tgt_ext_dataset = None

    # load attrs (code_types), if requested
    if attrs is None:
        attr_dataset = None
    else:
        attr_path = os.path.join(data_path, '{}.code_types'.format(split))
        attr_dataset = _load_dataset(attr_path, dataset_impl)
        if truncate_target:
            # FIX: truncate the attribute dataset itself. The original code
            # re-truncated tgt_dataset here (a copy-paste slip) and left the
            # attributes at full length, contradicting the log message below.
            attr_dataset = TruncateDataset(attr_dataset, max_target_positions)
            LOGGER.info('Truncate dataset\'s attributes into max length: {}'.format(max_target_positions))
        LOGGER.info('loaded {} examples from: {}'.format(len(attr_dataset), attr_path))

        # load attr.ext and sanity-check it against tokens.ext
        attr_ext_path = os.path.join(data_path, '{}.code_types.ext'.format(split))
        if indexed_dataset.SeqIndexedDataset.exists(attr_ext_path):
            attr_ext_dataset = indexed_dataset.SeqIndexedDataset(attr_ext_path)
            if truncate_target:
                attr_ext_dataset.clip(max_position=max_target_positions)
            # attributes must share the same extend indices as the tokens
            assert np.all(tgt_ext_dataset == attr_ext_dataset)
            del attr_ext_dataset  # only needed for the consistency check

    # teacher outputs for knowledge distillation
    topk_ids = TeacherOutDataset(topk_ids_prefix)
    topk_probs = TeacherOutDataset(topk_probs_prefix)

    return TopkKDCompletionDataset(
        topk_ids=topk_ids,
        topk_probs=topk_probs,
        topk=topk,
        distill_topk=distill_topk,
        tgt=tgt_dataset,
        tgt_sizes=tgt_dataset.sizes,
        tgt_dict=tgt_dict,
        extends=tgt_ext_dataset,
        attrs=attrs,
        attr_indices=attr_dataset,
        attr_dict=attr_dict,
        attrs_mapping=attrs_mapping,
        reversed_attrs_mapping=reversed_attrs_mapping,
        max_target_positions=max_target_positions,
    )
def load_token_dataset(
    data_path, split, tgt, tgt_dict, dataset_impl,
    attrs=None, attr_dict=None, attrs_mapping=None, reversed_attrs_mapping=None,
    truncate_target=False, max_target_positions=None,
    # lifelong learning
    prev_tasks=None, cur_task=None, sample_portion=None,
):
    """Load a token completion dataset for lifelong learning with replay.

    The current task's data is loaded in full; when ``prev_tasks`` and
    ``sample_portion`` are given, a slice of each previous task's data is
    replayed alongside it.

    Args:
        data_path: root directory; each task is a sub-directory.
        split: dataset split name (e.g. ``train``/``valid``/``test``).
        tgt: target suffix of the token files.
        tgt_dict: target dictionary forwarded to the dataset.
        dataset_impl: implementation flag passed to ``_load_dataset``.
        attrs: optional attribute names; when ``None`` no ``code_types``
            files are loaded.
        attr_dict: dictionary for attributes.
        attrs_mapping: attribute index mapping forwarded to the dataset.
        reversed_attrs_mapping: inverse of ``attrs_mapping``.
        truncate_target: if True, clip examples to ``max_target_positions``.
        max_target_positions: maximum target sequence length.
        prev_tasks: names of previously-learned tasks to replay from
            (default: none). ``None`` is treated as an empty list.
        cur_task: name of the task currently being learned.
        sample_portion: fraction of the current task's size to replay,
            split evenly across ``prev_tasks``.

    Returns:
        A ``CompletionDataset`` over the concatenated current + replayed data.
    """
    # avoid a mutable default argument; behavior is unchanged for callers
    prev_tasks = [] if prev_tasks is None else prev_tasks

    # load tokens of the current task
    tgt_path = os.path.join(data_path, cur_task, '{}.{}'.format(split, tgt))
    tgt_dataset = [_load_dataset(tgt_path, dataset_impl)]

    # number of examples to replay from each previous task (-1 disables replay)
    if len(prev_tasks) > 0 and cur_task is not None and sample_portion is not None:
        sample_size_per_task = int(len(tgt_dataset[0]) * sample_portion // len(prev_tasks))
    else:
        sample_size_per_task = -1

    if sample_size_per_task > 0:
        for p_task in prev_tasks:
            p_path = os.path.join(data_path, p_task, '{}.{}'.format(split, tgt))
            p_dataset = _load_dataset(p_path, dataset_impl)
            tgt_dataset.append(SliceDataset(p_dataset, end=sample_size_per_task))
    tgt_dataset = ConcatDataset(tgt_dataset)
    if truncate_target:
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)
        LOGGER.info('Truncate dataset into max length: {}'.format(max_target_positions))
    LOGGER.info('loaded {} examples from: [{}](current task) + {}(previous tasks)'.format(
        len(tgt_dataset), cur_task, prev_tasks))

    # load tokens.ext (per-example extend indices), if the file exists
    tgt_ext_path = os.path.join(data_path, cur_task, '{}.{}.ext'.format(split, tgt))
    if indexed_dataset.SeqIndexedDataset.exists(tgt_ext_path):
        tgt_ext_dataset = indexed_dataset.SeqIndexedDataset(tgt_ext_path)
        if sample_size_per_task > 0:
            for p_task in prev_tasks:
                p_ext_path = os.path.join(data_path, p_task, '{}.{}.ext'.format(split, tgt))
                p_ext_dataset = indexed_dataset.SeqIndexedDataset(p_ext_path)
                p_ext_dataset.truncate(end=sample_size_per_task)
                tgt_ext_dataset.append(p_ext_dataset)
        if truncate_target:
            tgt_ext_dataset.clip(max_position=max_target_positions)
        # extend indices must align with the token examples one-to-one
        assert len(tgt_dataset) == len(tgt_ext_dataset), (len(tgt_dataset), len(tgt_ext_dataset))
    else:
        tgt_ext_dataset = None

    # load attrs (code_types), if requested
    if attrs is None:
        attr_dataset = None
    else:
        attr_path = os.path.join(data_path, cur_task, '{}.code_types'.format(split))
        attr_dataset = [_load_dataset(attr_path, dataset_impl)]
        if sample_size_per_task > 0:
            for p_task in prev_tasks:
                p_attr_path = os.path.join(data_path, p_task, '{}.code_types'.format(split))
                p_attr_dataset = _load_dataset(p_attr_path, dataset_impl)
                attr_dataset.append(SliceDataset(p_attr_dataset, end=sample_size_per_task))
        attr_dataset = ConcatDataset(attr_dataset)
        if truncate_target:
            # FIX: truncate the attribute dataset itself. The original code
            # re-truncated tgt_dataset here (a copy-paste slip) and left the
            # attributes at full length, contradicting the log message below.
            attr_dataset = TruncateDataset(attr_dataset, max_target_positions)
            LOGGER.info(
                'Truncate dataset\'s attributes into max length: {}'.format(max_target_positions))
        LOGGER.info('loaded attributes {} examples from: [{}](current task) + {}(previous tasks)'.format(
            len(attr_dataset), cur_task, prev_tasks))

        # load attr.ext and sanity-check it against tokens.ext
        attr_ext_path = os.path.join(data_path, cur_task, '{}.code_types.ext'.format(split))
        if indexed_dataset.SeqIndexedDataset.exists(attr_ext_path):
            attr_ext_dataset = indexed_dataset.SeqIndexedDataset(attr_ext_path)
            if sample_size_per_task > 0:
                for p_task in prev_tasks:
                    p_attr_ext_path = os.path.join(
                        data_path, p_task, '{}.code_types.ext'.format(split))
                    p_attr_ext_dataset = indexed_dataset.SeqIndexedDataset(p_attr_ext_path)
                    p_attr_ext_dataset.truncate(end=sample_size_per_task)
                    attr_ext_dataset.append(p_attr_ext_dataset)
            if truncate_target:
                attr_ext_dataset.clip(max_position=max_target_positions)
            # attributes must share the same extend indices as the tokens
            assert np.all(tgt_ext_dataset == attr_ext_dataset)
            del attr_ext_dataset  # only needed for the consistency check

    return CompletionDataset(
        tgt_dataset,
        tgt_dataset.sizes,
        tgt_dict,
        extends=tgt_ext_dataset,
        attrs=attrs,
        attr_indices=attr_dataset,
        attr_dict=attr_dict,
        attrs_mapping=attrs_mapping,
        reversed_attrs_mapping=reversed_attrs_mapping,
        max_target_positions=max_target_positions,
    )
def load_token_dataset(
    data_paths, split, tgt, tgt_dict, dataset_impl,
    attrs=None, attr_dict=None, attrs_mapping=None, reversed_attrs_mapping=None,
    truncate_target=False, max_target_positions=None,
):
    """Load and concatenate token completion datasets from several directories.

    Args:
        data_paths: directories each holding binarized ``{split}.{tgt}`` files;
            their examples are concatenated in order.
        split: dataset split name (e.g. ``train``/``valid``/``test``).
        tgt: target suffix of the token files.
        tgt_dict: target dictionary forwarded to the dataset.
        dataset_impl: implementation flag passed to ``_load_dataset``.
        attrs: optional attribute names; when ``None`` no ``code_types``
            files are loaded.
        attr_dict: dictionary for attributes.
        attrs_mapping: attribute index mapping forwarded to the dataset.
        reversed_attrs_mapping: inverse of ``attrs_mapping``.
        truncate_target: if True, clip examples to ``max_target_positions``.
        max_target_positions: maximum target sequence length.

    Returns:
        A ``CompletionDataset`` over the concatenated data.
    """
    # load tokens from every data path and concatenate
    tgt_dataset = []
    for data_path in data_paths:
        tgt_path = os.path.join(data_path, '{}.{}'.format(split, tgt))
        tgt_dataset.append(_load_dataset(tgt_path, dataset_impl))
    tgt_dataset = ConcatDataset(tgt_dataset)
    if truncate_target:
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)
        LOGGER.info('Truncate dataset into max length: {}'.format(max_target_positions))
    LOGGER.info('loaded {} examples from: {}'.format(len(tgt_dataset), data_paths))

    # load tokens.ext only when every data path provides it
    tgt_ext_paths = [
        os.path.join(data_path, '{}.{}.ext'.format(split, tgt))
        for data_path in data_paths
    ]
    if all(
            indexed_dataset.SeqIndexedDataset.exists(tgt_ext_path)
            for tgt_ext_path in tgt_ext_paths):
        tgt_ext_dataset = indexed_dataset.SeqIndexedDataset(tgt_ext_paths[0])
        for tgt_ext_path in tgt_ext_paths[1:]:
            tgt_ext_dataset.append(indexed_dataset.SeqIndexedDataset(tgt_ext_path))
        if truncate_target:
            tgt_ext_dataset.clip(max_position=max_target_positions)
        # extend indices must align with the token examples one-to-one
        assert len(tgt_dataset) == len(tgt_ext_dataset), (len(tgt_dataset), len(tgt_ext_dataset))
    else:
        tgt_ext_dataset = None

    # load attrs (code_types), if requested
    if attrs is None:
        attr_dataset = None
    else:
        attr_dataset = []
        for data_path in data_paths:
            attr_path = os.path.join(data_path, '{}.code_types'.format(split))
            attr_dataset.append(_load_dataset(attr_path, dataset_impl))
        attr_dataset = ConcatDataset(attr_dataset)
        if truncate_target:
            # FIX: truncate the attribute dataset itself. The original code
            # re-truncated tgt_dataset here (a copy-paste slip) and left the
            # attributes at full length, contradicting the log message below.
            attr_dataset = TruncateDataset(attr_dataset, max_target_positions)
            LOGGER.info(
                'Truncate dataset\'s attributes into max length: {}'.format(max_target_positions))
        # FIX: log all source directories (the original referenced the stale
        # loop variable `data_path`, i.e. only the last directory)
        LOGGER.info('loaded {} examples from: {}'.format(len(attr_dataset), data_paths))

        # load attr.ext only when every data path provides it, and
        # sanity-check it against tokens.ext
        attr_ext_paths = [
            os.path.join(data_path, '{}.code_types.ext'.format(split))
            for data_path in data_paths
        ]
        if all(
                indexed_dataset.SeqIndexedDataset.exists(attr_ext_path)
                for attr_ext_path in attr_ext_paths):
            attr_ext_dataset = indexed_dataset.SeqIndexedDataset(attr_ext_paths[0])
            for attr_ext_path in attr_ext_paths[1:]:
                attr_ext_dataset.append(indexed_dataset.SeqIndexedDataset(attr_ext_path))
            if truncate_target:
                attr_ext_dataset.clip(max_position=max_target_positions)
            # attributes must share the same extend indices as the tokens
            assert np.all(tgt_ext_dataset == attr_ext_dataset)
            del attr_ext_dataset  # only needed for the consistency check

    return CompletionDataset(
        tgt_dataset,
        tgt_dataset.sizes,
        tgt_dict,
        extends=tgt_ext_dataset,
        attrs=attrs,
        attr_indices=attr_dataset,
        attr_dict=attr_dict,
        attrs_mapping=attrs_mapping,
        reversed_attrs_mapping=reversed_attrs_mapping,
        max_target_positions=max_target_positions,
    )