Code example #1
def load_langpair_dataset(
    data_path, split,
    srcs, src_dicts,
    tgts, tgt_dicts,
    dataset_impl,
    src_max_tokens, tgt_max_tokens,
    **kwargs,
):
    # load source dataset
    src_datasets, src_sizes = OrderedDict(), OrderedDict()
    for idx, src in enumerate(srcs):
        src_path = os.path.join(data_path, '{}.{}'.format(split, src))
        src_datasets[src] = _load_dataset(path=src_path, impl=dataset_impl, dict=src_dicts[src])
        src_datasets[src] = TruncateDataset(src_datasets[src], src_max_tokens[idx])
        src_sizes[src] = src_datasets[src].sizes
    # load target dataset
    tgt_datasets, tgt_sizes = OrderedDict(), OrderedDict()
    for idx, tgt in enumerate(tgts):
        tgt_path = os.path.join(data_path, '{}.{}'.format(split, tgt))
        tgt_datasets[tgt] = _load_dataset(path=tgt_path, impl=dataset_impl, dict=tgt_dicts[tgt])
        tgt_datasets[tgt] = TruncateDataset(tgt_datasets[tgt], tgt_max_tokens[idx])
        tgt_sizes[tgt] = tgt_datasets[tgt].sizes

    return DeepCSLanguagePairDataset(
        src_datasets, src_sizes, src_dicts,
        tgt_datasets, tgt_sizes, tgt_dicts,
        pad=src_dicts[srcs[0]].pad(),
        src_max_tokens=src_max_tokens, tgt_max_tokens=tgt_max_tokens,
        shuffle=(split == 'train'),
    )
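
Every loader in this collection clips the raw token datasets with TruncateDataset before building the pair dataset, so no example exceeds its position budget. The snippet below is a minimal, self-contained sketch of that pattern, written only for illustration; it assumes TruncateDataset clips each item to a fixed length and exposes the clipped sizes, whereas the real class in naturalcc works over indexed on-disk datasets.

# Illustrative stand-in for TruncateDataset (assumed behavior, not the project's class)
import numpy as np

class SimpleTruncateDataset:
    """Wrap an indexable dataset and clip every item to `max_tokens`."""

    def __init__(self, dataset, max_tokens):
        self.dataset = dataset
        self.max_tokens = max_tokens

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx][:self.max_tokens]

    @property
    def sizes(self):
        return np.minimum([len(x) for x in self.dataset], self.max_tokens)

# toy token-id sequences
toy = [[1, 2, 3, 4, 5], [6, 7], [8, 9, 10]]
truncated = SimpleTruncateDataset(toy, max_tokens=3)
print(truncated[0], truncated.sizes.tolist())  # [1, 2, 3] [3, 2, 3]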
Code example #2
def load_tokens_dataset(
    data_path, split, src, src_dict, tgt, tgt_dict, dataset_impl,
    max_source_positions=None, max_target_positions=None, max_positions=None,
    append_source_eos=False, append_target_eos=False,
    shuffle=False,
):
    src_path = os.path.join(data_path, '{}.{}'.format(split, src))
    src_dataset = _load_dataset(src_path, dataset_impl)
    if max_source_positions is not None:
        src_dataset = TruncateDataset(src_dataset, max_source_positions)
    LOGGER.info('loaded {} examples from: {}'.format(len(src_dataset), src_path))

    tgt_path = os.path.join(data_path, '{}.{}'.format(split, tgt))
    tgt_dataset = _load_dataset(tgt_path, dataset_impl)
    if max_target_positions is not None:
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)
    LOGGER.info('loaded {} examples from: {}'.format(len(tgt_dataset), tgt_path))

    return BertDataset(
        src_dataset, src_dataset.sizes, src_dict,
        tgt_dataset, tgt_dataset.sizes, tgt_dict,
        max_source_positions=max_source_positions, max_target_positions=max_target_positions,
        max_positions=max_positions,
        append_source_eos=append_source_eos, append_target_eos=append_target_eos,
        shuffle=shuffle,
    )
Code example #3
def load_token_dataset(
    data_path, split, tgt, tgt_dict, dataset_impl,
    attrs=None, attr_dict=None,
    attrs_mapping=None, reversed_attrs_mapping=None,
    truncate_target=False, max_target_positions=None,
    # kd
    topk_ids_prefix=None, topk_probs_prefix=None, topk=None, distill_topk=None,
):
    # load tokens
    tgt_path = os.path.join(data_path, '{}.{}'.format(split, tgt))
    tgt_dataset = _load_dataset(tgt_path, dataset_impl)
    if truncate_target:
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)
        LOGGER.info('Truncate dataset into max length: {}'.format(max_target_positions))
    LOGGER.info('loaded {} examples from: {}'.format(len(tgt_dataset), tgt_path))
    # load tokens.ext
    tgt_ext_path = os.path.join(data_path, '{}.{}.ext'.format(split, tgt))
    if indexed_dataset.SeqIndexedDataset.exists(tgt_ext_path):
        tgt_ext_dataset = indexed_dataset.SeqIndexedDataset(tgt_ext_path)
        if truncate_target:
            tgt_ext_dataset.clip(max_position=max_target_positions)
        assert len(tgt_dataset) == len(tgt_ext_dataset), (len(tgt_dataset), len(tgt_ext_dataset))
    else:
        tgt_ext_dataset = None
    # load attrs
    if attrs is None:
        attr_dataset = None
    else:
        attr_path = os.path.join(data_path, '{}.code_types'.format(split))
        attr_dataset = _load_dataset(attr_path, dataset_impl)
        if truncate_target:
            attr_dataset = TruncateDataset(attr_dataset, max_target_positions)
            LOGGER.info('Truncate dataset\'s attributes into max length: {}'.format(max_target_positions))
        LOGGER.info('loaded {} examples from: {}'.format(len(attr_dataset), attr_path))
        # load attr.ext
        attr_ext_path = os.path.join(data_path, '{}.code_types.ext'.format(split))
        if indexed_dataset.SeqIndexedDataset.exists(attr_ext_path):
            attr_ext_dataset = indexed_dataset.SeqIndexedDataset(attr_ext_path)
            if truncate_target:
                attr_ext_dataset.clip(max_position=max_target_positions)
            assert np.all(tgt_ext_dataset == attr_ext_dataset)
            del attr_ext_dataset

    topk_ids = TeacherOutDataset(topk_ids_prefix)
    topk_probs = TeacherOutDataset(topk_probs_prefix)

    return TopkKDCompletionDataset(
        topk_ids=topk_ids, topk_probs=topk_probs, topk=topk, distill_topk=distill_topk,
        tgt=tgt_dataset, tgt_sizes=tgt_dataset.sizes, tgt_dict=tgt_dict, extends=tgt_ext_dataset,
        attrs=attrs, attr_indices=attr_dataset, attr_dict=attr_dict,
        attrs_mapping=attrs_mapping, reversed_attrs_mapping=reversed_attrs_mapping,
        max_target_positions=max_target_positions,
    )
Code example #4
def load_langpair_dataset(
    data_path, split,
    src, src_dict,
    tgt, tgt_dict,
    dataset_impl,

    left_pad_source, left_pad_target,
    max_source_positions, max_target_positions,
    prepend_bos=False, load_alignments=False,
    truncate_source=False, append_source_id=False,
    truncate_target=False,
    append_eos_to_target=False,
    portion=None,
):
    src_path = os.path.join(data_path, '{}.{}'.format(split, src))
    src_dataset = _load_dataset(path=src_path, impl=dataset_impl, dict=src_dict)

    if portion is not None and split == 'train':
        LOGGER.info('set {}.{} portion to {}'.format(split, src, portion))
        src_dataset = PortionDataset(src_dataset, portion)

    tgt_path = os.path.join(data_path, '{}.{}'.format(split, tgt))
    tgt_dataset = _load_dataset(path=tgt_path, impl=dataset_impl, dict=tgt_dict)
    if truncate_target:
        LOGGER.info('truncate {}.{} to {}'.format(split, tgt, max_target_positions))
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    eos = None
    if append_source_id:
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(tgt_dataset, tgt_dict.index('[{}]'.format(tgt)))
        eos = tgt_dict.index('[{}]'.format(tgt))

    if portion is not None and split == 'train':
        LOGGER.info('set {}.{} portion to {}'.format(split, tgt, portion))
        tgt_dataset = PortionDataset(tgt_dataset, portion)

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None

    LOGGER.info('loaded {} examples from: {}'.format(len(src_dataset), src_path))
    LOGGER.info('loaded {} examples from: {}'.format(len(tgt_dataset), tgt_path))
    return GraphLanguagePairDataset(
        src_dataset, src_dataset.sizes, src_dict,
        tgt_dataset, tgt_dataset_sizes, tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        align_dataset=None, eos=eos,
        remove_eos_from_source=True,
        append_eos_to_target=append_eos_to_target,
        shuffle=True,

    )
Code example #5
def load_tokens_dataset(
    data_path, split, srcs, src_dicts, tgts, tgt_dicts, dataset_impl,
    max_srcs=None, max_tgts=None,
    shuffle=False, sample_neg=False, **kwargs,
):
    langs = kwargs.get('langs', None)
    # source langs
    src_paths = OrderedDict()
    for src in srcs:
        if langs is None:
            src_paths[src] = [os.path.join(data_path, f'{split}.{src}')]
        else:
            src_paths[src] = [os.path.join(data_path, f'{split}.{lang}.{src}') for lang in langs]
    src_datasets, src_sizes = OrderedDict(), OrderedDict()
    for idx, (src, paths) in enumerate(src_paths.items()):
        datasets = _load_dataset(paths, dataset_impl)
        if max_srcs is not None:
            datasets = [TruncateDataset(ds, max_srcs[idx]) for ds in datasets]
        datasets = ConcatDataset(datasets, labels=langs)
        src_datasets[src] = datasets
        src_sizes[src] = datasets.sizes
    LOGGER.info('loaded {} modality(ies) from: {}'.format(len(src_datasets), src_paths))
    # target langs
    tgt_paths = OrderedDict()
    for tgt in tgts:
        if langs is None:
            tgt_paths[tgt] = [os.path.join(data_path, f'{split}.{tgt}')]
        else:
            tgt_paths[tgt] = [os.path.join(data_path, f'{split}.{lang}.{tgt}') for lang in langs]
    tgt_datasets, tgt_sizes = OrderedDict(), OrderedDict()
    for idx, (tgt, paths) in enumerate(tgt_paths.items()):
        datasets = _load_dataset(paths, dataset_impl)
        if max_tgts is not None:
            datasets = [TruncateDataset(ds, max_tgts[idx]) for ds in datasets]
        datasets = ConcatDataset(datasets, labels=langs)
        tgt_datasets[tgt] = datasets
        tgt_sizes[tgt] = datasets.sizes
    LOGGER.info('loaded {} modality(ies) from: {}'.format(len(tgt_datasets), tgt_paths))

    return MultiModalitiesRetrievalDataset(
        src_datasets, src_sizes, src_dicts,
        tgt_datasets, tgt_sizes, tgt_dicts,
        max_source_positions=max_srcs, max_target_positions=max_tgts,
        fraction_using_func_name=kwargs.get('fraction_using_func_name', None),
        shuffle=shuffle,
        labels=langs, sample_neg=sample_neg,
    )
Code example #6
def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    dataset_impl,
    left_pad_source,
    max_source_positions,
    src_aux=None,
):
    # load source dataset
    src_path = os.path.join(data_path, '{}.{}'.format(split, src))
    src_dataset = _load_dataset(path=src_path,
                                impl=dataset_impl,
                                dict=src_dict)
    src_dataset = TruncateDataset(src_dataset,
                                  truncation_length=max_source_positions,
                                  truncate_prefix=0)

    # load target dataset
    tgt_path = os.path.join(data_path, '{}.{}'.format(split, tgt))
    tgt_dataset = _load_dataset(path=tgt_path,
                                impl=dataset_impl,
                                dict=tgt_dict)

    # load auxiliary dataset
    aux_datasets = OrderedDict()
    for aux in (src_aux or []):  # tolerate the None default
        aux_path = os.path.join(data_path, '{}.{}'.format(split, aux))
        with open(aux_path, 'rb') as reader:
            aux_datasets[aux] = pickle.load(reader)

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None

    LOGGER.info('loaded {} examples from: {}'.format(len(src_dataset),
                                                     src_path))
    LOGGER.info('loaded {} examples from: {}'.format(len(tgt_dataset),
                                                     tgt_path))
    return LanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        src_aux=aux_datasets,
        tgt=tgt_dataset,
        tgt_sizes=tgt_dataset_sizes,
        tgt_dict=tgt_dict,
        left_pad_source=left_pad_source,
        max_source_positions=max_source_positions,
        shuffle=(split == 'train'),
    )
Code example #7
def load_inference_token_dataset(
    data_paths,
    split,
    tgt,
    tgt_dict,
    dataset_impl,
    attrs=None,
    attr_dict=None,
    attrs_mapping=None,
    reversed_attrs_mapping=None,
    truncate_target=False,
    max_target_positions=None,
):
    # load tokens
    tgt_dataset = []
    for path in data_paths:
        tgt_path = os.path.join(path, '{}.{}'.format(split, tgt))
        tgt_dataset.append(_load_dataset(tgt_path, dataset_impl))
    tgt_dataset = ConcatDataset(tgt_dataset)
    if truncate_target:
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)
        LOGGER.info('Truncate dataset into max length: {}'.format(
            max_target_positions))
    LOGGER.info('loaded {} examples from: {}'.format(len(tgt_dataset),
                                                     data_paths))
    return CompletionDataset(
        tgt_dataset,
        tgt_dataset.sizes,
        tgt_dict,
        extends=None,
        attrs=None,
        attr_indices=None,
        attr_dict=None,
        attrs_mapping=None,
        reversed_attrs_mapping=None,
        max_target_positions=max_target_positions,
    )
Code example #8
def load_kd_token_dataset(
    data_path,
    split,
    tgt,
    tgt_dict,
    dataset_impl,
    attrs=None,
    attr_dict=None,
    attrs_mapping=None,
    reversed_attrs_mapping=None,
    truncate_target=False,
    max_target_positions=None,
    # lifelong learning
    prev_tasks=[],
    cur_task=None,
    sample_portion=None,
    # kd
    teacher_out_dir=None,
    topk=None,
    distill_topk=None,
):
    # load tokens
    tgt_path = os.path.join(data_path, cur_task, '{}.{}'.format(split, tgt))
    tgt_dataset = [_load_dataset(tgt_path, dataset_impl)]
    # teacher output
    topk_ids_prefix = os.path.join(str.replace(teacher_out_dir, '*', cur_task),
                                   f'{split}.top{topk}_idx')
    topk_ids = [TeacherOutDataset(topk_ids_prefix)]
    topk_probs_prefix = os.path.join(
        str.replace(teacher_out_dir, '*', cur_task), f'{split}.top{topk}_prob')
    topk_probs = [TeacherOutDataset(topk_probs_prefix)]

    if prev_tasks and cur_task is not None and sample_portion is not None:
        sample_size_per_task = int(
            len(tgt_dataset[0]) * sample_portion // len(prev_tasks))
    else:
        sample_size_per_task = -1
    if sample_size_per_task > 0:
        for p_task in prev_tasks:
            p_path = os.path.join(data_path, p_task,
                                  '{}.{}'.format(split, tgt))
            p_dataset = _load_dataset(p_path, dataset_impl)
            tgt_dataset.append(
                SliceDataset(p_dataset, end=sample_size_per_task))
            topk_ids.append(
                PlaceholderDataset(placeholder=None,
                                   length=sample_size_per_task))
            topk_probs.append(
                PlaceholderDataset(placeholder=None,
                                   length=sample_size_per_task))
    tgt_dataset = ConcatDataset(tgt_dataset)
    topk_ids = ConcatDataset(topk_ids)
    topk_probs = ConcatDataset(topk_probs)
    if truncate_target:
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)
        LOGGER.info('Truncate dataset into max length: {}'.format(
            max_target_positions))
    LOGGER.info('loaded {} examples from: [{}](current task) + {}(previous tasks)'. \
                format(len(tgt_dataset), cur_task, prev_tasks))

    return TopkKDCompletionDataset(
        topk_ids=topk_ids,
        topk_probs=topk_probs,
        topk=topk,
        distill_topk=distill_topk,
        tgt=tgt_dataset,
        tgt_sizes=tgt_dataset.sizes,
        tgt_dict=tgt_dict,
        extends=None,
        attrs=None,
        attr_indices=None,
        attr_dict=None,
        attrs_mapping=None,
        reversed_attrs_mapping=None,
        max_target_positions=max_target_positions,
        shuffle=(split == 'train'),
    )
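
Code example #8 pairs each target sample with the teacher's top-k token ids and probabilities (TeacherOutDataset) for knowledge distillation. How TopkKDCompletionDataset and its matching criterion consume them is not shown here; the sketch below is one common way to turn such top-k soft targets into a distillation loss, under the assumption that the student produces per-position logits and that placeholder samples (probabilities of None) are filtered out beforehand.

# Hypothetical top-k distillation loss; not the project's actual criterion
import torch
import torch.nn.functional as F

def topk_kd_loss(student_logits, topk_ids, topk_probs, distill_topk):
    """Cross-entropy of student log-probs against the teacher's renormalized top-k distribution.

    student_logits: (batch, vocab)  raw student scores
    topk_ids:       (batch, k)      teacher's top-k token ids
    topk_probs:     (batch, k)      teacher's probabilities for those ids
    """
    topk_ids = topk_ids[:, :distill_topk]
    topk_probs = topk_probs[:, :distill_topk]
    topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)  # renormalize the truncated distribution
    student_log_probs = F.log_softmax(student_logits, dim=-1).gather(dim=-1, index=topk_ids)
    return -(topk_probs * student_log_probs).sum(dim=-1).mean()

# toy check: batch of 2, vocab of 10, teacher keeps top-3
logits = torch.randn(2, 10)
ids = torch.topk(torch.randn(2, 10), k=3, dim=-1).indices
probs = torch.softmax(torch.randn(2, 3), dim=-1)
print(topk_kd_loss(logits, ids, probs, distill_topk=3).item())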
Code example #9
def cli_main():
    SEED = 204
    BATCH_SIZE = 64
    MAX_SOURCE_POSITIONS = 1024
    EPOCH = 50

    from ncc.utils.set_seed import set_seed
    set_seed(SEED)

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        device = os.environ.get('CUDA_VISIBLE_DEVICES', '0')[0]  # get first device as default
        torch.cuda.set_device(f'cuda:{device}')
    criterion = DeepTuneLoss(task=None, sentence_avg=-1)
    if use_cuda:
        criterion = criterion.cuda()

    data = []
    for i, platform in enumerate(LANGUAGES):
        DATA_DIR = os.path.join(DATASET_DIR, f'mapping/{platform}/data-mmap')

        def get_attr(attr):
            oracle_file = os.path.join(DATA_DIR, f'train.{attr}')
            with open(oracle_file, 'rb') as reader:
                out = pickle.load(reader)
            return np.asarray(out)

        platform_name = mapping_metrics.platform2str(platform)
        benchmarks = get_attr('benchmark')
        runtime_cpus = get_attr('runtime_cpu')
        runtime_gpus = get_attr('runtime_gpu')

        #################### load dataset ####################
        src_dataset = load_mmap_dataset(os.path.join(DATA_DIR, f'train.src_tokens'))
        src_dataset = TruncateDataset(src_dataset, truncation_length=MAX_SOURCE_POSITIONS, truncate_prefix=0)
        tgt_dataset = load_mmap_dataset(os.path.join(DATA_DIR, f'train.oracle'))

        src_dict = Dictionary.load(os.path.join(DATA_DIR, 'src_tokens.dict.jsonl'))
        src_aux = OrderedDict()
        src_aux['transfer'] = get_attr('transfer')
        src_aux['wgsize'] = get_attr('wgsize')

        tgt_dict = Dictionary.load(os.path.join(DATA_DIR, 'oracle.dict.jsonl'))

        dataset = LanguagePairDataset(
            src=src_dataset, src_sizes=src_dataset.sizes, src_dict=src_dict, src_aux=src_aux,
            tgt=tgt_dataset, tgt_sizes=tgt_dataset.sizes, tgt_dict=tgt_dict, tgt_aux=None,
            left_pad_source=True, max_source_positions=MAX_SOURCE_POSITIONS,
        )
        #################### load dataset ####################

        # build toy dataset for 10-fold cross validation
        tgt_data = [tgt_dataset[idx].item() for idx in range(len(tgt_dataset))]
        src_data = [None] * len(tgt_data)

        # 10-fold cross-validation
        kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=SEED)
        for j, (train_ids, test_ids) in enumerate(kf.split(src_data, tgt_data)):
            # deeptune model
            model = DeepTuneEncoder(dictionary=src_dict, embed_dim=64,
                                    rnn_cell='lstm', rnn_hidden_dim=64, rnn_dropout=0., rnn_num_layers=2,
                                    aux_dim=2, inner_dim=32, out_dim=2)
            if use_cuda:
                model = model.cuda()
            optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
            for epoch_i in range(EPOCH):
                if dataset.shuffle:
                    random.shuffle(train_ids)
                train_batch_sampler = data_utils.batch_by_size(
                    train_ids,
                    num_tokens_fn=lambda *args: -1,
                    max_sentences=BATCH_SIZE,
                )
                train_dataloader = DataLoader(dataset=dataset,
                                              batch_sampler=train_batch_sampler,
                                              collate_fn=collate, )
                with tqdm(total=len(train_dataloader)) as t:
                    for sample_i, sample in enumerate(train_dataloader, start=1):
                        t.set_description(f'Epoch {epoch_i + 1}/{EPOCH} Batch {sample_i}/{len(train_dataloader)}')
                        if use_cuda:
                            sample = move_to_cuda(sample)
                        loss, sample_size, logging_output = criterion(model, sample)
                        loss.div_(sample_size)
                        t.set_postfix(loss=loss.item())
                        t.update()

                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

            # test accuracy
            test_batch_sampler = data_utils.batch_by_size(
                test_ids,
                num_tokens_fn=lambda *args: -1,
                max_sentences=BATCH_SIZE,
            )
            test_dataloader = DataLoader(dataset=dataset,
                                         batch_sampler=test_batch_sampler,
                                         collate_fn=collate, )
            predictions, ground_truth = [], []
            for sample in test_dataloader:
                if use_cuda:
                    sample = move_to_cuda(sample)
                hybrid_out, _ = model(**sample['net_input'])
                predictions.append(hybrid_out.max(dim=-1)[1])
                ground_truth.append(sample['target'].view(-1))
            predictions = torch.cat(predictions)
            ground_truth = torch.cat(ground_truth)

            accuracy = (predictions == ground_truth).tolist()
            # runtimes of baseline mapping (CPU on AMD, GPU on NVIDIA)
            gt_runtimes = (runtime_cpus if platform == "amd" else runtime_gpus)[test_ids]
            pred_runtimes = [
                (runtime_cpus if pred == 0 else runtime_gpus)[idx]
                for idx, pred in zip(test_ids, predictions)
            ]
            speedup = gt_runtimes / pred_runtimes

            # record results
            for benchmark_, o_, p_, accuracy_, p_speedup_ in \
                zip(benchmarks[test_ids], ground_truth, predictions, accuracy, speedup):
                data.append({
                    "Model": model.__class__.__name__,
                    "Platform": platform_name,
                    'Benchmark': mapping_metrics.escape_benchmark_name(benchmark_),
                    'Benchmark Suite': mapping_metrics.escape_suite_name(benchmark_),
                    "Oracle Mapping": o_,
                    "Predicted Mapping": p_,
                    "Accuracy": accuracy_,
                    "Speedup": p_speedup_,
                })
            del model, optimizer
    performance = pd.DataFrame(
        data, index=range(1, len(data) + 1), columns=[
            "Model",
            "Platform",
            "Benchmark",
            "Benchmark Suite",
            "Oracle Mapping",
            "Predicted Mapping",
            "Accuracy",
            "Speedup"
        ])
    benchmark_out = performance.groupby(['Platform', 'Benchmark Suite'])[['Accuracy', 'Speedup']].mean()
    benchmark_out['Accuracy'] = round(benchmark_out['Accuracy'] * 100, 2)
    benchmark_out['Speedup'] = round(benchmark_out['Speedup'], 2)
    print(benchmark_out)
    out = performance.groupby(['Platform'])[['Accuracy', 'Speedup']].mean()
    out['Accuracy'] = round(out['Accuracy'] * 100, 2)
    out['Speedup'] = round(out['Speedup'], 2)
    print(out)
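
The evaluation at the end of code example #9 boils down to two numbers per fold: accuracy of the predicted CPU/GPU mapping against the oracle, and speedup of the predicted device's runtime over a static baseline device (CPU on AMD, GPU on NVIDIA). Below is a self-contained numeric sketch of that calculation; the runtime arrays are invented purely for illustration.

import numpy as np

# toy runtimes (seconds) for 4 kernels; the baseline device is the CPU, as on the AMD platform
runtime_cpu = np.array([1.0, 2.0, 0.5, 4.0])
runtime_gpu = np.array([0.5, 3.0, 0.4, 1.0])
oracle = (runtime_gpu < runtime_cpu).astype(int)      # 1 = GPU is the faster mapping
predictions = np.array([1, 1, 0, 1])                  # hypothetical model output

accuracy = (predictions == oracle).mean()
pred_runtimes = np.where(predictions == 1, runtime_gpu, runtime_cpu)
speedup = runtime_cpu / pred_runtimes                 # baseline runtime / predicted-device runtime

print(f'accuracy={accuracy:.2f}, mean speedup={speedup.mean():.2f}')  # accuracy=0.50, mean speedup=1.92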
Code example #10
def load_token_dataset(
    data_paths,
    split,
    tgt,
    tgt_dict,
    dataset_impl,
    attrs=None,
    attr_dict=None,
    attrs_mapping=None,
    reversed_attrs_mapping=None,
    truncate_target=False,
    max_target_positions=None,
):
    # load tokens
    tgt_dataset = []
    for data_path in data_paths:
        tgt_path = os.path.join(data_path, '{}.{}'.format(split, tgt))
        tgt_dataset.append(_load_dataset(tgt_path, dataset_impl))
    tgt_dataset = ConcatDataset(tgt_dataset)
    if truncate_target:
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)
        LOGGER.info('Truncate dataset into max length: {}'.format(
            max_target_positions))
    LOGGER.info('loaded {} examples from: {}'.format(len(tgt_dataset),
                                                     data_paths))
    # load tokens.ext
    tgt_ext_paths = [
        os.path.join(data_path, '{}.{}.ext'.format(split, tgt))
        for data_path in data_paths
    ]
    if all(
            indexed_dataset.SeqIndexedDataset.exists(tgt_ext_path)
            for tgt_ext_path in tgt_ext_paths):
        tgt_ext_dataset = indexed_dataset.SeqIndexedDataset(tgt_ext_paths[0])
        for tgt_ext_path in tgt_ext_paths[1:]:
            tgt_ext_dataset.append(
                indexed_dataset.SeqIndexedDataset(tgt_ext_path))
        if truncate_target:
            tgt_ext_dataset.clip(max_position=max_target_positions)
        assert len(tgt_dataset) == len(tgt_ext_dataset), (len(tgt_dataset),
                                                          len(tgt_ext_dataset))
    else:
        tgt_ext_dataset = None
    # load attrs
    if attrs is None:
        attr_dataset = None
    else:
        attr_dataset = []
        for data_path in data_paths:
            attr_path = os.path.join(data_path, '{}.code_types'.format(split))
            attr_dataset.append(_load_dataset(attr_path, dataset_impl))
        attr_dataset = ConcatDataset(attr_dataset)
        if truncate_target:
            attr_dataset = TruncateDataset(attr_dataset, max_target_positions)
            LOGGER.info(
                'Truncate dataset\'s attributes into max length: {}'.format(
                    max_target_positions))
        LOGGER.info('loaded {} examples from: {}'.format(
            len(attr_dataset), data_paths))
        # load attr.ext
        attr_ext_paths = [
            os.path.join(data_path, '{}.code_types.ext'.format(split))
            for data_path in data_paths
        ]
        if all(
                indexed_dataset.SeqIndexedDataset.exists(attr_ext_path)
                for attr_ext_path in attr_ext_paths):
            attr_ext_dataset = indexed_dataset.SeqIndexedDataset(
                attr_ext_paths[0])
            for attr_ext_path in attr_ext_paths[1:]:
                attr_ext_dataset.append(
                    indexed_dataset.SeqIndexedDataset(attr_ext_path))
            if truncate_target:
                attr_ext_dataset.clip(max_position=max_target_positions)
            assert np.all(tgt_ext_dataset == attr_ext_dataset)
            del attr_ext_dataset

    return CompletionDataset(
        tgt_dataset,
        tgt_dataset.sizes,
        tgt_dict,
        extends=tgt_ext_dataset,
        attrs=attrs,
        attr_indices=attr_dataset,
        attr_dict=attr_dict,
        attrs_mapping=attrs_mapping,
        reversed_attrs_mapping=reversed_attrs_mapping,
        max_target_positions=max_target_positions,
    )
Code example #11
def load_langpair_dataset(
    args,
    programming_langs,
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    is_distill=False,
):
    def split_exists(split, src, data_path):
        filename = os.path.join(data_path,
                                '{}.{}'.format(split,
                                               src))  # -{}.{} , tgt, lang
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    topk_idxs = []
    topk_probs = []
    expert_scores = []
    dataset_ids = []
    lng_borders = [0]
    is_train = split == 'train'

    for ds_idx, program_lang in enumerate(programming_langs):
        lang_data_path = os.path.join(data_path, program_lang)

        split_k = split
        # infer langcode
        if split_exists(split_k, src, lang_data_path):
            prefix = os.path.join(lang_data_path,
                                  '{}.'.format(split_k))  # {}-{}. , src, tgt
        elif split_exists(split_k, tgt, lang_data_path):
            prefix = os.path.join(lang_data_path,
                                  '{}.'.format(split_k))  # {}-{}. , tgt, src
        else:
            raise NotImplementedError('No data in {}'.format(lang_data_path))

        src_dataset = data_utils.load_indexed_dataset(prefix + src, src,
                                                      src_dict, dataset_impl)
        if truncate_source:
            src_dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)
        length = len(src_dataset)
        lng_borders.append(lng_borders[-1] + length)

        tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt,
                                                      tgt_dict, dataset_impl)
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        for i in range(length):
            dataset_ids.append(ds_idx)

        if is_distill and is_train:  # distill only for train
            path = '{}_{}_{}_topk_idx'.format(lang_data_path, src, tgt)
            topk_idxs.append(TeacherOutputDataset(path))
            path = '{}_{}_{}_topk_prob'.format(lang_data_path, src, tgt)
            topk_probs.append(TeacherOutputDataset(path))
            expert_bleu = os.path.join(
                data_path,
                'expert_bleu_{}_{}_{}.json'.format(program_lang, src, tgt))
            expert_bleu = json.load(open(expert_bleu))
            expert_scores.append(expert_bleu[f"bleu_{program_lang}"])

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    sample_ratios = [1] * len(src_datasets)
    sample_ratios[0] = upsample_primary
    src_dataset = ConcatDataset(src_datasets, sample_ratios)
    if len(tgt_datasets) > 0:
        tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
    else:
        tgt_dataset = None

    LOGGER.info('src data: {}, tgt data: {}'.format(len(src_dataset),
                                                    len(tgt_dataset)))

    if is_distill and is_train:  # distill only for train
        topk_idx_dataset = ConcatDataset(topk_idxs)
        topk_probs_dataset = ConcatDataset(topk_probs)
        assert len(topk_probs_dataset) == len(src_dataset), (
            len(topk_probs_dataset), len(src_dataset))
        assert len(topk_idx_dataset) == len(src_dataset)
    else:
        topk_idx_dataset = None
        topk_probs_dataset = None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(
            tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None

    return UniversalDataset(
        args,
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        dataset_ids=dataset_ids,
        lng_borders=lng_borders,
        dataset_names=programming_langs,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        topk_idxs=topk_idx_dataset,
        topk_probs=topk_probs_dataset,
        expert_scores=expert_scores,
        is_train=is_train,
    )
Code example #12
def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    dataset_impl,
    # combine, dataset_impl, upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=True,
    append_eos=True,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    truncate_target=False,
):
    # truncate sentence for prepend <bos> and append <eos>
    max_target_positions -= int(prepend_bos) + int(append_eos)

    # load source dataset
    src_path = os.path.join(data_path, '{}.{}'.format(split, src))
    src_dataset = _load_dataset(path=src_path,
                                impl=dataset_impl,
                                dict=src_dict)

    if truncate_source:
        # sntn => sntn[:max_source_positions]
        LOGGER.info('truncate {}.{} to {}'.format(split, src,
                                                  max_source_positions))
        src_dataset = TruncateDataset(src_dataset, max_source_positions)

    # load target dataset
    tgt_path = os.path.join(data_path, '{}.{}'.format(split, tgt))
    tgt_dataset = _load_dataset(path=tgt_path,
                                impl=dataset_impl,
                                dict=tgt_dict)
    if truncate_target:
        # sntn => sntn[:max_target_positions]
        LOGGER.info('truncate {}.{} to {}'.format(split, tgt,
                                                  max_target_positions))
        tgt_dataset = TruncateDataset(
            tgt_dataset, max_target_positions)  # 2 for BOS and EOS
    # sntn[:max_target_positions] => <bos> sntn[:max_target_positions]
    if prepend_bos:
        tgt_dataset = PrependTokenDataset(tgt_dataset, token=tgt_dict.bos())
    if append_eos:
        tgt_dataset = AppendTokenDataset(tgt_dataset, token=tgt_dict.eos())
    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None

    # load tgt ids
    tgt_ids_path = os.path.join(data_path, '{}.id'.format(split))
    tgt_ids = _load_ids(tgt_ids_path)

    LOGGER.info('loaded {} examples from: {}'.format(len(src_dataset),
                                                     src_path))
    LOGGER.info('loaded {} examples from: {}'.format(len(tgt_dataset),
                                                     tgt_path))
    return BELanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        tgt_ids=tgt_ids,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        align_dataset=None,
        bos=src_dict.bos(),
        eos=src_dict.eos(),
        # shuffle=True,
        shuffle=False,  # debug
    )
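
Code example #12 subtracts the <bos>/<eos> budget from max_target_positions before truncating, so the sequence still fits the limit after PrependTokenDataset and AppendTokenDataset wrap it. A tiny sketch of that accounting on plain lists (the token ids BOS=0 and EOS=2 are arbitrary placeholders):

BOS, EOS = 0, 2
max_target_positions = 6
prepend_bos, append_eos = True, True

# reserve room for the special tokens before truncating
budget = max_target_positions - int(prepend_bos) - int(append_eos)   # 4

tokens = [11, 12, 13, 14, 15, 16, 17]
tokens = tokens[:budget]                                 # [11, 12, 13, 14]
if prepend_bos:
    tokens = [BOS] + tokens
if append_eos:
    tokens = tokens + [EOS]
assert len(tokens) <= max_target_positions               # [0, 11, 12, 13, 14, 2] -> 6 tokens
print(tokens)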
Code example #13
File: retrieval.py  Project: mir-am/naturalcc
def load_tokens_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    dataset_impl,
    src_max_tokens=None,
    tgt_max_tokens=None,
    src_aux=None,
    src_aux_dict=None,
    tgt_aux=None,
    tgt_aux_dict=None,
    src_aux_max_tokens=None,
    tgt_aux_max_tokens=None,
    fraction_using_func_name=0.,
):
    src_path = os.path.join(data_path, '{}.{}'.format(split, src))
    src_dataset = _load_dataset(src_path, dataset_impl)
    src_dataset = TruncateDataset(src_dataset, src_max_tokens)

    tgt_path = os.path.join(data_path, '{}.{}'.format(split, tgt))
    tgt_dataset = _load_dataset(tgt_path, dataset_impl)
    tgt_dataset = TruncateDataset(tgt_dataset, tgt_max_tokens)

    LOGGER.info('loaded {} examples from: {}'.format(len(src_dataset),
                                                     src_path))
    LOGGER.info('loaded {} examples from: {}'.format(len(tgt_dataset),
                                                     tgt_path))

    src_aux_dataset = None
    if src_aux is not None:
        src_aux_path = os.path.join(data_path, '{}.{}'.format(split, src_aux))
        src_aux_dataset = _load_dataset(src_aux_path, dataset_impl)
        if src_aux_max_tokens is None:
            src_aux_max_tokens = src_max_tokens
        src_aux_dataset = TruncateDataset(src_aux_dataset, src_aux_max_tokens)
        LOGGER.info('loaded {} examples from: {}'.format(
            len(src_aux_dataset), src_aux_path))

    tgt_aux_dataset = None
    if tgt_aux is not None:
        tgt_aux_path = os.path.join(data_path, '{}.{}'.format(split, tgt_aux))
        tgt_aux_dataset = _load_dataset(tgt_aux_path, dataset_impl)
        if tgt_aux_max_tokens is None:
            tgt_aux_max_tokens = tgt_max_tokens
        tgt_aux_dataset = TruncateDataset(tgt_aux_dataset, tgt_aux_max_tokens)
        LOGGER.info('loaded {} examples from: {}'.format(
            len(tgt_aux_dataset), tgt_aux_path))

    return RetrievalDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset.sizes,
        tgt_dict,
        max_source_positions=src_max_tokens,
        max_target_positions=tgt_max_tokens,
        src_aux=src_aux_dataset,
        src_aux_sizes=None if src_aux is None else src_aux_dataset.sizes,
        src_aux_dict=src_dict if src_aux_dict is None else src_aux_dict,
        tgt_aux=tgt_aux_dataset,
        tgt_aux_sizes=None if tgt_aux is None else tgt_aux_dataset.sizes,
        tgt_aux_dict=tgt_dict if tgt_aux_dict is None else tgt_aux_dict,
        fraction_using_func_name=fraction_using_func_name,
    )
Code example #14
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        # paths = self.args.data.split(':')
        paths = utils.split_paths(self.args['task']['data'])
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]
        # split_path = os.path.join(data_path, split)

        if self.langs is None:
            languages = sorted([
                name for name in os.listdir(data_path)
                if os.path.isdir(os.path.join(data_path, name))
            ])
        else:
            languages = self.langs  # .split(',')
            # for name in languages:
            #     assert os.path.exists(os.path.join(data_path, name)), FileNotFoundError(os.path.join(data_path, name))

        LOGGER.info("| Training on {0} languages: {1}".format(
            len(languages), languages))
        LOGGER.info("| Language to id mapping: {}".format(
            {lang: id for id, lang in enumerate(languages)}))

        mask_whole_words = get_whole_word_mask(self.args, self.dictionary)
        lang_datasets = []
        for language in languages:
            # split_path = os.path.join(data_path, language, split)
            if language == 'docstring':
                split_path = os.path.join(data_path, language,
                                          f"{split}.docstring.spm")
            else:
                split_path = os.path.join(data_path, language,
                                          f"{split}.code.spm")
            # split_path = os.path.join(data_path, language, f"{split}.spm.{language}")
            # dataset = data_utils.load_indexed_dataset(
            #     split_path,
            #     self.source_dictionary,
            #     self.args['dataset']['dataset_impl'],
            #     combine=combine,
            # )
            dataset = load_lang_dataset_denoising(
                path=split_path,
                impl=self.args['dataset']['dataset_impl'],
                dict=self.source_dictionary)

            if dataset is None:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(
                    split, split_path))

            dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(dataset, self.source_dictionary.eos()),
                    self.args['task']['max_source_positions'] -
                    3),  # <lang>, <bos>, <eos>
                token=self.source_dictionary.eos(),
            )

            end_token = self.source_dictionary.index('[{}]'.format(language)) \
                if self.args['task']['add_lang_token'] else self.source_dictionary.eos()

            # create continuous blocks of tokens
            dataset = TokenBlockDataset(
                dataset,
                dataset.sizes,
                self.args['task']['tokens_per_sample'] -
                2,  # one less for <s> and one for </s>
                pad=self.source_dictionary.pad(),
                eos=end_token,
                break_mode=self.args['task']['sample_break_mode'],
                document_sep_len=0,
            )
            LOGGER.info('| loaded {} blocks from: {}'.format(
                len(dataset), split_path))

            # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
            dataset = PrependTokenDataset(dataset,
                                          self.source_dictionary.bos())
            dataset = AppendTokenDataset(dataset, end_token)

            lang_dataset = DenoisingDataset(
                dataset,
                dataset.sizes,
                self.dictionary,
                self.mask_idx,
                mask_whole_words,
                shuffle=self.args['dataset']['shuffle_instance'],
                seed=self.seed,
                args=self.args,
                eos=None if not self.args['task']['add_lang_token'] else
                self.source_dictionary.index('[{}]'.format(language)),
            )
            lang_datasets.append(lang_dataset)

        dataset_lengths = np.array(
            [len(d) for d in lang_datasets],
            dtype=float,
        )
        LOGGER.info('| loaded total {} blocks for all languages'.format(
            dataset_lengths.sum(), ))
        if split == self.args['dataset']['train_subset']:
            # For train subset, additionally up or down sample languages.
            sample_probs = self._get_sample_prob(dataset_lengths)
            LOGGER.info(
                "| Sample probability by language: {}".format({
                    lang: "{0:.4f}".format(sample_probs[id])
                    for id, lang in enumerate(languages)
                }))
            size_ratio = (sample_probs *
                          dataset_lengths.sum()) / dataset_lengths
            LOGGER.info(
                "| Up/Down Sampling ratio by language: {}".format({
                    lang: "{0:.2f}".format(size_ratio[id])
                    for id, lang in enumerate(languages)
                }))

            resampled_lang_datasets = [
                ResamplingDataset(
                    lang_datasets[i],
                    size_ratio=size_ratio[i],
                    seed=self.args['common']['seed'],
                    epoch=epoch,
                    replace=size_ratio[i] >= 1.0,
                ) for i, d in enumerate(lang_datasets)
            ]
            dataset = ConcatDataset(resampled_lang_datasets, )
        else:
            dataset = ConcatDataset(lang_datasets)
            lang_splits = [split]
            # for lang_id, lang_dataset in enumerate(lang_datasets):
            #     split_name = split + '_' + languages[lang_id]
            #     lang_splits.append(split_name)
            #     self.datasets[split_name] = lang_dataset

            if split in self.args['dataset']['valid_subset']:
                self.args['dataset']['valid_subset'] = self.args['dataset'][
                    'valid_subset'].replace(split, ','.join(lang_splits))

        with data_utils.numpy_seed(self.args['common']['seed'] + epoch):
            shuffle = np.random.permutation(len(dataset))

        self.datasets[split] = SortDataset(
            dataset,
            sort_order=[
                shuffle,
                dataset.sizes,
            ],
        )
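
For the train subset, code example #14 resamples each language with a probability returned by self._get_sample_prob and turns it into a per-language size ratio for ResamplingDataset. The helper itself is not reproduced here; the sketch below shows the usual temperature-based formula (p_l proportional to n_l ** alpha), with the exponent alpha=0.7 chosen arbitrarily for illustration.

import numpy as np

def sample_probs(dataset_lengths, alpha=0.7):
    """p_l proportional to n_l ** alpha; alpha < 1 up-samples low-resource languages."""
    p = dataset_lengths / dataset_lengths.sum()
    smoothed = p ** alpha
    return smoothed / smoothed.sum()

lengths = np.array([100000.0, 10000.0, 1000.0])      # toy per-language block counts
probs = sample_probs(lengths)
size_ratio = probs * lengths.sum() / lengths          # > 1.0 means up-sample (replace=True above)
print(np.round(probs, 3), np.round(size_ratio, 2))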
Code example #15
File: summarization.py  Project: mir-am/naturalcc
def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    dataset_impl,
    # combine, dataset_impl, upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    truncate_target=False,
    append_eos_to_target=False,
):
    # load source dataset
    src_path = os.path.join(data_path, '{}.{}'.format(split, src))
    src_dataset = _load_dataset(path=src_path,
                                impl=dataset_impl,
                                dict=src_dict)

    if truncate_source:
        # sntn => sntn[:max_source_positions]
        src_dataset = TruncateDataset(src_dataset, max_source_positions)

    # load target dataset
    tgt_path = os.path.join(data_path, '{}.{}'.format(split, tgt))
    tgt_dataset = _load_dataset(path=tgt_path,
                                impl=dataset_impl,
                                dict=tgt_dict)
    if truncate_target:
        # sntn => sntn[:max_target_positions]
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(
            tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    eos = None
    if append_source_id:
        src_dataset = AppendTokenDataset(src_dataset,
                                         src_dict.index('[{}]'.format(src)))
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(
                tgt_dataset, tgt_dict.index('[{}]'.format(tgt)))
        eos = tgt_dict.index('[{}]'.format(tgt))

    # align_dataset = None
    # if load_alignments:
    #     align_path = os.path.join(data_path, '{}.align.{}-{}'.format(split, src, tgt))
    #     if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
    #         align_dataset = data_utils.load_indexed_dataset(align_path, None, dataset_impl)

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None

    LOGGER.info('loaded {} examples from: {}'.format(len(src_dataset),
                                                     src_path))
    LOGGER.info('loaded {} examples from: {}'.format(len(tgt_dataset),
                                                     tgt_path))
    return LanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        align_dataset=None,
        eos=eos,
        remove_eos_from_source=True,
        append_eos_to_target=append_eos_to_target,
        shuffle=True,
        # shuffle=False,  # debug
    )
Code example #16
def load_token_dataset(
    data_path,
    split,
    tgt,
    tgt_dict,
    dataset_impl,
    attrs=None,
    attr_dict=None,
    attrs_mapping=None,
    reversed_attrs_mapping=None,
    truncate_target=False,
    max_target_positions=None,
    # lifelong learning
    prev_tasks=[],
    cur_task=None,
    sample_portion=None,
):
    # load tokens
    tgt_path = os.path.join(data_path, cur_task, '{}.{}'.format(split, tgt))
    tgt_dataset = [_load_dataset(tgt_path, dataset_impl)]
    kd_ids = [PlaceholderDataset(placeholder=True, length=len(tgt_dataset[0]))]

    if prev_tasks and cur_task is not None and sample_portion is not None:
        sample_size_per_task = int(
            len(tgt_dataset[0]) * sample_portion // len(prev_tasks))
    else:
        sample_size_per_task = -1
    if sample_size_per_task > 0:
        for p_task in prev_tasks:
            p_path = os.path.join(data_path, p_task,
                                  '{}.{}'.format(split, tgt))
            p_dataset = _load_dataset(p_path, dataset_impl)
            tgt_dataset.append(
                SliceDataset(p_dataset, end=sample_size_per_task))
            kd_ids.append(
                PlaceholderDataset(placeholder=False,
                                   length=sample_size_per_task))
    tgt_dataset = ConcatDataset(tgt_dataset)
    kd_ids = ConcatDataset(kd_ids)
    if truncate_target:
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)
        LOGGER.info('Truncate dataset into max length: {}'.format(
            max_target_positions))
    LOGGER.info('loaded {} examples from: [{}](current task) + {}(previous tasks)'. \
                format(len(tgt_dataset), cur_task, prev_tasks))

    return LifelongKDCompletionDataset(
        kd_indices=kd_ids,
        tgt=tgt_dataset,
        tgt_sizes=tgt_dataset.sizes,
        tgt_dict=tgt_dict,
        extends=None,
        attrs=None,
        attr_indices=None,
        attr_dict=None,
        attrs_mapping=None,
        reversed_attrs_mapping=None,
        max_target_positions=max_target_positions,
        shuffle=(split == 'train'),
    )
Code example #17
def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    dataset_impl,
    # combine, dataset_impl, upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=True,
    append_eos=True,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    truncate_target=False,
    # lifelong learning
    prev_tasks=[],
    cur_task=None,
    sample_portion=None,
):
    # truncate sentence for prepend <bos> and append <eos>
    max_target_positions -= int(prepend_bos) + int(append_eos)

    # load source dataset
    src_path = os.path.join(data_path, cur_task, '{}.{}'.format(split, src))
    src_dataset = [
        _load_dataset(path=src_path, impl=dataset_impl, dict=src_dict)
    ]
    # load previous tasks
    if prev_tasks and cur_task is not None and sample_portion is not None:
        sample_size_per_task = int(
            len(src_dataset[0]) * sample_portion // len(prev_tasks))
    else:
        sample_size_per_task = -1
    if sample_size_per_task > 0:
        for p_task in prev_tasks:
            p_path = os.path.join(data_path, p_task,
                                  '{}.{}'.format(split, src))
            p_dataset = _load_dataset(p_path, dataset_impl, src_dict)
            src_dataset.append(
                SliceDataset(p_dataset, end=sample_size_per_task))
    src_dataset = ConcatDataset(src_dataset)
    # truncate dataset
    if truncate_source:
        # sntn => sntn[:max_source_positions]
        LOGGER.info('truncate {}.{} to {}'.format(split, src,
                                                  max_source_positions))
        src_dataset = TruncateDataset(src_dataset, max_source_positions)

    # load target dataset
    tgt_path = os.path.join(data_path, cur_task, '{}.{}'.format(split, tgt))
    tgt_dataset = [
        _load_dataset(path=tgt_path, impl=dataset_impl, dict=tgt_dict)
    ]
    if sample_size_per_task > 0:
        for p_task in prev_tasks:
            p_path = os.path.join(data_path, p_task,
                                  '{}.{}'.format(split, tgt))
            p_dataset = _load_dataset(p_path, dataset_impl, tgt_dict)
            tgt_dataset.append(
                SliceDataset(p_dataset, end=sample_size_per_task))
    tgt_dataset = ConcatDataset(tgt_dataset)
    if truncate_target:
        # sntn => sntn[:max_target_positions]
        LOGGER.info('truncate {}.{} to {}'.format(split, tgt,
                                                  max_target_positions))
        tgt_dataset = TruncateDataset(
            tgt_dataset, max_target_positions)  # 2 for BOS and EOS
    # sntn[:max_target_positions] => <bos> sntn[:max_target_positions]
    if prepend_bos:
        tgt_dataset = PrependTokenDataset(tgt_dataset, token=tgt_dict.bos())
    if append_eos:
        tgt_dataset = AppendTokenDataset(tgt_dataset, token=tgt_dict.eos())
    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None

    assert len(src_dataset) == len(tgt_dataset), (len(src_dataset),
                                                  len(tgt_dataset))
    LOGGER.info('loaded {} examples from: [{}](current task) + {}(previous tasks)'. \
                format(len(src_dataset), cur_task, prev_tasks))
    return BELanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        align_dataset=None,
        bos=src_dict.bos(),
        eos=src_dict.eos(),
        shuffle=(split == 'train'),
    )
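
Code examples #16 and #17 mix the current task with a slice of every previous task: a fixed portion of the current task's size, split evenly across previous tasks, is replayed via SliceDataset and ConcatDataset. A toy sketch of that arithmetic with plain lists standing in for the dataset wrappers:

cur_task_examples = list(range(1000))        # toy current-task dataset
prev_tasks = {'java': list(range(5000)), 'python': list(range(8000))}
sample_portion = 0.2

# 20% of the current task, split evenly across previous tasks -> 100 each
sample_size_per_task = int(len(cur_task_examples) * sample_portion // len(prev_tasks))

mixed = list(cur_task_examples)
for name, prev in prev_tasks.items():
    mixed.extend(prev[:sample_size_per_task])   # SliceDataset(prev, end=...)

print(sample_size_per_task, len(mixed))          # 100, 1200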
Code example #18
def load_tokens_dataset(
    data_path, split, src, src_dict, tgt, tgt_dict, dataset_impl, src_max_tokens=None, tgt_max_tokens=None,
    src_aux=None, src_aux_dict=None, tgt_aux=None, tgt_aux_dict=None, src_aux_max_tokens=None, tgt_aux_max_tokens=None,
    fraction_using_func_name=0., labels=None, shuffle=False,
):
    if len(labels) == 1 and os.path.exists(os.path.join(data_path, '{}.{}.idx'.format(split, src))):
        src_paths = [os.path.join(data_path, '{}.{}'.format(split, src))]
    else:
        src_paths = [os.path.join(data_path, '{}.{}.{}'.format(split, lbl, src)) for lbl in labels]
    src_datasets = _load_dataset(src_paths, dataset_impl)
    src_datasets = [TruncateDataset(ds, src_max_tokens) for ds in src_datasets]
    src_datasets = ConcatDataset(src_datasets, labels=labels)

    if len(labels) == 1 and os.path.exists(os.path.join(data_path, '{}.{}.idx'.format(split, tgt))):
        tgt_paths = [os.path.join(data_path, '{}.{}'.format(split, tgt))]
    else:
        tgt_paths = [os.path.join(data_path, '{}.{}.{}'.format(split, lbl, tgt)) for lbl in labels]
    tgt_datasets = _load_dataset(tgt_paths, dataset_impl)
    tgt_datasets = [TruncateDataset(ds, tgt_max_tokens) for ds in tgt_datasets]
    tgt_datasets = ConcatDataset(tgt_datasets, labels=labels)

    LOGGER.info('loaded {} examples from: {}'.format(len(src_datasets), src_paths))
    LOGGER.info('loaded {} examples from: {}'.format(len(tgt_datasets), tgt_paths))

    if split == 'train' and src_aux is not None:
        if len(labels) == 1 and os.path.exists(os.path.join(data_path, '{}.{}.idx'.format(split, src_aux))):
            src_aux_paths = [os.path.join(data_path, '{}.{}'.format(split, src_aux))]
        else:
            src_aux_paths = [os.path.join(data_path, '{}.{}.{}'.format(split, lbl, src_aux)) for lbl in labels]
        src_aux_datasets = _load_dataset(src_aux_paths, dataset_impl)
        if src_aux_max_tokens is None:
            src_aux_max_tokens = src_max_tokens
        src_aux_datasets = [TruncateDataset(ds, src_aux_max_tokens) for ds in src_aux_datasets]
        src_aux_datasets = ConcatDataset(src_aux_datasets, labels=labels)
        LOGGER.info('loaded {} examples from: {}'.format(len(src_aux_datasets), src_aux_paths))
    else:
        src_aux_datasets = None

    if split == 'train' and tgt_aux is not None:
        if len(labels) == 1 and os.path.exists(os.path.join(data_path, '{}.{}.idx'.format(split, tgt_aux))):
            tgt_aux_paths = [os.path.join(data_path, '{}.{}'.format(split, tgt_aux))]
        else:
            tgt_aux_paths = [os.path.join(data_path, '{}.{}.{}'.format(split, lbl, tgt_aux)) for lbl in labels]
        tgt_aux_datasets = _load_dataset(tgt_aux_paths, dataset_impl)
        if tgt_aux_max_tokens is None:
            tgt_aux_max_tokens = tgt_max_tokens
        tgt_aux_datasets = [TruncateDataset(ds, tgt_aux_max_tokens) for ds in tgt_aux_datasets]
        tgt_aux_datasets = ConcatDataset(tgt_aux_datasets, labels=labels)
        LOGGER.info('loaded {} examples from: {}'.format(len(tgt_aux_datasets), tgt_aux_paths))
    else:
        tgt_aux_datasets = None

    return HybridRetrievalDataset(
        src_datasets, src_datasets.sizes, src_dict,
        tgt_datasets, tgt_datasets.sizes, tgt_dict,
        max_source_positions=src_max_tokens, max_target_positions=tgt_max_tokens,
        src_aux=src_aux_datasets,
        src_aux_sizes=None if src_aux_datasets is None else src_aux_datasets.sizes,
        src_aux_dict=src_dict if src_aux_dict is None else src_aux_dict,
        tgt_aux=tgt_aux_datasets,
        tgt_aux_sizes=None if tgt_aux_datasets is None else tgt_aux_datasets.sizes,
        tgt_aux_dict=tgt_dict if tgt_aux_dict is None else tgt_aux_dict,
        fraction_using_func_name=fraction_using_func_name,
        shuffle=shuffle,
        labels=labels,
    )
Code example #19
def load_token_dataset(
    data_path,
    split,
    tgt,
    tgt_dict,
    dataset_impl,
    attrs=None,
    attr_dict=None,
    attrs_mapping=None,
    reversed_attrs_mapping=None,
    truncate_target=False,
    max_target_positions=None,
    # lifelong learning
    prev_tasks=[],
    cur_task=None,
    sample_portion=None,
):
    # load tokens
    tgt_path = os.path.join(data_path, cur_task, '{}.{}'.format(split, tgt))
    tgt_dataset = [_load_dataset(tgt_path, dataset_impl)]
    if prev_tasks and cur_task is not None and sample_portion is not None:
        sample_size_per_task = int(
            len(tgt_dataset[0]) * sample_portion // len(prev_tasks))
    else:
        sample_size_per_task = -1
    if sample_size_per_task > 0:
        for p_task in prev_tasks:
            p_path = os.path.join(data_path, p_task,
                                  '{}.{}'.format(split, tgt))
            p_dataset = _load_dataset(p_path, dataset_impl)
            tgt_dataset.append(
                SliceDataset(p_dataset, end=sample_size_per_task))
    tgt_dataset = ConcatDataset(tgt_dataset)
    if truncate_target:
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)
        LOGGER.info('Truncate dataset into max length: {}'.format(
            max_target_positions))
    LOGGER.info('loaded {} examples from: [{}](current task) + {}(previous tasks)'. \
                format(len(tgt_dataset), cur_task, prev_tasks))
    # load tokens.ext
    tgt_ext_path = os.path.join(data_path, cur_task,
                                '{}.{}.ext'.format(split, tgt))
    if indexed_dataset.SeqIndexedDataset.exists(tgt_ext_path):
        tgt_ext_dataset = indexed_dataset.SeqIndexedDataset(tgt_ext_path)
        if sample_size_per_task > 0:
            for p_task in prev_tasks:
                p_ext_path = os.path.join(data_path, p_task,
                                          '{}.{}.ext'.format(split, tgt))
                p_ext_dataset = indexed_dataset.SeqIndexedDataset(p_ext_path)
                p_ext_dataset.truncate(end=sample_size_per_task)
                tgt_ext_dataset.append(p_ext_dataset)
        if truncate_target:
            tgt_ext_dataset.clip(max_position=max_target_positions)
        assert len(tgt_dataset) == len(tgt_ext_dataset), (len(tgt_dataset),
                                                          len(tgt_ext_dataset))
    else:
        tgt_ext_dataset = None
    # load attrs
    if attrs is None:
        attr_dataset = None
    else:
        attr_path = os.path.join(data_path, cur_task,
                                 '{}.code_types'.format(split))
        attr_dataset = [_load_dataset(attr_path, dataset_impl)]
        if sample_size_per_task > 0:
            for p_task in prev_tasks:
                p_attr_path = os.path.join(data_path, p_task,
                                           '{}.code_types'.format(split))
                p_attr_dataset = _load_dataset(p_attr_path, dataset_impl)
                attr_dataset.append(
                    SliceDataset(p_attr_dataset, end=sample_size_per_task))
        attr_dataset = ConcatDataset(attr_dataset)
        if truncate_target:
            attr_dataset = TruncateDataset(attr_dataset, max_target_positions)
            LOGGER.info(
                'Truncate dataset\'s attributes into max length: {}'.format(
                    max_target_positions))
        LOGGER.info('loaded attributes {} examples from: [{}](current task) + {}(previous tasks)'. \
                    format(len(attr_dataset), cur_task, prev_tasks))
        # load attr.ext
        attr_ext_path = os.path.join(data_path, cur_task,
                                     '{}.code_types.ext'.format(split))
        if indexed_dataset.SeqIndexedDataset.exists(attr_ext_path):
            attr_ext_dataset = indexed_dataset.SeqIndexedDataset(attr_ext_path)
            if sample_size_per_task > 0:
                for p_task in prev_tasks:
                    p_attr_ext_path = os.path.join(
                        data_path, p_task, '{}.code_types.ext'.format(split))
                    p_attr_ext_dataset = indexed_dataset.SeqIndexedDataset(
                        p_attr_ext_path)
                    p_attr_ext_dataset.truncate(end=sample_size_per_task)
                    attr_ext_dataset.append(p_attr_ext_dataset)
            if truncate_target:
                attr_ext_dataset.clip(max_position=max_target_positions)
            assert np.all(tgt_ext_dataset == attr_ext_dataset)
            del attr_ext_dataset

    return CompletionDataset(
        tgt_dataset,
        tgt_dataset.sizes,
        tgt_dict,
        extends=tgt_ext_dataset,
        attrs=attrs,
        attr_indices=attr_dataset,
        attr_dict=attr_dict,
        attrs_mapping=attrs_mapping,
        reversed_attrs_mapping=reversed_attrs_mapping,
        max_target_positions=max_target_positions,
    )
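A minimal usage sketch of the `load_token_dataset` above, assuming the helpers it relies on (`_load_dataset`, `SliceDataset`, `ConcatDataset`, `TruncateDataset`, `indexed_dataset`) are importable as in the snippet; the data root, field names, task names, and dictionary object are hypothetical placeholders, not taken from the original source:

# `tgt_dict` is assumed to be a pre-built token dictionary for the target vocabulary.
completion_dataset = load_token_dataset(
    data_path='data-bin/completion',   # hypothetical data root
    split='train',
    tgt='code_tokens',                 # hypothetical target field name
    tgt_dict=tgt_dict,
    dataset_impl='mmap',               # hypothetical indexed-dataset backend
    truncate_target=True,
    max_target_positions=1024,
    # lifelong learning: replay a slice of each previous task's data,
    # with a total budget of ~10% of the current task's size
    prev_tasks=['python', 'java'],     # hypothetical previous tasks
    cur_task='go',                     # hypothetical current task
    sample_portion=0.1,
)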
Code Example #20
0
def load_langpair_dataset(
    data_path,
    split,
    domains,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    append_eos_to_target=False,
):
    def split_exists(split, src, data_path, domain):
        filename = os.path.join(data_path, domain,
                                '{}.{}'.format(split,
                                               src))  # -{}.{} , tgt, lang
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    for dm in domains:
        # load datasets of src domains
        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')
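            # Shards are named '<split>', '<split>1', '<split>2', ...; keep loading
            # until a shard is missing (or stop after the first when combine is False).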

            # infer langcode
            if split_exists(split_k, src, data_path, dm):
                prefix = os.path.join(
                    data_path, dm, '{}.'.format(split_k))  # {}-{}. , src, tgt
            elif split_exists(split_k, tgt, data_path, dm):
                prefix = os.path.join(
                    data_path, dm, '{}.'.format(split_k))  # {}-{}. , tgt, src

            else:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError(
                        'Dataset not found: {} ({})'.format(split, data_path))

            src_dataset = data_utils.load_indexed_dataset(
                prefix + src, src, src_dict, dataset_impl)
            if truncate_source:
                src_dataset = AppendTokenDataset(
                    TruncateDataset(
                        StripTokenDataset(src_dataset, src_dict.eos()),
                        max_source_positions - 1,
                    ),
                    src_dict.eos(),
                )
            src_datasets.append(src_dataset)

            tgt_dataset = data_utils.load_indexed_dataset(
                prefix + tgt, tgt, tgt_dict, dataset_impl)
            if tgt_dataset is not None:
                tgt_datasets.append(tgt_dataset)

            if not combine:
                break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
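        # With multiple shards, upsample the first (primary) shard by `upsample_primary`.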
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(
            tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    eos = None
    if append_source_id:
        src_dataset = AppendTokenDataset(src_dataset,
                                         src_dict.index('[{}]'.format(src)))
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(
                tgt_dataset, tgt_dict.index('[{}]'.format(tgt)))
        eos = tgt_dict.index('[{}]'.format(tgt))

    # align_dataset = None
    # if load_alignments:
    #     align_path = os.path.join(data_path, '{}.align.{}-{}'.format(split, src, tgt))
    #     if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
    #         align_dataset = data_utils.load_indexed_dataset(align_path, None, dataset_impl)

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return LanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        align_dataset=None,
        eos=eos,
        remove_eos_from_source=True,
        append_eos_to_target=append_eos_to_target,
        shuffle=True,  # TODO debug: shuffle=False
    )
Code Example #21
0
def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    load_alignments=False,
    truncate_source=False,
    truncate_target=False,
):
    def split_exists(split, src, data_path):
        filename = os.path.join(data_path,
                                '{}.{}'.format(split,
                                               src))  # -{}.{} , tgt, lang
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')

        # infer langcode
        if split_exists(split_k, src, data_path):
            prefix = os.path.join(data_path,
                                  '{}.'.format(split_k))  # {}-{}. , src, tgt
        elif split_exists(split_k, tgt, data_path):
            prefix = os.path.join(data_path,
                                  '{}.'.format(split_k))  # {}-{}. , tgt, src

        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(
                    split, data_path))

        src_dataset = data_utils.load_indexed_dataset(prefix + src, 'text',
                                                      src_dict, dataset_impl)
        if truncate_source and max_source_positions:
            src_dataset = TruncateDataset(src_dataset, max_source_positions)
        src_datasets.append(src_dataset)

        tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, 'text',
                                                      tgt_dict, dataset_impl)
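        # Reserve two positions so BOS/EOS can be prepended/appended after truncation.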
        if truncate_target and max_target_positions:
            tgt_dataset = PrependTokenDataset(
                AppendTokenDataset(
                    TruncateDataset(tgt_dataset, max_target_positions - 2),
                    token=tgt_dict.eos(),
                ),
                tgt_dict.bos(),
            )

        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return LanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        shuffle=True,  # TODO debug: shuffle=False
    )
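A minimal usage sketch of this `load_langpair_dataset` variant; the concrete values below (data root, field names, dictionaries, limits) are hypothetical placeholders, not taken from the original source:

# `src_dict` / `tgt_dict` are assumed to be pre-built dictionaries for the
# source and target token streams.
pair_dataset = load_langpair_dataset(
    data_path='data-bin/summarization',  # hypothetical data root
    split='train',
    src='code', src_dict=src_dict,       # hypothetical source field name
    tgt='docstring', tgt_dict=tgt_dict,  # hypothetical target field name
    combine=True,                        # concatenate train, train1, ... shards if present
    dataset_impl='mmap',                 # hypothetical indexed-dataset backend
    upsample_primary=1,
    left_pad_source=True,
    left_pad_target=False,
    max_source_positions=512,
    max_target_positions=128,
    truncate_source=True,
    truncate_target=True,
)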