Code example #1
File: interactive.py  Project: zzzzxciid/fairseq
def make_batches(lines, args, src_dict, max_positions):
    tokens = [
        tokenizer.Tokenizer.tokenize(src_str, src_dict,
                                     add_if_not_exist=False).long()
        for src_str in lines
    ]
    lengths = np.array([t.numel() for t in tokens])
    itr = data.EpochBatchIterator(
        dataset=data.LanguagePairDataset(tokens, lengths, src_dict),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
    ).next_epoch_itr(shuffle=False)
    for batch in itr:
        yield Batch(
            srcs=[lines[i] for i in batch['id']],
            tokens=batch['net_input']['src_tokens'],
            lengths=batch['net_input']['src_lengths'],
        ), batch['id']
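
The helper above is the batching routine from fairseq's interactive translation script. Below is a minimal, hypothetical driver for it: the Batch namedtuple simply mirrors the fields constructed above, while the args values and the assumption that src_dict and max_positions come from an already-loaded model/task are placeholders for illustration.

import argparse
from collections import namedtuple

# Mirrors the fields that make_batches above fills in.
Batch = namedtuple("Batch", ["srcs", "tokens", "lengths"])

def iterate_interactive_batches(lines, src_dict, max_positions):
    # max_tokens / max_sentences control batching exactly as in make_batches.
    args = argparse.Namespace(max_tokens=None, max_sentences=1)
    for batch, batch_ids in make_batches(lines, args, src_dict, max_positions):
        # batch.tokens and batch.lengths are what a SequenceGenerator consumes.
        print(batch_ids, batch.tokens.size(), batch.lengths)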
Code example #2
def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
    tokens = torch.LongTensor(list(range(epoch_size))).view(1, -1)
    tokens_ds = data.TokenBlockDataset(
        tokens,
        sizes=[tokens.size(-1)],
        block_size=1,
        pad=0,
        eos=1,
        include_targets=False,
    )
    trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
    dataset = data.LanguagePairDataset(
        tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False
    )
    epoch_itr = data.EpochBatchIterator(
        dataset=dataset,
        collate_fn=dataset.collater,
        batch_sampler=[[i] for i in range(epoch_size)],
    )
    return trainer, epoch_itr
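
This test helper wraps a LanguagePairDataset in an EpochBatchIterator that yields one sentence per batch. A hypothetical way to consume it (mock_trainer and mock_dict are assumed to be defined alongside the helper, as in fairseq's iterator tests):

trainer, epoch_itr = get_trainer_and_epoch_itr(
    epoch=2, epoch_size=100, num_updates=50, iterations_in_epoch=10
)
itr = epoch_itr.next_epoch_itr(shuffle=False)
for sample in itr:
    pass  # each batch holds exactly one sentence, per the batch_sampler above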
Code example #3
def make_batches(lines, args, task, max_positions):
    tokens = [
        tokenizer.Tokenizer.tokenize(src_str,
                                     task.source_dictionary,
                                     add_if_not_exist=False).long()
        for src_str in lines
    ]
    lengths = np.array([t.numel() for t in tokens])
    itr = task.get_batch_iterator(
        dataset=data.LanguagePairDataset(tokens, lengths,
                                         task.source_dictionary),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
    ).next_epoch_itr(shuffle=False)
    for batch in itr:
        yield Batch(
            srcs=[lines[i] for i in batch["id"]],
            tokens=batch["net_input"]["src_tokens"],
            lengths=batch["net_input"]["src_lengths"],
        ), batch["id"]
Code example #4
def load_binarized_dataset(
    train_corpus: ParallelCorpusConfig,
    eval_corpus: ParallelCorpusConfig,
    train_split: str,
    eval_split: str,
    args: argparse.Namespace,
) -> data.LanguageDatasets:
    source_dict = pytorch_translate_dictionary.Dictionary.load(
        args.source_vocab_file)
    target_dict = pytorch_translate_dictionary.Dictionary.load(
        args.target_vocab_file)

    dataset = data.LanguageDatasets(
        src=train_corpus.source.dialect,
        dst=train_corpus.target.dialect,
        src_dict=source_dict,
        dst_dict=target_dict,
    )

    for split, corpus in [(train_split, train_corpus),
                          (eval_split, eval_corpus)]:
        if (not indexed_dataset.IndexedInMemoryDataset.exists(
                corpus.source.data_file)
                or not indexed_dataset.IndexedInMemoryDataset.exists(
                    corpus.target.data_file)):
            raise ValueError(
                f"One or both of source file: {corpus.source.data_file} and "
                f"target file: {corpus.target.data_file} for split {split} "
                f"was not found.")

        dataset.splits[split] = data.LanguagePairDataset(
            src=indexed_dataset.IndexedInMemoryDataset(
                corpus.source.data_file),
            dst=indexed_dataset.IndexedInMemoryDataset(
                corpus.target.data_file),
            pad_idx=source_dict.pad(),
            eos_idx=source_dict.eos(),
        )

    return dataset
Code example #5
File: data.py  Project: kc17/translate
def load_binarized_dataset(
    train_corpus: ParallelCorpusConfig,
    eval_corpus: ParallelCorpusConfig,
    train_split: str,
    eval_split: str,
    args: argparse.Namespace,
) -> data.LanguageDatasets:
    source_dict = pytorch_translate_dictionary.Dictionary.load(
        args.source_vocab_file)
    target_dict = pytorch_translate_dictionary.Dictionary.load(
        args.target_vocab_file)

    dataset = data.LanguageDatasets(
        src=train_corpus.source.dialect,
        dst=train_corpus.target.dialect,
        src_dict=source_dict,
        dst_dict=target_dict,
    )

    for split, corpus in [(train_split, train_corpus),
                          (eval_split, eval_corpus)]:
        if not os.path.exists(corpus.source.data_file):
            raise ValueError(
                f"{corpus.source.data_file} for {split} not found!")
        if not os.path.exists(corpus.target.data_file):
            raise ValueError(
                f"{corpus.target.data_file} for {split} not found!")

        dataset.splits[split] = data.LanguagePairDataset(
            src=InMemoryNumpyDataset.create_from_file(corpus.source.data_file),
            dst=InMemoryNumpyDataset.create_from_file(corpus.target.data_file),
            pad_idx=source_dict.pad(),
            eos_idx=source_dict.eos(),
        )

    return dataset
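
Both variants of load_binarized_dataset are called the same way. A hypothetical invocation is sketched below; the vocab paths are placeholders, and train_corpus / eval_corpus are assumed to be ParallelCorpusConfig objects produced by pytorch_translate's preprocessing step.

import argparse

args = argparse.Namespace(
    source_vocab_file="/path/to/src_vocab.txt",
    target_vocab_file="/path/to/tgt_vocab.txt",
)
datasets = load_binarized_dataset(
    train_corpus=train_corpus,  # ParallelCorpusConfig for the training split
    eval_corpus=eval_corpus,    # ParallelCorpusConfig for the eval split
    train_split="train",
    eval_split="valid",
    args=args,
)
train_pairs = datasets.splits["train"]  # a data.LanguagePairDataset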
Code example #6
    def load_dataset_from_text(
        self,
        split: str,
        source_text_file: str,
        target_text_file: str,
        append_eos: Optional[bool] = False,
        reverse_source: Optional[bool] = True,
    ):
        append_bos = getattr(self.args, "append_bos", False)
        if self.char_target_dict is not None:
            dst_dataset = char_data.InMemoryNumpyWordCharDataset()
            dst_dataset.parse(
                path=target_text_file,
                word_dict=self.target_dictionary,
                char_dict=self.char_target_dict,
                reverse_order=False,
                append_eos=True,
            )
        else:
            dst_dataset = data.IndexedRawTextDataset(
                path=target_text_file,
                dictionary=self.target_dictionary,
                # We always append EOS to the target sentence since we still want
                # the model to output an indication the sentence has finished, even
                # if we don't append the EOS symbol to the source sentence
                # (to prevent the model from misaligning UNKs or other words
                # to the frequently occurring EOS).
                append_eos=True,
                # We don't reverse the order of the target sentence, since
                # even if the source sentence is fed to the model backwards,
                # we still want the model to start outputting from the first word.
                reverse_order=False,
            )

        if self.char_source_dict is not None:
            src_dataset = char_data.InMemoryNumpyWordCharDataset()
            src_dataset.parse(
                path=source_text_file,
                word_dict=self.source_dictionary,
                char_dict=self.char_source_dict,
                reverse_order=reverse_source,
                append_eos=append_eos,
            )
            char_data_class = (
                char_data.LanguagePairCharDataset
                if self.char_target_dict is not None
                else char_data.LanguagePairSourceCharDataset
            )
            self.datasets[split] = char_data_class(
                src=src_dataset,
                src_sizes=src_dataset.sizes,
                src_dict=self.source_dictionary,
                tgt=dst_dataset,
                tgt_sizes=dst_dataset.sizes,
                tgt_dict=self.target_dictionary,
            )
        else:
            src_dataset = data.IndexedRawTextDataset(
                path=source_text_file,
                dictionary=self.source_dictionary,
                append_eos=append_eos,
                reverse_order=reverse_source,
            )
            self.datasets[split] = data.LanguagePairDataset(
                src=src_dataset,
                src_sizes=src_dataset.sizes,
                src_dict=self.source_dictionary,
                tgt=dst_dataset,
                tgt_sizes=dst_dataset.sizes,
                tgt_dict=self.target_dictionary,
                left_pad_source=False,
                append_bos=append_bos,
            )

        print(f"| {split} {len(self.datasets[split])} examples")
Code example #7
def generate(args):
    pytorch_translate_options.print_args(args)

    src_dict = pytorch_translate_dictionary.Dictionary.load(args.source_vocab_file)
    dst_dict = pytorch_translate_dictionary.Dictionary.load(args.target_vocab_file)
    use_char_source = args.char_source_vocab_file != ""
    if use_char_source:
        char_source_dict = pytorch_translate_dictionary.Dictionary.load(
            args.char_source_vocab_file
        )
        # this attribute is used for CharSourceModel construction
        args.char_source_dict_size = len(char_source_dict)
    else:
        char_source_dict = None

    dataset = data.LanguageDatasets(
        src=args.source_lang, dst=args.target_lang, src_dict=src_dict, dst_dict=dst_dict
    )
    models, model_args = pytorch_translate_utils.load_diverse_ensemble_for_inference(
        args.path, dataset.src_dict, dataset.dst_dict
    )
    append_eos_to_source = model_args[0].append_eos_to_source
    reverse_source = model_args[0].reverse_source
    assert all(
        a.append_eos_to_source == append_eos_to_source
        and a.reverse_source == reverse_source
        for a in model_args
    )
    if args.source_binary_file != "":
        assert args.target_binary_file != ""
        dst_dataset = pytorch_translate_data.InMemoryNumpyDataset.create_from_file(
            args.target_binary_file
        )
        if use_char_source:
            src_dataset = char_data.InMemoryNumpyWordCharDataset.create_from_file(
                args.source_binary_file
            )
            gen_split = char_data.LanguagePairSourceCharDataset(
                src=src_dataset,
                dst=dst_dataset,
                pad_idx=src_dict.pad(),
                eos_idx=dst_dict.eos(),
            )
        else:
            src_dataset = pytorch_translate_data.InMemoryNumpyDataset.create_from_file(
                args.source_binary_file
            )
            gen_split = data.LanguagePairDataset(
                src=src_dataset,
                dst=dst_dataset,
                pad_idx=src_dict.pad(),
                eos_idx=dst_dict.eos(),
            )
    elif pytorch_translate_data.is_multilingual(args):
        gen_split = pytorch_translate_data.make_language_pair_dataset_from_text_multilingual(
            source_text_file=args.source_text_file[0],
            target_text_file=args.target_text_file,
            source_lang_id=args.multiling_source_lang_id,
            target_lang_id=args.multiling_target_lang_id,
            source_dict=src_dict,
            target_dict=dst_dict,
            append_eos=append_eos_to_source,
            reverse_source=reverse_source,
        )
    elif args.source_ensembling:
        gen_split = multisource_data.make_multisource_language_pair_dataset_from_text(
            source_text_files=args.source_text_file,
            target_text_file=args.target_text_file,
            source_dict=src_dict,
            target_dict=dst_dict,
            append_eos=append_eos_to_source,
            reverse_source=reverse_source,
        )
    else:
        gen_split = pytorch_translate_data.make_language_pair_dataset_from_text(
            source_text_file=args.source_text_file[0],
            target_text_file=args.target_text_file,
            source_dict=src_dict,
            target_dict=dst_dict,
            append_eos=append_eos_to_source,
            reverse_source=reverse_source,
            char_source_dict=char_source_dict,
        )
    dataset.splits[args.gen_subset] = gen_split

    if args.source_lang is None or args.target_lang is None:
        # record inferred languages in args
        args.source_lang, args.target_lang = dataset.src, dataset.dst

    print(f"| [{dataset.src}] dictionary: {len(dataset.src_dict)} types")
    print(f"| [{dataset.dst}] dictionary: {len(dataset.dst_dict)} types")
    print(f"| {args.gen_subset} {len(dataset.splits[args.gen_subset])} examples")
    scorer, num_sentences, gen_timer, _ = _generate_score(
        models=models, args=args, dataset=dataset, dataset_split=args.gen_subset
    )
    print(
        f"| Translated {num_sentences} sentences ({gen_timer.n} tokens) "
        f"in {gen_timer.sum:.1f}s ({1. / gen_timer.avg:.2f} tokens/s)"
    )
    print(
        f"| Generate {args.gen_subset} with beam={args.beam}: "
        f"{scorer.result_string()}"
    )
    return scorer.score()
Code example #8
def decode_from_file(models, task, args, use_cuda, source_filename=None,
                     target_filename=None, output_filename=None):
    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # I/O files
    source_filename = source_filename if source_filename is not None else args.decode_source_file
    target_filename = target_filename if target_filename is not None else args.decode_target_file
    output_filename = output_filename if output_filename is not None else args.decode_output_file
    if output_filename is not None:
        base_filename = output_filename
    else:
        base_filename = source_filename
        if args.num_shards:
            base_filename += "%.2d" % args.shard_id
    decode_filename = _decode_filename(base_filename, args)
    outfile = open(decode_filename, "w")
    if args.decode_to_file:
        print("| [decode] writing decodes into {}".format(decode_filename))

    # Get sorted input (and reversed)
    sorted_inputs, sorted_keys, sorted_targets = _get_sorted_inputs(
        source_filename, args.num_shards, args.delimiter, target_filename, args.shard_id)

    # Build input iterator
    src_tokens = [
        tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long()
        for src_str in sorted_inputs]
    src_sizes = np.array([t.numel() for t in src_tokens])
    tgt_tokens = [
        tokenizer.Tokenizer.tokenize(tgt_str, tgt_dict, add_if_not_exist=False).long()
        for tgt_str in sorted_targets] if sorted_targets is not None else None
    tgt_sizes = np.array([t.numel() for t in tgt_tokens]) if tgt_tokens is not None else None
    print('| loading {} examples, {} tokens'.format(len(sorted_inputs), sum(src_sizes)))

    dataset = data.LanguagePairDataset(
        src_tokens, src_sizes, src_dict, tgt_tokens, tgt_sizes, tgt_dict, shuffle=False)
    itr = data.EpochBatchIterator(
        dataset=dataset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=models[0].max_positions(),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    if args.score_reference:
        translator = SequenceScorer(models, task.target_dictionary)
    else:
        translator = SequenceGenerator(
            models, task.target_dictionary, beam_size=args.beam,
            stop_early=(not args.no_early_stop), normalize_scores=(not args.unnormalized),
            len_penalty=args.lenpen, unk_penalty=args.unkpen,
            sampling=args.sampling, sampling_topk=args.sampling_topk, minlen=args.min_len,
        )

    if use_cuda:
        translator.cuda()

    # Generate and compute BLEU score
    scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True

    if args.score_reference:
        translations = translator.score_batched_itr(itr, cuda=use_cuda, timer=gen_timer)
    else:
        translations = translator.generate_batched_itr(
            itr, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
            cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size,
        )

    decodes = dict()
    sids = []
    wps_meter = TimeMeter()
    start = time.perf_counter()
    for sample_id, src_tokens, target_tokens, hypos in translations:
        # Process input and ground truth
        has_target = target_tokens is not None
        target_tokens = target_tokens.int().cpu() if has_target else None

        # Either retrieve the original sentences or regenerate them from tokens.
        if align_dict is not None:
            src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
            target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
        else:
            src_str = src_dict.string(src_tokens, args.remove_bpe)
            if has_target:
                target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)

        if not args.quiet:
            try:
                print('S-{}\t{}'.format(sample_id, src_str))
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str))
            except:
                print('S-{}\t{}'.format(sample_id, src_str.encode('utf-8')))
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str.encode('utf-8')))

        # Process top predictions
        for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
            hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                hypo_tokens=hypo['tokens'].int().cpu(),
                src_str=src_str,
                alignment=hypo['alignment'].int().cpu(),
                align_dict=align_dict,
                tgt_dict=tgt_dict,
                remove_bpe=args.remove_bpe,
            )
            if i == 0:
                decodes[sample_id.tolist()] = hypo_str
                # sids.append(sample_id.tolist())

            if not args.quiet:
                try:
                    print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str))
                except:
                    print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str.encode('utf-8')))
                print('P-{}\t{}'.format(
                    sample_id,
                    ' '.join(map(
                        lambda x: '{:.4f}'.format(x),
                        hypo['positional_scores'].tolist(),
                    ))
                ))
                print('A-{}\t{}'.format(
                    sample_id,
                    ' '.join(map(lambda x: str(utils.item(x)), alignment))
                ))

            # Score only the top hypothesis
            if has_target and i == 0:
                if align_dict is not None or args.remove_bpe is not None:
                    # Convert back to tokens for evaluation with unk replacement and/or without BPE
                    target_tokens = tokenizer.Tokenizer.tokenize(
                        target_str, tgt_dict, add_if_not_exist=True)
                scorer.add(target_tokens, hypo_tokens)

        wps_meter.update(src_tokens.size(0))

        num_sentences += 1
        if args.quiet and num_sentences % 100 == 0:
            print("| {} / {} sentences decoded ({})".format(num_sentences, len(sorted_inputs), len(decodes)))

    used_time = time.perf_counter() - start
    print("| Used time:" + repr(used_time))
    print("| Average time:" + repr(used_time / len(sorted_inputs)))

    if args.decode_to_file:
        print("| [decode] writing decodes into {}".format(decode_filename))
        # print(sids)
        for index in range(len(sorted_inputs)):
            try:
                outfile.write("{}{}".format(decodes[sorted_keys[index]], args.delimiter))
            except:
                outfile.write("{}{}".format(decodes[sorted_keys[index]].encode('utf-8'), args.delimiter))
        outfile.close()

    print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
Code example #9
def eval_from_file(models, task, args, use_cuda, source_filename=None,
                   target_filename=None, score_filename=None):
    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # I/O files
    source_filename = source_filename if source_filename is not None else args.source_file
    target_filename = target_filename if target_filename is not None else args.target_file
    score_filename = score_filename if score_filename is not None else args.score_file
    if score_filename is None:
        score_filename = target_filename + ".eval.score"
    outfile = open(score_filename, "w")

    # Get sorted input (and reversed)
    sorted_inputs, sorted_keys, sorted_targets = _get_sorted_inputs(
        source_filename, args.num_shards, args.delimiter, target_filename, args.shard_id,
        args.dup_src, args.dup_tgt)

    # Build input iterator
    src_tokens = [
        tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long()
        for src_str in sorted_inputs]
    tgt_tokens = [
        tokenizer.Tokenizer.tokenize(tgt_str, tgt_dict, add_if_not_exist=False).long()
        for tgt_str in sorted_targets] if sorted_targets is not None else None
    src_sizes = np.array([t.numel() for t in src_tokens])
    tgt_sizes = np.array([t.numel() for t in tgt_tokens])
    print('| loading {} examples, {} tokens'.format(len(sorted_inputs), sum(src_sizes)))

    dataset = data.LanguagePairDataset(
        src_tokens, src_sizes, src_dict, tgt_tokens, tgt_sizes, tgt_dict, shuffle=False)
    itr = data.EpochBatchIterator(
        dataset=dataset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=models[0].max_positions(),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(models, task.target_dictionary)
    if use_cuda:
        scorer.cuda()

    all_scores = dict()
    score_sum = 0.
    count, sen_count = 0, 0
    results = scorer.score_batched_itr(itr, cuda=use_cuda, timer=gen_timer)
    wps_meter = TimeMeter()
    for sample_id, src_tokens, target_tokens, hypos in results:
        for i, hypo in enumerate(hypos):
            pos_scores = hypo['positional_scores']
            inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
            if inf_scores.any():
                print('| Skipping tokens with inf scores:',
                      task.target_dictionary.string(hypo['tokens'][inf_scores.nonzero()]))
                pos_scores = pos_scores[(~inf_scores).nonzero()]
            score_sum += pos_scores.sum()
            count += pos_scores.numel()
            sentence_score = hypo['score']
            if i == 0:
                all_scores[sample_id.tolist()] = sentence_score
        sen_count += 1
        wps_meter.update(src_tokens.size(0))

    print("| [eval] writing scores into {}".format(score_filename))
    # print(sids)
    for index in range(len(sorted_inputs)):
        outfile.write("{}{}".format(all_scores[sorted_keys[index]], args.delimiter))
    outfile.close()

    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss)))
Code example #10
def load_binarized_dataset(
    train_corpus: ParallelCorpusConfig,
    eval_corpus: ParallelCorpusConfig,
    train_split: str,
    eval_split: str,
    args: argparse.Namespace,
    use_char_source: bool = False,
) -> data.LanguageDatasets:
    if is_multilingual(args):  # Dummy dictionaries
        source_dict = pytorch_translate_dictionary.Dictionary()
        target_dict = pytorch_translate_dictionary.Dictionary()
    else:
        source_dict = pytorch_translate_dictionary.Dictionary.load(
            args.source_vocab_file)
        target_dict = pytorch_translate_dictionary.Dictionary.load(
            args.target_vocab_file)

    if use_char_source:
        char_source_dict = pytorch_translate_dictionary.Dictionary.load(
            args.char_source_vocab_file)
        # this attribute is used for CharSourceModel construction
        args.char_source_dict_size = len(char_source_dict)

    dataset = data.LanguageDatasets(
        src=train_corpus.source.dialect,
        dst=train_corpus.target.dialect,
        src_dict=source_dict,
        dst_dict=target_dict,
    )

    for split, corpus in [(train_split, train_corpus),
                          (eval_split, eval_corpus)]:
        if not os.path.exists(corpus.source.data_file):
            raise ValueError(
                f"{corpus.source.data_file} for {split} not found!")
        if not os.path.exists(corpus.target.data_file):
            raise ValueError(
                f"{corpus.target.data_file} for {split} not found!")

        dst_dataset = InMemoryNumpyDataset.create_from_file(
            corpus.target.data_file)
        if use_char_source:
            src_dataset = char_data.InMemoryNumpyWordCharDataset.create_from_file(
                corpus.source.data_file)
            dataset.splits[split] = char_data.LanguagePairSourceCharDataset(
                src=src_dataset,
                dst=dst_dataset,
                pad_idx=source_dict.pad(),
                eos_idx=source_dict.eos(),
            )
        else:
            src_dataset = InMemoryNumpyDataset.create_from_file(
                corpus.source.data_file)
            dataset.splits[split] = data.LanguagePairDataset(
                src=src_dataset,
                dst=dst_dataset,
                pad_idx=source_dict.pad(),
                eos_idx=source_dict.eos(),
            )

    return dataset