Example #1
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]
        split_path = os.path.join(data_path, split)

        dataset = data_utils.load_indexed_dataset(
            split_path,
            self.dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        if dataset is None:
            raise FileNotFoundError("Dataset not found: {} ({})".format(
                split, split_path))

        dataset = StripTokenDataset(dataset, self.dictionary.eos())

        # create continuous blocks of tokens
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.args.tokens_per_sample - 2,  # one less for <s> and one for </s>
            pad=self.dictionary.pad(),
            eos=self.dictionary.eos(),
            break_mode=self.args.sample_break_mode,
            document_sep_len=0,
        )

        # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
        dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
        dataset = AppendTokenDataset(dataset, self.source_dictionary.eos())

        mask_whole_words = (get_whole_word_mask(self.args,
                                                self.source_dictionary)
                            if self.args.mask_length != "subword" else None)

        self.datasets[split] = DenoisingDataset(
            dataset,
            dataset.sizes,
            self.dictionary,
            self.mask_idx,
            mask_whole_words,
            shuffle=self.args.shuffle_instance,
            seed=self.seed,
            args=self.args,
        )
        logger.info(
            "Split: {0}, Loaded {1} samples of denoising_dataset".format(
                split,
                len(self.datasets[split]),
            ))
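
The "- 2" above leaves room for the <s> and </s> tokens that are prepended and appended afterwards. A minimal standalone sketch of that bookkeeping (plain Python, no fairseq; the block size and token ids are made up):

tokens_per_sample = 8          # assumed model context size
bos, eos = 0, 2                # assumed ids for <s> and </s>

# Block the raw token stream so that, after adding <s> and </s>, each sample fits exactly.
block_size = tokens_per_sample - 2
stream = list(range(100, 120))  # toy token stream
blocks = [stream[i:i + block_size] for i in range(0, len(stream), block_size)]
samples = [[bos] + b + [eos] for b in blocks]

assert all(len(s) <= tokens_per_sample for s in samples)
print(samples[0])  # [0, 100, 101, 102, 103, 104, 105, 2]
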
Example #2
 def load_noise_hyperparam(self, noise_params):
     self.noise = True
     self.noise_params = noise_params
     self.mask_idx = self.task.mask_idx
     self.vocab_length = len(self.task.source_dictionary)
     if self.noise_params['mask_whole_word']:
         self.noise_params['mask_whole_word'] = get_whole_word_mask(
             self.args, self.task.source_dictionary)
     else:
         self.noise_params['mask_whole_word'] = None
Example #3
 def load_noise_hyperparam(self, noise_params):
     self.noise = True
     self.noise_params = noise_params
     self.seed_noise_params = noise_params
     self.mask_idx = self.task.mask_idx
     self.vocab_length = len(self.task.source_dictionary)
     if self.noise_params['mask_whole_word']:
         self.noise_params['mask_whole_word'] = get_whole_word_mask(self.args, self.task.source_dictionary)
     else:
         self.noise_params['mask_whole_word'] = None
     self.seed = noise_params['seed']
     np.random.seed(self.seed)
     utils.set_torch_seed(self.seed)
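
The last two lines make the noise sampling reproducible. utils.set_torch_seed is fairseq's helper; outside fairseq, roughly the same effect can be sketched as:

import numpy as np
import torch

def seed_noise(seed):
    # Seed numpy (data-side noise sampling) and torch (model-side randomness).
    np.random.seed(seed)
    torch.manual_seed(seed)

seed_noise(42)
print(np.random.randint(0, 10), torch.randint(0, 10, (1,)).item())
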
Example #4
def gen_whole_word_mask(args, dictionary):
    def is_beginning_of_word(i):
        if i < dictionary.nspecial:
            # special elements are always considered beginnings
            return True
        tok = dictionary[i]
        if tok.startswith("madeupword"):
            return True

        if tok in ["<unk>", "<s>", "</s>", "<pad>"]:
            return True
        return tok.startswith("\u2581")

    if args.use_mask_whole_words:
        mask_whole_words = torch.ByteTensor(
            list(map(is_beginning_of_word, range(len(dictionary)))))
    else:
        # it will mask every token as a word-leading token, since no bpe model is loaded for phoneme tokens
        return get_whole_word_mask(args, dictionary)
    return mask_whole_words
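
For intuition, here is a self-contained toy version of the same boundary test over a made-up SentencePiece-style vocabulary, where pieces starting with "\u2581" begin a new word (an illustration, not the fairseq implementation):

import torch

toy_vocab = ["<s>", "<pad>", "</s>", "<unk>", "\u2581the", "\u2581qu", "ick", "\u2581fox"]
nspecial = 4

def is_beginning_of_word(i):
    # Special symbols count as word beginnings; otherwise check the SentencePiece marker.
    return i < nspecial or toy_vocab[i].startswith("\u2581")

mask_whole_words = torch.ByteTensor(
    [is_beginning_of_word(i) for i in range(len(toy_vocab))])
print(mask_whole_words)  # tensor([1, 1, 1, 1, 1, 1, 0, 1], dtype=torch.uint8)
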
Example #5
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        dataset = self._load_dataset_split(split, epoch, combine)

        mask_whole_words = (
            get_whole_word_mask(self.cfg.bpe, self.source_dictionary)
            if self.cfg.mask_length != "subword"
            else None
        )

        self.datasets[split] = DenoisingDataset(
            dataset,
            dataset.sizes,
            self.dictionary,
            self.mask_idx,
            mask_whole_words,
            shuffle=self.cfg.shuffle_instance,
            seed=self.cfg.seed,
            mask=self.cfg.mask,
            mask_random=self.cfg.mask_random,
            insert=self.cfg.insert,
            rotate=self.cfg.rotate,
            permute_sentences=self.cfg.permute_sentences,
            bpe=self.cfg.bpe,
            replace_length=self.cfg.replace_length,
            mask_length=self.cfg.mask_length,
            poisson_lambda=self.cfg.poisson_lambda,
        )
        logger.info(
            "Split: {0}, Loaded {1} samples of denoising_dataset".format(
                split,
                len(self.datasets[split]),
            )
        )
Example #6
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]
        split_path = os.path.join(data_path, split)

        dataset = data_utils.load_indexed_dataset(
            split_path,
            self.source_dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        if dataset is None:
            raise FileNotFoundError("Dataset not found: {} ({})".format(
                split, split_path))

        dataset = maybe_shorten_dataset(
            dataset,
            split,
            self.args.shorten_data_split_list,
            self.args.shorten_method,
            self.args.tokens_per_sample,
            self.args.seed,
        )

        # create continuous blocks of tokens
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.args.tokens_per_sample - 1,  # one less for <s>
            pad=self.source_dictionary.pad(),
            eos=self.source_dictionary.eos(),
            break_mode=self.args.sample_break_mode,
        )
        logger.info("loaded {} blocks from: {}".format(len(dataset),
                                                       split_path))

        # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
        dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())

        # create masked input and targets
        mask_whole_words = (get_whole_word_mask(self.args,
                                                self.source_dictionary)
                            if self.args.mask_whole_words else None)

        src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
            dataset,
            self.source_dictionary,
            pad_idx=self.source_dictionary.pad(),
            mask_idx=self.mask_idx,
            seed=self.args.seed,
            mask_prob=self.args.mask_prob,
            leave_unmasked_prob=self.args.leave_unmasked_prob,
            random_token_prob=self.args.random_token_prob,
            freq_weighted_replacement=self.args.freq_weighted_replacement,
            mask_whole_words=mask_whole_words,
            mask_multiple_length=self.args.mask_multiple_length,
            mask_stdev=self.args.mask_stdev,
        )

        with data_utils.numpy_seed(self.args.seed + epoch):
            shuffle = np.random.permutation(len(src_dataset))

        self.datasets[split] = SortDataset(
            NestedDictionaryDataset(
                {
                    "id":
                    IdDataset(),
                    "net_input": {
                        "src_tokens":
                        RightPadDataset(
                            src_dataset,
                            pad_idx=self.source_dictionary.pad(),
                        ),
                        "src_lengths":
                        NumelDataset(src_dataset, reduce=False),
                    },
                    "target":
                    RightPadDataset(
                        tgt_dataset,
                        pad_idx=self.source_dictionary.pad(),
                    ),
                    "nsentences":
                    NumSamplesDataset(),
                    "ntokens":
                    NumelDataset(src_dataset, reduce=True),
                },
                sizes=[src_dataset.sizes],
            ),
            sort_order=[
                shuffle,
                src_dataset.sizes,
            ],
        )
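
What mask_whole_words buys in MaskTokensDataset is that the mask/keep decision is made per word rather than per subword. A rough standalone illustration of that grouping (made-up token ids and probabilities, not the fairseq code):

import numpy as np

tokens            = np.array([4, 5, 6, 7, 8, 9, 10])   # one block of subword ids
word_begins       = np.array([1, 1, 0, 0, 1, 0, 1])    # 1 = first subword of a word
mask_idx, pad_idx = 50264, 1
mask_prob = 0.5
rng = np.random.default_rng(0)

# Group subwords into words, then draw one mask decision per whole word.
word_ids = np.cumsum(word_begins) - 1                  # [0, 1, 1, 1, 2, 2, 3]
masked_words = rng.random(word_ids[-1] + 1) < mask_prob

src = tokens.copy()
tgt = np.full_like(tokens, pad_idx)                    # loss is computed only on masked positions
for pos, wid in enumerate(word_ids):
    if masked_words[wid]:
        tgt[pos] = tokens[pos]
        src[pos] = mask_idx
print(src, tgt)
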
Example #7
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = self.args.data.split(":")
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]
        split_path = os.path.join(data_path, split)

        if self.langs is None:
            languages = sorted(
                [
                    name
                    for name in os.listdir(data_path)
                    if os.path.isdir(os.path.join(data_path, name))
                ]
            )
        else:
            languages = self.langs.split(",")
            for name in languages:
                p = os.path.join(data_path, name)
                assert os.path.exists(p), "data not found: {}".format(p)

        logger.info("Training on {0} languages: {1}".format(len(languages), languages))
        logger.info(
            "Language to id mapping: {}".format(
                {lang: id for id, lang in enumerate(languages)}
            )
        )

        mask_whole_words = get_whole_word_mask(self.args, self.dictionary)
        language_without_segmentations = self.args.no_whole_word_mask_langs.split(",")
        lang_datasets = []
        for language in languages:
            split_path = os.path.join(data_path, language, split)

            dataset = data_utils.load_indexed_dataset(
                split_path,
                self.source_dictionary,
                self.args.dataset_impl,
                combine=combine,
            )
            if dataset is None:
                raise FileNotFoundError(
                    "Dataset not found: {} ({})".format(split, split_path)
                )

            end_token = (
                self.source_dictionary.index("[{}]".format(language))
                if self.args.add_lang_token
                else self.source_dictionary.eos()
            )

            # create continuous blocks of tokens
            dataset = TokenBlockDataset(
                dataset,
                dataset.sizes,
                self.args.tokens_per_sample - 2,  # one for <s> and one for the end-of-sentence/lang token
                pad=self.source_dictionary.pad(),
                eos=end_token,
                break_mode=self.args.sample_break_mode,
            )
            logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))

            # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
            dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
            dataset = AppendTokenDataset(dataset, end_token)

            lang_mask_whole_words = (
                mask_whole_words
                if language not in language_without_segmentations
                else None
            )
            lang_dataset = DenoisingDataset(
                dataset,
                dataset.sizes,
                self.dictionary,
                self.mask_idx,
                lang_mask_whole_words,
                shuffle=self.args.shuffle_instance,
                seed=self.seed,
                args=self.args,
                eos=None
                if not self.args.add_lang_token
                else self.source_dictionary.index("[{}]".format(language)),
            )
            lang_datasets.append(lang_dataset)

        dataset_lengths = np.array(
            [len(d) for d in lang_datasets],
            dtype=float,
        )
        logger.info(
            "loaded total {} blocks for all languages".format(
                int(dataset_lengths.sum()),
            )
        )
        if split == self.args.train_subset:
            # For train subset, additionally up or down sample languages.
            sample_probs = self._get_sample_prob(dataset_lengths)
            logger.info(
                "Sample probability by language: {}".format(
                    {
                        lang: "{0:.4f}".format(sample_probs[id])
                        for id, lang in enumerate(languages)
                    }
                )
            )
            size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths
            logger.info(
                "Up/Down Sampling ratio by language: {}".format(
                    {
                        lang: "{0:.2f}".format(size_ratio[id])
                        for id, lang in enumerate(languages)
                    }
                )
            )

            resampled_lang_datasets = [
                ResamplingDataset(
                    lang_datasets[i],
                    size_ratio=size_ratio[i],
                    seed=self.args.seed,
                    epoch=epoch,
                    replace=size_ratio[i] >= 1.0,
                )
                for i, d in enumerate(lang_datasets)
            ]
            dataset = ConcatDataset(
                resampled_lang_datasets,
            )
        else:
            dataset = ConcatDataset(lang_datasets)
            lang_splits = [split]
            for lang_id, lang_dataset in enumerate(lang_datasets):
                split_name = split + "_" + languages[lang_id]
                lang_splits.append(split_name)
                self.datasets[split_name] = lang_dataset

            if split in self.args.valid_subset:
                self.args.valid_subset = self.args.valid_subset.replace(
                    split, ",".join(lang_splits)
                )

        with data_utils.numpy_seed(self.args.seed + epoch):
            shuffle = np.random.permutation(len(dataset))

        self.datasets[split] = SortDataset(
            dataset,
            sort_order=[
                shuffle,
                dataset.sizes,
            ],
        )
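
The resampling branch above depends on self._get_sample_prob, which is not shown in this snippet; in multilingual fairseq setups it usually applies temperature-style smoothing, with p_i proportional to n_i ** alpha. A small sketch of that idea with an assumed alpha:

import numpy as np

dataset_lengths = np.array([1_000_000, 100_000, 10_000], dtype=float)
multilang_sampling_alpha = 0.7   # assumed smoothing exponent (1.0 = purely proportional)

# Smoothed sampling probabilities: rare languages are up-weighted relative to their size.
prob = dataset_lengths / dataset_lengths.sum()
smoothed = prob ** multilang_sampling_alpha
sample_probs = smoothed / smoothed.sum()

# Up/down-sampling ratio per language, as computed in the snippet above.
size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths
print(sample_probs.round(4), size_ratio.round(2))
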
Example #8
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = self.args.data.split(':')
        assert len(paths) > 0
        data_path = paths[epoch % len(paths)]
        split_path = os.path.join(data_path, split)

        dataset = data_utils.load_indexed_dataset(
            split_path,
            self.source_dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        if dataset is None:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(
                split, split_path))

        # create continuous blocks of tokens
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.args.tokens_per_sample - 1,  # one less for <s>
            pad=self.source_dictionary.pad(),
            eos=self.source_dictionary.eos(),
            break_mode=self.args.sample_break_mode,
        )
        print('| loaded {} blocks from: {}'.format(len(dataset), split_path))

        # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
        dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())

        # create masked input and targets
        mask_whole_words = get_whole_word_mask(self.args, self.source_dictionary) \
            if self.args.mask_whole_words else None

        src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
            dataset,
            self.source_dictionary,
            pad_idx=self.source_dictionary.pad(),
            mask_idx=self.mask_idx,
            seed=self.args.seed,
            mask_prob=self.args.mask_prob,
            leave_unmasked_prob=self.args.leave_unmasked_prob,
            random_token_prob=self.args.random_token_prob,
            freq_weighted_replacement=self.args.freq_weighted_replacement,
            mask_whole_words=mask_whole_words,
        )

        with data_utils.numpy_seed(self.args.seed + epoch):
            shuffle = np.random.permutation(len(src_dataset))

        self.datasets[split] = SortDataset(
            NestedDictionaryDataset(
                {
                    'id':
                    IdDataset(),
                    'net_input': {
                        'src_tokens':
                        PadDataset(
                            src_dataset,
                            pad_idx=self.source_dictionary.pad(),
                            left_pad=False,
                        ),
                        'src_lengths':
                        NumelDataset(src_dataset, reduce=False),
                    },
                    'target':
                    PadDataset(
                        tgt_dataset,
                        pad_idx=self.source_dictionary.pad(),
                        left_pad=False,
                    ),
                    'nsentences':
                    NumSamplesDataset(),
                    'ntokens':
                    NumelDataset(src_dataset, reduce=True),
                },
                sizes=[src_dataset.sizes],
            ),
            sort_order=[
                shuffle,
                src_dataset.sizes,
            ],
        )
Example #9
    dataset = TokenBlockDataset(
        dataset,
        dataset.sizes,
        tokens_per_sample - 2,  # one less for <s> and one for </s>
        pad=dictionary.pad(),
        eos=dictionary.eos(),
        break_mode=args.sample_break_mode,
        document_sep_len=0,
    )

    assert len(dataset) == prev_size

    # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
    dataset = PrependTokenDataset(dataset, source_dictionary.bos())
    dataset = AppendTokenDataset(dataset, source_dictionary.eos())

    mask_whole_words = (get_whole_word_mask(args, source_dictionary)
                        if mask_length != 'subword' else None)

    bpe = encoders.build_bpe(args)
    eoh = dictionary.indices[bpe.encode('</h>')]
    denoising_dataset = DenoisingDataset(dataset,
                                         dataset.sizes,
                                         dictionary,
                                         mask_idx,
                                         mask_whole_words,
                                         shuffle=False,
                                         seed=seed,
                                         args=args,
                                         eoh=eoh)

    for i in range(len(denoising_dataset)):
Example #10
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        text_data = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.text'), )
        annotation_data = MMapNumpyArray(
            os.path.join(self.args.data_path, split + '.annotations.npy'), )
        annotated_text = AnnotatedText(
            text_data=text_data,
            annotation_data=annotation_data,
            dictionary=self.dictionary,
            mask_type=self.args.mask_type,
            non_mask_rate=self.args.non_mask_rate,
        )
        dataset = TokenBlockAnnotatedDataset(
            annotated_text=annotated_text,
            max_positions=self.max_positions() - 5,  # <cls>, e1/e2 start/end
            pad=self.dictionary.pad(),
            eos=self.dictionary.eos(),
            seed=self.seed,
            document_sep_len=1,
        )
        if split == 'train' and self.args.epoch_size is not None:
            dataset = EpochSplitDataset(
                dataset=dataset,
                epoch_size=self.args.epoch_size,
                seed=self.args.seed,
            )

        dataset = PrependTokenDataset(dataset, self.dictionary.bos())

        # create masked input and targets
        mask_whole_words = get_whole_word_mask(self.args, self.source_dictionary) \
            if self.args.mask_whole_words else None

        src_dataset, tgt_dataset = CustomMaskTokensDataset.apply_mask(
            dataset,
            self.dictionary,
            pad_idx=self.dictionary.pad(),
            mask_idx=self.dictionary.mask(),
            seed=self.seed,
            mask_prob=self.args.mask_prob,
            leave_unmasked_prob=self.args.leave_unmasked_prob,
            random_token_prob=self.args.random_token_prob,
            freq_weighted_replacement=self.args.freq_weighted_replacement,
            mask_whole_words=mask_whole_words,
        )

        dataset = DictionaryDataset(
            {
                'id':
                IdDataset(),
                'src_tokens':
                PadDataset(
                    src_dataset,
                    pad_idx=self.source_dictionary.pad(),
                    left_pad=False,
                ),
                'src_lengths':
                NumelDataset(src_dataset, reduce=False),
                'target':
                PadDataset(
                    tgt_dataset,
                    pad_idx=self.source_dictionary.pad(),
                    left_pad=False,
                ),
                'nsentences':
                NumSamplesDataset(),
                'ntokens':
                NumelDataset(src_dataset, reduce=True),
            },
            main_key='src_tokens',
        )

        n_examples = getattr(self.args, 'n_' + split + '_examples', None)
        if n_examples is not None:
            dataset = FixedSizeDataset(
                dataset=dataset,
                size=n_examples,
                seed=self.seed,
            )

        self.datasets[split] = dataset
Example #11
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]
        split_path = os.path.join(data_path, split)

        dataset = data_utils.load_indexed_dataset(
            split_path,
            self.source_dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        if dataset is None:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path))

        dataset = maybe_shorten_dataset(
            dataset,
            split,
            self.args.shorten_data_split_list,
            self.args.shorten_method,
            self.args.tokens_per_sample,
            self.args.seed,
        )

        # create continuous blocks of tokens
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.args.tokens_per_sample,
            pad=self.source_dictionary.pad(),
            eos=self.source_dictionary.eos(),
            break_mode=self.args.sample_break_mode,
        )
        logger.info('loaded {} blocks from: {}'.format(len(dataset), split_path))

        # remove tail
        dataset = RemoveTailDataset(dataset)

        # create masked input and targets
        mask_whole_words = get_whole_word_mask(self.args, self.source_dictionary) \
            if self.args.mask_whole_words else None

        src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
            dataset,
            self.source_dictionary,
            pad_idx=self.source_dictionary.pad(),
            mask_idx=self.mask_idx,
            seed=self.args.seed,
            mask_prob=self.args.mask_prob,
            leave_unmasked_prob=self.args.leave_unmasked_prob,
            random_token_prob=self.args.random_token_prob,
            freq_weighted_replacement=self.args.freq_weighted_replacement,
            mask_whole_words=mask_whole_words,
        )

        with data_utils.numpy_seed(self.args.seed + epoch):
            shuffle = np.random.permutation(len(src_dataset))

        self.datasets[split] = SortDataset(
            NestedDictionaryDataset(
                {
                    'id': IdDataset(),
                    'net_input': {
                        'src_tokens': RightPadDataset(
                            src_dataset,
                            pad_idx=self.source_dictionary.pad(),
                        ),
                        'src_lengths': NumelDataset(src_dataset, reduce=False),
                    },
                    'target': RightPadDataset(
                        tgt_dataset,
                        pad_idx=self.source_dictionary.pad(),
                    ),
                    'nsentences': NumSamplesDataset(),
                    'ntokens': NumelDataset(src_dataset, reduce=True),
                },
                sizes=[src_dataset.sizes],
            ),
            sort_order=[
                shuffle,
                src_dataset.sizes,
            ],
        )
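
A note on sort_order=[shuffle, src_dataset.sizes]: SortDataset typically resolves this with a lexicographic sort in which the last key is primary, so samples end up ordered by length while the random permutation only breaks ties between equal-length samples. A small numpy sketch of that behaviour (not the fairseq class itself):

import numpy as np

rng = np.random.default_rng(7)
sizes = np.array([5, 9, 5, 3, 9, 5])
shuffle = rng.permutation(len(sizes))

# np.lexsort treats the *last* key as the primary one:
# order by size, break ties between equal-length samples randomly.
order = np.lexsort([shuffle, sizes])
print(order, sizes[order])
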
Example #12
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        dataset = self._load_dataset_split(split, epoch, combine)

        # create masked input and targets
        mask_whole_words = (get_whole_word_mask(self.args,
                                                self.source_dictionary)
                            if self.cfg.mask_whole_words else None)

        src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
            dataset,
            self.source_dictionary,
            pad_idx=self.source_dictionary.pad(),
            mask_idx=self.mask_idx,
            seed=self.cfg.seed,
            mask_prob=self.cfg.mask_prob,
            leave_unmasked_prob=self.cfg.leave_unmasked_prob,
            random_token_prob=self.cfg.random_token_prob,
            freq_weighted_replacement=self.cfg.freq_weighted_replacement,
            mask_whole_words=mask_whole_words,
            mask_multiple_length=self.cfg.mask_multiple_length,
            mask_stdev=self.cfg.mask_stdev,
        )

        with data_utils.numpy_seed(self.cfg.seed):
            shuffle = np.random.permutation(len(src_dataset))

        target_dataset = RightPadDataset(
            tgt_dataset,
            pad_idx=self.source_dictionary.pad(),
        )

        input_dict = {
            "src_tokens":
            RightPadDataset(
                src_dataset,
                pad_idx=self.source_dictionary.pad(),
            ),
            "src_lengths":
            NumelDataset(src_dataset, reduce=False),
        }
        if self.cfg.include_target_tokens:
            input_dict["target_tokens"] = target_dataset

        self.datasets[split] = SortDataset(
            NestedDictionaryDataset(
                {
                    "id": IdDataset(),
                    "net_input": input_dict,
                    "target": target_dataset,
                    "nsentences": NumSamplesDataset(),
                    "ntokens": NumelDataset(src_dataset, reduce=True),
                },
                sizes=[src_dataset.sizes],
            ),
            sort_order=[
                shuffle,
                src_dataset.sizes,
            ],
        )
Example #13
    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split (e.g., train, valid, test)."""
        def get_path(type, split):
            return os.path.join(self.args.data, type, split)

        def make_dataset(type, dictionary):
            split_path = get_path(type, split)

            dataset = data_utils.load_indexed_dataset(
                split_path,
                dictionary,
                self.args.dataset_impl,
                combine=combine,
            )
            return dataset

        # input0 is source, input1 is synthetic target, input2 is reference
        input0 = make_dataset(self.args.input0, self.source_dictionary)
        assert input0 is not None, 'could not find dataset: {}'.format(
            get_path(self.args.input0, split))
        input1 = make_dataset(self.args.input1, self.source_dictionary)

        if self.args.init_token is not None:
            input0 = PrependTokenDataset(input0, self.args.init_token)

        if self.args.input2 is not None:
            input2 = make_dataset(self.args.input2, self.source_dictionary)

        if self.args.input2 is not None and self.add_ref_prob > 0 and split != 'valid':
            input3 = PrependTokenDataset(input2, self.args.separator_token)
        else:
            input3 = None

        if input1 is None:
            src_tokens = input0
        else:
            if self.args.separator_token is not None:
                input1 = PrependTokenDataset(input1, self.args.separator_token)

            if self.args.input2 is not None and self.add_ref_prob > 0. and split != 'valid':
                src_tokens = ConcatSentencesDataset(
                    input0,
                    input3,
                    input1,
                    add_ref_prob=self.add_ref_prob,
                    drop_ref_rate=self.args.dropout_ref,
                    pad_idx=self.source_dictionary.pad(),
                    eos_idx=self.source_dictionary.eos(),
                    bos_idx=self.source_dictionary.bos())
            else:
                src_tokens = ConcatSentencesDataset(input0, input1)

        with data_utils.numpy_seed(self.args.seed):
            shuffle = np.random.permutation(len(src_tokens))

        if self.args.truncate_sequence:
            src_tokens = TruncateDataset(src_tokens, self.args.max_positions)

        if self.args.input2 is not None and self.args.add_tran_loss:
            # create masked input and targets
            mask_whole_words = get_whole_word_mask(self.args, self.source_dictionary) \
                if self.args.mask_whole_words else None
            ref_dataset, ref_target_dataset = MaskTokensDataset.apply_mask(
                input2,
                self.source_dictionary,
                pad_idx=self.source_dictionary.pad(),
                mask_idx=self.mask_idx,
                seed=self.args.seed,
                mask_prob=self.args.mask_prob,
                leave_unmasked_prob=self.args.leave_unmasked_prob,
                random_token_prob=self.args.random_token_prob,
                freq_weighted_replacement=self.args.freq_weighted_replacement,
                mask_whole_words=mask_whole_words,
            )

            if self.args.separator_token is not None:
                input2 = PrependTokenDataset(ref_dataset,
                                             self.args.separator_token)
            parallel_src_tokens = ConcatSentencesDataset(input0, input2)
            if self.args.truncate_sequence:
                parallel_src_tokens = TruncateDataset(parallel_src_tokens,
                                                      self.args.max_positions)

        dataset = {
            'id': IdDataset(),
            'net_input': {
                'src_tokens':
                RightPadDataset(
                    src_tokens,
                    pad_idx=self.source_dictionary.pad(),
                ),
                'src_lengths':
                NumelDataset(src_tokens, reduce=False),
            },
            'nsentences': NumSamplesDataset(),
            'ntokens': NumelDataset(src_tokens, reduce=True),
        }

        if self.args.input2 is not None and self.args.add_tran_loss:
            dataset['net_input']['parallel_src_tokens'] = RightPadDataset(
                parallel_src_tokens,
                pad_idx=self.source_dictionary.pad(),
            )

        if self.args.add_prev_output_tokens:
            prev_tokens_dataset = RightPadDataset(
                RollDataset(src_tokens, 1),
                pad_idx=self.dictionary.pad(),
            )
            dataset['net_input'].update(
                prev_output_tokens=prev_tokens_dataset, )

        if not self.args.regression_target:
            label_dataset = make_dataset('label', self.label_dictionary)
            if label_dataset is not None:
                dataset.update(target=OffsetTokensDataset(
                    StripTokenDataset(
                        label_dataset,
                        id_to_strip=self.label_dictionary.eos(),
                    ),
                    offset=-self.label_dictionary.nspecial,
                ))
            if self.args.input2 is not None and self.args.add_tran_loss:
                # used as translation target when calculating loss
                dataset.update(parallel_target=RightPadDataset(
                    ref_target_dataset,
                    pad_idx=self.source_dictionary.pad(),
                ))
        else:
            label_path = "{0}.label".format(get_path('label', split))
            if os.path.exists(label_path):

                def parse_regression_target(i, line):
                    values = line.split()
                    assert len(values) == self.args.num_classes, \
                        f'expected num_classes={self.args.num_classes} regression target values on line {i}, found: "{line}"'
                    return [float(x) for x in values]

                dataset.update(target=RawLabelDataset([
                    parse_regression_target(i, line.strip())
                    for i, line in enumerate(open(label_path).readlines())
                ]))

        nested_dataset = NestedDictionaryDataset(
            dataset,
            sizes=[src_tokens.sizes],
            all_sizes=src_tokens.all_sizes
            if self.args.add_target_num_tokens else None,
            padding_idx=self.source_dictionary.pad(),
            add_ref_prob=self.add_ref_prob if split != 'valid' else 0.,
        )

        if self.args.no_shuffle:
            dataset = nested_dataset
        else:
            dataset = SortDataset(
                nested_dataset,
                # shuffle
                sort_order=[shuffle],
            )

        logger.info("Loaded {0} with #samples: {1}".format(
            split, len(dataset)))

        self.datasets[split] = dataset
        return self.datasets[split]
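
In the classification branch above, OffsetTokensDataset shifts label indices down by label_dictionary.nspecial after the trailing EOS is stripped, so the first real label symbol maps to class 0. A toy illustration with made-up ids:

nspecial = 4            # <s>, <pad>, </s>, <unk> occupy the first dictionary slots
eos = 2
label_tokens = [4, 2]   # the first "real" label symbol, followed by </s>

stripped = [t for t in label_tokens if t != eos]      # StripTokenDataset
targets = [t - nspecial for t in stripped]            # OffsetTokensDataset
print(targets)  # [0]
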
Example #14
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]
        split_path = os.path.join(data_path, split) #+ '.bpe'

        dataset = data_utils.load_indexed_dataset(
            split_path,
            self.source_dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        if dataset is None:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path))

        # create continuous blocks of tokens
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.args.tokens_per_sample - 1,  # one less for <s>
            pad=self.source_dictionary.pad(),
            eos=self.source_dictionary.eos(),
            break_mode=self.args.sample_break_mode,
        )
        logger.info('loaded {} blocks from: {}'.format(len(dataset), split_path))

        # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
        dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())

        # create masked input and targets
        mask_whole_words = get_whole_word_mask(self.args, self.source_dictionary) \
            if self.args.mask_whole_words else None

        src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
            dataset,
            self.source_dictionary,
            pad_idx=self.source_dictionary.pad(),
            mask_idx=self.mask_idx,
            seed=self.args.seed,
            mask_prob=self.args.mask_prob,
            leave_unmasked_prob=self.args.leave_unmasked_prob,
            random_token_prob=self.args.random_token_prob,
            freq_weighted_replacement=self.args.freq_weighted_replacement,
            mask_whole_words=mask_whole_words,
        )

        with data_utils.numpy_seed(self.args.seed + epoch):
            shuffle = np.random.permutation(len(src_dataset))


        # load counts
        thresh = 100
        with open(split_path + '.counts') as count_file:
            lines = [line.rstrip() for line in count_file]
            counts = [line.split(' ') for line in lines]
            for i, count in enumerate(counts):
                count = [int(el) for el in count]
                counts[i] = [el if el < thresh else thresh for el in count]
                counts[i] = torch.LongTensor(np.concatenate([[0],counts[i],[0]]))

        # load embeddings
        if self.args.input_format != 'tokens':
            embs = torch.load(split_path + '.features')


        # mask counts and embeddings
        for i, data in enumerate(src_dataset):
            counts[i] = counts[i] * (data != self.mask_idx)
            if self.args.input_format != 'tokens':
                embs[i] = embs[i] * (data != self.mask_idx)[1:-1, None]

        self.datasets[split] = SortDataset(
            NestedDictionaryDataset(
                {
                    'id': IdDataset(),
                    'net_input': {
                        'src_tokens': PadDataset(
                            src_dataset,
                            pad_idx=self.source_dictionary.pad(),
                            left_pad=False,
                        ),
                        'src_counts': PadDataset(
                            counts,
                            pad_idx=0,
                            left_pad=False,
                        ),
                        'src_embs': EmbeddingDataset(
                            embs,
                            pad_idx=0,
                            left_pad=False,
                        ) if self.args.input_format != 'tokens' else None,
                        'src_lengths': NumelDataset(src_dataset, reduce=False),
                    },
                    'target': PadDataset(
                        tgt_dataset,
                        pad_idx=self.source_dictionary.pad(),
                        left_pad=False,
                    ),
                    'nsentences': NumSamplesDataset(),
                    'ntokens': NumelDataset(src_dataset, reduce=True),
                },
                sizes=[src_dataset.sizes],
            ),
            sort_order=[
                shuffle,
                src_dataset.sizes,
            ],
        )
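
The count handling in this variant (clamp at a threshold, pad with zeros for the added boundary tokens, then zero out masked positions) can be sketched without fairseq; all values below are made up:

import numpy as np
import torch

thresh = 100
mask_idx = 99                               # assumed id of the <mask> symbol

raw_counts = [3, 250, 7]                    # per-token counts for one sentence
clipped = [c if c < thresh else thresh for c in raw_counts]
counts = torch.LongTensor(np.concatenate([[0], clipped, [0]]))  # zeros for the boundary slots

# Token ids for the same sample, with position 2 masked out.
data = torch.tensor([0, 11, 99, 13, 2])
counts = counts * (data != mask_idx)
print(counts)  # tensor([0, 3, 0, 7, 0])
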
Example #15
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        is_train_subset = split == getattr(self.args, "train_subset", None)
        if not is_train_subset:
            # if not training data set, use the first shard for valid and test
            paths = paths[:1]
        data_path = paths[(epoch - 1) % len(paths)]

        # infer langcode
        src, tgt = self.args.source_lang, self.args.target_lang
        """
        this is mask_word_initial
        WordNoising uses mask_word_end or mask_bpe_cont
        probably easiest to write FlippedDataset that reverses sequences
        and use the standard pipeline

        load_langpair_dataset:
            find files by pattern
            load_indexed source
                maybe truncate
                load target
            check shard counts
            sample ratios
            bos, source_id
            load_alignments
            LangpairDataset constructor

        """

        src_dataset, tgt_dataset = load_unpaired_langpair(
            data_path,
            split,
            src,
            self.src_dict,
            tgt,
            self.tgt_dict,
            combine=combine,
            dataset_impl=self.args.dataset_impl,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
            truncate_source=self.args.truncate_source,
            prepend_bos=self.args.prepend_bos,
        )

        if self.args.bpe_dropout > 0:
            src_dataset = DynamicGPT2BPEDropoutResampling(
                self.args,
                src_dataset,
                self.source_dictionary,
                dropout=self.args.bpe_dropout,
            )

        # load backtranslation
        if is_train_subset and not self.args.skip_backtranslation_data:
            """
            noised vs unnoised validation set? they might converge at different times
            """
            bt_src_dataset, bt_tgt_dataset = load_unpaired_langpair(
                # data_path, "{}.bt".format(split), src, self.src_dict, tgt, self.tgt_dict,
                data_path,
                "{}.bt".format(split),
                src,
                self.src_dict,
                tgt,
                self.tgt_dict,
                combine=combine,
                dataset_impl=self.args.dataset_impl,
                max_source_positions=self.args.max_source_positions,
                max_target_positions=self.args.max_target_positions,
                truncate_source=self.args.truncate_source,
                prepend_bos=self.args.prepend_bos,
            )
            if self.args.bpe == "gpt2":
                mask_is_beginning_of_word = get_whole_word_mask(
                    self.args, self.source_dictionary)
                mask_is_beginning_of_word = (
                    mask_is_beginning_of_word.numpy().astype(bool))
                # noiser = GPT2WordNoising(
                #     self.src_dict,
                #     mask_is_beginning_of_word,
                #     self.args.max_word_shuffle_distance,
                #     self.args.word_dropout_prob,
                #     self.args.word_blanking_prob,
                # )
                if self.args.bpe_dropout > 0:
                    bt_src_dataset = DynamicGPT2BPEDropoutResampling(
                        self.args,
                        bt_src_dataset,
                        self.source_dictionary,
                        dropout=self.args.bpe_dropout,
                    )
                noiser = GPT2WordNoisingV2(
                    self.src_dict,
                    mask_is_beginning_of_word,
                    self.args.max_word_shuffle_distance,
                    self.args.word_dropout_prob,
                    self.args.word_blanking_prob,
                )
                bt_src_dataset = DynamicNoisingDataset(
                    bt_src_dataset,
                    self.src_dict,
                    seed=1,
                    noiser=noiser,
                )

                # try:
                #     from icecream import ic
                #     ic.configureOutput(includeContext=True)
                # except ImportError:  # Graceful fallback if IceCream isn't installed.
                #     ic = lambda *a: None if not a else (a[0] if len(a) == 1 else a)  # noqa
                # ic("gpt2 bbpe")
                # bpe = encoders.build_bpe(self.args)
                # def decode(foo):
                #     return bpe.decode(self.src_dict.string(foo))
                # def disp(foo):
                #     return " ".join([bpe.decode(i) for i in self.src_dict.string(foo).split(" ")])
                #     # foo = [bpe.decode(str(i)) for i in range(0,1000)]
                #     # doo = [bpe.decode((i)) for i in self.src_dict.symbols[4:1000]]
                # for i in range(5):
                #     ic(_bt_src_dataset[i])
                #     ic(decode(_bt_src_dataset[i]))
                #     ic(disp(_bt_src_dataset[i]))
                #     ic(disp(bt_src_dataset[i]))
                #     ic(bt_src_dataset[i])
                # import pdb; pdb.set_trace()
            else:
                assert self.args.bpe_dropout <= 0, "BPE dropout not supported for this BPE scheme"
                # standard bpe with @@ as continuation marker
                bt_src_dataset = DynamicNoisingDataset(
                    bt_src_dataset,
                    self.src_dict,
                    seed=1,
                    max_word_shuffle_distance=self.args.max_word_shuffle_distance,
                    word_dropout_prob=self.args.word_dropout_prob,
                    word_blanking_prob=self.args.word_blanking_prob,
                )
            # if self.append_backtranslation_tag:
            if self.args.tagged_backtranslation:
                bt_src_dataset = AppendTokenDataset(
                    AppendTokenDataset(
                        StripTokenDataset(bt_src_dataset, self.src_dict.eos()),
                        self.bt_idx),
                    self.src_dict.eos(),
                )

            sample_ratios = [self.args.upsample_primary, 1]
            src_dataset = ConcatDataset([src_dataset, bt_src_dataset],
                                        sample_ratios)
            tgt_dataset = ConcatDataset([tgt_dataset, bt_tgt_dataset],
                                        sample_ratios)

        self.datasets[split] = LanguagePairDataset(
            src_dataset,
            src_dataset.sizes,
            self.src_dict,
            tgt_dataset,
            tgt_dataset.sizes,
            self.tgt_dict,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            align_dataset=None,
            eos=self.tgt_dict.eos(),
            num_buckets=self.args.num_batch_buckets,
            shuffle=(split not in ("test", "valid")),
        )
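
DynamicNoisingDataset applies the usual denoising perturbations to the backtranslated source: bounded word shuffling, word dropout, and word blanking. A self-contained sketch of those three operations on a plain list of words, with made-up probabilities (not the fairseq noiser):

import numpy as np

rng = np.random.default_rng(0)

def noise_words(words, max_shuffle_distance=3, dropout_prob=0.1, blank_prob=0.1, blank="<unk>"):
    # Word shuffle: each word may move at most max_shuffle_distance positions to the right.
    offsets = rng.uniform(0, max_shuffle_distance, size=len(words))
    order = np.argsort(np.arange(len(words)) + offsets)
    shuffled = [words[i] for i in order]
    # Word dropout / blanking: independently drop a word or replace it with a placeholder.
    out = []
    for w in shuffled:
        r = rng.random()
        if r < dropout_prob:
            continue
        out.append(blank if r < dropout_prob + blank_prob else w)
    return out

print(noise_words("the quick brown fox jumps over the lazy dog".split()))
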