Ejemplo n.º 1
0
 def __getitem__(self, index):
     """Return item ``index``, randomly cropped to ``self.truncation_length``.

     An item whose first dimension does not exceed ``self.truncation_length``
     is returned untouched. A longer item is cut down to a contiguous window
     of exactly ``self.truncation_length`` elements; the window start is
     drawn inside ``data_utils.numpy_seed(self.seed, self.epoch, index)``.
     """
     with data_utils.numpy_seed(self.seed, self.epoch, index):
         sample = self.dataset[index]
         surplus = sample.size(0) - self.truncation_length
         if surplus <= 0:
             return sample
         # randint's upper bound is exclusive, so the offset lies in
         # [0, surplus - 1].
         offset = np.random.randint(0, surplus)
         return sample[offset:offset + self.truncation_length]
Ejemplo n.º 2
0
        def construct_batch_sampler(dataset, epoch):
            """Build the batch sampler for ``dataset`` and log timing/memory.

            ``seed``, ``max_positions``, ``ignore_invalid_inputs``,
            ``max_tokens``, ``max_sentences`` and
            ``required_batch_size_multiple`` are taken from the enclosing
            scope.

            Args:
                dataset: dataset to batch; its split name is looked up in
                    ``self.datasets`` for use in the log messages.
                epoch: if not None, forwarded to ``dataset.set_epoch`` before
                    the indices are ordered.

            Returns:
                Whatever ``dataset.batch_by_size`` returns for the ordered
                (and optionally size-filtered) indices.
            """
            # Name of the first registered split holding this dataset, used
            # only to tag log lines.
            splits = [
                name for name, candidate in self.datasets.items()
                if candidate == dataset
            ]
            split = splits[0] if splits else None

            if epoch is not None:
                # make sure the dataset starts from the requested epoch
                dataset.set_epoch(epoch)

            start_time = time.time()
            logger.info(
                f"start batch sampler: mem usage: {data_utils.get_mem_usage()}"
            )

            # step 1: order the indices
            with data_utils.numpy_seed(seed):
                indices = dataset.ordered_indices()
            logger.info(
                f"[{split}] @batch_sampler order indices time: {get_time_gap(start_time, time.time())}"
            )
            logger.info(f"mem usage: {data_utils.get_mem_usage()}")

            # step 2: drop examples that exceed max_positions
            if max_positions is not None:
                my_time = time.time()
                indices = self.filter_indices_by_size(
                    indices, dataset, max_positions, ignore_invalid_inputs)
                logger.info(
                    f"[{split}] @batch_sampler filter_by_size time: {get_time_gap(my_time, time.time())}"
                )
                logger.info(f"mem usage: {data_utils.get_mem_usage()}")

            # step 3: group the remaining indices into mini-batches
            my_time = time.time()
            batch_sampler = dataset.batch_by_size(
                indices,
                max_tokens=max_tokens,
                max_sentences=max_sentences,
                required_batch_size_multiple=required_batch_size_multiple,
            )
            logger.info(
                f"[{split}] @batch_sampler batch_by_size time: {get_time_gap(my_time, time.time())}"
            )
            logger.info(
                f"[{split}] per epoch batch_sampler set-up time: {get_time_gap(start_time, time.time())}"
            )
            logger.info(f"mem usage: {data_utils.get_mem_usage()}")

            return batch_sampler
Ejemplo n.º 3
0
    def ordered_indices(self):
        """Sample, shuffle and optionally sort indices across all datasets.

        Dataset ``i`` contributes ``int(self.distribution[i] *
        self.total_num_instances)`` indices; the last dataset receives
        whatever is left so the total is exactly
        ``self.total_num_instances``. All random draws happen inside
        ``data_utils.numpy_seed(self.seed, self.epoch)``.

        Returns:
            np.ndarray of dtype int64 holding the sampled indices.
        """
        start = time.time()
        with data_utils.numpy_seed(self.seed, self.epoch):
            sampled_indices = []
            num_selected_instances = 0
            last_position = len(self.datasets) - 1

            for i, key in enumerate(self.datasets):
                low = self.dataset_offsets[i]
                if i < last_position:
                    num_instances = int(self.distribution[i] *
                                        self.total_num_instances)
                    high = self.dataset_offsets[i + 1]
                else:
                    # The final dataset absorbs the rounding remainder.
                    num_instances = self.total_num_instances - num_selected_instances
                    high = self.total_num_instances

                logger.info(f"sampling {num_instances} from {key} dataset")
                num_selected_instances += num_instances

                # Take as many whole copies of the dataset as fit, so every
                # example is represented as evenly as possible, and fill the
                # rest with a random subset.
                dataset_size = len(self.datasets[key])
                num_copies, num_leftover = divmod(num_instances, dataset_size)
                shuffled_range = np.random.permutation(high - low) + low
                dataset_indices = shuffled_range[:num_leftover]

                if num_copies > 0:
                    whole_copies = np.repeat(np.arange(low, high), num_copies)
                    sampled_indices += list(
                        np.concatenate((whole_copies, dataset_indices)))
                else:
                    sampled_indices += list(dataset_indices)

            assert (len(sampled_indices) == self.total_num_instances
                    ), f"{len(sampled_indices)} vs {self.total_num_instances}"

            np.random.shuffle(sampled_indices)
            if self.sort_indices:
                sampled_indices.sort(key=self.num_tokens)

            elapsed = time.time() - start
            logger.info(
                "multi_corpus_dataset ordered_indices took {}s".format(elapsed))
            return np.array(sampled_indices, dtype=np.int64)
