    def _get_eval_dataset(self, task: TranslationTask):
        split_words = self.args.input_type == 'word'
        src_dataset = TextLookupDataset(
            task.src_dataset,
            self.dictionaries[task.source_language + '.src'],
            words=split_words,
            lower=task.lower,
            bos=False,
            eos=False,
            trunc_len=self.args.seq_length_trunc)

        if self.args.eval_noise:
            src_dataset = NoisyTextDataset(src_dataset, self.args.word_shuffle,
                                           self.args.noise_word_dropout,
                                           self.args.word_blank,
                                           self.args.bpe_symbol)

        if task.tgt_dataset is not None:
            tgt_dataset = TextLookupDataset(
                task.tgt_dataset,
                self.dictionaries[task.target_language + '.tgt'],
                words=split_words,
                lower=task.lower,
                bos=True,
                eos=True,
                trunc_len=self.args.seq_length_trunc)
        else:
            tgt_dataset = None
        dataset = ParallelDataset(src_dataset, tgt_dataset)
        return dataset
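
# The noise model itself lives in NoisyTextDataset. As a rough, self-contained
# illustration of what the word_shuffle, word_dropout/noise_word_dropout and
# word_blank settings control, the sketch below applies comparable corruptions
# to a plain token list. The function name, defaults and exact behaviour are
# assumptions for illustration only, not nmtg's implementation.
import random


def corrupt_tokens(tokens, shuffle_distance=3, dropout=0.1, blank=0.1,
                   blank_token='<unk>'):
    # Drop some words at random (but keep at least one token).
    kept = [t for t in tokens if random.random() > dropout] or tokens[:1]
    # Replace some of the remaining words with a placeholder token.
    blanked = [blank_token if random.random() < blank else t for t in kept]
    # Lightly shuffle: each word moves at most `shuffle_distance` positions.
    keys = [i + random.uniform(0, shuffle_distance) for i in range(len(blanked))]
    return [t for _, t in sorted(zip(keys, blanked))]


print(corrupt_tokens('the quick brown fox jumps over the lazy dog'.split()))
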
    def _get_train_dataset(self):
        logger.info('Loading training data')
        split_words = self.args.input_type == 'word'

        src_data, src_lengths = TextLookupDataset.load(self.args.train_src, self.src_dict, self.args.data_dir,
                                                       self.args.load_into_memory, split_words,
                                                       bos=False, eos=False, trunc_len=self.args.src_seq_length_trunc,
                                                       lower=self.args.lower)

        if self.args.translation_noise:
            src_data = NoisyTextDataset(src_data, self.args.word_shuffle, self.args.noise_word_dropout,
                                        self.args.word_blank, self.args.bpe_symbol)

        tgt_data, tgt_lengths = TextLookupDataset.load(self.args.train_tgt, self.tgt_dict, self.args.data_dir,
                                                       self.args.load_into_memory, split_words,
                                                       bos=True, eos=True, trunc_len=self.args.tgt_seq_length_trunc,
                                                       lower=self.args.lower)
        src_data.lengths = src_lengths
        tgt_data.lengths = tgt_lengths
        dataset = ParallelDataset(src_data, tgt_data)
        logger.info('Number of training sentences: {:,d}'.format(len(dataset)))
        return dataset
    def _get_eval_iterator(self, task):
        split_words = self.args.input_type == 'word'
        noisy_data = NoisyTextDataset(
            TextLookupDataset(task.src_dataset,
                              self.dictionary,
                              split_words,
                              bos=False,
                              eos=False,
                              lower=self.args.lower),
            self.args.word_shuffle,
            self.args.noise_word_dropout,
            self.args.word_blank,
            self.args.bpe_symbol)

        clean_data = TextLookupDataset(task.tgt_dataset,
                                       self.dictionary,
                                       split_words,
                                       bos=True,
                                       eos=True,
                                       lower=self.args.lower)

        dataset = ParallelDataset(noisy_data, clean_data)
        return dataset.get_iterator(batch_size=self.args.batch_size,
                                    num_workers=self.args.data_loader_threads,
                                    cuda=self.args.cuda)
    def load_data(self, model_args=None):
        logger.info('Loading training data')
        split_words = self.args.input_type == 'word'

        train_clean_name = os.path.basename(self.args.train_clean)

        if self.args.load_into_memory:
            clean_data = TextLineDataset.load_into_memory(
                self.args.train_clean)
        else:
            offsets = os.path.join(self.args.data_dir,
                                   train_clean_name + '.idx.npy')
            clean_data = TextLineDataset.load_indexed(self.args.train_clean,
                                                      offsets)

        if self.args.train_noisy is not None:
            train_noisy_name = os.path.basename(self.args.train_noisy)
            if self.args.load_into_memory:
                noisy_data = TextLineDataset.load_into_memory(
                    self.args.train_noisy)
            else:
                offsets = os.path.join(self.args.data_dir,
                                       train_noisy_name + '.idx.npy')
                noisy_data = TextLineDataset.load_indexed(
                    self.args.train_noisy, offsets)
        else:
            # No separate noisy corpus was given; the clean data will be
            # corrupted on the fly by NoisyTextDataset below.
            noisy_data = clean_data

        noisy_data = NoisyTextDataset(
            TextLookupDataset(noisy_data,
                              self.dictionary,
                              words=split_words,
                              bos=False,
                              eos=False,
                              trunc_len=self.args.seq_length_trunc,
                              lower=self.args.lower),
            self.args.word_shuffle,
            self.args.word_dropout,
            self.args.word_blank,
            self.args.bpe_symbol)

        clean_data = TextLookupDataset(clean_data,
                                       self.dictionary,
                                       words=split_words,
                                       bos=True,
                                       eos=True,
                                       trunc_len=self.args.seq_length_trunc,
                                       lower=self.args.lower)
        dataset = ParallelDataset(noisy_data, clean_data)
        logger.info('Number of training sentences: {:,d}'.format(len(dataset)))

        clean_len_filename = os.path.join(self.args.data_dir,
                                          train_clean_name + '.len.npy')
        lengths = np.load(clean_len_filename)

        def filter_fn(i):
            return lengths[i] <= self.args.seq_length

        logger.info('Generating batches')
        batches = generate_length_based_batches_from_lengths(
            lengths,
            self.args.batch_size_words,
            self.args.batch_size_sents,
            self.args.batch_size_multiplier,
            self.args.pad_count,
            filter_fn=filter_fn)
        logger.info('Number of training batches: {:,d}'.format(len(batches)))

        filtered = len(lengths) - sum(len(batch) for batch in batches)
        logger.info('Filtered {:,d}/{:,d} training examples for length'.format(
            filtered, len(lengths)))

        sampler = PreGeneratedBatchSampler(batches, self.args.curriculum == 0)

        model = self.build_model(model_args)
        params = list(filter(lambda p: p.requires_grad, model.parameters()))
        lr_scheduler, optimizer = self._build_optimizer(params)
        return TrainData(model, dataset, sampler, lr_scheduler, optimizer,
                         self._get_training_metrics())
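
# generate_length_based_batches_from_lengths is part of nmtg. As a rough
# approximation of what such length-based batching does (group similarly sized
# sentences and cut a batch whenever a token or sentence budget would be
# exceeded), a minimal sketch could look like this; the parameter names and
# exact budgeting rule are assumptions, not nmtg's implementation.
import numpy as np


def length_based_batches(lengths, max_tokens=4096, max_sentences=128,
                         filter_fn=None):
    order = np.argsort(lengths)
    batches, batch, longest = [], [], 0
    for idx in order:
        if filter_fn is not None and not filter_fn(idx):
            continue  # skip examples rejected by the filter, e.g. too long
        # Budget by padded tokens (longest example times batch size) and count.
        if batch and (max(longest, lengths[idx]) * (len(batch) + 1) > max_tokens
                      or len(batch) + 1 > max_sentences):
            batches.append(batch)
            batch, longest = [], 0
        batch.append(int(idx))
        longest = max(longest, lengths[idx])
    if batch:
        batches.append(batch)
    return batches
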
    def _get_train_dataset(self):
        logger.info('Loading training data')
        split_words = self.args.input_type == 'word'
        # output_select decides where the language token goes: on the encoder
        # input (src_bos) or on the decoder input (tgt_lang_bos).
        src_bos = self.model.output_select == 'encoder_bos'
        tgt_lang_bos = self.model.output_select == 'decoder_bos'

        # Multi-parallel mode: one corpus per language, and every language
        # pair not listed in exclude_pairs is formed from the same data.
        if self.data_mode == 'all_to_all':
            assert self.args.join_src_tgt_vocab
            filenames = {
                lang: name
                for lang, name in map(lambda x: x.split('='),
                                      self.args.train_data)
            }
            datasets = {}
            lengths = {}
            for lang in self.source_languages:  # == target_languages
                dictionary = self.dictionaries[lang + '.src']
                dataset, length = TextLookupDataset.load(
                    filenames[lang],
                    dictionary,
                    self.args.data_dir,
                    self.args.load_into_memory,
                    split_words,
                    bos=src_bos,
                    eos=False,
                    trunc_len=self.args.seq_length_trunc,
                    lower=self.args.lower)
                datasets[lang] = dataset
                lengths[lang] = length

            exclude_pairs = [
                pair.split('-') for pair in self.args.exclude_pairs
            ]

            if len(self.noisy_languages) > 0:
                dataset = NoisyMultiParallelDataset(
                    exclude_pairs, src_bos, tgt_lang_bos,
                    self.args.word_shuffle, self.args.noise_word_dropout,
                    self.args.word_blank, self.args.bpe_symbol,
                    self.noisy_languages, **datasets)
            else:
                dataset = MultiParallelDataset(exclude_pairs, src_bos,
                                               tgt_lang_bos, **datasets)
            dataset.lengths = lengths
            logger.info('Number of training sentences: {:,d}'.format(
                dataset.num_sentences))
            logger.info('Number of training examples: {:,d}'.format(
                len(dataset)))
            return dataset
        else:
            # Pairwise mode: explicit language pairs, each with its own
            # source and target file.
            pairs = [pair.split('-') for pair in self.args.langs]
            files = [(self.args.train_data[2 * i],
                      self.args.train_data[2 * i + 1])
                     for i in range(len(pairs))]

            if self.args.invert_pairs:
                pairs.extend([(p[1], p[0]) for p in pairs])
                files.extend([(f[1], f[0]) for f in files])

            datasets = []
            src_lengths = []
            tgt_lengths = []
            for (source, target), (source_file,
                                   target_file) in zip(pairs, files):
                src_dataset, src_length = TextLookupDataset.load(
                    source_file,
                    self.dictionaries[source + '.src'],
                    self.args.data_dir,
                    self.args.load_into_memory,
                    split_words,
                    bos=src_bos,
                    eos=False,
                    trunc_len=self.args.seq_length_trunc,
                    lang=target if src_bos else None)

                if source in self.noisy_languages:
                    src_dataset = NoisyTextDataset(
                        src_dataset, self.args.word_shuffle,
                        self.args.noise_word_dropout, self.args.word_blank,
                        self.args.bpe_symbol)

                tgt_dataset, tgt_length = TextLookupDataset.load(
                    target_file,
                    self.dictionaries[target + '.tgt'],
                    self.args.data_dir,
                    self.args.load_into_memory,
                    split_words,
                    bos=True,
                    eos=True,
                    trunc_len=self.args.seq_length_trunc,
                    lang=target if tgt_lang_bos else None)
                dataset = ParallelDataset(src_dataset, tgt_dataset)
                datasets.append(dataset)
                src_lengths.append(src_length)
                tgt_lengths.append(tgt_length)

            dataset = ConcatDataset(*datasets, balance=self.args.balance_pairs)
            src_lengths = dataset.concat_lengths(*src_lengths)
            tgt_lengths = dataset.concat_lengths(*tgt_lengths)
            dataset.lengths = (src_lengths, tgt_lengths)
            logger.info('Number of training sentences: {:,d}'.format(
                sum(len(ds) for ds in datasets)))
            return dataset
import argparse

from nmtg.data import Dictionary
from nmtg.data.noisy_text import NoisyTextDataset
from nmtg.data.text_lookup_dataset import TextLookupDataset
from nmtg.tasks.denoising_text_task import DenoisingTextTask

parser = argparse.ArgumentParser()
DenoisingTextTask.add_options(parser)
args = parser.parse_args()

task = DenoisingTextTask.setup_task(args)
dictionary = Dictionary.infer_from_text(task.tgt_dataset)

# Wrap the source side in a NoisyTextDataset so that each lookup returns a
# corrupted (shuffled, dropped, blanked) version of the sentence.
noisy_text = NoisyTextDataset(TextLookupDataset(task.src_dataset, dictionary, True,
                                                args.lower, False, False, False),
                              args.word_shuffle, args.noise_word_dropout, args.word_blank, args.bpe_symbol)

# Print each clean target line next to its noised, de-indexed source and wait
# for Enter before showing the next pair.
for i in range(len(noisy_text)):
    print(task.tgt_dataset[i])
    print(dictionary.string(noisy_text[i]))
    input()
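
# For context, Dictionary.infer_from_text and dictionary.string map between
# tokens and integer ids. The toy stand-in below only illustrates that round
# trip; it is not nmtg's Dictionary and its special tokens are invented here.
class ToyDictionary:
    def __init__(self, sentences):
        tokens = sorted({tok for line in sentences for tok in line.split()})
        self.idx2tok = ['<pad>', '<unk>', '<s>', '</s>'] + tokens
        self.tok2idx = {t: i for i, t in enumerate(self.idx2tok)}

    def lookup(self, line):
        # Unknown words map to the <unk> index (1).
        return [self.tok2idx.get(tok, 1) for tok in line.split()]

    def string(self, indices):
        return ' '.join(self.idx2tok[i] for i in indices)


toy = ToyDictionary(['a b c', 'b c d'])
print(toy.string(toy.lookup('a d e')))  # prints: a d <unk>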