import os

import numpy as np

# NOTE: LOGGER, _load_dataset, indexed_dataset, TruncateDataset, ConcatDataset,
# SliceDataset, and CompletionDataset are project-local helpers; their exact
# import paths depend on the surrounding package and are assumed to be in scope.


def load_token_dataset(
    data_path, split, tgt, tgt_dict, dataset_impl,
    attrs=None, attr_dict=None,
    attrs_mapping=None, reversed_attrs_mapping=None,
    truncate_target=False, max_target_positions=None,
    shuffle=True,
):
    # load tokens
    tgt_path = os.path.join(data_path, '{}.{}'.format(split, tgt))
    tgt_dataset = _load_dataset(tgt_path, dataset_impl)
    if truncate_target:
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)
        LOGGER.info('Truncated dataset to max length: {}'.format(max_target_positions))
    LOGGER.info('loaded {} examples from: {}'.format(len(tgt_dataset), tgt_path))

    # load tokens.ext (per-example extension indices, if present)
    tgt_ext_path = os.path.join(data_path, '{}.{}.ext'.format(split, tgt))
    if indexed_dataset.SeqIndexedDataset.exists(tgt_ext_path):
        tgt_ext_dataset = indexed_dataset.SeqIndexedDataset(tgt_ext_path)
        if truncate_target:
            tgt_ext_dataset.clip(max_position=max_target_positions)
        assert len(tgt_dataset) == len(tgt_ext_dataset), (len(tgt_dataset), len(tgt_ext_dataset))
    else:
        tgt_ext_dataset = None

    # load attrs (code-type annotations aligned with the token stream)
    if attrs is None:
        attr_dataset = None
    else:
        attr_path = os.path.join(data_path, '{}.code_types'.format(split))
        attr_dataset = _load_dataset(attr_path, dataset_impl)
        if truncate_target:
            # bug fix: the original truncated tgt_dataset a second time here;
            # the attribute dataset must be truncated to stay aligned with it.
            attr_dataset = TruncateDataset(attr_dataset, max_target_positions)
            LOGGER.info("Truncated dataset's attributes to max length: {}".format(max_target_positions))
        LOGGER.info('loaded {} examples from: {}'.format(len(attr_dataset), attr_path))

        # load attr.ext and check it agrees with tokens.ext
        attr_ext_path = os.path.join(data_path, '{}.code_types.ext'.format(split))
        if indexed_dataset.SeqIndexedDataset.exists(attr_ext_path):
            attr_ext_dataset = indexed_dataset.SeqIndexedDataset(attr_ext_path)
            if truncate_target:
                attr_ext_dataset.clip(max_position=max_target_positions)
            assert np.all(tgt_ext_dataset == attr_ext_dataset)
            del attr_ext_dataset

    return CompletionDataset(
        tgt_dataset, tgt_dataset.sizes, tgt_dict,
        extends=tgt_ext_dataset,
        attrs=attrs, attr_indices=attr_dataset, attr_dict=attr_dict,
        attrs_mapping=attrs_mapping, reversed_attrs_mapping=reversed_attrs_mapping,
        max_target_positions=max_target_positions, shuffle=shuffle,
    )
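# Usage sketch (illustrative, not prescribed by this module): the directory,
# split name, token file suffix, and dictionary object below are assumptions;
# dataset_impl is forwarded to _load_dataset (e.g. 'mmap' in fairseq-style
# codebases).
#
#   train_dataset = load_token_dataset(
#       data_path='data-bin/python',        # hypothetical preprocessed dir
#       split='train',
#       tgt='code_tokens',                  # hypothetical token file suffix
#       tgt_dict=task.target_dictionary,    # hypothetical dictionary handle
#       dataset_impl='mmap',
#       truncate_target=True,
#       max_target_positions=1024,
#   )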
def load_inference_token_dataset(
    data_paths, split, tgt, tgt_dict, dataset_impl,
    attrs=None, attr_dict=None,
    attrs_mapping=None, reversed_attrs_mapping=None,
    truncate_target=False, max_target_positions=None,
):
    # load tokens from every data path and concatenate them
    tgt_dataset = []
    for path in data_paths:
        tgt_path = os.path.join(path, '{}.{}'.format(split, tgt))
        tgt_dataset.append(_load_dataset(tgt_path, dataset_impl))
    tgt_dataset = ConcatDataset(tgt_dataset)
    if truncate_target:
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)
        LOGGER.info('Truncated dataset to max length: {}'.format(max_target_positions))
    LOGGER.info('loaded {} examples from: {}'.format(len(tgt_dataset), data_paths))
    # inference needs no extension or attribute datasets
    return CompletionDataset(
        tgt_dataset, tgt_dataset.sizes, tgt_dict,
        extends=None,
        attrs=None, attr_indices=None, attr_dict=None,
        attrs_mapping=None, reversed_attrs_mapping=None,
        max_target_positions=max_target_positions,
    )
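# Usage sketch (illustrative; the directories are assumptions). Unlike
# load_token_dataset above, this loader takes a *list* of data directories,
# concatenates the per-directory datasets, and skips .ext/attribute files:
#
#   test_dataset = load_inference_token_dataset(
#       data_paths=['data-bin/python', 'data-bin/java'],  # hypothetical dirs
#       split='test',
#       tgt='code_tokens',
#       tgt_dict=task.target_dictionary,
#       dataset_impl='mmap',
#       truncate_target=True,
#       max_target_positions=1024,
#   )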
# NOTE: lifelong-learning variant of load_token_dataset; in the source tree
# it presumably lives in a separate module, hence the repeated name.
def load_token_dataset(
    data_path, split, tgt, tgt_dict, dataset_impl,
    attrs=None, attr_dict=None,
    attrs_mapping=None, reversed_attrs_mapping=None,
    truncate_target=False, max_target_positions=None,
    # lifelong learning
    prev_tasks=[], cur_task=None, sample_portion=None,
):
    # load tokens of the current task
    tgt_path = os.path.join(data_path, cur_task, '{}.{}'.format(split, tgt))
    tgt_dataset = [_load_dataset(tgt_path, dataset_impl)]
    # replay budget: a portion of the current task's size, split evenly
    # across previous tasks
    if len(prev_tasks) > 0 and cur_task is not None and sample_portion is not None:
        sample_size_per_task = int(len(tgt_dataset[0]) * sample_portion // len(prev_tasks))
    else:
        sample_size_per_task = -1
    if sample_size_per_task > 0:
        for p_task in prev_tasks:
            p_path = os.path.join(data_path, p_task, '{}.{}'.format(split, tgt))
            p_dataset = _load_dataset(p_path, dataset_impl)
            tgt_dataset.append(SliceDataset(p_dataset, end=sample_size_per_task))
    tgt_dataset = ConcatDataset(tgt_dataset)
    if truncate_target:
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)
        LOGGER.info('Truncated dataset to max length: {}'.format(max_target_positions))
    LOGGER.info('loaded {} examples from: [{}](current task) + {}(previous tasks)'.format(
        len(tgt_dataset), cur_task, prev_tasks))

    # load tokens.ext
    tgt_ext_path = os.path.join(data_path, cur_task, '{}.{}.ext'.format(split, tgt))
    if indexed_dataset.SeqIndexedDataset.exists(tgt_ext_path):
        tgt_ext_dataset = indexed_dataset.SeqIndexedDataset(tgt_ext_path)
        if sample_size_per_task > 0:
            for p_task in prev_tasks:
                p_ext_path = os.path.join(data_path, p_task, '{}.{}.ext'.format(split, tgt))
                p_ext_dataset = indexed_dataset.SeqIndexedDataset(p_ext_path)
                p_ext_dataset.truncate(end=sample_size_per_task)
                tgt_ext_dataset.append(p_ext_dataset)
        if truncate_target:
            tgt_ext_dataset.clip(max_position=max_target_positions)
        assert len(tgt_dataset) == len(tgt_ext_dataset), (len(tgt_dataset), len(tgt_ext_dataset))
    else:
        tgt_ext_dataset = None

    # load attrs
    if attrs is None:
        attr_dataset = None
    else:
        attr_path = os.path.join(data_path, cur_task, '{}.code_types'.format(split))
        attr_dataset = [_load_dataset(attr_path, dataset_impl)]
        if sample_size_per_task > 0:
            for p_task in prev_tasks:
                p_attr_path = os.path.join(data_path, p_task, '{}.code_types'.format(split))
                p_attr_dataset = _load_dataset(p_attr_path, dataset_impl)
                attr_dataset.append(SliceDataset(p_attr_dataset, end=sample_size_per_task))
        attr_dataset = ConcatDataset(attr_dataset)
        if truncate_target:
            # bug fix: the original truncated tgt_dataset a second time here;
            # the attribute dataset must be truncated to stay aligned with it.
            attr_dataset = TruncateDataset(attr_dataset, max_target_positions)
            LOGGER.info("Truncated dataset's attributes to max length: {}".format(max_target_positions))
        LOGGER.info('loaded attributes {} examples from: [{}](current task) + {}(previous tasks)'.format(
            len(attr_dataset), cur_task, prev_tasks))

        # load attr.ext and check it agrees with tokens.ext
        attr_ext_path = os.path.join(data_path, cur_task, '{}.code_types.ext'.format(split))
        if indexed_dataset.SeqIndexedDataset.exists(attr_ext_path):
            attr_ext_dataset = indexed_dataset.SeqIndexedDataset(attr_ext_path)
            if sample_size_per_task > 0:
                for p_task in prev_tasks:
                    p_attr_ext_path = os.path.join(data_path, p_task, '{}.code_types.ext'.format(split))
                    p_attr_ext_dataset = indexed_dataset.SeqIndexedDataset(p_attr_ext_path)
                    p_attr_ext_dataset.truncate(end=sample_size_per_task)
                    attr_ext_dataset.append(p_attr_ext_dataset)
            if truncate_target:
                attr_ext_dataset.clip(max_position=max_target_positions)
            assert np.all(tgt_ext_dataset == attr_ext_dataset)
            del attr_ext_dataset

    return CompletionDataset(
        tgt_dataset, tgt_dataset.sizes, tgt_dict,
        extends=tgt_ext_dataset,
        attrs=attrs, attr_indices=attr_dataset, attr_dict=attr_dict,
        attrs_mapping=attrs_mapping, reversed_attrs_mapping=reversed_attrs_mapping,
        max_target_positions=max_target_positions,
    )
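# Usage sketch for the lifelong-learning loader (illustrative; root dir, task
# names, and portion are assumptions). With sample_portion=0.1 and two
# previous tasks, each previous task replays
# int(len(current_task_data) * 0.1 // 2) examples, so the total replay stays
# at roughly 10% of the current task's size:
#
#   dataset = load_token_dataset(
#       data_path='data-bin/lifelong',      # hypothetical root dir
#       split='train',
#       tgt='code_tokens',
#       tgt_dict=task.target_dictionary,
#       dataset_impl='mmap',
#       truncate_target=True,
#       max_target_positions=1024,
#       prev_tasks=['python', 'java'],      # hypothetical earlier tasks
#       cur_task='go',                      # hypothetical current task
#       sample_portion=0.1,                 # replay budget vs. current task
#   )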
def build_dataset_for_inference(self, src_tokens, src_lengths):
    # TODO: bug -- the loaders above construct CompletionDataset from an
    # indexed token dataset, its sizes array, and the dictionary; passing
    # raw src_tokens/src_lengths through unchanged is flagged as broken here.
    return CompletionDataset(src_tokens, src_lengths, self.target_dictionary)