Ejemplo n.º 4
0
    def __getitem__(self, index):
        """
        Return one noised sample. The collater later combines several such
        samples to create a noising dataset batch.
        """
        clean_tokens = self.src_dataset[index]
        lengths = torch.LongTensor([len(clean_tokens)])

        # The noising function expects x as (sequence length, batch size):
        # add a batch dimension of one, then transpose.
        column = torch.t(clean_tokens.unsqueeze(0))

        with data_utils.numpy_seed(self.seed + index):
            noised_column = self.noiser.noising(column, lengths)

        # (sequence length, 1) -> (1, sequence length); hand back the only row.
        return torch.t(noised_column)[0]
Ejemplo n.º 5
0
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split and register it in ``self.datasets``.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): 1-based epoch number; selects one of the configured
                data paths round-robin
            combine (bool): forwarded to ``data_utils.load_indexed_dataset``

        Raises:
            FileNotFoundError: if no dataset could be loaded from the split
                path
        """
        args = self.args
        vocab = self.source_dictionary

        paths = utils.split_paths(args.data)
        assert len(paths) > 0
        split_path = os.path.join(paths[(epoch - 1) % len(paths)], split)

        dataset = data_utils.load_indexed_dataset(
            split_path, vocab, args.dataset_impl, combine=combine)
        if dataset is None:
            raise FileNotFoundError("Dataset not found: {} ({})".format(
                split, split_path))

        dataset = maybe_shorten_dataset(
            dataset,
            split,
            args.shorten_data_split_list,
            args.shorten_method,
            args.tokens_per_sample,
            args.seed,
        )

        # Pack tokens into contiguous blocks, reserving one position for the
        # <s> token that is prepended below.
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            args.tokens_per_sample - 1,
            pad=vocab.pad(),
            eos=vocab.eos(),
            break_mode=args.sample_break_mode,
        )
        logger.info("loaded {} blocks from: {}".format(len(dataset),
                                                       split_path))

        # <s> at the front plays the role of [CLS] in BERT.
        dataset = PrependTokenDataset(dataset, vocab.bos())

        # Build the masked inputs and their prediction targets.
        if args.mask_whole_words:
            mask_whole_words = get_whole_word_mask(args, vocab)
        else:
            mask_whole_words = None

        src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
            dataset,
            vocab,
            pad_idx=vocab.pad(),
            mask_idx=self.mask_idx,
            seed=args.seed,
            mask_prob=args.mask_prob,
            leave_unmasked_prob=args.leave_unmasked_prob,
            random_token_prob=args.random_token_prob,
            freq_weighted_replacement=args.freq_weighted_replacement,
            mask_whole_words=mask_whole_words,
            mask_multiple_length=args.mask_multiple_length,
            mask_stdev=args.mask_stdev,
        )

        with data_utils.numpy_seed(args.seed):
            shuffle = np.random.permutation(len(src_dataset))

        # Assemble the sample structure field by field.
        fields = {"id": IdDataset()}
        fields["net_input"] = {
            "src_tokens": RightPadDataset(src_dataset, pad_idx=vocab.pad()),
            "src_lengths": NumelDataset(src_dataset, reduce=False),
        }
        fields["target"] = RightPadDataset(tgt_dataset, pad_idx=vocab.pad())
        fields["nsentences"] = NumSamplesDataset()
        fields["ntokens"] = NumelDataset(src_dataset, reduce=True)

        nested = NestedDictionaryDataset(fields, sizes=[src_dataset.sizes])
        self.datasets[split] = SortDataset(
            nested,
            sort_order=[shuffle, src_dataset.sizes],
        )
Ejemplo n.º 6
0
    def load_dataset(self, split, epoch=1, combine=False):
        """Load a given dataset split and register it in ``self.datasets``.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): 1-based epoch number; selects one of the configured
                data paths round-robin
            combine (bool): when True, keep loading the numbered shards
                ``<split>1``, ``<split>2``, ... until one is missing; when
                False, only the un-numbered shard is loaded

        Raises:
            FileNotFoundError: if the first (un-numbered) shard cannot be
                loaded
        """
        loaded_datasets = []

        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]
        # logging formats lazily as ``msg % args``; without a placeholder the
        # extra argument triggers a formatting error and the path is never
        # written to the log.
        logger.info("data_path: %s", data_path)

        for k in itertools.count():
            # Shard 0 has no numeric suffix; later shards are split1, split2...
            split_k = split + (str(k) if k > 0 else "")
            path = os.path.join(data_path, split_k)
            ds = indexed_dataset.make_dataset(
                path,
                impl=self.args.dataset_impl,
                fix_lua_indexing=True,
                dictionary=self.dictionary,
            )

            if ds is None:
                if k > 0:
                    # Ran out of numbered shards: normal end of the scan.
                    break
                else:
                    raise FileNotFoundError(
                        "Dataset not found: {} ({})".format(split, data_path)
                    )

            # A distinct seed per shard is active while the block pairs are
            # constructed.
            with data_utils.numpy_seed(self.seed + k):
                loaded_datasets.append(
                    BlockPairDataset(
                        ds,
                        self.dictionary,
                        ds.sizes,
                        self.args.tokens_per_sample,
                        break_mode=self.args.break_mode,
                        doc_break_size=1,
                    )
                )

            logger.info(
                "{} {} {} examples".format(data_path, split_k, len(loaded_datasets[-1]))
            )

            if not combine:
                break

        if len(loaded_datasets) == 1:
            dataset = loaded_datasets[0]
            sizes = dataset.sizes
        else:
            dataset = ConcatDataset(loaded_datasets)
            sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

        self.datasets[split] = MaskedLMDataset(
            dataset=dataset,
            sizes=sizes,
            vocab=self.dictionary,
            pad_idx=self.dictionary.pad(),
            mask_idx=self.dictionary.mask(),
            classif_token_idx=self.dictionary.cls(),
            sep_token_idx=self.dictionary.sep(),
            shuffle=self.args.shuffle_dataset,
            seed=self.seed,
        )
Ejemplo n.º 7
0
    def _collate(self, samples: List[Dict], pad_idx: int, eos_idx: int):
        """
        Does the heavy lifting for creating a batch from the input list of
        examples. The logic is as follows:
            1. Mask the input blocks. In case has_pair is True then we have 2
               blocks to mask.
            2. Prepend the first masked block tensor with the special token
               used as sentence embedding. Eg: CLS in BERT. This happens
               irrespective of the value of has_pair.
            3. If has_pair is True, then append the first masked block with the
               special separator token (eg: SEP for BERT) and compute segment
               label accordingly. In this case, also append the second masked
               block with this special separator token and compute its segment
               label.
            4. For the targets tensor, prepend and append with padding index
               accordingly.
            5. Concatenate all tensors.
        """
        if len(samples) == 0:
            return {}
        # To ensure determinism, we reset the state of the PRNG after every
        # batch based on the seed and the first id of the batch. This ensures
        # that across epochs we get the same mask for the same example. This
        # is needed for reproducibility and is how BERT does masking
        # TODO: Can we add deteminism without this constraint?
        with data_utils.numpy_seed(self.seed + samples[0]["id"]):
            for s in samples:

                # token range is needed for replacing with random token during
                # masking
                token_range = (self.vocab.nspecial, len(self.vocab))

                # mask according to specified probabilities.
                masked_blk_one, masked_tgt_one = self._mask_block(
                    s["block_one"],
                    self.mask_idx,
                    self.pad_idx,
                    token_range,
                )

                tokens = np.concatenate([[self.classif_token_idx],
                                         masked_blk_one])
                targets = np.concatenate([[self.pad_idx], masked_tgt_one])
                segments = np.ones(len(tokens)) * self.segment_id

                # if has_pairs is True then we need to add the SEP token to both
                # the blocks after masking and re-compute segments based on the new
                # lengths.
                if self.has_pairs:
                    tokens_one = np.concatenate([tokens, [self.sep_token_idx]])
                    targets_one = np.concatenate([targets, [self.pad_idx]])

                    masked_blk_two, masked_tgt_two = self._mask_block(
                        s["block_two"], self.mask_idx, self.pad_idx,
                        token_range)
                    tokens_two = np.concatenate(
                        [masked_blk_two, [self.sep_token_idx]])
                    targets_two = np.concatenate(
                        [masked_tgt_two, [self.pad_idx]])

                    # block + 1 sep + 1 special (CLS)
                    segments_one = np.zeros(len(tokens_one))
                    # block + 1 sep
                    segments_two = np.ones(len(tokens_two))

                    tokens = np.concatenate([tokens_one, tokens_two])
                    targets = np.concatenate([targets_one, targets_two])
                    segments = np.concatenate([segments_one, segments_two])

                s["source"] = torch.LongTensor(tokens)
                s["segment_labels"] = torch.LongTensor(segments)
                s["lm_target"] = torch.LongTensor(targets)

        def merge(key):
            return data_utils.collate_tokens([s[key] for s in samples],
                                             pad_idx,
                                             eos_idx,
                                             left_pad=False)

        return {
            "id":
            torch.LongTensor([s["id"] for s in samples]),
            "ntokens":
            sum(len(s["source"]) for s in samples),
            "net_input": {
                "src_tokens": merge("source"),
                "segment_labels": merge("segment_labels"),
            },
            "lm_target":
            merge("lm_target"),
            "sentence_target":
            torch.LongTensor([s["sentence_target"]
                              for s in samples]) if self.has_pairs else None,
            "nsentences":
            len(samples),
        }
Ejemplo n.º 8
0
    def get_batch_iterator(
        self,
        dataset,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
        data_buffer_size=0,
        disable_iterator_cache=False,
    ):
        """
        Build (or fetch from the cache) a batched iterator over ``dataset``.

        Args:
            dataset (~fairseq_stchde.data.FairseqDataset): dataset to batch
            max_tokens (int, optional): upper bound on tokens per batch
                (default: None).
            max_sentences (int, optional): upper bound on sentences per
                batch (default: None).
            max_positions (optional): longest sentence the model supports;
                when given, indices are filtered by size (default: None).
            ignore_invalid_inputs (bool, optional): do not raise for
                sentences that are too long (default: False).
            required_batch_size_multiple (int, optional): batch sizes must
                be a multiple of N (default: 1).
            seed (int, optional): random seed, for reproducibility
                (default: 1).
            num_shards (int, optional): number of shards the iterator is
                split into (default: 1).
            shard_id (int, optional): shard of the iterator to return
                (default: 0).
            num_workers (int, optional): number of data-loading
                subprocesses; 0 loads in the main process (default: 0).
            epoch (int, optional): epoch the iterator starts from
                (default: 1).
            data_buffer_size (int, optional): number of batches to preload
                (default: 0).
            disable_iterator_cache (bool, optional): never cache the
                EpochBatchIterator (ignores
                `FairseqTask::can_reuse_epoch_itr`) (default: False).
        Returns:
            ~fairseq_stchde.iterators.EpochBatchIterator: a batched iterator
                over the given dataset split
        """
        can_reuse_epoch_itr = (not disable_iterator_cache
                               and self.can_reuse_epoch_itr(dataset))

        # Fast path: an iterator for this dataset was built earlier.
        if can_reuse_epoch_itr and dataset in self.dataset_to_epoch_iter:
            logger.debug(
                "reusing EpochBatchIterator for epoch {}".format(epoch))
            return self.dataset_to_epoch_iter[dataset]

        assert isinstance(dataset, FairseqDataset)

        # The dataset must know its starting epoch before indices are drawn.
        dataset.set_epoch(epoch)

        # Indices, ordered by the dataset under a fixed seed.
        with data_utils.numpy_seed(seed):
            indices = dataset.ordered_indices()

        # Drop examples that are too large for the model.
        if max_positions is not None:
            indices = self.filter_indices_by_size(
                indices, dataset, max_positions, ignore_invalid_inputs)

        # Mini-batches that respect the size constraints.
        batch_sampler = dataset.batch_by_size(
            indices,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
        )

        # Reusable, sharded iterator over those batches.
        iterator_options = dict(
            dataset=dataset,
            collate_fn=dataset.collater,
            batch_sampler=batch_sampler,
            seed=seed,
            num_shards=num_shards,
            shard_id=shard_id,
            num_workers=num_workers,
            epoch=epoch,
            buffer_size=data_buffer_size,
        )
        epoch_iter = iterators.EpochBatchIterator(**iterator_options)

        if can_reuse_epoch_itr:
            self.dataset_to_epoch_iter[dataset] = epoch_iter

        return epoch_iter
Ejemplo n.º 9
0
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split and register it in ``self.datasets``.

        Every sub-directory of the selected data path is treated as one
        language (sorted by name; the position in that order is the language
        id). Each language is loaded, blocked, masked and wrapped on its own;
        the per-language datasets are then concatenated. For the training
        subset the languages are additionally up- or down-sampled; for other
        splits each language is also registered as ``<split>_<language>``.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): 1-based epoch number; selects one of the configured
                data paths round-robin and feeds the resampling and shuffle
                seeds
            combine (bool): forwarded to ``data_utils.load_indexed_dataset``

        Raises:
            FileNotFoundError: if a language has no dataset for ``split``
        """
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        languages = sorted(name for name in os.listdir(data_path)
                           if os.path.isdir(os.path.join(data_path, name)))

        logger.info("Training on {0} languages: {1}".format(
            len(languages), languages))
        # logging renders ``msg % args``; the mapping needs a %s placeholder,
        # otherwise it is silently dropped from the log line.
        logger.info(
            "Language to id mapping: %s",
            {lang: lang_id for lang_id, lang in enumerate(languages)},
        )

        mask_whole_words = self._get_whole_word_mask()
        lang_datasets = []
        for lang_id, language in enumerate(languages):
            split_path = os.path.join(data_path, language, split)

            dataset = data_utils.load_indexed_dataset(
                split_path,
                self.source_dictionary,
                self.args.dataset_impl,
                combine=combine,
            )
            if dataset is None:
                raise FileNotFoundError("Dataset not found: {} ({})".format(
                    split, split_path))

            # create continuous blocks of tokens
            dataset = TokenBlockDataset(
                dataset,
                dataset.sizes,
                self.args.tokens_per_sample - 1,  # one less for <s>
                pad=self.source_dictionary.pad(),
                eos=self.source_dictionary.eos(),
                break_mode=self.args.sample_break_mode,
            )
            logger.info("loaded {} blocks from: {}".format(
                len(dataset), split_path))

            # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
            dataset = PrependTokenDataset(dataset,
                                          self.source_dictionary.bos())

            src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
                dataset,
                self.source_dictionary,
                pad_idx=self.source_dictionary.pad(),
                mask_idx=self.mask_idx,
                seed=self.args.seed,
                mask_prob=self.args.mask_prob,
                leave_unmasked_prob=self.args.leave_unmasked_prob,
                random_token_prob=self.args.random_token_prob,
                freq_weighted_replacement=self.args.freq_weighted_replacement,
                mask_whole_words=mask_whole_words,
            )

            lang_dataset = NestedDictionaryDataset(
                {
                    "net_input": {
                        "src_tokens":
                        PadDataset(
                            src_dataset,
                            pad_idx=self.source_dictionary.pad(),
                            left_pad=False,
                        ),
                        "src_lengths":
                        NumelDataset(src_dataset, reduce=False),
                    },
                    "target":
                    PadDataset(
                        tgt_dataset,
                        pad_idx=self.source_dictionary.pad(),
                        left_pad=False,
                    ),
                    "nsentences":
                    NumSamplesDataset(),
                    "ntokens":
                    NumelDataset(src_dataset, reduce=True),
                    # one language-id label per example
                    "lang_id":
                    RawLabelDataset([lang_id] * src_dataset.sizes.shape[0]),
                },
                sizes=[src_dataset.sizes],
            )
            lang_datasets.append(lang_dataset)

        dataset_lengths = np.array(
            [len(d) for d in lang_datasets],
            dtype=float,
        )
        logger.info("loaded total {} blocks for all languages".format(
            dataset_lengths.sum(), ))
        if split == self.args.train_subset:
            # For train subset, additionally up or down sample languages.
            sample_probs = self._get_sample_prob(dataset_lengths)
            logger.info(
                "Sample probability by language: %s",
                {
                    lang: "{0:.4f}".format(sample_probs[lang_id])
                    for lang_id, lang in enumerate(languages)
                },
            )
            # Ratio between the size a language should have under
            # sample_probs and the size it actually has.
            size_ratio = (sample_probs *
                          dataset_lengths.sum()) / dataset_lengths
            logger.info(
                "Up/Down Sampling ratio by language: %s",
                {
                    lang: "{0:.2f}".format(size_ratio[lang_id])
                    for lang_id, lang in enumerate(languages)
                },
            )

            resampled_lang_datasets = [
                ResamplingDataset(
                    lang_datasets[i],
                    size_ratio=size_ratio[i],
                    seed=self.args.seed,
                    epoch=epoch,
                    # sample with replacement only when up-sampling
                    replace=size_ratio[i] >= 1.0,
                ) for i in range(len(lang_datasets))
            ]
            dataset = ConcatDataset(resampled_lang_datasets)
        else:
            dataset = ConcatDataset(lang_datasets)
            lang_splits = [split]
            for lang_id, lang_dataset in enumerate(lang_datasets):
                split_name = split + "_" + languages[lang_id]
                lang_splits.append(split_name)
                self.datasets[split_name] = lang_dataset

            # [TODO]: This is hacky for now to print validation ppl for each
            # language individually. Maybe need task API changes to allow it
            # in more generic ways.
            if split in self.args.valid_subset:
                self.args.valid_subset = self.args.valid_subset.replace(
                    split, ",".join(lang_splits))

        with data_utils.numpy_seed(self.args.seed + epoch):
            shuffle = np.random.permutation(len(dataset))

        self.datasets[split] = SortDataset(
            dataset,
            sort_order=[
                shuffle,
                dataset.sizes,
            ],
        )
