Example 1

import os
import numpy as np

# TruncateDataset, ConcatDataset, SliceDataset, CompletionDataset,
# indexed_dataset, _load_dataset and LOGGER are assumed to be provided by the
# surrounding (fairseq-style) codebase; they are not shown in these snippets.
def load_token_dataset(
    data_path, split, tgt, tgt_dict, dataset_impl,
    attrs=None, attr_dict=None,
    attrs_mapping=None, reversed_attrs_mapping=None,
    truncate_target=False, max_target_positions=None,
    shuffle=True,
):
    # load tokens
    tgt_path = os.path.join(data_path, '{}.{}'.format(split, tgt))
    tgt_dataset = _load_dataset(tgt_path, dataset_impl)
    if truncate_target:
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)
        LOGGER.info('Truncated dataset to max length: {}'.format(max_target_positions))
    LOGGER.info('loaded {} examples from: {}'.format(len(tgt_dataset), tgt_path))
    # load tokens.ext
    tgt_ext_path = os.path.join(data_path, '{}.{}.ext'.format(split, tgt))
    if indexed_dataset.SeqIndexedDataset.exists(tgt_ext_path):
        tgt_ext_dataset = indexed_dataset.SeqIndexedDataset(tgt_ext_path)
        if truncate_target:
            tgt_ext_dataset.clip(max_position=max_target_positions)
        assert len(tgt_dataset) == len(tgt_ext_dataset), (len(tgt_dataset), len(tgt_ext_dataset))
    else:
        tgt_ext_dataset = None
    # load attrs
    if attrs is None:
        attr_dataset = None
    else:
        attr_path = os.path.join(data_path, '{}.code_types'.format(split))
        attr_dataset = _load_dataset(attr_path, dataset_impl)
        if truncate_target:
            # truncate the attributes, not the (already truncated) targets
            attr_dataset = TruncateDataset(attr_dataset, max_target_positions)
            LOGGER.info('Truncated dataset\'s attributes to max length: {}'.format(max_target_positions))
        LOGGER.info('loaded {} examples from: {}'.format(len(attr_dataset), attr_path))
        # load attr.ext
        attr_ext_path = os.path.join(data_path, '{}.code_types.ext'.format(split))
        if indexed_dataset.SeqIndexedDataset.exists(attr_ext_path):
            attr_ext_dataset = indexed_dataset.SeqIndexedDataset(attr_ext_path)
            if truncate_target:
                attr_ext_dataset.clip(max_position=max_target_positions)
            assert np.all(tgt_ext_dataset == attr_ext_dataset)
            del attr_ext_dataset

    return CompletionDataset(
        tgt_dataset, tgt_dataset.sizes, tgt_dict, extends=tgt_ext_dataset,
        attrs=attrs, attr_indices=attr_dataset, attr_dict=attr_dict,
        attrs_mapping=attrs_mapping, reversed_attrs_mapping=reversed_attrs_mapping,
        max_target_positions=max_target_positions,
        shuffle=shuffle,
    )
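
A minimal usage sketch of the loader above. The data directory, target field
name, and dataset_impl value are hypothetical placeholders, and tgt_dict is
assumed to be a token dictionary built elsewhere by the task:

# Hypothetical call; 'data-bin/py150', 'code_tokens' and 'mmap' are placeholders.
train_set = load_token_dataset(
    data_path='data-bin/py150',
    split='train',
    tgt='code_tokens',
    tgt_dict=tgt_dict,
    dataset_impl='mmap',
    truncate_target=True,
    max_target_positions=1024,
    shuffle=True,
)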
def load_inference_token_dataset(
    data_paths,
    split,
    tgt,
    tgt_dict,
    dataset_impl,
    attrs=None,
    attr_dict=None,
    attrs_mapping=None,
    reversed_attrs_mapping=None,
    truncate_target=False,
    max_target_positions=None,
):
    # load tokens
    tgt_dataset = []
    for path in data_paths:
        tgt_path = os.path.join(path, '{}.{}'.format(split, tgt))
        tgt_dataset.append(_load_dataset(tgt_path, dataset_impl))
    tgt_dataset = ConcatDataset(tgt_dataset)
    if truncate_target:
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)
        LOGGER.info('Truncated dataset to max length: {}'.format(max_target_positions))
    LOGGER.info('loaded {} examples from: {}'.format(len(tgt_dataset), data_paths))
    return CompletionDataset(
        tgt_dataset,
        tgt_dataset.sizes,
        tgt_dict,
        extends=None,
        attrs=None,
        attr_indices=None,
        attr_dict=None,
        attrs_mapping=None,
        reversed_attrs_mapping=None,
        max_target_positions=max_target_positions,
    )
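
The inference variant concatenates the same split from several data roots and
deliberately drops attributes, extent files, and shuffling. A hedged sketch,
with hypothetical paths and the same assumed tgt_dict as above:

# Hypothetical call: evaluate one test split drawn from two corpora at once.
eval_set = load_inference_token_dataset(
    data_paths=['data-bin/py150', 'data-bin/github-eval'],
    split='test',
    tgt='code_tokens',
    tgt_dict=tgt_dict,
    dataset_impl='mmap',
    truncate_target=True,
    max_target_positions=1024,
)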
Example 3
def load_token_dataset(
    data_path,
    split,
    tgt,
    tgt_dict,
    dataset_impl,
    attrs=None,
    attr_dict=None,
    attrs_mapping=None,
    reversed_attrs_mapping=None,
    truncate_target=False,
    max_target_positions=None,
    # lifelong learning
    prev_tasks=(),  # immutable default avoids shared mutable state across calls
    cur_task=None,
    sample_portion=None,
):
    # load tokens
    tgt_path = os.path.join(data_path, cur_task, '{}.{}'.format(split, tgt))
    tgt_dataset = [_load_dataset(tgt_path, dataset_impl)]
    if prev_tasks and cur_task is not None and sample_portion is not None:
        sample_size_per_task = int(len(tgt_dataset[0]) * sample_portion // len(prev_tasks))
    else:
        sample_size_per_task = -1
    if sample_size_per_task > 0:
        # replay a fixed-size slice of each previous task's data
        for p_task in prev_tasks:
            p_path = os.path.join(data_path, p_task, '{}.{}'.format(split, tgt))
            p_dataset = _load_dataset(p_path, dataset_impl)
            tgt_dataset.append(SliceDataset(p_dataset, end=sample_size_per_task))
    tgt_dataset = ConcatDataset(tgt_dataset)
    if truncate_target:
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)
        LOGGER.info('Truncated dataset to max length: {}'.format(max_target_positions))
    LOGGER.info('loaded {} examples from: [{}](current task) + {}(previous tasks)'.format(
        len(tgt_dataset), cur_task, prev_tasks))
    # load tokens.ext
    tgt_ext_path = os.path.join(data_path, cur_task,
                                '{}.{}.ext'.format(split, tgt))
    if indexed_dataset.SeqIndexedDataset.exists(tgt_ext_path):
        tgt_ext_dataset = indexed_dataset.SeqIndexedDataset(tgt_ext_path)
        if sample_size_per_task > 0:
            for p_task in prev_tasks:
                p_ext_path = os.path.join(data_path, p_task,
                                          '{}.{}.ext'.format(split, tgt))
                p_ext_dataset = indexed_dataset.SeqIndexedDataset(p_ext_path)
                p_ext_dataset.truncate(end=sample_size_per_task)
                tgt_ext_dataset.append(p_ext_dataset)
        if truncate_target:
            tgt_ext_dataset.clip(max_position=max_target_positions)
        assert len(tgt_dataset) == len(tgt_ext_dataset), (len(tgt_dataset),
                                                          len(tgt_ext_dataset))
    else:
        tgt_ext_dataset = None
    # load attrs
    if attrs is None:
        attr_dataset = None
    else:
        attr_path = os.path.join(data_path, cur_task,
                                 '{}.code_types'.format(split))
        attr_dataset = [_load_dataset(attr_path, dataset_impl)]
        if sample_size_per_task > 0:
            for p_task in prev_tasks:
                p_attr_path = os.path.join(data_path, p_task,
                                           '{}.code_types'.format(split))
                p_attr_dataset = _load_dataset(p_attr_path, dataset_impl)
                attr_dataset.append(
                    SliceDataset(p_attr_dataset, end=sample_size_per_task))
        attr_dataset = ConcatDataset(attr_dataset)
        if truncate_target:
            # truncate the attributes, not the (already truncated) targets
            attr_dataset = TruncateDataset(attr_dataset, max_target_positions)
            LOGGER.info('Truncated dataset\'s attributes to max length: {}'.format(max_target_positions))
        LOGGER.info('loaded {} attribute examples from: [{}](current task) + {}(previous tasks)'.format(
            len(attr_dataset), cur_task, prev_tasks))
        # load attr.ext
        attr_ext_path = os.path.join(data_path, cur_task,
                                     '{}.code_types.ext'.format(split))
        if indexed_dataset.SeqIndexedDataset.exists(attr_ext_path):
            attr_ext_dataset = indexed_dataset.SeqIndexedDataset(attr_ext_path)
            if sample_size_per_task > 0:
                for p_task in prev_tasks:
                    p_attr_ext_path = os.path.join(
                        data_path, p_task, '{}.code_types.ext'.format(split))
                    p_attr_ext_dataset = indexed_dataset.SeqIndexedDataset(
                        p_attr_ext_path)
                    p_attr_ext_dataset.truncate(end=sample_size_per_task)
                    attr_ext_dataset.append(p_attr_ext_dataset)
            if truncate_target:
                attr_ext_dataset.clip(max_position=max_target_positions)
            assert np.all(tgt_ext_dataset == attr_ext_dataset)
            del attr_ext_dataset

    return CompletionDataset(
        tgt_dataset,
        tgt_dataset.sizes,
        tgt_dict,
        extends=tgt_ext_dataset,
        attrs=attrs,
        attr_indices=attr_dataset,
        attr_dict=attr_dict,
        attrs_mapping=attrs_mapping,
        reversed_attrs_mapping=reversed_attrs_mapping,
        max_target_positions=max_target_positions,
    )
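
In the lifelong-learning variant, sample_size_per_task spreads sample_portion
of the current task's size evenly over the previous tasks, and each previous
task replays only that many leading examples via SliceDataset. A worked
example with hypothetical numbers:

# 10,000 current-task examples, a 0.2 replay portion, and 4 previous tasks
# give int(10000 * 0.2 // 4) = 500 replayed examples per previous task.
cur_len, sample_portion, prev_tasks = 10_000, 0.2, ['t1', 't2', 't3', 't4']
sample_size_per_task = int(cur_len * sample_portion // len(prev_tasks))
assert sample_size_per_task == 500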
Example 4
# method on the completion task class (shown here out of its class context)
def build_dataset_for_inference(self, src_tokens, src_lengths):
    return CompletionDataset(src_tokens, src_lengths,
                             self.target_dictionary)  # TODO: bug