Example #1
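All three variants below are lifted from the same dataset-loading module and lean on its helpers; a minimal sketch of the module-level context they assume (only `os` and `numpy` are certain, the rest is left as a comment because the host package's import paths are not shown in the source):

import os
import numpy as np
# The remaining names (LOGGER, _load_dataset, indexed_dataset, TruncateDataset,
# ConcatDataset, SliceDataset, TeacherOutDataset, CompletionDataset,
# TopkKDCompletionDataset) are helpers from the surrounding codebase; their
# import paths depend on the host package and are not shown here.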
def load_token_dataset(
    data_path, split, tgt, tgt_dict, dataset_impl,
    attrs=None, attr_dict=None,
    attrs_mapping=None, reversed_attrs_mapping=None,
    truncate_target=False, max_target_positions=None,
    # knowledge distillation (KD)
    topk_ids_prefix=None, topk_probs_prefix=None, topk=None, distill_topk=None,
):
    # load tokens
    tgt_path = os.path.join(data_path, '{}.{}'.format(split, tgt))
    tgt_dataset = _load_dataset(tgt_path, dataset_impl)
    if truncate_target:
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)
        LOGGER.info('Truncated dataset to max length: {}'.format(max_target_positions))
    LOGGER.info('loaded {} examples from: {}'.format(len(tgt_dataset), tgt_path))
    # load tokens.ext
    tgt_ext_path = os.path.join(data_path, '{}.{}.ext'.format(split, tgt))
    if indexed_dataset.SeqIndexedDataset.exists(tgt_ext_path):
        tgt_ext_dataset = indexed_dataset.SeqIndexedDataset(tgt_ext_path)
        if truncate_target:
            tgt_ext_dataset.clip(max_position=max_target_positions)
        assert len(tgt_dataset) == len(tgt_ext_dataset), (len(tgt_dataset), len(tgt_ext_dataset))
    else:
        tgt_ext_dataset = None
    # load attrs
    if attrs is None:
        attr_dataset = None
    else:
        attr_path = os.path.join(data_path, '{}.code_types'.format(split))
        attr_dataset = _load_dataset(attr_path, dataset_impl)
        if truncate_target:
            attr_dataset = TruncateDataset(attr_dataset, max_target_positions)
            LOGGER.info('Truncated dataset attributes to max length: {}'.format(max_target_positions))
        LOGGER.info('loaded {} examples from: {}'.format(len(attr_dataset), attr_path))
        # load attr.ext
        attr_ext_path = os.path.join(data_path, '{}.code_types.ext'.format(split))
        if indexed_dataset.SeqIndexedDataset.exists(attr_ext_path):
            attr_ext_dataset = indexed_dataset.SeqIndexedDataset(attr_ext_path)
            if truncate_target:
                attr_ext_dataset.clip(max_position=max_target_positions)
            # the token and attribute extension indices must match element-wise
            assert np.all(tgt_ext_dataset == attr_ext_dataset)
            del attr_ext_dataset

    # teacher model's top-k outputs, used as soft targets for distillation
    topk_ids = TeacherOutDataset(topk_ids_prefix)
    topk_probs = TeacherOutDataset(topk_probs_prefix)

    return TopkKDCompletionDataset(
        topk_ids=topk_ids, topk_probs=topk_probs, topk=topk, distill_topk=distill_topk,
        tgt=tgt_dataset, tgt_sizes=tgt_dataset.sizes, tgt_dict=tgt_dict, extends=tgt_ext_dataset,
        attrs=attrs, attr_indices=attr_dataset, attr_dict=attr_dict,
        attrs_mapping=attrs_mapping, reversed_attrs_mapping=reversed_attrs_mapping,
        max_target_positions=max_target_positions,
    )
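A hypothetical invocation of this KD variant; the paths, teacher-output prefixes, and `tgt_dict` (a previously loaded token dictionary) are placeholders, not values from the source:

# Hypothetical call: every path, prefix, and dictionary here is illustrative.
dataset = load_token_dataset(
    data_path='data-bin/py150', split='train', tgt='code_tokens',
    tgt_dict=tgt_dict, dataset_impl='mmap',
    truncate_target=True, max_target_positions=1024,
    # teacher outputs dumped beforehand; file prefixes are illustrative
    topk_ids_prefix='teacher/train.topk_idx',
    topk_probs_prefix='teacher/train.topk_prob',
    topk=8, distill_topk=8,
)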
Example #2
def load_token_dataset(
    data_path,
    split,
    tgt,
    tgt_dict,
    dataset_impl,
    attrs=None,
    attr_dict=None,
    attrs_mapping=None,
    reversed_attrs_mapping=None,
    truncate_target=False,
    max_target_positions=None,
    # lifelong learning
    prev_tasks=(),  # avoid a mutable default argument; only iterated, never mutated
    cur_task=None,
    sample_portion=None,
):
    # load tokens
    tgt_path = os.path.join(data_path, cur_task, '{}.{}'.format(split, tgt))
    tgt_dataset = [_load_dataset(tgt_path, dataset_impl)]
    if prev_tasks and cur_task is not None and sample_portion is not None:
        sample_size_per_task = int(len(tgt_dataset[0]) * sample_portion // len(prev_tasks))
    else:
        sample_size_per_task = -1
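    # e.g. with 10,000 current-task examples, sample_portion=0.2 and two previous
    # tasks, each previous task replays int(10000 * 0.2 // 2) = 1000 examples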
    if sample_size_per_task > 0:
        for p_task in prev_tasks:
            p_path = os.path.join(data_path, p_task,
                                  '{}.{}'.format(split, tgt))
            p_dataset = _load_dataset(p_path, dataset_impl)
            tgt_dataset.append(
                SliceDataset(p_dataset, end=sample_size_per_task))
    tgt_dataset = ConcatDataset(tgt_dataset)
    if truncate_target:
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)
        LOGGER.info('Truncated dataset to max length: {}'.format(max_target_positions))
    LOGGER.info('loaded {} examples from: [{}](current task) + {}(previous tasks)'.format(
        len(tgt_dataset), cur_task, prev_tasks))
    # load tokens.ext
    tgt_ext_path = os.path.join(data_path, cur_task,
                                '{}.{}.ext'.format(split, tgt))
    if indexed_dataset.SeqIndexedDataset.exists(tgt_ext_path):
        tgt_ext_dataset = indexed_dataset.SeqIndexedDataset(tgt_ext_path)
        if sample_size_per_task > 0:
            for p_task in prev_tasks:
                p_ext_path = os.path.join(data_path, p_task,
                                          '{}.{}.ext'.format(split, tgt))
                p_ext_dataset = indexed_dataset.SeqIndexedDataset(p_ext_path)
                p_ext_dataset.truncate(end=sample_size_per_task)
                tgt_ext_dataset.append(p_ext_dataset)
        if truncate_target:
            tgt_ext_dataset.clip(max_position=max_target_positions)
        assert len(tgt_dataset) == len(tgt_ext_dataset), (len(tgt_dataset),
                                                          len(tgt_ext_dataset))
    else:
        tgt_ext_dataset = None
    # load attrs
    if attrs is None:
        attr_dataset = None
    else:
        attr_path = os.path.join(data_path, cur_task,
                                 '{}.code_types'.format(split))
        attr_dataset = [_load_dataset(attr_path, dataset_impl)]
        if sample_size_per_task > 0:
            for p_task in prev_tasks:
                p_attr_path = os.path.join(data_path, p_task,
                                           '{}.code_types'.format(split))
                p_attr_dataset = _load_dataset(p_attr_path, dataset_impl)
                attr_dataset.append(
                    SliceDataset(p_attr_dataset, end=sample_size_per_task))
        attr_dataset = ConcatDataset(attr_dataset)
        if truncate_target:
            attr_dataset = TruncateDataset(attr_dataset, max_target_positions)
            LOGGER.info('Truncated dataset attributes to max length: {}'.format(max_target_positions))
        LOGGER.info('loaded {} attribute examples from: [{}](current task) + {}(previous tasks)'.format(
            len(attr_dataset), cur_task, prev_tasks))
        # load attr.ext
        attr_ext_path = os.path.join(data_path, cur_task,
                                     '{}.code_types.ext'.format(split))
        if indexed_dataset.SeqIndexedDataset.exists(attr_ext_path):
            attr_ext_dataset = indexed_dataset.SeqIndexedDataset(attr_ext_path)
            if sample_size_per_task > 0:
                for p_task in prev_tasks:
                    p_attr_ext_path = os.path.join(
                        data_path, p_task, '{}.code_types.ext'.format(split))
                    p_attr_ext_dataset = indexed_dataset.SeqIndexedDataset(
                        p_attr_ext_path)
                    p_attr_ext_dataset.truncate(end=sample_size_per_task)
                    attr_ext_dataset.append(p_attr_ext_dataset)
            if truncate_target:
                attr_ext_dataset.clip(max_position=max_target_positions)
            assert np.all(tgt_ext_dataset == attr_ext_dataset)
            del attr_ext_dataset

    return CompletionDataset(
        tgt_dataset,
        tgt_dataset.sizes,
        tgt_dict,
        extends=tgt_ext_dataset,
        attrs=attrs,
        attr_indices=attr_dataset,
        attr_dict=attr_dict,
        attrs_mapping=attrs_mapping,
        reversed_attrs_mapping=reversed_attrs_mapping,
        max_target_positions=max_target_positions,
    )
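For the lifelong-learning variant, a hypothetical call that mixes the current task with replayed slices of two earlier tasks (task names, portion, and `tgt_dict` are placeholders):

# Hypothetical call: task directory names and settings are illustrative.
dataset = load_token_dataset(
    data_path='data-bin', split='train', tgt='code_tokens',
    tgt_dict=tgt_dict, dataset_impl='mmap',
    prev_tasks=['java', 'csharp'], cur_task='python',
    sample_portion=0.2,  # replay 20% of the current task's size, split across previous tasks
    truncate_target=True, max_target_positions=1024,
)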
Example #3
def load_token_dataset(
    data_paths,
    split,
    tgt,
    tgt_dict,
    dataset_impl,
    attrs=None,
    attr_dict=None,
    attrs_mapping=None,
    reversed_attrs_mapping=None,
    truncate_target=False,
    max_target_positions=None,
):
    # load tokens
    tgt_dataset = []
    for data_path in data_paths:
        tgt_path = os.path.join(data_path, '{}.{}'.format(split, tgt))
        tgt_dataset.append(_load_dataset(tgt_path, dataset_impl))
    tgt_dataset = ConcatDataset(tgt_dataset)
    if truncate_target:
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)
        LOGGER.info('Truncated dataset to max length: {}'.format(max_target_positions))
    LOGGER.info('loaded {} examples from: {}'.format(len(tgt_dataset),
                                                     data_paths))
    # load tokens.ext
    tgt_ext_paths = [
        os.path.join(data_path, '{}.{}.ext'.format(split, tgt))
        for data_path in data_paths
    ]
    if all(
            indexed_dataset.SeqIndexedDataset.exists(tgt_ext_path)
            for tgt_ext_path in tgt_ext_paths):
        tgt_ext_dataset = indexed_dataset.SeqIndexedDataset(tgt_ext_paths[0])
        for tgt_ext_path in tgt_ext_paths[1:]:
            tgt_ext_dataset.append(
                indexed_dataset.SeqIndexedDataset(tgt_ext_path))
        if truncate_target:
            tgt_ext_dataset.clip(max_position=max_target_positions)
        assert len(tgt_dataset) == len(tgt_ext_dataset), (len(tgt_dataset),
                                                          len(tgt_ext_dataset))
    else:
        tgt_ext_dataset = None
    # load attrs
    if attrs is None:
        attr_dataset = None
    else:
        attr_dataset = []
        for data_path in data_paths:
            attr_path = os.path.join(data_path, '{}.code_types'.format(split))
            attr_dataset.append(_load_dataset(attr_path, dataset_impl))
        attr_dataset = ConcatDataset(attr_dataset)
        if truncate_target:
            attr_dataset = TruncateDataset(attr_dataset, max_target_positions)
            LOGGER.info('Truncated dataset attributes to max length: {}'.format(max_target_positions))
        LOGGER.info('loaded {} examples from: {}'.format(len(attr_dataset), data_paths))
        # load attr.ext
        attr_ext_paths = [
            os.path.join(data_path, '{}.code_types.ext'.format(split))
            for data_path in data_paths
        ]
        if all(
                indexed_dataset.SeqIndexedDataset.exists(attr_ext_path)
                for attr_ext_path in attr_ext_paths):
            attr_ext_dataset = indexed_dataset.SeqIndexedDataset(
                attr_ext_paths[0])
            for attr_ext_path in attr_ext_paths[1:]:
                attr_ext_dataset.append(
                    indexed_dataset.SeqIndexedDataset(attr_ext_path))
            if truncate_target:
                attr_ext_dataset.clip(max_position=max_target_positions)
            assert np.all(tgt_ext_dataset == attr_ext_dataset)
            del attr_ext_dataset

    return CompletionDataset(
        tgt_dataset,
        tgt_dataset.sizes,
        tgt_dict,
        extends=tgt_ext_dataset,
        attrs=attrs,
        attr_indices=attr_dataset,
        attr_dict=attr_dict,
        attrs_mapping=attrs_mapping,
        reversed_attrs_mapping=reversed_attrs_mapping,
        max_target_positions=max_target_positions,
    )
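Finally, a hypothetical call for the multi-directory variant, which concatenates the same split across several data roots (directory names and `tgt_dict` are placeholders):

# Hypothetical call: data roots are illustrative.
dataset = load_token_dataset(
    data_paths=['data-bin/python', 'data-bin/java'], split='train',
    tgt='code_tokens', tgt_dict=tgt_dict, dataset_impl='mmap',
    truncate_target=True, max_target_positions=1024,
)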