Ejemplo n.º 10
0
    def __getitem_cached__(self, seed: int, epoch: int, index: int):
        """Return the masked-LM view of ``self.dataset[index]`` as a tensor.

        All random draws happen inside
        ``data_utils.numpy_seed(self.seed, self.epoch, index)``.

        When ``self.return_masked_tokens`` is true the result holds
        ``self.pad_idx`` everywhere except at the masked positions, which
        keep the original tokens (the prediction targets). Otherwise the
        result is a copy of the item in which masked positions are replaced
        by ``self.mask_idx``, left unchanged, or replaced by a random
        vocabulary id, according to ``self.leave_unmasked_prob`` and
        ``self.random_token_prob``.

        Args:
            seed: not read in the body (``self.seed`` is used instead) --
                presumably part of a cache key; confirm against the caller.
            epoch: not read in the body (``self.epoch`` is used instead) --
                presumably part of a cache key; confirm against the caller.
            index: position of the item in ``self.dataset``.

        Raises:
            AssertionError: if the item already contains ``self.mask_idx``.
        """
        with data_utils.numpy_seed(self.seed, self.epoch, index):
            item = self.dataset[index]
            sz = len(item)

            assert (
                self.mask_idx not in item
            ), "Dataset contains mask_idx (={}), this is not expected!".format(
                self.mask_idx, )

            if self.mask_whole_words is not None:
                # Work at word granularity: sz becomes the number of word
                # beginnings and word_lens the token count of each word.
                word_begins_mask = self.mask_whole_words.gather(0, item)
                word_begins_idx = word_begins_mask.nonzero().view(-1)
                sz = len(word_begins_idx)
                words = np.split(word_begins_mask, word_begins_idx)[1:]
                assert len(words) == sz
                word_lens = list(map(len, words))

            # decide elements to mask
            mask = np.full(sz, False)
            num_mask = int(
                # add a random number for probabilistic rounding
                self.mask_prob * sz / float(self.mask_multiple_length) +
                np.random.rand())

            # multiple masking as described in the vq-wav2vec paper (https://arxiv.org/abs/1910.05453)
            mask_idc = np.random.choice(sz, num_mask, replace=False)
            if self.mask_stdev > 0.0:
                # Span lengths are drawn from a normal distribution, rounded
                # and clamped at zero; each start index is expanded into
                # that many consecutive positions.
                lengths = np.random.normal(self.mask_multiple_length,
                                           self.mask_stdev,
                                           size=num_mask)
                lengths = [max(0, int(round(x))) for x in lengths]
                mask_idc = np.asarray(
                    [
                        mask_idc[j] + offset for j in range(len(mask_idc))
                        for offset in range(lengths[j])
                    ],
                    dtype=np.int64,
                )
            else:
                # Fixed span length: every start index covers
                # mask_multiple_length consecutive positions.
                mask_idc = np.concatenate(
                    [mask_idc + i for i in range(self.mask_multiple_length)])
            # Discard positions that run past the end of the mask.
            mask_idc = mask_idc[mask_idc < len(mask)]
            try:
                mask[mask_idc] = True
            except:  # report the offending values, then re-raise unchanged
                print("Assigning mask indexes {} to mask {} failed!".format(
                    mask_idc, mask))
                raise

            if self.return_masked_tokens:
                # exit early if we're just returning the masked tokens
                # (i.e., the targets for masked LM training)
                if self.mask_whole_words is not None:
                    # expand the word-level mask back to token level
                    mask = np.repeat(mask, word_lens)
                new_item = np.full(len(mask), self.pad_idx)
                new_item[mask] = item[torch.from_numpy(mask.astype(np.uint8))
                                      == 1]
                return torch.from_numpy(new_item)

            # decide unmasking and random replacement
            rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob
            if rand_or_unmask_prob > 0.0:
                # Subset of the masked positions that will NOT receive
                # mask_idx.
                rand_or_unmask = mask & (np.random.rand(sz) <
                                         rand_or_unmask_prob)
                if self.random_token_prob == 0.0:
                    unmask = rand_or_unmask
                    rand_mask = None
                elif self.leave_unmasked_prob == 0.0:
                    unmask = None
                    rand_mask = rand_or_unmask
                else:
                    # Split that subset between "leave unchanged" and
                    # "random token" in proportion to the two probabilities.
                    unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob
                    decision = np.random.rand(sz) < unmask_prob
                    unmask = rand_or_unmask & decision
                    rand_mask = rand_or_unmask & (~decision)
            else:
                unmask = rand_mask = None

            if unmask is not None:
                # unmask is a subset of mask, so XOR clears exactly those
                # positions
                mask = mask ^ unmask

            if self.mask_whole_words is not None:
                # expand the word-level mask back to token level
                mask = np.repeat(mask, word_lens)

            new_item = np.copy(item)
            new_item[mask] = self.mask_idx
            if rand_mask is not None:
                num_rand = rand_mask.sum()
                if num_rand > 0:
                    if self.mask_whole_words is not None:
                        rand_mask = np.repeat(rand_mask, word_lens)
                        num_rand = rand_mask.sum()

                    # Replacement ids are sampled from the vocabulary with
                    # probabilities self.weights.
                    new_item[rand_mask] = np.random.choice(
                        len(self.vocab),
                        num_rand,
                        p=self.weights,
                    )

            return torch.from_numpy(new_item)
