Exemple #1
0
def load_tokens_dataset(
    data_path, split, srcs, src_dicts, tgts, tgt_dicts, dataset_impl,
    max_srcs=None, max_tgts=None,
    shuffle=False, sample_neg=False, **kwargs,
):
    """Assemble a ``MultiModalitiesRetrievalDataset`` for one split.

    For every modality listed in ``srcs`` and ``tgts`` the file prefix
    ``{split}.{modality}`` under ``data_path`` is loaded; when the optional
    ``langs`` keyword is given, one prefix ``{split}.{lang}.{modality}`` per
    language is loaded instead and the parts are concatenated.

    Args:
        data_path: directory holding the dataset files.
        split: split name used as file-name prefix.
        srcs / tgts: iterables of source / target modality names.
        src_dicts / tgt_dicts: forwarded unchanged to the resulting dataset.
        dataset_impl: forwarded to ``_load_dataset``.
        max_srcs / max_tgts: optional per-modality truncation lengths, indexed
            by the position of the modality in ``srcs`` / ``tgts``.
        shuffle, sample_neg: forwarded to the resulting dataset.
        **kwargs: only ``langs`` and ``fraction_using_func_name`` are read.

    Returns:
        The constructed ``MultiModalitiesRetrievalDataset``.
    """
    langs = kwargs.get('langs', None)

    def resolve_paths(modalities):
        # One list of file prefixes per modality, keeping the caller's order.
        resolved = OrderedDict()
        for name in modalities:
            if langs is not None:
                resolved[name] = [
                    os.path.join(data_path, f'{split}.{lang}.{name}')
                    for lang in langs
                ]
            else:
                resolved[name] = [os.path.join(data_path, f'{split}.{name}')]
        return resolved

    def load_modalities(paths_by_name, max_lengths):
        # Load, optionally truncate, and concatenate the parts of each modality.
        loaded, sizes = OrderedDict(), OrderedDict()
        for position, (name, paths) in enumerate(paths_by_name.items()):
            parts = _load_dataset(paths, dataset_impl)
            if max_lengths is not None:
                limit_owner = max_lengths
                parts = [TruncateDataset(part, limit_owner[position]) for part in parts]
            merged = ConcatDataset(parts, labels=langs)
            loaded[name] = merged
            sizes[name] = merged.sizes
        return loaded, sizes

    # source modalities
    src_paths = resolve_paths(srcs)
    src_datasets, src_sizes = load_modalities(src_paths, max_srcs)
    LOGGER.info('loaded {} modality(ies) from: {}'.format(len(src_datasets), src_paths))

    # target modalities
    tgt_paths = resolve_paths(tgts)
    tgt_datasets, tgt_sizes = load_modalities(tgt_paths, max_tgts)
    LOGGER.info('loaded {} modality(ies) from: {}'.format(len(tgt_datasets), tgt_paths))

    return MultiModalitiesRetrievalDataset(
        src_datasets, src_sizes, src_dicts,
        tgt_datasets, tgt_sizes, tgt_dicts,
        max_source_positions=max_srcs, max_target_positions=max_tgts,
        fraction_using_func_name=kwargs.get('fraction_using_func_name', None),
        shuffle=shuffle,
        labels=langs, sample_neg=sample_neg,
    )
def load_inference_token_dataset(
    data_paths,
    split,
    tgt,
    tgt_dict,
    dataset_impl,
    attrs=None,
    attr_dict=None,
    attrs_mapping=None,
    reversed_attrs_mapping=None,
    truncate_target=False,
    max_target_positions=None,
):
    """Build a ``CompletionDataset`` for inference from several data roots.

    The file prefix ``{split}.{tgt}`` is loaded from every directory in
    ``data_paths`` and the results are concatenated in that order.

    Note:
        ``attrs``, ``attr_dict``, ``attrs_mapping`` and
        ``reversed_attrs_mapping`` are accepted but not forwarded; the
        resulting dataset is always built with those fields set to ``None``.

    Args:
        data_paths: iterable of directories to read from.
        split: split name used as file-name prefix.
        tgt: target modality name used as file-name suffix.
        tgt_dict: forwarded to the resulting dataset.
        dataset_impl: forwarded to ``_load_dataset``.
        truncate_target: when true, wrap the tokens in a ``TruncateDataset``.
        max_target_positions: truncation length, also forwarded to the result.

    Returns:
        The constructed ``CompletionDataset``.
    """
    file_name = '{}.{}'.format(split, tgt)
    shards = [
        _load_dataset(os.path.join(root, file_name), dataset_impl)
        for root in data_paths
    ]
    tokens = ConcatDataset(shards)

    if truncate_target:
        tokens = TruncateDataset(tokens, max_target_positions)
        LOGGER.info('Truncate dataset into max length: {}'.format(max_target_positions))
    LOGGER.info('loaded {} examples from: {}'.format(len(tokens), data_paths))

    return CompletionDataset(
        tokens,
        tokens.sizes,
        tgt_dict,
        extends=None,
        attrs=None,
        attr_indices=None,
        attr_dict=None,
        attrs_mapping=None,
        reversed_attrs_mapping=None,
        max_target_positions=max_target_positions,
    )
