def load_masked_code_docstring_dataset_unilm(data_path,
                                             split,
                                             src,
                                             src_dict,
                                             tgt,
                                             tgt_dict,
                                             combine,
                                             dataset_impl,
                                             upsample_primary,
                                             left_pad_source,
                                             left_pad_target,
                                             max_source_positions,
                                             max_target_positions,
                                             prepend_bos=False,
                                             load_alignments=False,
                                             max_src_len=0,
                                             max_tgt_len=0,
                                             truncate_source=False,
                                             append_source_id=False):
    source_path = os.path.join(data_path, '{}.code'.format(split))
    target_path = os.path.join(data_path, '{}.docstring'.format(split))

    # source_dataset
    source_dataset = data_utils.load_indexed_dataset(source_path, 'text',
                                                     src_dict, dataset_impl)
    if source_dataset is None:
        raise FileNotFoundError('Dataset not found: {} ({})'.format(
            split, source_path))

    # target_dataset
    target_dataset = data_utils.load_indexed_dataset(target_path, 'text',
                                                     tgt_dict, dataset_impl)
    if target_dataset is None:
        raise FileNotFoundError('Dataset not found: {} ({})'.format(
            split, target_path))

    eos = None
    align_dataset = None
    target_dataset_sizes = target_dataset.sizes if target_dataset is not None else None

    return MaskCodeDocstringPairDataset(
        source_dataset,
        source_dataset.sizes,
        src_dict,
        target_dataset,
        target_dataset_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        align_dataset=align_dataset,
        eos=eos,
        skipgram_prb=0.0,
        skipgram_size=0.0,
    )
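
A hedged usage sketch for the loader above; `data_dir`, `src_dict`, and `tgt_dict` are assumed to be prepared elsewhere (e.g. loaded `Dictionary` objects), and every concrete value is illustrative rather than taken from the original configuration.

pair_dataset = load_masked_code_docstring_dataset_unilm(
    data_path=data_dir,  # directory holding '{split}.code' / '{split}.docstring'
    split='train',
    src='code', src_dict=src_dict,
    tgt='docstring', tgt_dict=tgt_dict,
    combine=False,
    dataset_impl='mmap',
    upsample_primary=1,
    left_pad_source=True,
    left_pad_target=False,
    max_source_positions=512,
    max_target_positions=512,
)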
Example #2
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]
        split_path = os.path.join(data_path, split)

        dataset = data_utils.load_indexed_dataset(
            split_path,
            self.dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        if dataset is None:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(
                split, split_path))

        dataset = StripTokenDataset(dataset, self.dictionary.eos())

        # create continuous blocks of tokens
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.args.tokens_per_sample - 2,  # one for <s> and one for </s>
            pad=self.dictionary.pad(),
            eos=self.dictionary.eos(),
            break_mode=self.args.sample_break_mode,
            document_sep_len=0)

        # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
        dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
        dataset = AppendTokenDataset(dataset, self.source_dictionary.eos())

        mask_whole_words = get_whole_word_mask(self.args, self.source_dictionary) \
            if self.args.mask_length != 'subword' else None

        self.datasets[split] = DenoisingDataset(
            dataset,
            dataset.sizes,
            self.dictionary,
            self.mask_idx,
            mask_whole_words,
            shuffle=self.args.shuffle_instance,
            seed=self.seed,
            args=self.args)
        LOGGER.info(
            "Split: {0}, Loaded {1} samples of denoising_dataset".format(
                split,
                len(self.datasets[split]),
            ))
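
A hedged sketch of exercising this task method; `task` is assumed to be an instance of the surrounding task class with `args`, `dictionary`, `mask_idx`, and `seed` already initialized, and the sample indices are illustrative.

task.load_dataset('train')
denoising = task.datasets['train']
# FairseqDataset subclasses expose a collater for turning individual samples into a batch
batch = denoising.collater([denoising[0], denoising[1]])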
Example #3
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """

        paths = utils.split_paths(self.args['task']['data'])
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]
        split_path = os.path.join(data_path, split)

        dataset = data_utils.load_indexed_dataset(
            path=split_path,
            dictionary=self.source_dictionary,
            dataset_impl=self.args['dataset']['dataset_impl'],
            combine=combine,
        )
        if dataset is None:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(
                split, split_path))

        # create continuous blocks of tokens
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.args['task']['tokens_per_sample'] - 1,  # one less for <s>
            pad=self.source_dictionary.pad(),
            eos=self.source_dictionary.eos(),
            break_mode=self.args['task']['sample_break_mode'],
        )
        LOGGER.info('loaded {} blocks from: {}'.format(len(dataset),
                                                       split_path))

        # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
        dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())

        # create masked input and targets
        mask_whole_words = get_whole_word_mask(self.args, self.source_dictionary) \
            if self.args['task']['mask_whole_words'] else None

        src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
            dataset,
            self.source_dictionary,
            pad_idx=self.source_dictionary.pad(),
            mask_idx=self.mask_idx,
            seed=self.args['common']['seed'],
            mask_prob=self.args['task']['mask_prob'],
            leave_unmasked_prob=self.args['task']['leave_unmasked_prob'],
            random_token_prob=self.args['task']['random_token_prob'],
            freq_weighted_replacement=self.args['task']['freq_weighted_replacement'],
            mask_whole_words=mask_whole_words,
        )

        with data_utils.numpy_seed(self.args['common']['seed'] + epoch):
            shuffle = np.random.permutation(len(src_dataset))

        self.datasets[split] = SortDataset(
            NestedDictionaryDataset(
                {
                    'id':
                    IdDataset(),
                    'net_input': {
                        'src_tokens':
                        PadDataset(
                            src_dataset,
                            pad_idx=self.source_dictionary.pad(),
                            left_pad=False,
                        ),
                        'src_lengths':
                        NumelDataset(src_dataset, reduce=False),
                    },
                    'target':
                    PadDataset(
                        tgt_dataset,
                        pad_idx=self.source_dictionary.pad(),
                        left_pad=False,
                    ),
                    'nsentences':
                    NumSamplesDataset(),
                    'ntokens':
                    NumelDataset(src_dataset, reduce=True),
                },
                sizes=[src_dataset.sizes],
            ),
            sort_order=[
                shuffle,
                src_dataset.sizes,
            ],
        )
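
A hedged inspection sketch: fairseq's NestedDictionaryDataset flattens nested keys with dots, so a single item of the split built above should expose keys such as 'net_input.src_tokens'; `task` is again an assumed, already-configured task instance.

task.load_dataset('train')
sample = task.datasets['train'][0]
print(sorted(sample.keys()))
# expected roughly: ['id', 'net_input.src_lengths', 'net_input.src_tokens',
#                    'nsentences', 'ntokens', 'target']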
def load_masked_code_docstring_dataset_roberta(args,
                                               epoch,
                                               data_path,
                                               split,
                                               src,
                                               src_dict,
                                               tgt,
                                               tgt_dict,
                                               combine,
                                               dataset_impl,
                                               upsample_primary,
                                               left_pad_source,
                                               left_pad_target,
                                               max_source_positions,
                                               max_target_positions,
                                               prepend_bos=False,
                                               load_alignments=False,
                                               truncate_source=False,
                                               append_source_id=False):
    source_path = os.path.join(data_path, '{}.code'.format(split))
    target_path = os.path.join(data_path, '{}.docstring'.format(split))

    # source_dataset
    source_dataset = data_utils.load_indexed_dataset(source_path,
                                                     'text',
                                                     src_dict,
                                                     tokenizer=None,
                                                     dataset_impl=dataset_impl)
    if source_dataset is None:
        raise FileNotFoundError('Dataset not found: {} ({})'.format(
            split, source_path))
    # target_dataset
    target_dataset = data_utils.load_indexed_dataset(target_path,
                                                     'text',
                                                     tgt_dict,
                                                     tokenizer=None,
                                                     dataset_impl=dataset_impl)
    if target_dataset is None:
        raise FileNotFoundError('Dataset not found: {} ({})'.format(
            split, target_path))

    # concatenate source and target sentences
    dataset = ConcatSentencesDataset([source_dataset, target_dataset])
    # create continuous blocks of tokens
    dataset = TokenBlockDataset(
        dataset,
        dataset.sizes,
        args['task']['tokens_per_sample'] - 1,  # one less for <s>
        pad=src_dict.pad(),
        eos=src_dict.eos(),
        break_mode=args['task']['sample_break_mode'],
    )
    # LOGGER.info('loaded {} blocks from: {}'.format(len(dataset), split_path))

    # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
    dataset = PrependTokenDataset(dataset, src_dict.bos())

    # create masked input and targets
    mask_whole_words = get_whole_word_mask(args, src_dict) \
        if args['task']['mask_whole_words'] else None

    src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
        dataset,
        src_dict,
        pad_idx=src_dict.pad(),
        mask_idx=src_dict.index(constants.T_MASK),  # self.mask_idx,
        seed=args['common']['seed'],
        mask_prob=args['task']['mask_prob'],
        leave_unmasked_prob=args['task']['leave_unmasked_prob'],
        random_token_prob=args['task']['random_token_prob'],
        freq_weighted_replacement=args['task']['freq_weighted_replacement'],
        mask_whole_words=mask_whole_words,
    )

    with data_utils.numpy_seed(args['common']['seed'] + epoch):
        shuffle = np.random.permutation(len(src_dataset))

    return SortDataset(
        NestedDictionaryDataset(
            {
                'id':
                IdDataset(),
                'net_input': {
                    'src_tokens':
                    PadDataset(
                        src_dataset,
                        pad_idx=src_dict.pad(),
                        left_pad=False,
                    ),
                    'src_lengths':
                    NumelDataset(src_dataset, reduce=False),
                },
                'target':
                PadDataset(
                    tgt_dataset,
                    pad_idx=src_dict.pad(),
                    left_pad=False,
                ),
                'nsentences':
                NumSamplesDataset(),
                'ntokens':
                NumelDataset(src_dataset, reduce=True),
            },
            sizes=[src_dataset.sizes],
        ),
        sort_order=[
            shuffle,
            src_dataset.sizes,
        ],
    )
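
A hedged call sketch for the RoBERTa-style loader above; `args` mirrors only the keys the function actually reads, `data_dir`/`src_dict`/`tgt_dict` are assumed to exist, and `src_dict` must already contain the `constants.T_MASK` symbol. All values are illustrative.

args = {
    'common': {'seed': 1},
    'task': {
        'tokens_per_sample': 512,
        'sample_break_mode': 'complete',
        'mask_whole_words': False,
        'mask_prob': 0.15,
        'leave_unmasked_prob': 0.1,
        'random_token_prob': 0.1,
        'freq_weighted_replacement': False,
    },
}
masked_pairs = load_masked_code_docstring_dataset_roberta(
    args, epoch=1, data_path=data_dir, split='train',
    src='code', src_dict=src_dict, tgt='docstring', tgt_dict=tgt_dict,
    combine=False, dataset_impl='mmap', upsample_primary=1,
    left_pad_source=True, left_pad_target=False,
    max_source_positions=512, max_target_positions=512,
)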
Example #5
def load_masked_traverse_dataset_roberta(
    args,
    epoch,
    data_path,
    split,
    source_dictionary,
    combine,
):
    split_path = os.path.join(data_path, '{}.ast_trav_df'.format(split))
    dataset = data_utils.load_indexed_dataset(
        path=split_path,
        dictionary=source_dictionary,
        dataset_impl=args['dataset']['dataset_impl'],
        combine=combine,
    )
    if dataset is None:
        raise FileNotFoundError('Dataset not found: {} ({})'.format(
            split, split_path))

    # # create continuous blocks of tokens
    # dataset = TokenBlockDataset(
    #     dataset,
    #     dataset.sizes,
    #     args['task']['tokens_per_sample'] - 1,  # one less for <s>
    #     pad=source_dictionary.pad(),
    #     eos=source_dictionary.eos(),
    #     break_mode=args['task']['sample_break_mode'],
    # )
    # LOGGER.info('loaded {} blocks from: {}'.format(len(dataset), split_path))

    # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
    dataset = PrependTokenDataset(dataset, source_dictionary.bos())

    # create masked input and targets
    mask_whole_words = get_whole_word_mask(args, source_dictionary) \
        if args['task']['mask_whole_words'] else None

    src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
        dataset,
        source_dictionary,
        pad_idx=source_dictionary.pad(),
        mask_idx=source_dictionary.index(constants.MASK),  # self.mask_idx,
        seed=args['common']['seed'],
        mask_prob=args['task']['mask_prob'],
        leave_unmasked_prob=args['task']['leave_unmasked_prob'],
        random_token_prob=args['task']['random_token_prob'],
        freq_weighted_replacement=args['task']['freq_weighted_replacement'],
        mask_whole_words=mask_whole_words,
    )

    with data_utils.numpy_seed(args['common']['seed'] + epoch):
        shuffle = np.random.permutation(len(src_dataset))

    return SortDataset(
        NestedDictionaryDataset(
            {
                'id':
                IdDataset(),
                'net_input': {
                    'src_tokens':
                    PadDataset(
                        src_dataset,
                        pad_idx=source_dictionary.pad(),
                        left_pad=False,
                    ),
                    'src_lengths':
                    NumelDataset(src_dataset, reduce=False),
                },
                'target':
                PadDataset(
                    tgt_dataset,
                    pad_idx=source_dictionary.pad(),
                    left_pad=False,
                ),
                'nsentences':
                NumSamplesDataset(),
                'ntokens':
                NumelDataset(src_dataset, reduce=True),
            },
            sizes=[src_dataset.sizes],
        ),
        sort_order=[
            shuffle,
            src_dataset.sizes,
        ],
    )
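
A hedged call sketch for the AST-traversal loader; it additionally reads `args['dataset']['dataset_impl']`, and `source_dictionary` must already contain the `constants.MASK` symbol (the `args` dict is assumed to be shaped like the illustrative one in the sketch above).

args['dataset'] = {'dataset_impl': 'mmap'}
traverse_dataset = load_masked_traverse_dataset_roberta(
    args, epoch=1, data_path=data_dir, split='train',
    source_dictionary=source_dictionary, combine=False,
)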
Example #6
def load_langpair_dataset(
    data_path,
    split,
    domains,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    append_eos_to_target=False,
):
    def split_exists(split, src, data_path, domain):
        filename = os.path.join(data_path, domain,
                                '{}.{}'.format(split,
                                               src))  # -{}.{} , tgt, lang
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    for dm in domains:
        # load datasets of src domains
        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')

            # infer langcode
            if split_exists(split_k, src, data_path, dm):
                prefix = os.path.join(
                    data_path, dm, '{}.'.format(split_k))  # {}-{}. , src, tgt
            elif split_exists(split_k, tgt, data_path, dm):
                prefix = os.path.join(
                    data_path, dm, '{}.'.format(split_k))  # {}-{}. , tgt, src

            else:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError(
                        'Dataset not found: {} ({})'.format(split, data_path))

            src_dataset = data_utils.load_indexed_dataset(
                prefix + src, src, src_dict, dataset_impl)
            if truncate_source:
                src_dataset = AppendTokenDataset(
                    TruncateDataset(
                        StripTokenDataset(src_dataset, src_dict.eos()),
                        max_source_positions - 1,
                    ),
                    src_dict.eos(),
                )
            src_datasets.append(src_dataset)

            tgt_dataset = data_utils.load_indexed_dataset(
                prefix + tgt, tgt, tgt_dict, dataset_impl)
            if tgt_dataset is not None:
                tgt_datasets.append(tgt_dataset)

            if not combine:
                break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(
            tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    eos = None
    if append_source_id:
        src_dataset = AppendTokenDataset(src_dataset,
                                         src_dict.index('[{}]'.format(src)))
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(
                tgt_dataset, tgt_dict.index('[{}]'.format(tgt)))
        eos = tgt_dict.index('[{}]'.format(tgt))

    # align_dataset = None
    # if load_alignments:
    #     align_path = os.path.join(data_path, '{}.align.{}-{}'.format(split, src, tgt))
    #     if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
    #         align_dataset = data_utils.load_indexed_dataset(align_path, None, dataset_impl)

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return LanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        align_dataset=None,
        eos=eos,
        remove_eos_from_source=True,
        append_eos_to_target=append_eos_to_target,
        shuffle=True,  # TODO debug: shuffle=False
    )
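
A hedged call sketch for the domain-aware variant above; each entry of `domains` is expected to be a sub-directory of `data_dir` holding '{split}.{src}' and '{split}.{tgt}' indexed files, and the domain names here are purely illustrative.

domain_pairs = load_langpair_dataset(
    data_path=data_dir, split='train', domains=['python', 'java'],
    src='code', src_dict=src_dict, tgt='docstring', tgt_dict=tgt_dict,
    combine=True, dataset_impl='mmap', upsample_primary=1,
    left_pad_source=True, left_pad_target=False,
    max_source_positions=512, max_target_positions=512,
    truncate_source=True,
)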
Example #7
def load_langpair_dataset(
    args,
    programming_langs,
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    is_distill=False,
):
    def split_exists(split, src, data_path):
        filename = os.path.join(data_path,
                                '{}.{}'.format(split,
                                               src))  # -{}.{} , tgt, lang
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    topk_idxs = []
    topk_probs = []
    expert_scores = []
    dataset_ids = []
    lng_borders = [0]
    is_train = split == 'train'

    for ds_idx, program_lang in enumerate(programming_langs):
        lang_data_path = os.path.join(data_path, program_lang)

        split_k = split
        # infer langcode
        if split_exists(split_k, src, lang_data_path):
            prefix = os.path.join(lang_data_path,
                                  '{}.'.format(split_k))  # {}-{}. , src, tgt
        elif split_exists(split_k, tgt, lang_data_path):
            prefix = os.path.join(lang_data_path,
                                  '{}.'.format(split_k))  # {}-{}. , tgt, src
        else:
            raise NotImplementedError('No data in {}'.format(lang_data_path))

        src_dataset = data_utils.load_indexed_dataset(prefix + src, src,
                                                      src_dict, dataset_impl)
        if truncate_source:
            src_dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)
        length = len(src_dataset)
        lng_borders.append(lng_borders[-1] + length)

        tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt,
                                                      tgt_dict, dataset_impl)
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        for i in range(length):
            dataset_ids.append(ds_idx)

        if is_distill and is_train:  # distill only for train
            path = '{}_{}_{}_topk_idx'.format(lang_data_path, src, tgt)
            topk_idxs.append(TeacherOutputDataset(path))
            path = '{}_{}_{}_topk_prob'.format(lang_data_path, src, tgt)
            topk_probs.append(TeacherOutputDataset(path))
            expert_bleu = os.path.join(
                data_path,
                'expert_bleu_{}_{}_{}.json'.format(program_lang, src, tgt))
            with open(expert_bleu) as reader:
                expert_bleu = json.load(reader)
            expert_scores.append(expert_bleu[f"bleu_{program_lang}"])

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    sample_ratios = [1] * len(src_datasets)
    sample_ratios[0] = upsample_primary
    src_dataset = ConcatDataset(src_datasets, sample_ratios)
    if len(tgt_datasets) > 0:
        tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
    else:
        tgt_dataset = None

    LOGGER.info('src data: {}, tgt data: {}'.format(
        len(src_dataset),
        len(tgt_dataset) if tgt_dataset is not None else 0))

    if is_distill and is_train:  # distill only for train
        topk_idx_dataset = ConcatDataset(topk_idxs)
        topk_probs_dataset = ConcatDataset(topk_probs)
        assert len(topk_probs_dataset) == len(src_dataset), (
            len(topk_probs_dataset), len(src_dataset))
        assert len(topk_idx_dataset) == len(src_dataset)
    else:
        topk_idx_dataset = None
        topk_probs_dataset = None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(
            tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None

    return UniversalDataset(
        args,
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        dataset_ids=dataset_ids,
        lng_borders=lng_borders,
        dataset_names=programming_langs,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        topk_idxs=topk_idx_dataset,
        topk_probs=topk_probs_dataset,
        expert_scores=expert_scores,
        is_train=is_train,
    )
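
A hedged call sketch for the multi-language variant above; `args` is forwarded untouched to UniversalDataset, and with is_distill=True on the train split the loader additionally expects per-language '*_topk_idx'/'*_topk_prob' teacher outputs and 'expert_bleu_{lang}_{src}_{tgt}.json' files next to the data.

universal_dataset = load_langpair_dataset(
    args, programming_langs=['python', 'java'], data_path=data_dir,
    split='train', src='code', src_dict=src_dict,
    tgt='docstring', tgt_dict=tgt_dict,
    combine=False, dataset_impl='mmap', upsample_primary=1,
    left_pad_source=True, left_pad_target=False,
    max_source_positions=512, max_target_positions=512,
    is_distill=False,  # keep False unless the teacher side files are available
)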
Example #8
def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    dataset_impl,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    truncate_source=False,
    append_source_id=False,
    truncate_target=False,
    append_eos_to_target=False,
    portion=None,
):
    # load source dataset
    src_path = os.path.join(data_path, '{}.{}'.format(split, src))
    src_dataset = data_utils.load_indexed_dataset(src_path,
                                                  dictionary=src_dict,
                                                  dataset_impl=dataset_impl)

    # load target dataset
    tgt_path = os.path.join(data_path, '{}.{}'.format(split, tgt))
    tgt_dataset = data_utils.load_indexed_dataset(tgt_path,
                                                  dictionary=tgt_dict,
                                                  dataset_impl=dataset_impl)

    # few-shot learning
    if portion is not None and split == 'train':
        LOGGER.info('set {}.{} portion to {}'.format(split, src, portion))
        src_dataset = PortionDataset(src_dataset, portion)
        LOGGER.info('set {}.{} portion to {}'.format(split, tgt, portion))
        tgt_dataset = PortionDataset(tgt_dataset, portion)

    # prepend BOS
    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(
            tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    # truncate source/target (strip EOS, truncate, then re-append EOS)
    if truncate_source:
        LOGGER.info('truncate {}.{} to {}'.format(split, src,
                                                  max_source_positions))
        src_dataset = AppendTokenDataset(
            TruncateDataset(
                StripTokenDataset(src_dataset, src_dict.eos()),
                max_source_positions - 1 - int(append_source_id),
            ),
            src_dict.eos(),
        )

    if truncate_target:
        LOGGER.info('truncate {}.{} to {}'.format(split, tgt,
                                                  max_target_positions))
        tgt_dataset = AppendTokenDataset(
            TruncateDataset(
                StripTokenDataset(tgt_dataset, tgt_dict.eos()),
                max_target_positions - 1 - int(append_source_id),
            ),
            tgt_dict.eos(),
        )

    # append [lang]
    eos = None
    if append_source_id:
        src_dataset = AppendTokenDataset(src_dataset,
                                         src_dict.index(f"[{src}]"))
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(tgt_dataset,
                                             tgt_dict.index(f"[{tgt}]"))
        eos = tgt_dict.index(f"[{tgt}]")

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None

    LOGGER.info('loaded {} examples from: {}'.format(len(src_dataset),
                                                     src_path))
    LOGGER.info('loaded {} examples from: {}'.format(len(tgt_dataset),
                                                     tgt_path))
    return LanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        align_dataset=None,
        eos=eos,
        shuffle=(split == "train"),
    )
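
A hedged call sketch for the variant above, highlighting the few-shot `portion` option; names and values are illustrative.

few_shot_pairs = load_langpair_dataset(
    data_path=data_dir, split='train', src='code', src_dict=src_dict,
    tgt='docstring', tgt_dict=tgt_dict, dataset_impl='mmap',
    left_pad_source=True, left_pad_target=False,
    max_source_positions=512, max_target_positions=512,
    truncate_source=True, truncate_target=True,
    portion=0.1,  # keep roughly 10% of the training pairs (few-shot setting)
)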
def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    load_alignments=False,
    truncate_source=False,
    truncate_target=False,
):
    def split_exists(split, src, data_path):
        filename = os.path.join(data_path,
                                '{}.{}'.format(split,
                                               src))  # -{}.{} , tgt, lang
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')

        # infer langcode
        if split_exists(split_k, src, data_path):
            prefix = os.path.join(data_path,
                                  '{}.'.format(split_k))  # {}-{}. , src, tgt
        elif split_exists(split_k, tgt, data_path):
            prefix = os.path.join(data_path,
                                  '{}.'.format(split_k))  # {}-{}. , tgt, src

        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(
                    split, data_path))

        src_dataset = data_utils.load_indexed_dataset(prefix + src, 'text',
                                                      src_dict, dataset_impl)
        if truncate_source and max_source_positions:
            src_dataset = TruncateDataset(src_dataset, max_source_positions)
        src_datasets.append(src_dataset)

        tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, 'text',
                                                      tgt_dict, dataset_impl)
        if truncate_target and max_target_positions:
            tgt_dataset = PrependTokenDataset(
                AppendTokenDataset(TruncateDataset(tgt_dataset,
                                                   max_target_positions - 2),
                                   token=tgt_dict.eos()), tgt_dict.bos())

        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return LanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        shuffle=True,  # TODO debug: shuffle=False
    )
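
A hedged call sketch for the last variant; with combine=True the itertools.count loop above will also pick up sharded splits ('train1', 'train2', ...) if they exist alongside 'train'. Names and values are illustrative.

sharded_pairs = load_langpair_dataset(
    data_path=data_dir, split='train', src='code', src_dict=src_dict,
    tgt='docstring', tgt_dict=tgt_dict,
    combine=True, dataset_impl='mmap', upsample_primary=1,
    left_pad_source=True, left_pad_target=False,
    max_source_positions=512, max_target_positions=512,
)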