Ejemplo n.º 11
0
    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split (e.g., train, valid, test).

        Loads ``input0`` and ``input1`` .. ``input{num_classes}`` from
        sub-directories of ``self.args.data``, pairs every option with
        ``input0`` through ``ConcatSentencesDataset``, and stores the
        assembled dataset under ``self.datasets[split]``.

        Args:
            split: name of the split; used as the last path component of
                every dataset location.
            combine: forwarded unchanged to
                ``data_utils.load_indexed_dataset``.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            The dataset stored in ``self.datasets[split]``.
        """

        def get_path(key, split):
            # Layout is <data>/<key>/<split>.
            return os.path.join(self.args.data, key, split)

        def make_dataset(key, dictionary):
            # Use the dictionary handed in by the caller; it used to be
            # silently ignored in favour of self.source_dictionary.
            return data_utils.load_indexed_dataset(
                get_path(key, split),
                dictionary,
                self.args.dataset_impl,
                combine=combine,
            )

        input0 = make_dataset("input0", self.source_dictionary)
        # Options live in input1 .. input{num_classes}.
        input_options = [
            make_dataset("input{idx}".format(idx=idx + 1), self.source_dictionary)
            for idx in range(self.args.num_classes)
        ]

        if self.args.separator_token is not None:
            input0 = PrependTokenDataset(input0, self.args.separator_token)

        src_tokens = []
        for input_option in input_options:
            if self.args.init_token is not None:
                input_option = PrependTokenDataset(input_option, self.args.init_token)
            if self.args.max_option_length is not None:
                input_option = TruncateDataset(
                    input_option, self.args.max_option_length
                )
            # The option comes first, input0 second.
            src_token = ConcatSentencesDataset(input_option, input0)
            src_token = maybe_shorten_dataset(
                src_token,
                split,
                self.args.shorten_data_split_list,
                self.args.shorten_method,
                self.args.max_positions,
                self.args.seed,
            )
            src_tokens.append(src_token)

        # Draw the permutation inside the seeded context
        # (presumably so the order is reproducible for a given seed).
        with data_utils.numpy_seed(self.args.seed):
            shuffle = np.random.permutation(len(src_tokens[0]))

        dataset = {
            "id": IdDataset(),
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(src_tokens[0], reduce=True),
        }

        # One "net_input{k}" entry per option, numbered from 1.
        for option_number, src_token in enumerate(src_tokens, start=1):
            dataset["net_input{idx}".format(idx=option_number)] = {
                "src_tokens": RightPadDataset(
                    src_token,
                    pad_idx=self.source_dictionary.pad(),
                ),
                "src_lengths": NumelDataset(src_token, reduce=False),
            }

        # Targets are optional: they are attached only when the label file
        # exists, with one integer label per line.
        label_path = "{}.label".format(get_path("label", split))
        if os.path.exists(label_path):
            with open(label_path) as h:
                dataset["target"] = RawLabelDataset([int(x.strip()) for x in h])

        # The size of an example is the element-wise maximum over its options.
        nested_dataset = NestedDictionaryDataset(
            dataset,
            sizes=[np.maximum.reduce([src_token.sizes for src_token in src_tokens])],
        )

        if self.args.no_shuffle:
            dataset = nested_dataset
        else:
            dataset = SortDataset(
                nested_dataset,
                # shuffle
                sort_order=[shuffle],
            )

        logger.info("Loaded {0} with #samples: {1}".format(split, len(dataset)))

        self.datasets[split] = dataset
        return self.datasets[split]
Ejemplo n.º 12
0
    def load_dataset(self, split, combine=False, **kwargs):
        """Build the dataset for ``split`` and register it in ``self.datasets``.

        ``input0`` is required; ``input1`` is optional and, when present, is
        combined with ``input0`` through ``ConcatSentencesDataset``. Targets
        come either from an indexed ``label`` dataset or, when
        ``self.args.regression_target`` is set, from a plain-text
        ``<split>.label`` file holding ``num_classes`` floats per line.

        Args:
            split: name of the split; used as the last path component.
            combine: forwarded to ``data_utils.load_indexed_dataset``.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            The dataset stored in ``self.datasets[split]``.
        """

        def data_path(name):
            # Layout is <data>/<name>/<split>.
            return os.path.join(self.args.data, name, split)

        def load_indexed(name, vocab):
            return data_utils.load_indexed_dataset(
                data_path(name),
                vocab,
                self.args.dataset_impl,
                combine=combine,
            )

        first = load_indexed("input0", self.source_dictionary)
        assert first is not None, "could not find dataset: {}".format(
            data_path("input0"))
        second = load_indexed("input1", self.source_dictionary)

        if self.args.init_token is not None:
            first = PrependTokenDataset(first, self.args.init_token)

        if second is not None:
            if self.args.separator_token is not None:
                second = PrependTokenDataset(second, self.args.separator_token)
            src_tokens = ConcatSentencesDataset(first, second)
        else:
            # Single-input case: nothing to concatenate.
            src_tokens = first

        # The permutation is drawn before the dataset is (possibly) shortened.
        with data_utils.numpy_seed(self.args.seed):
            shuffle = np.random.permutation(len(src_tokens))

        src_tokens = maybe_shorten_dataset(
            src_tokens,
            split,
            self.args.shorten_data_split_list,
            self.args.shorten_method,
            self.max_positions(),
            self.args.seed,
        )

        # Assemble the entries one at a time, keeping the key order
        # id, net_input, nsentences, ntokens.
        dataset = {"id": IdDataset()}
        dataset["net_input"] = {
            "src_tokens": RightPadDataset(
                src_tokens,
                pad_idx=self.source_dictionary.pad(),
            ),
            "src_lengths": NumelDataset(src_tokens, reduce=False),
        }
        dataset["nsentences"] = NumSamplesDataset()
        dataset["ntokens"] = NumelDataset(src_tokens, reduce=True)

        if self.args.add_prev_output_tokens:
            dataset["net_input"]["prev_output_tokens"] = RightPadDataset(
                RollDataset(src_tokens, 1),
                pad_idx=self.dictionary.pad(),
            )

        if self.args.regression_target:
            label_path = "{0}.label".format(data_path("label"))
            if os.path.exists(label_path):

                def parse_regression_target(i, line):
                    values = line.split()
                    assert (
                        len(values) == self.args.num_classes
                    ), f'expected num_classes={self.args.num_classes} regression target values on line {i}, found: "{line}"'
                    return [float(x) for x in values]

                # One whitespace-separated row of floats per line.
                targets = []
                with open(label_path) as h:
                    for i, line in enumerate(h.readlines()):
                        targets.append(parse_regression_target(i, line.strip()))
                dataset["target"] = RawLabelDataset(targets)
        else:
            label_dataset = load_indexed("label", self.label_dictionary)
            if label_dataset is not None:
                # Strip the label dictionary's eos id, then offset the
                # remaining ids by -nspecial.
                stripped = StripTokenDataset(
                    label_dataset,
                    id_to_strip=self.label_dictionary.eos(),
                )
                dataset["target"] = OffsetTokensDataset(
                    stripped,
                    offset=-self.label_dictionary.nspecial,
                )

        nested_dataset = NestedDictionaryDataset(
            dataset,
            sizes=[src_tokens.sizes],
        )

        # Wrap in SortDataset with the precomputed permutation unless
        # shuffling is disabled.
        dataset = (
            nested_dataset
            if self.args.no_shuffle
            else SortDataset(nested_dataset, sort_order=[shuffle])
        )

        logger.info("Loaded {0} with #samples: {1}".format(
            split, len(dataset)))

        self.datasets[split] = dataset
        return self.datasets[split]