Exemple #3
0
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        One denoising dataset is built per language and the results are
        concatenated, ordered by an epoch-seeded random permutation and then
        by size, and stored in ``self.datasets[split]``. For the training
        subset the per-language datasets are first up/down-sampled with the
        probabilities returned by ``self._get_sample_prob``.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): 1-based epoch number; selects the data path
                round-robin, is passed to ``ResamplingDataset`` and is added
                to the shuffle seed
            combine (bool): accepted but not used by this method
            **kwargs: accepted but not used by this method

        Raises:
            FileNotFoundError: if the loader returns ``None`` for a language.
        """
        paths = utils.split_paths(self.args['task']['data'])
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        if self.langs is None:
            # No explicit language list: every sub-directory is a language.
            languages = sorted(
                name for name in os.listdir(data_path)
                if os.path.isdir(os.path.join(data_path, name))
            )
        else:
            languages = self.langs

        LOGGER.info("| Training on {0} languages: {1}".format(
            len(languages), languages))
        # The mapping is formatted into the message itself: passing a dict as
        # a logging argument without a placeholder silently drops it.
        LOGGER.info("| Language to id mapping: {}".format(
            {lang: lang_id for lang_id, lang in enumerate(languages)}))

        mask_whole_words = get_whole_word_mask(self.args, self.dictionary)
        add_lang_token = self.args['task']['add_lang_token']
        lang_datasets = []
        for language in languages:
            # 'docstring' is stored under a different file name than code.
            if language == 'docstring':
                split_path = os.path.join(data_path, language,
                                          f"{split}.docstring.spm")
            else:
                split_path = os.path.join(data_path, language,
                                          f"{split}.code.spm")
            dataset = load_lang_dataset_denoising(
                path=split_path,
                impl=self.args['dataset']['dataset_impl'],
                dict=self.source_dictionary)

            if dataset is None:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(
                    split, split_path))

            # Strip <eos>, truncate leaving room for 3 special tokens, then
            # re-append <eos>.
            dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(dataset, self.source_dictionary.eos()),
                    self.args['task']['max_source_positions'] -
                    3),  # <lang>, <bos>, <eos>
                token=self.source_dictionary.eos(),
            )

            # Blocks end with the language token when enabled, else <eos>.
            end_token = self.source_dictionary.index('[{}]'.format(language)) \
                if add_lang_token else self.source_dictionary.eos()

            # create continuous blocks of tokens
            dataset = TokenBlockDataset(
                dataset,
                dataset.sizes,
                self.args['task']['tokens_per_sample'] -
                2,  # one less for <s> and one for </s>
                pad=self.source_dictionary.pad(),
                eos=end_token,
                break_mode=self.args['task']['sample_break_mode'],
                document_sep_len=0,
            )
            LOGGER.info('| loaded {} blocks from: {}'.format(
                len(dataset), split_path))

            # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
            dataset = PrependTokenDataset(dataset,
                                          self.source_dictionary.bos())
            dataset = AppendTokenDataset(dataset, end_token)

            lang_dataset = DenoisingDataset(
                dataset,
                dataset.sizes,
                self.dictionary,
                self.mask_idx,
                mask_whole_words,
                shuffle=self.args['dataset']['shuffle_instance'],
                seed=self.seed,
                args=self.args,
                eos=None if not add_lang_token else
                self.source_dictionary.index('[{}]'.format(language)),
            )
            lang_datasets.append(lang_dataset)

        dataset_lengths = np.array(
            [len(d) for d in lang_datasets],
            dtype=float,
        )
        LOGGER.info('| loaded total {} blocks for all languages'.format(
            dataset_lengths.sum(), ))
        if split == self.args['dataset']['train_subset']:
            # For train subset, additionally up or down sample languages.
            sample_probs = self._get_sample_prob(dataset_lengths)
            LOGGER.info("| Sample probability by language: {}".format({
                lang: "{0:.4f}".format(sample_probs[lang_id])
                for lang_id, lang in enumerate(languages)
            }))
            size_ratio = (sample_probs *
                          dataset_lengths.sum()) / dataset_lengths
            LOGGER.info("| Up/Down Sampling ratio by language: {}".format({
                lang: "{0:.2f}".format(size_ratio[lang_id])
                for lang_id, lang in enumerate(languages)
            }))

            resampled_lang_datasets = [
                ResamplingDataset(
                    lang_dataset,
                    size_ratio=ratio,
                    seed=self.args['common']['seed'],
                    epoch=epoch,
                    # sample with replacement only when up-sampling
                    replace=ratio >= 1.0,
                ) for lang_dataset, ratio in zip(lang_datasets, size_ratio)
            ]
            dataset = ConcatDataset(resampled_lang_datasets, )
        else:
            dataset = ConcatDataset(lang_datasets)
            # Per-language validation splits are not registered, so the only
            # split name substituted below is ``split`` itself.
            lang_splits = [split]

            if split in self.args['dataset']['valid_subset']:
                self.args['dataset']['valid_subset'] = self.args['dataset'][
                    'valid_subset'].replace(split, ','.join(lang_splits))

        with data_utils.numpy_seed(self.args['common']['seed'] + epoch):
            shuffle = np.random.permutation(len(dataset))

        self.datasets[split] = SortDataset(
            dataset,
            sort_order=[
                shuffle,
                dataset.sizes,
            ],
        )
Exemple #4
0
def load_tokens_dataset(
    data_path, split, src, src_dict, tgt, tgt_dict, dataset_impl, src_max_tokens=None, tgt_max_tokens=None,
    src_aux=None, src_aux_dict=None, tgt_aux=None, tgt_aux_dict=None, src_aux_max_tokens=None, tgt_aux_max_tokens=None,
    fraction_using_func_name=0., labels=None, shuffle=False,
):
    """Build a ``HybridRetrievalDataset`` for one split.

    Each modality is read either from the single prefix ``{split}.{modality}``
    or from one prefix ``{split}.{label}.{modality}`` per label, truncated and
    concatenated. Auxiliary modalities are only loaded for the ``'train'``
    split.

    Args:
        data_path: directory holding the dataset files.
        split: split name used as file-name prefix.
        src / tgt: source / target modality names.
        src_dict / tgt_dict: forwarded to the resulting dataset.
        dataset_impl: forwarded to ``_load_dataset``.
        src_max_tokens / tgt_max_tokens: truncation lengths.
        src_aux / tgt_aux: optional auxiliary modality names.
        src_aux_dict / tgt_aux_dict: dictionaries for the auxiliary
            modalities; default to ``src_dict`` / ``tgt_dict``.
        src_aux_max_tokens / tgt_aux_max_tokens: truncation lengths for the
            auxiliary modalities; default to the main modality's length.
        fraction_using_func_name, shuffle: forwarded to the resulting dataset.
        labels: label names. With exactly one label the unlabeled prefix is
            preferred when its ``.idx`` file exists. ``None`` (the default)
            previously raised ``TypeError``; it now selects the unlabeled
            prefix. NOTE(review): confirm ``HybridRetrievalDataset`` accepts
            ``labels=None``.

    Returns:
        The constructed ``HybridRetrievalDataset``.
    """

    def _load_modality(modality, max_tokens):
        # Resolve the file prefixes of one modality, then load, truncate and
        # concatenate them. Returns the dataset and the prefixes (for logging).
        single_prefix = os.path.join(data_path, '{}.{}'.format(split, modality))
        if labels is None or (len(labels) == 1 and os.path.exists(single_prefix + '.idx')):
            paths = [single_prefix]
        else:
            paths = [os.path.join(data_path, '{}.{}.{}'.format(split, lbl, modality)) for lbl in labels]
        parts = _load_dataset(paths, dataset_impl)
        parts = [TruncateDataset(ds, max_tokens) for ds in parts]
        return ConcatDataset(parts, labels=labels), paths

    src_datasets, src_paths = _load_modality(src, src_max_tokens)
    tgt_datasets, tgt_paths = _load_modality(tgt, tgt_max_tokens)

    LOGGER.info('loaded {} examples from: {}'.format(len(src_datasets), src_paths))
    LOGGER.info('loaded {} examples from: {}'.format(len(tgt_datasets), tgt_paths))

    # Auxiliary modalities are a training-only feature.
    load_aux = split == 'train'

    src_aux_datasets = None
    if load_aux and src_aux is not None:
        if src_aux_max_tokens is None:
            src_aux_max_tokens = src_max_tokens
        src_aux_datasets, src_aux_paths = _load_modality(src_aux, src_aux_max_tokens)
        LOGGER.info('loaded {} examples from: {}'.format(len(src_aux_datasets), src_aux_paths))

    tgt_aux_datasets = None
    if load_aux and tgt_aux is not None:
        if tgt_aux_max_tokens is None:
            tgt_aux_max_tokens = tgt_max_tokens
        tgt_aux_datasets, tgt_aux_paths = _load_modality(tgt_aux, tgt_aux_max_tokens)
        LOGGER.info('loaded {} examples from: {}'.format(len(tgt_aux_datasets), tgt_aux_paths))

    return HybridRetrievalDataset(
        src_datasets, src_datasets.sizes, src_dict,
        tgt_datasets, tgt_datasets.sizes, tgt_dict,
        max_source_positions=src_max_tokens, max_target_positions=tgt_max_tokens,
        src_aux=src_aux_datasets,
        src_aux_sizes=None if src_aux_datasets is None else src_aux_datasets.sizes,
        src_aux_dict=src_dict if src_aux_dict is None else src_aux_dict,
        tgt_aux=tgt_aux_datasets,
        tgt_aux_sizes=None if tgt_aux_datasets is None else tgt_aux_datasets.sizes,
        tgt_aux_dict=tgt_dict if tgt_aux_dict is None else tgt_aux_dict,
        fraction_using_func_name=fraction_using_func_name,
        shuffle=shuffle,
        labels=labels,
    )
def load_kd_token_dataset(
    data_path,
    split,
    tgt,
    tgt_dict,
    dataset_impl,
    attrs=None,
    attr_dict=None,
    attrs_mapping=None,
    reversed_attrs_mapping=None,
    truncate_target=False,
    max_target_positions=None,
    # lifelong learning
    prev_tasks=[],
    cur_task=None,
    sample_portion=None,
    # kd
    teacher_out_dir=None,
    topk=None,
    distill_topk=None,
):
    """Build a ``TopkKDCompletionDataset`` for the current task plus replay.

    Tokens of ``cur_task`` are paired with the teacher's top-k outputs read
    from ``teacher_out_dir`` (every ``'*'`` in it is replaced by ``cur_task``).
    When ``prev_tasks`` is non-empty and ``sample_portion`` is set, the leading
    examples of each previous task are appended, with placeholder (``None``)
    teacher outputs.

    Note:
        The ``attrs*`` / ``attr_dict`` arguments are accepted but not
        forwarded; the result is always built with those fields as ``None``.

    Returns:
        The constructed ``TopkKDCompletionDataset``; it is shuffled only when
        ``split == 'train'``.
    """
    file_name = '{}.{}'.format(split, tgt)

    # current task tokens
    token_parts = [
        _load_dataset(os.path.join(data_path, cur_task, file_name), dataset_impl)
    ]

    # teacher output for the current task
    teacher_dir = str.replace(teacher_out_dir, '*', cur_task)
    id_parts = [
        TeacherOutDataset(os.path.join(teacher_dir, f'{split}.top{topk}_idx'))
    ]
    teacher_dir = str.replace(teacher_out_dir, '*', cur_task)
    prob_parts = [
        TeacherOutDataset(os.path.join(teacher_dir, f'{split}.top{topk}_prob'))
    ]

    # number of replayed examples taken from each previous task
    replay_enabled = (len(prev_tasks) > 0 and cur_task is not None
                      and sample_portion is not None)
    if replay_enabled:
        per_task = int(len(token_parts[0]) * sample_portion // len(prev_tasks))
    else:
        per_task = -1

    if per_task > 0:
        for previous in prev_tasks:
            previous_tokens = _load_dataset(
                os.path.join(data_path, previous, file_name), dataset_impl)
            token_parts.append(SliceDataset(previous_tokens, end=per_task))
            # previous tasks have no teacher output
            id_parts.append(PlaceholderDataset(placeholder=None, length=per_task))
            prob_parts.append(PlaceholderDataset(placeholder=None, length=per_task))

    tokens = ConcatDataset(token_parts)
    teacher_ids = ConcatDataset(id_parts)
    teacher_probs = ConcatDataset(prob_parts)

    if truncate_target:
        tokens = TruncateDataset(tokens, max_target_positions)
        LOGGER.info('Truncate dataset into max length: {}'.format(max_target_positions))
    LOGGER.info(
        'loaded {} examples from: [{}](current task) + {}(previous tasks)'.format(
            len(tokens), cur_task, prev_tasks))

    return TopkKDCompletionDataset(
        topk_ids=teacher_ids,
        topk_probs=teacher_probs,
        topk=topk,
        distill_topk=distill_topk,
        tgt=tokens,
        tgt_sizes=tokens.sizes,
        tgt_dict=tgt_dict,
        extends=None,
        attrs=None,
        attr_indices=None,
        attr_dict=None,
        attrs_mapping=None,
        reversed_attrs_mapping=None,
        max_target_positions=max_target_positions,
        shuffle=(split == 'train'),
    )
Exemple #6
0
def load_langpair_dataset(
    data_path,
    split,
    domains,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    append_eos_to_target=False,
):
    """Build a ``LanguagePairDataset`` from per-domain sub-directories.

    For every domain the shards ``{split}``, ``{split}1``, ``{split}2``, ...
    are read from ``data_path/<domain>/`` until one is missing; only the first
    shard is read when ``combine`` is false.

    Note:
        ``load_alignments`` is accepted but not used; the result is always
        built with ``align_dataset=None``.

    Raises:
        FileNotFoundError: if the first shard of a domain exists for neither
            ``src`` nor ``tgt``.

    Returns:
        The constructed ``LanguagePairDataset`` (always with ``shuffle=True``).
    """

    def has_shard(shard_name, lang, domain):
        candidate = os.path.join(data_path, domain, '{}.{}'.format(shard_name, lang))
        return indexed_dataset.dataset_exists(candidate, impl=dataset_impl)

    src_datasets, tgt_datasets = [], []

    for dm in domains:
        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')

            # A shard counts as present when either side exists on disk.
            if not (has_shard(split_k, src, dm) or has_shard(split_k, tgt, dm)):
                if k > 0:
                    break
                raise FileNotFoundError(
                    'Dataset not found: {} ({})'.format(split, data_path))
            prefix = os.path.join(data_path, dm, '{}.'.format(split_k))

            source_part = data_utils.load_indexed_dataset(
                prefix + src, src, src_dict, dataset_impl)
            if truncate_source:
                # Drop <eos>, cut to leave room for it, then put it back.
                stripped = StripTokenDataset(source_part, src_dict.eos())
                clipped = TruncateDataset(stripped, max_source_positions - 1)
                source_part = AppendTokenDataset(clipped, src_dict.eos())
            src_datasets.append(source_part)

            target_part = data_utils.load_indexed_dataset(
                prefix + tgt, tgt, tgt_dict, dataset_impl)
            if target_part is not None:
                tgt_datasets.append(target_part)

            if not combine:
                break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    has_target = len(tgt_datasets) > 0
    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if has_target else None
    else:
        # Only the first (primary) shard gets the up-sampling ratio.
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) if has_target else None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(
            tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    eos = None
    if append_source_id:
        src_dataset = AppendTokenDataset(
            src_dataset, src_dict.index('[{}]'.format(src)))
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(
                tgt_dataset, tgt_dict.index('[{}]'.format(tgt)))
        # The target language tag doubles as the end-of-sentence symbol.
        eos = tgt_dict.index('[{}]'.format(tgt))

    if tgt_dataset is not None:
        tgt_dataset_sizes = tgt_dataset.sizes
    else:
        tgt_dataset_sizes = None

    return LanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        align_dataset=None,
        eos=eos,
        remove_eos_from_source=True,
        append_eos_to_target=append_eos_to_target,
        shuffle=True,  # TODO debug: shuffle=False
    )
Exemple #7
0
def load_token_dataset(
    data_path,
    split,
    tgt,
    tgt_dict,
    dataset_impl,
    attrs=None,
    attr_dict=None,
    attrs_mapping=None,
    reversed_attrs_mapping=None,
    truncate_target=False,
    max_target_positions=None,
    # lifelong learning
    prev_tasks=[],
    cur_task=None,
    sample_portion=None,
):
    """Build a ``LifelongKDCompletionDataset`` for the current task plus replay.

    Every example carries a flag in ``kd_indices``: ``True`` for examples of
    ``cur_task`` and ``False`` for examples replayed from ``prev_tasks``.
    Replay happens only when ``prev_tasks`` is non-empty and
    ``sample_portion`` is set; the leading examples of each previous task are
    taken.

    Note:
        The ``attrs*`` / ``attr_dict`` arguments are accepted but not
        forwarded; the result is always built with those fields as ``None``.

    Returns:
        The constructed ``LifelongKDCompletionDataset``; it is shuffled only
        when ``split == 'train'``.
    """
    file_name = '{}.{}'.format(split, tgt)

    # current task
    current = _load_dataset(os.path.join(data_path, cur_task, file_name), dataset_impl)
    token_parts = [current]
    flag_parts = [PlaceholderDataset(placeholder=True, length=len(current))]

    # number of replayed examples taken from each previous task
    replay_enabled = (len(prev_tasks) > 0 and cur_task is not None
                      and sample_portion is not None)
    if replay_enabled:
        per_task = int(len(current) * sample_portion // len(prev_tasks))
    else:
        per_task = -1

    if per_task > 0:
        for previous in prev_tasks:
            previous_tokens = _load_dataset(
                os.path.join(data_path, previous, file_name), dataset_impl)
            token_parts.append(SliceDataset(previous_tokens, end=per_task))
            flag_parts.append(PlaceholderDataset(placeholder=False, length=per_task))

    tokens = ConcatDataset(token_parts)
    kd_flags = ConcatDataset(flag_parts)

    if truncate_target:
        tokens = TruncateDataset(tokens, max_target_positions)
        LOGGER.info('Truncate dataset into max length: {}'.format(max_target_positions))
    LOGGER.info(
        'loaded {} examples from: [{}](current task) + {}(previous tasks)'.format(
            len(tokens), cur_task, prev_tasks))

    return LifelongKDCompletionDataset(
        kd_indices=kd_flags,
        tgt=tokens,
        tgt_sizes=tokens.sizes,
        tgt_dict=tgt_dict,
        extends=None,
        attrs=None,
        attr_indices=None,
        attr_dict=None,
        attrs_mapping=None,
        reversed_attrs_mapping=None,
        max_target_positions=max_target_positions,
        shuffle=(split == 'train'),
    )
Exemple #8
0
def load_token_dataset(
    data_path,
    split,
    tgt,
    tgt_dict,
    dataset_impl,
    attrs=None,
    attr_dict=None,
    attrs_mapping=None,
    reversed_attrs_mapping=None,
    truncate_target=False,
    max_target_positions=None,
    # lifelong learning
    prev_tasks=[],
    cur_task=None,
    sample_portion=None,
):
    """Build a ``CompletionDataset`` for the current task plus replay.

    Loads, from ``data_path/<cur_task>/``, the tokens (``{split}.{tgt}``), the
    optional extension index (``{split}.{tgt}.ext``) and, when ``attrs`` is
    given, the attribute indices (``{split}.code_types``). When ``prev_tasks``
    is non-empty and ``sample_portion`` is set, the leading examples of each
    previous task are appended to every one of these.

    Args:
        data_path: root directory containing one sub-directory per task.
        split: split name used as file-name prefix.
        tgt: target modality name.
        tgt_dict: forwarded to the resulting dataset.
        dataset_impl: forwarded to ``_load_dataset``.
        attrs: when not ``None``, attribute indices are loaded as well.
        attr_dict, attrs_mapping, reversed_attrs_mapping: forwarded unchanged.
        truncate_target: when true, tokens, attributes and extension indices
            are limited to ``max_target_positions``.
        max_target_positions: truncation length, also forwarded to the result.
        prev_tasks: names of earlier tasks to replay (never mutated here).
        cur_task: name of the current task.
        sample_portion: fraction of the current task's size that is replayed
            in total, split evenly across ``prev_tasks``.

    Returns:
        The constructed ``CompletionDataset``.
    """
    # load tokens
    tgt_path = os.path.join(data_path, cur_task, '{}.{}'.format(split, tgt))
    tgt_dataset = [_load_dataset(tgt_path, dataset_impl)]
    if len(prev_tasks
           ) > 0 and cur_task is not None and sample_portion is not None:
        sample_size_per_task = int(
            len(tgt_dataset[0]) * sample_portion // len(prev_tasks))
    else:
        # -1 disables replay of previous tasks everywhere below
        sample_size_per_task = -1
    if sample_size_per_task > 0:
        for p_task in prev_tasks:
            p_path = os.path.join(data_path, p_task,
                                  '{}.{}'.format(split, tgt))
            p_dataset = _load_dataset(p_path, dataset_impl)
            tgt_dataset.append(
                SliceDataset(p_dataset, end=sample_size_per_task))
    tgt_dataset = ConcatDataset(tgt_dataset)
    if truncate_target:
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)
        LOGGER.info('Truncate dataset into max length: {}'.format(
            max_target_positions))
    LOGGER.info('loaded {} examples from: [{}](current task) + {}(previous tasks)'. \
                format(len(tgt_dataset), cur_task, prev_tasks))
    # load tokens.ext
    tgt_ext_path = os.path.join(data_path, cur_task,
                                '{}.{}.ext'.format(split, tgt))
    if indexed_dataset.SeqIndexedDataset.exists(tgt_ext_path):
        tgt_ext_dataset = indexed_dataset.SeqIndexedDataset(tgt_ext_path)
        if sample_size_per_task > 0:
            for p_task in prev_tasks:
                p_ext_path = os.path.join(data_path, p_task,
                                          '{}.{}.ext'.format(split, tgt))
                p_ext_dataset = indexed_dataset.SeqIndexedDataset(p_ext_path)
                p_ext_dataset.truncate(end=sample_size_per_task)
                tgt_ext_dataset.append(p_ext_dataset)
        if truncate_target:
            tgt_ext_dataset.clip(max_position=max_target_positions)
        assert len(tgt_dataset) == len(tgt_ext_dataset), (len(tgt_dataset),
                                                          len(tgt_ext_dataset))
    else:
        tgt_ext_dataset = None
    # load attrs
    if attrs is None:
        attr_dataset = None
    else:
        attr_path = os.path.join(data_path, cur_task,
                                 '{}.code_types'.format(split))
        attr_dataset = [_load_dataset(attr_path, dataset_impl)]
        if sample_size_per_task > 0:
            for p_task in prev_tasks:
                p_attr_path = os.path.join(data_path, p_task,
                                           '{}.code_types'.format(split))
                p_attr_dataset = _load_dataset(p_attr_path, dataset_impl)
                attr_dataset.append(
                    SliceDataset(p_attr_dataset, end=sample_size_per_task))
        attr_dataset = ConcatDataset(attr_dataset)
        if truncate_target:
            # Truncate the attributes themselves so they stay aligned with
            # the truncated tokens (the tokens were already truncated above).
            attr_dataset = TruncateDataset(attr_dataset, max_target_positions)
            LOGGER.info(
                'Truncate dataset\'s attributes into max length: {}'.format(
                    max_target_positions))
        LOGGER.info('loaded attributes {} examples from: [{}](current task) + {}(previous tasks)'. \
                    format(len(attr_dataset), cur_task, prev_tasks))
        # load attr.ext
        attr_ext_path = os.path.join(data_path, cur_task,
                                     '{}.code_types.ext'.format(split))
        if indexed_dataset.SeqIndexedDataset.exists(attr_ext_path):
            attr_ext_dataset = indexed_dataset.SeqIndexedDataset(attr_ext_path)
            if sample_size_per_task > 0:
                for p_task in prev_tasks:
                    p_attr_ext_path = os.path.join(
                        data_path, p_task, '{}.code_types.ext'.format(split))
                    p_attr_ext_dataset = indexed_dataset.SeqIndexedDataset(
                        p_attr_ext_path)
                    p_attr_ext_dataset.truncate(end=sample_size_per_task)
                    attr_ext_dataset.append(p_attr_ext_dataset)
            if truncate_target:
                attr_ext_dataset.clip(max_position=max_target_positions)
            # The attribute extension is only a consistency check against
            # the token extension; it is not passed on.
            assert np.all(tgt_ext_dataset == attr_ext_dataset)
            del attr_ext_dataset

    return CompletionDataset(
        tgt_dataset,
        tgt_dataset.sizes,
        tgt_dict,
        extends=tgt_ext_dataset,
        attrs=attrs,
        attr_indices=attr_dataset,
        attr_dict=attr_dict,
        attrs_mapping=attrs_mapping,
        reversed_attrs_mapping=reversed_attrs_mapping,
        max_target_positions=max_target_positions,
    )
def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    dataset_impl,
    # combine, dataset_impl, upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=True,
    append_eos=True,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    truncate_target=False,
    # lifelong learning
    prev_tasks=[],
    cur_task=None,
    sample_portion=None,
):
    """Build a ``BELanguagePairDataset`` for the current task plus replay.

    Source and target are read from ``data_path/<cur_task>/{split}.{lang}``.
    When ``prev_tasks`` is non-empty and ``sample_portion`` is set, the leading
    examples of each previous task are appended to both sides. The target is
    optionally truncated and then wrapped with ``<bos>`` / ``<eos>``.

    Note:
        ``max_target_positions`` is reduced by the number of special tokens
        that will be added, and the reduced value is what gets forwarded to
        the resulting dataset. ``load_alignments`` and ``append_source_id``
        are accepted but not used.

    Returns:
        The constructed ``BELanguagePairDataset``; it is shuffled only when
        ``split == 'train'``.
    """
    # reserve room for the <bos> / <eos> tokens added to the target below
    max_target_positions -= int(prepend_bos) + int(append_eos)

    def replay_slices(lang, dictionary):
        # Leading ``sample_size_per_task`` examples of every previous task.
        slices = []
        if sample_size_per_task > 0:
            for previous in prev_tasks:
                previous_path = os.path.join(
                    data_path, previous, '{}.{}'.format(split, lang))
                previous_data = _load_dataset(previous_path, dataset_impl, dictionary)
                slices.append(SliceDataset(previous_data, end=sample_size_per_task))
        return slices

    # ---- source side ----
    src_path = os.path.join(data_path, cur_task, '{}.{}'.format(split, src))
    current_src = _load_dataset(path=src_path, impl=dataset_impl, dict=src_dict)

    # The replay size is derived from the current task's source size.
    replay_enabled = (len(prev_tasks) > 0 and cur_task is not None
                      and sample_portion is not None)
    if replay_enabled:
        sample_size_per_task = int(
            len(current_src) * sample_portion // len(prev_tasks))
    else:
        sample_size_per_task = -1

    src_dataset = ConcatDataset([current_src] + replay_slices(src, src_dict))
    if truncate_source:
        # sntn => sntn[:max_source_positions]
        LOGGER.info('truncate {}.{} to {}'.format(split, src, max_source_positions))
        src_dataset = TruncateDataset(src_dataset, max_source_positions)

    # ---- target side ----
    tgt_path = os.path.join(data_path, cur_task, '{}.{}'.format(split, tgt))
    current_tgt = _load_dataset(path=tgt_path, impl=dataset_impl, dict=tgt_dict)
    tgt_dataset = ConcatDataset([current_tgt] + replay_slices(tgt, tgt_dict))
    if truncate_target:
        # sntn => sntn[:max_target_positions]
        LOGGER.info('truncate {}.{} to {}'.format(split, tgt, max_target_positions))
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)
    # sntn[:max_target_positions] => <bos> sntn[:max_target_positions] <eos>
    if prepend_bos:
        tgt_dataset = PrependTokenDataset(tgt_dataset, token=tgt_dict.bos())
    if append_eos:
        tgt_dataset = AppendTokenDataset(tgt_dataset, token=tgt_dict.eos())
    tgt_dataset_sizes = tgt_dataset.sizes

    n_src, n_tgt = len(src_dataset), len(tgt_dataset)
    assert n_src == n_tgt, (n_src, n_tgt)
    LOGGER.info(
        'loaded {} examples from: [{}](current task) + {}(previous tasks)'.format(
            len(src_dataset), cur_task, prev_tasks))

    return BELanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        align_dataset=None,
        bos=src_dict.bos(),
        eos=src_dict.eos(),
        shuffle=(split == 'train'),
    )
def load_inference_langpair_dataset(
    data_paths,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    dataset_impl,
    # combine, dataset_impl, upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=True,
    append_eos=True,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    truncate_target=False,
):
    """Load a source/target pair dataset concatenated over several directories.

    For every directory in ``data_paths`` the files ``{split}.{src}`` and
    ``{split}.{tgt}`` are loaded with ``_load_dataset`` and concatenated, in
    the order of ``data_paths``, into one source and one target dataset.

    Args:
        data_paths: iterable of directories holding the split files.
        split: split name used as file-name prefix (e.g. ``'train'``).
        src, tgt: file-name suffixes of the source / target side.
        src_dict, tgt_dict: dictionaries passed to ``_load_dataset``; the
            target dictionary also supplies the BOS/EOS tokens added to the
            target side.
        dataset_impl: dataset implementation forwarded to ``_load_dataset``.
        left_pad_source, left_pad_target: forwarded to the returned dataset.
        max_source_positions: source truncation length (when
            ``truncate_source``) and forwarded to the returned dataset.
        max_target_positions: target length budget; it is reduced by one for
            each of ``prepend_bos`` / ``append_eos`` before being used for
            truncation and before being forwarded.
        prepend_bos, append_eos: whether to wrap each target sample with
            ``tgt_dict.bos()`` / ``tgt_dict.eos()``.
        truncate_source, truncate_target: whether to truncate each side.
        load_alignments, append_source_id: accepted but not used here.

    Returns:
        A ``BELanguagePairDataset`` over the concatenated source and target
        datasets; shuffling is enabled only when ``split == 'train'``.
    """
    # reserve room for the <bos>/<eos> tokens that are added after truncation
    max_target_positions -= int(prepend_bos) + int(append_eos)

    # load source dataset
    src_dataset = ConcatDataset([
        _load_dataset(
            path=os.path.join(data_path, '{}.{}'.format(split, src)),
            impl=dataset_impl,
            dict=src_dict,
        )
        for data_path in data_paths
    ])
    if truncate_source:
        # sntn => sntn[:max_source_positions]
        LOGGER.info('truncate {}.{} to {}'.format(split, src,
                                                  max_source_positions))
        src_dataset = TruncateDataset(src_dataset, max_source_positions)

    # load target dataset
    # NOTE: the path must use ``tgt``; it previously reused ``src``, which
    # silently made the target side a copy of the source files.
    tgt_dataset = ConcatDataset([
        _load_dataset(
            path=os.path.join(data_path, '{}.{}'.format(split, tgt)),
            impl=dataset_impl,
            dict=tgt_dict,
        )
        for data_path in data_paths
    ])
    if truncate_target:
        # sntn => sntn[:max_target_positions] (budget already excludes BOS/EOS)
        LOGGER.info('truncate {}.{} to {}'.format(split, tgt,
                                                  max_target_positions))
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)
    # sntn[:max_target_positions] => <bos> sntn[:max_target_positions] <eos>
    if prepend_bos:
        tgt_dataset = PrependTokenDataset(tgt_dataset, token=tgt_dict.bos())
    if append_eos:
        tgt_dataset = AppendTokenDataset(tgt_dataset, token=tgt_dict.eos())

    LOGGER.info('loaded {} examples from: {}.{}'.format(
        len(src_dataset), data_paths, src))
    LOGGER.info('loaded {} examples from: {}.{}'.format(
        len(tgt_dataset), data_paths, tgt))
    return BELanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset.sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        align_dataset=None,
        bos=src_dict.bos(),
        eos=src_dict.eos(),
        shuffle=(split == 'train'),
    )
Exemple #11
0
def load_langpair_dataset(
    args,
    programming_langs,
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    is_distill=False,
):
    """Load one source/target dataset per programming language and merge them.

    For each entry of ``programming_langs`` the files ``{split}.{src}`` and
    ``{split}.{tgt}`` are read from ``data_path/<lang>`` and the per-language
    datasets are concatenated into a single ``UniversalDataset``.

    Args:
        args: forwarded unchanged to ``UniversalDataset``.
        programming_langs: sub-directory names under ``data_path``; their
            order defines the dataset ids and ``lng_borders``.
        data_path: root directory containing one folder per language.
        split: split name; distillation data is only loaded for ``'train'``.
        src, tgt: file-name suffixes of the source / target side.
        src_dict, tgt_dict: dictionaries used to load each side.
        dataset_impl: dataset implementation used for existence checks and
            loading.
        upsample_primary: sample ratio applied to the first language.
        left_pad_source, left_pad_target, max_target_positions: forwarded to
            ``UniversalDataset``.
        max_source_positions: forwarded; also the truncation length (including
            the re-appended EOS) when ``truncate_source`` is set.
        prepend_bos: prepend ``bos()`` to both sides.
        truncate_source: strip EOS, truncate to ``max_source_positions - 1``
            and re-append EOS.
        is_distill: additionally load teacher top-k outputs and expert scores.
        combine, load_alignments, append_source_id: accepted but not used here.

    Returns:
        A ``UniversalDataset`` over the concatenated data.

    Raises:
        NotImplementedError: if neither side of the split exists for a
            language directory.
    """

    def split_exists(split, lang, data_path):
        filename = os.path.join(data_path, '{}.{}'.format(split, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    topk_idxs = []
    topk_probs = []
    expert_scores = []
    dataset_ids = []
    # cumulative sample counts: lng_borders[i] is the first index of language i
    lng_borders = [0]
    is_train = split == 'train'

    for ds_idx, program_lang in enumerate(programming_langs):
        lang_data_path = os.path.join(data_path, program_lang)

        # Either side being present is enough; the prefix is identical.
        if not (split_exists(split, src, lang_data_path)
                or split_exists(split, tgt, lang_data_path)):
            raise NotImplementedError('No data in {}'.format(lang_data_path))
        prefix = os.path.join(lang_data_path, '{}.'.format(split))

        src_dataset = data_utils.load_indexed_dataset(prefix + src, src,
                                                      src_dict, dataset_impl)
        if truncate_source:
            # keep EOS as the last token after truncation
            src_dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)
        length = len(src_dataset)
        lng_borders.append(lng_borders[-1] + length)

        tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt,
                                                      tgt_dict, dataset_impl)
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        # every sample of this language is tagged with the language index
        dataset_ids.extend([ds_idx] * length)

        if is_distill and is_train:  # distill only for train
            path = '{}_{}_{}_topk_idx'.format(lang_data_path, src, tgt)
            topk_idxs.append(TeacherOutputDataset(path))
            path = '{}_{}_{}_topk_prob'.format(lang_data_path, src, tgt)
            topk_probs.append(TeacherOutputDataset(path))
            expert_bleu_path = os.path.join(
                data_path,
                'expert_bleu_{}_{}_{}.json'.format(program_lang, src, tgt))
            # use a context manager so the file handle is closed promptly
            with open(expert_bleu_path) as reader:
                expert_bleu = json.load(reader)
            expert_scores.append(expert_bleu[f"bleu_{program_lang}"])

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    sample_ratios = [1] * len(src_datasets)
    sample_ratios[0] = upsample_primary
    src_dataset = ConcatDataset(src_datasets, sample_ratios)
    if len(tgt_datasets) > 0:
        tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
    else:
        tgt_dataset = None

    # tgt_dataset may legitimately be None (source-only data); do not call
    # len() on it in that case.
    LOGGER.info('src data: {}, tgt data: {}'.format(
        len(src_dataset),
        len(tgt_dataset) if tgt_dataset is not None else 0))

    if is_distill and is_train:  # distill only for train
        topk_idx_dataset = ConcatDataset(topk_idxs)
        topk_probs_dataset = ConcatDataset(topk_probs)
        assert len(topk_probs_dataset) == len(src_dataset), (
            len(topk_probs_dataset), len(src_dataset))
        assert len(topk_idx_dataset) == len(src_dataset)
    else:
        topk_idx_dataset = None
        topk_probs_dataset = None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(
            tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None

    return UniversalDataset(
        args,
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        dataset_ids=dataset_ids,
        lng_borders=lng_borders,
        dataset_names=programming_langs,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        topk_idxs=topk_idx_dataset,
        topk_probs=topk_probs_dataset,
        expert_scores=expert_scores,
        is_train=is_train,
    )
Exemple #12
0
def load_token_dataset(
    data_paths,
    split,
    tgt,
    tgt_dict,
    dataset_impl,
    attrs=None,
    attr_dict=None,
    attrs_mapping=None,
    reversed_attrs_mapping=None,
    truncate_target=False,
    max_target_positions=None,
):
    """Load token (and optional attribute) datasets for code completion.

    For every directory in ``data_paths`` the file ``{split}.{tgt}`` is loaded
    and the results are concatenated.  If a ``{split}.{tgt}.ext`` file exists
    in *all* directories, the extension datasets are merged as well;
    otherwise no extension data is used.  When ``attrs`` is given, the
    ``{split}.code_types`` files are loaded the same way.

    Args:
        data_paths: iterable of directories holding the split files.
        split: split name used as file-name prefix.
        tgt: file-name suffix of the token files.
        tgt_dict: dictionary forwarded to ``CompletionDataset``.
        dataset_impl: dataset implementation forwarded to ``_load_dataset``.
        attrs: attribute names; ``None`` disables attribute loading.
        attr_dict, attrs_mapping, reversed_attrs_mapping: forwarded to
            ``CompletionDataset``.
        truncate_target: truncate tokens, attributes and extension data to
            ``max_target_positions``.
        max_target_positions: truncation length; also forwarded.

    Returns:
        A ``CompletionDataset`` over the loaded data.
    """
    # load tokens
    tgt_dataset = ConcatDataset([
        _load_dataset(
            os.path.join(data_path, '{}.{}'.format(split, tgt)), dataset_impl)
        for data_path in data_paths
    ])
    if truncate_target:
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)
        LOGGER.info('Truncate dataset into max length: {}'.format(
            max_target_positions))
    LOGGER.info('loaded {} examples from: {}'.format(len(tgt_dataset),
                                                     data_paths))
    # load tokens.ext
    tgt_ext_paths = [
        os.path.join(data_path, '{}.{}.ext'.format(split, tgt))
        for data_path in data_paths
    ]
    if all(
            indexed_dataset.SeqIndexedDataset.exists(tgt_ext_path)
            for tgt_ext_path in tgt_ext_paths):
        # merge all extension files into the first one
        tgt_ext_dataset = indexed_dataset.SeqIndexedDataset(tgt_ext_paths[0])
        for tgt_ext_path in tgt_ext_paths[1:]:
            tgt_ext_dataset.append(
                indexed_dataset.SeqIndexedDataset(tgt_ext_path))
        if truncate_target:
            tgt_ext_dataset.clip(max_position=max_target_positions)
        assert len(tgt_dataset) == len(tgt_ext_dataset), (len(tgt_dataset),
                                                          len(tgt_ext_dataset))
    else:
        tgt_ext_dataset = None
    # load attrs
    if attrs is None:
        attr_dataset = None
    else:
        attr_dataset = ConcatDataset([
            _load_dataset(
                os.path.join(data_path, '{}.code_types'.format(split)),
                dataset_impl)
            for data_path in data_paths
        ])
        if truncate_target:
            # Truncate the attributes themselves so they stay aligned with the
            # (already truncated) tokens; the token dataset must not be
            # wrapped a second time here.
            attr_dataset = TruncateDataset(attr_dataset, max_target_positions)
            LOGGER.info(
                'Truncate dataset\'s attributes into max length: {}'.format(
                    max_target_positions))
        LOGGER.info('loaded {} examples from: {}'.format(
            len(attr_dataset), data_paths))
        # load attr.ext
        attr_ext_paths = [
            os.path.join(data_path, '{}.code_types.ext'.format(split))
            for data_path in data_paths
        ]
        if all(
                indexed_dataset.SeqIndexedDataset.exists(attr_ext_path)
                for attr_ext_path in attr_ext_paths):
            attr_ext_dataset = indexed_dataset.SeqIndexedDataset(
                attr_ext_paths[0])
            for attr_ext_path in attr_ext_paths[1:]:
                attr_ext_dataset.append(
                    indexed_dataset.SeqIndexedDataset(attr_ext_path))
            if truncate_target:
                attr_ext_dataset.clip(max_position=max_target_positions)
            # attribute extensions are only used as a consistency check
            # against the token extensions and are then discarded
            assert np.all(tgt_ext_dataset == attr_ext_dataset)
            del attr_ext_dataset

    return CompletionDataset(
        tgt_dataset,
        tgt_dataset.sizes,
        tgt_dict,
        extends=tgt_ext_dataset,
        attrs=attrs,
        attr_indices=attr_dataset,
        attr_dict=attr_dict,
        attrs_mapping=attrs_mapping,
        reversed_attrs_mapping=reversed_attrs_mapping,
        max_target_positions=max_target_positions,
    )
def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    load_alignments=False,
    truncate_source=False,
    truncate_target=False,
):
    """Load a (possibly sharded) source/target pair dataset.

    Shards are named ``{split}``, ``{split}1``, ``{split}2``, ... inside
    ``data_path``.  With ``combine`` set, shards are read until the first
    missing one and concatenated; otherwise only the first shard is used.

    Args:
        data_path: directory holding the split files.
        split: base split name.
        src, tgt: file-name suffixes of the source / target side.
        src_dict, tgt_dict: dictionaries used to load each side; the target
            dictionary also supplies BOS/EOS when the target is truncated.
        combine: whether to load and concatenate all numbered shards.
        dataset_impl: dataset implementation used for existence checks and
            loading.
        upsample_primary: sample ratio of the first shard when several
            shards are concatenated.
        left_pad_source, left_pad_target: forwarded to the returned dataset.
        max_source_positions: forwarded; also the source truncation length.
        max_target_positions: forwarded; when truncating, targets are cut to
            ``max_target_positions - 2`` and wrapped with BOS/EOS.
        load_alignments: accepted but not used here.
        truncate_source, truncate_target: enable truncation of each side
            (only effective when the matching max length is truthy).

    Returns:
        A ``LanguagePairDataset`` with ``shuffle=True``.

    Raises:
        FileNotFoundError: if not even the first shard exists.
    """

    def split_exists(split, lang, data_path):
        filename = os.path.join(data_path, '{}.{}'.format(split, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')

        # A shard counts as present if either side exists; the prefix is the
        # same in both cases.
        if not (split_exists(split_k, src, data_path)
                or split_exists(split_k, tgt, data_path)):
            if k > 0:
                break
            raise FileNotFoundError('Dataset not found: {} ({})'.format(
                split, data_path))
        prefix = os.path.join(data_path, '{}.'.format(split_k))

        src_dataset = data_utils.load_indexed_dataset(prefix + src, 'text',
                                                      src_dict, dataset_impl)
        if truncate_source and max_source_positions:
            src_dataset = TruncateDataset(src_dataset, max_source_positions)
        src_datasets.append(src_dataset)

        tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, 'text',
                                                      tgt_dict, dataset_impl)
        # Only wrap/keep the target when it was actually loaded; wrapping a
        # missing (None) target would register a bogus target dataset.
        if tgt_dataset is not None:
            if truncate_target and max_target_positions:
                # reserve two positions for the <bos>/<eos> added below
                tgt_dataset = PrependTokenDataset(
                    AppendTokenDataset(
                        TruncateDataset(tgt_dataset,
                                        max_target_positions - 2),
                        token=tgt_dict.eos()),
                    tgt_dict.bos())
            tgt_datasets.append(tgt_dataset)

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return LanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        shuffle=True,  # TODO debug: shuffle=False
    )