コード例 #1
0
class LanguagePairDataset(FairseqDataset):
    """
    A pair of torch.utils.data.Datasets.

    Args:
        src (torch.utils.data.Dataset): source dataset to wrap
        src_sizes (List[int]): source sentence lengths
        src_dict (~fairseq.data.Dictionary): source vocabulary
        tgt (torch.utils.data.Dataset, optional): target dataset to wrap
        tgt_sizes (List[int], optional): target sentence lengths
        tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary
        left_pad_source (bool, optional): pad source tensors on the left side
            (default: True).
        left_pad_target (bool, optional): pad target tensors on the left side
            (default: False).
        shuffle (bool, optional): shuffle dataset elements before batching
            (default: True).
        input_feeding (bool, optional): create a shifted version of the targets
            to be passed into the model for teacher forcing (default: True).
        remove_eos_from_source (bool, optional): if set, removes eos from end
            of source if it's present (default: False).
        append_eos_to_target (bool, optional): if set, appends eos to end of
            target if it's absent (default: False).
        align_dataset (torch.utils.data.Dataset, optional): dataset
            containing alignments.
        constraints (Tensor, optional): 2d tensor with a concatenated, zero-
            delimited list of constraints for each sentence.
        append_bos (bool, optional): if set, appends bos to the beginning of
            source/target sentence.
        eos (int, optional): index handed to the collater as ``eos_idx``;
            falls back to ``src_dict.eos()`` when not given.
        num_buckets (int, optional): if set to a value greater than 0, then
            batches will be bucketed into the given number of batch shapes.
        src_lang_id (int, optional): source language ID, if set, the collated batch
            will contain a field 'src_lang_id' in 'net_input' which indicates the
            source language of the samples.
        tgt_lang_id (int, optional): target language ID, if set, the collated batch
            will contain a field 'tgt_lang_id' which indicates the target language
             of the samples.
    """
    def __init__(
        self,
        src,
        src_sizes,
        src_dict,
        tgt=None,
        tgt_sizes=None,
        tgt_dict=None,
        left_pad_source=True,
        left_pad_target=False,
        shuffle=True,
        input_feeding=True,
        remove_eos_from_source=False,
        append_eos_to_target=False,
        align_dataset=None,
        constraints=None,
        append_bos=False,
        eos=None,
        num_buckets=0,
        src_lang_id=None,
        tgt_lang_id=None,
    ):
        # Both vocabularies must agree on the special symbols because a single
        # pad/eos index is used when collating source and target together.
        if tgt_dict is not None:
            assert src_dict.pad() == tgt_dict.pad()
            assert src_dict.eos() == tgt_dict.eos()
            assert src_dict.unk() == tgt_dict.unk()
        if tgt is not None:
            assert len(src) == len(
                tgt
            ), "Source and target must contain the same number of examples"
        self.src = src
        self.tgt = tgt
        self.src_sizes = np.array(src_sizes)
        self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict
        self.left_pad_source = left_pad_source
        self.left_pad_target = left_pad_target
        self.shuffle = shuffle
        self.input_feeding = input_feeding
        self.remove_eos_from_source = remove_eos_from_source
        self.append_eos_to_target = append_eos_to_target
        self.align_dataset = align_dataset
        if self.align_dataset is not None:
            assert self.tgt_sizes is not None, "Both source and target needed when alignments are provided"
        self.constraints = constraints
        self.append_bos = append_bos
        self.eos = (eos if eos is not None else src_dict.eos())
        self.src_lang_id = src_lang_id
        self.tgt_lang_id = tgt_lang_id
        if num_buckets > 0:
            from fairseq.data import BucketPadLengthDataset
            self.src = BucketPadLengthDataset(
                self.src,
                sizes=self.src_sizes,
                num_buckets=num_buckets,
                pad_idx=self.src_dict.pad(),
                left_pad=self.left_pad_source,
            )
            self.src_sizes = self.src.sizes
            logger.info('bucketing source lengths: {}'.format(
                list(self.src.buckets)))
            if self.tgt is not None:
                self.tgt = BucketPadLengthDataset(
                    self.tgt,
                    sizes=self.tgt_sizes,
                    num_buckets=num_buckets,
                    pad_idx=self.tgt_dict.pad(),
                    left_pad=self.left_pad_target,
                )
                self.tgt_sizes = self.tgt.sizes
                logger.info('bucketing target lengths: {}'.format(
                    list(self.tgt.buckets)))

            # determine bucket sizes using self.num_tokens, which will return
            # the padded lengths (thanks to BucketPadLengthDataset).
            # np.int64 is used because the np.long alias no longer exists in
            # current NumPy releases.
            num_tokens = np.vectorize(self.num_tokens, otypes=[np.int64])
            self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
            self.buckets = [
                (None, num_tokens)
                for num_tokens in np.unique(self.bucketed_num_tokens)
            ]
        else:
            self.buckets = None

    def get_batch_shapes(self):
        """Return the list of ``(None, num_tokens)`` bucket shapes, or
        ``None`` when bucketing is disabled."""
        return self.buckets

    def __getitem__(self, index):
        """Return the example at *index* as a dict with keys ``id``,
        ``source`` and ``target`` (plus ``alignment`` / ``constraints`` when
        the corresponding datasets were given).

        The optional EOS/BOS adjustments are applied cumulatively to the
        already-fetched items, so enabling several of them at once keeps
        every modification, and each underlying dataset is indexed only once.
        """
        tgt_item = self.tgt[index] if self.tgt is not None else None
        src_item = self.src[index]
        # Append EOS to end of tgt sentence if it does not have an EOS and remove
        # EOS from end of src sentence if it exists. This is useful when we
        # use existing datasets for opposite directions i.e., when we want to
        # use tgt_dataset as src_dataset and vice versa
        if self.append_eos_to_target and tgt_item is not None:
            eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos()
            if tgt_item[-1] != eos:
                tgt_item = torch.cat([tgt_item, torch.LongTensor([eos])])

        if self.append_bos:
            if tgt_item is not None:
                bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos()
                if tgt_item[0] != bos:
                    tgt_item = torch.cat([torch.LongTensor([bos]), tgt_item])

            bos = self.src_dict.bos()
            if src_item[0] != bos:
                src_item = torch.cat([torch.LongTensor([bos]), src_item])

        if self.remove_eos_from_source:
            eos = self.src_dict.eos()
            if src_item[-1] == eos:
                src_item = src_item[:-1]

        example = {
            'id': index,
            'source': src_item,
            'target': tgt_item,
        }
        if self.align_dataset is not None:
            example['alignment'] = self.align_dataset[index]
        if self.constraints is not None:
            example["constraints"] = self.constraints[index]
        return example

    def __len__(self):
        return len(self.src)

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate
            pad_to_length (dict, optional): a dictionary of
                {'source': source_pad_to_length, 'target': target_pad_to_length}
                to indicate the max length to pad to in source and target respectively.

        Returns:
            dict: a mini-batch with the following keys:

                - `id` (LongTensor): example IDs in the original input order
                - `ntokens` (int): total number of tokens in the batch
                - `net_input` (dict): the input to the Model, containing keys:

                  - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
                    the source sentence of shape `(bsz, src_len)`. Padding will
                    appear on the left if *left_pad_source* is ``True``.
                  - `src_lengths` (LongTensor): 1D Tensor of the unpadded
                    lengths of each source sentence of shape `(bsz)`
                  - `prev_output_tokens` (LongTensor): a padded 2D Tensor of
                    tokens in the target sentence, shifted right by one
                    position for teacher forcing, of shape `(bsz, tgt_len)`.
                    This key will not be present if *input_feeding* is
                    ``False``.  Padding will appear on the left if
                    *left_pad_target* is ``True``.
                  - `src_lang_id` (LongTensor): a long Tensor which contains source
                    language IDs of each sample in the batch

                - `target` (LongTensor): a padded 2D Tensor of tokens in the
                  target sentence of shape `(bsz, tgt_len)`. Padding will appear
                  on the left if *left_pad_target* is ``True``.
                - `tgt_lang_id` (LongTensor): a long Tensor which contains target language
                   IDs of each sample in the batch
        """
        res = collate(
            samples,
            pad_idx=self.src_dict.pad(),
            eos_idx=self.eos,
            left_pad_source=self.left_pad_source,
            left_pad_target=self.left_pad_target,
            input_feeding=self.input_feeding,
            pad_to_length=pad_to_length,
        )
        if self.src_lang_id is not None or self.tgt_lang_id is not None:
            src_tokens = res['net_input']['src_tokens']
            bsz = src_tokens.size(0)
            # language IDs are broadcast to one row per sample and moved to
            # the same dtype/device as the source tokens
            if self.src_lang_id is not None:
                res['net_input']['src_lang_id'] = torch.LongTensor(
                    [[self.src_lang_id]]).expand(bsz, 1).to(src_tokens)
            if self.tgt_lang_id is not None:
                res['tgt_lang_id'] = torch.LongTensor(
                    [[self.tgt_lang_id]]).expand(bsz, 1).to(src_tokens)
        return res

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return max(self.src_sizes[index],
                   self.tgt_sizes[index] if self.tgt_sizes is not None else 0)

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return (self.src_sizes[index],
                self.tgt_sizes[index] if self.tgt_sizes is not None else 0)

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self)).astype(np.int64)
        else:
            indices = np.arange(len(self), dtype=np.int64)
        if self.buckets is None:
            # sort by target length, then source length; mergesort is stable,
            # so the (possibly shuffled) order is kept among equal lengths
            if self.tgt_sizes is not None:
                indices = indices[np.argsort(self.tgt_sizes[indices],
                                             kind='mergesort')]
            return indices[np.argsort(self.src_sizes[indices],
                                      kind='mergesort')]
        else:
            # sort by bucketed_num_tokens, which is:
            #   max(padded_src_len, padded_tgt_len)
            return indices[np.argsort(self.bucketed_num_tokens[indices],
                                      kind='mergesort')]

    @property
    def supports_prefetch(self):
        """True when the source, and the target if present, support prefetch."""
        return (getattr(self.src, 'supports_prefetch', False)
                and (getattr(self.tgt, 'supports_prefetch', False)
                     or self.tgt is None))

    def prefetch(self, indices):
        """Forward the prefetch request to every wrapped dataset."""
        self.src.prefetch(indices)
        if self.tgt is not None:
            self.tgt.prefetch(indices)
        if self.align_dataset is not None:
            self.align_dataset.prefetch(indices)

    def filter_indices_by_size(self, indices, max_sizes):
        """ Filter a list of sample indices. Remove those that are longer
            than specified in max_sizes.

        Args:
            indices (np.array): original array of sample indices
            max_sizes (int or list[int] or tuple[int]): max sample size,
                can be defined separately for src and tgt (then list or tuple)

        Returns:
            np.array: filtered sample array
            list: list of removed indices
        """
        if max_sizes is None:
            return indices, []
        # a single number (Python or NumPy scalar) applies to both sides
        if isinstance(max_sizes, (int, float, np.integer, np.floating)):
            max_src_size, max_tgt_size = max_sizes, max_sizes
        else:
            max_src_size, max_tgt_size = max_sizes
        too_long = self.src_sizes[indices] > max_src_size
        if self.tgt_sizes is not None:
            too_long = too_long | (self.tgt_sizes[indices] > max_tgt_size)
        ignored = indices[too_long]
        if len(ignored) > 0:
            indices = indices[~too_long]
        return indices, ignored.tolist()
コード例 #2
0
class AsrXentDataset(FairseqDataset):
    """
    A pair of torch.utils.data.Datasets.

    Args:
        src (torch.utils.data.Dataset): source dataset to wrap
        src_sizes (List[int]): source sentence lengths
        tgt (espresso.data.AliScpCachedDataset, optional): target alignment dataset to wrap
        tgt_sizes (List[int], optional): target sizes (num of states in the numerator graph)
        text  (torch.utils.data.Dataset, optional): text dataset to wrap
        shuffle (bool, optional): shuffle dataset elements before batching
            (default: True).
        num_buckets (int, optional): if set to a value greater than 0, then
            batches will be bucketed into the given number of batch shapes.
        pad_to_multiple (int, optional): pad src/tgt lengths to a multiple of this value
        seed (int, optional): random seed for generating a chunk from an utterance.
        chunk_width (int, optional): chunk width for chunk-wise training.
        chunk_left_context (int, optional): number of frames appended to the left of a chunk.
            ``None`` is treated as 0.
        chunk_right_context (int, optional): number of frames appended to the right of a chunk.
            ``None`` is treated as 0.
        label_delay (int, optional): offset of the alignments as prediction labels. Can be
            useful in archs such as asymmetric convolution, unidirectional LSTM, etc.
        random_chunking (bool, optional): whether to do random chunking from an utterance, or
            sequentially obtain chunks within each utterance. True for train and False for
            valid/test data.
    """

    # Padding index for targets; the same value is used when collating and
    # when bucket-padding so that both kinds of padding are indistinguishable.
    # -100 matches the default in criterions.
    _TARGET_PAD_IDX = -100

    def __init__(
        self,
        src,
        src_sizes,
        tgt: Optional[AliScpCachedDataset] = None,
        tgt_sizes=None,
        text=None,
        shuffle=True,
        num_buckets=0,
        pad_to_multiple=1,
        seed=1,
        chunk_width=None,
        chunk_left_context=None,
        chunk_right_context=None,
        label_delay=0,
        random_chunking=True,
    ):
        self.src = src
        self.tgt = tgt
        self.src_sizes = np.array(src_sizes)
        self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
        self.text = text
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = 1
        assert chunk_width is None or chunk_width > 0
        self.chunk_width = chunk_width
        # The signature defaults are None; normalise them to "no context" so
        # the comparisons below and the arithmetic in num_tokens() work.
        if chunk_left_context is None:
            chunk_left_context = 0
        if chunk_right_context is None:
            chunk_right_context = 0
        assert chunk_left_context >= 0 and chunk_right_context >= 0
        self.chunk_left_context = chunk_left_context
        self.chunk_right_context = chunk_right_context
        assert (label_delay < 0 and -label_delay <= chunk_right_context) or \
            (label_delay >= 0 and (chunk_width is None or label_delay < chunk_width))
        self.label_delay = label_delay
        self.random_chunking = random_chunking
        if self.tgt is not None:
            self._match_src_tgt()
        if self.text is not None:
            changed = self._match_src_text()
            # matching against text may have dropped source utterances, so the
            # targets have to be re-aligned with the new source order
            if self.tgt is not None and changed:
                self._match_src_tgt()

        if chunk_width is not None:
            # remove those whose lengths are shorter than chunk_size
            indices = np.flatnonzero(self.src.sizes >= chunk_width)
            if len(indices) < self.src.size:
                logger.warning(
                    "Removing {} examples whose lengths are shorter than chunk_size={}"
                    .format(self.src.size - len(indices), chunk_width))
                self.src.filter_and_reorder(indices)
                if self.tgt is not None:
                    self.tgt.filter_and_reorder(indices)
                if self.text is not None:
                    self.text.filter_and_reorder(indices)
                # keep the cached size arrays aligned with the filtered data
                self.src_sizes = self.src_sizes[indices]
                if self.tgt_sizes is not None:
                    self.tgt_sizes = self.tgt_sizes[indices]
                logger.warning("Done removal. {} examples remaining".format(
                    len(indices)))

        # built only after all filtering so it describes the final examples
        self.sizes = np.vstack(
            (self.src_sizes, self.tgt_sizes
             )).T if self.tgt_sizes is not None else self.src_sizes

        if num_buckets > 0:
            from fairseq.data import BucketPadLengthDataset
            from espresso.data import FeatBucketPadLengthDataset
            self.src = FeatBucketPadLengthDataset(
                self.src,
                sizes=self.src_sizes,
                num_buckets=num_buckets,
                pad_idx=0.0,
                left_pad=False,
            )
            self.src_sizes = self.src.sizes
            logger.info("bucketing source lengths: {}".format(
                list(self.src.buckets)))
            if self.tgt is not None:
                # this class holds no dictionary; pad with the same index the
                # collater uses for targets
                self.tgt = BucketPadLengthDataset(
                    self.tgt,
                    sizes=self.tgt_sizes,
                    num_buckets=num_buckets,
                    pad_idx=self._TARGET_PAD_IDX,
                    left_pad=False,
                )
                self.tgt_sizes = self.tgt.sizes
                logger.info("bucketing target lengths: {}".format(
                    list(self.tgt.buckets)))

            # determine bucket sizes using self.num_tokens, which will return
            # the padded lengths (thanks to FeatBucketPadLengthDataset).
            # np.int64 is used because the np.long alias no longer exists in
            # current NumPy releases.
            num_tokens = np.vectorize(self.num_tokens, otypes=[np.int64])
            self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
            self.buckets = [
                (None, num_tokens)
                for num_tokens in np.unique(self.bucketed_num_tokens)
            ]
        else:
            self.buckets = None
        self.pad_to_multiple = pad_to_multiple

    def _reorder_to_match_src(self, other, name):
        """Drop source utterances missing from *other*, then reorder *other*
        to follow the source's utt_id order. *name* only labels the error."""
        other_utt_ids = set(other.utt_ids)
        src_indices = [
            i for i, utt_id in enumerate(self.src.utt_ids)
            if utt_id in other_utt_ids
        ]
        self.src.filter_and_reorder(src_indices)
        self.src_sizes = np.array(self.src.sizes)
        # first-occurrence position of every utt_id: same answer as
        # list.index() but without a linear scan per utterance
        first_position = {}
        for i, utt_id in enumerate(other.utt_ids):
            first_position.setdefault(utt_id, i)
        try:
            other_indices = [
                first_position[utt_id] for utt_id in self.src.utt_ids
            ]
        except KeyError as err:
            raise ValueError(
                "Unable to find some utt_id(s) in {}. which is unlikely to happen. "
                "Something must be wrong.".format(name)) from err
        other.filter_and_reorder(other_indices)

    def _match_src_tgt(self):
        """Makes utterances in src and tgt the same order in terms of
        their utt_ids. Removes those that are only present in one of them."""
        assert self.tgt is not None
        if self.src.utt_ids == self.tgt.utt_ids:
            assert np.all(self.src.sizes == self.tgt.sizes
                          ), "frame and alignment lengths mismatch"
            return
        self._reorder_to_match_src(self.tgt, "tgt")
        self.tgt_sizes = np.array(self.tgt.sizes)
        assert self.src.utt_ids == self.tgt.utt_ids
        assert np.all(self.src.sizes ==
                      self.tgt.sizes), "frame and alignment lengths mismatch"

    def _match_src_text(self):
        """Makes utterances in src and text the same order in terms of
        their utt_ids. Removes those that are only present in one of them.

        Returns:
            bool: True if anything had to be filtered or reordered.
        """
        assert self.text is not None
        if self.src.utt_ids == self.text.utt_ids:
            return False
        self._reorder_to_match_src(self.text, "text")
        assert self.src.utt_ids == self.text.utt_ids
        return True

    def get_batch_shapes(self):
        """Return the list of ``(None, num_tokens)`` bucket shapes, or
        ``None`` when bucketing is disabled."""
        return self.buckets

    def __getitem__(self, index):
        """Return the example at *index* as a dict with keys ``id``,
        ``utt_id``, ``source``, ``target`` and ``text``."""
        tgt_item = self.tgt[index] if self.tgt is not None else None
        # only element [1] of the text dataset's item is kept
        text_item = self.text[index][1] if self.text is not None else None
        src_item = self.src[index]
        example = {
            "id": index,
            "utt_id": self.src.utt_ids[index],
            "source": src_item,
            "target": tgt_item,
            "text": text_item,
        }
        return example

    def __len__(self):
        return len(self.src)

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate
            pad_to_length (dict, optional): a dictionary of
                {'source': source_pad_to_length, 'target': target_pad_to_length}
                to indicate the max length to pad to in source and target respectively.


        Returns:
            dict: a mini-batch with the following keys:

                - `id` (LongTensor): example IDs in the original input order
                - `utt_id` (List[str]): list of utterance ids
                - `nsentences` (int): batch size
                - `ntokens` (int): total number of tokens in the batch
                - `net_input` (dict): the input to the Model, containing keys:

                  - `src_tokens` (FloatTensor): a padded 3D Tensor of features in
                    the source of shape `(bsz, src_len, feat_dim)`.
                  - `src_lengths` (IntTensor): 1D Tensor of the unpadded
                    lengths of each source sequence of shape `(bsz)`

                - `target` (LongTensor): a padded 2D Tensor of indices in the
                  target alignments of shape `(bsz, tgt_len)`
                - `text` (List[str]): list of original text
        """
        # pad_idx=-100 matches the default in criterions
        return collate(
            samples,
            pad_idx=self._TARGET_PAD_IDX,
            chunk_width=self.chunk_width,
            chunk_left_context=self.chunk_left_context,
            chunk_right_context=self.chunk_right_context,
            label_delay=self.label_delay,
            seed=self.seed,
            epoch=self.epoch,
            pad_to_length=pad_to_length,
            pad_to_multiple=self.pad_to_multiple,
            src_bucketed=(self.buckets is not None),
            random_chunking=self.random_chunking,
        )

    def num_tokens(self, index):
        """Return the number of frames in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        if self.chunk_width is None:
            return self.src_sizes[index]
        # in chunk mode every sample has the same fixed length
        return self.chunk_width + self.chunk_left_context + self.chunk_right_context

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return (self.src_sizes[index],
                self.tgt_sizes[index] if self.tgt_sizes is not None else 0)

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self)).astype(np.int64)
        else:
            indices = np.arange(len(self), dtype=np.int64)
        if self.buckets is None:
            # sort by target length, then source length; mergesort is stable,
            # so the (possibly shuffled) order is kept among equal lengths
            if self.tgt_sizes is not None:
                indices = indices[np.argsort(self.tgt_sizes[indices],
                                             kind="mergesort")]
            return indices[np.argsort(self.src_sizes[indices],
                                      kind="mergesort")]
        else:
            # sort by bucketed_num_tokens, which is padded_src_len
            return indices[np.argsort(self.bucketed_num_tokens[indices],
                                      kind="mergesort")]

    @property
    def supports_prefetch(self):
        """Whether the source dataset reports prefetch support."""
        return getattr(self.src, "supports_prefetch", False)

    def prefetch(self, indices):
        """Forward the prefetch request to the source and, if present, the
        target dataset."""
        self.src.prefetch(indices)
        if self.tgt is not None:
            self.tgt.prefetch(indices)

    def filter_indices_by_size(self, indices, max_sizes):
        """ Filter a list of sample indices. Remove those that are longer
            than specified in max_sizes.

        Args:
            indices (np.array): original array of sample indices
            max_sizes (int or list[int] or tuple[int]): max sample size,
                can be defined separately for src and tgt (then list or tuple)

        Returns:
            np.array: filtered sample array
            list: list of removed indices
        """
        return data_utils.filter_paired_dataset_indices_by_size(
            self.src_sizes,
            self.tgt_sizes,
            indices,
            max_sizes,
        )

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return False  # to avoid running out of CPU RAM

    def set_epoch(self, epoch):
        """Record the current epoch (it is passed to the collater) and
        propagate it to the wrapped datasets that accept it."""
        super().set_epoch(epoch)
        self.epoch = epoch
        if hasattr(self.src, "set_epoch"):
            self.src.set_epoch(epoch)
        if self.tgt is not None and hasattr(self.tgt, "set_epoch"):
            self.tgt.set_epoch(epoch)
コード例 #3
0
class LanguagePairDataset(FairseqDataset):
    """
    A pair of torch.utils.data.Datasets.

    Args:
        src (torch.utils.data.Dataset): source dataset to wrap
        src_sizes (List[int]): source sentence lengths
        src_dict (~fairseq.data.Dictionary): source vocabulary
        tgt (torch.utils.data.Dataset, optional): target dataset to wrap
        tgt_sizes (List[int], optional): target sentence lengths
        tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary
        left_pad_source (bool, optional): pad source tensors on the left side
            (default: True).
        left_pad_target (bool, optional): pad target tensors on the left side
            (default: False).
        shuffle (bool, optional): shuffle dataset elements before batching
            (default: True).
        input_feeding (bool, optional): create a shifted version of the targets
            to be passed into the model for teacher forcing (default: True).
        remove_eos_from_source (bool, optional): if set, removes eos from end
            of source if it's present (default: False).
        append_eos_to_target (bool, optional): if set, appends eos to end of
            target if it's absent (default: False).
        align_dataset (torch.utils.data.Dataset, optional): dataset
            containing alignments.
        append_bos (bool, optional): if set, appends bos to the beginning of
            source/target sentence.
        eos (int, optional): end-of-sentence index handed to ``collate`` as
            ``eos_idx`` (default: ``src_dict.eos()``).
        num_buckets (int, optional): if set to a value greater than 0, then
            batches will be bucketed into the given number of batch shapes.
        max_source_positions (int, optional): source length limit used when
            building dummy batches (default: 1024).
        max_target_positions (int, optional): target length limit used when
            building dummy batches (default: 1024).
    """

    def __init__(
        self, src, src_sizes, src_dict,
        tgt=None, tgt_sizes=None, tgt_dict=None,
        left_pad_source=True, left_pad_target=False,
        shuffle=True, input_feeding=True,
        remove_eos_from_source=False, append_eos_to_target=False,
        align_dataset=None,
        append_bos=False, eos=None,
        num_buckets=0,
        max_source_positions=1024,
        max_target_positions=1024,
    ):
        # Both vocabularies must agree on the special symbols because a single
        # pad/eos index is used for both sides when collating.
        if tgt_dict is not None:
            assert src_dict.pad() == tgt_dict.pad()
            assert src_dict.eos() == tgt_dict.eos()
            assert src_dict.unk() == tgt_dict.unk()
        if tgt is not None:
            assert len(src) == len(tgt), "Source and target must contain the same number of examples"
        self.src = src
        self.tgt = tgt
        self.src_sizes = np.array(src_sizes)
        self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict
        self.left_pad_source = left_pad_source
        self.left_pad_target = left_pad_target
        self.max_source_positions = max_source_positions
        self.max_target_positions = max_target_positions
        self.shuffle = shuffle
        self.input_feeding = input_feeding
        self.remove_eos_from_source = remove_eos_from_source
        self.append_eos_to_target = append_eos_to_target
        self.align_dataset = align_dataset
        if self.align_dataset is not None:
            assert self.tgt_sizes is not None, "Both source and target needed when alignments are provided"
        self.append_bos = append_bos
        self.eos = (eos if eos is not None else src_dict.eos())

        if num_buckets > 0:
            from fairseq.data import BucketPadLengthDataset
            self.src = BucketPadLengthDataset(
                self.src,
                sizes=self.src_sizes,
                num_buckets=num_buckets,
                pad_idx=self.src_dict.pad(),
                left_pad=self.left_pad_source,
            )
            self.src_sizes = self.src.sizes
            logger.info('bucketing source lengths: {}'.format(list(self.src.buckets)))
            if self.tgt is not None:
                self.tgt = BucketPadLengthDataset(
                    self.tgt,
                    sizes=self.tgt_sizes,
                    num_buckets=num_buckets,
                    pad_idx=self.tgt_dict.pad(),
                    left_pad=self.left_pad_target,
                )
                self.tgt_sizes = self.tgt.sizes
                logger.info('bucketing target lengths: {}'.format(list(self.tgt.buckets)))

            # determine bucket sizes using self.num_tokens, which will return
            # the padded lengths (thanks to BucketPadLengthDataset).
            # np.int64 is used because the ``np.long`` alias was removed from
            # NumPy in 1.24.
            num_tokens = np.vectorize(self.num_tokens, otypes=[np.int64])
            self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
            self.buckets = [
                (None, padded_len)
                for padded_len in np.unique(self.bucketed_num_tokens)
            ]
        else:
            self.buckets = None

    def get_batch_shapes(self):
        """Return the list of ``(None, num_tokens)`` bucket shapes, or ``None``
        when bucketing is disabled."""
        return self.buckets

    def __getitem__(self, index):
        """Return the example at *index* as a dict.

        The dict has the keys ``id``, ``source`` and ``target`` (``None`` when
        no target dataset is attached), plus ``alignment`` when an alignment
        dataset was given. The EOS/BOS options are applied cumulatively, so
        enabling several of them at once keeps every requested change.
        """
        tgt_item = self.tgt[index] if self.tgt is not None else None
        src_item = self.src[index]
        # Append EOS to end of tgt sentence if it does not have an EOS and remove
        # EOS from end of src sentence if it exists. This is useful when we
        # use existing datasets for opposite directions i.e., when we want to
        # use tgt_dataset as src_dataset and vice versa
        if self.append_eos_to_target and tgt_item is not None:
            eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos()
            if tgt_item[-1] != eos:
                tgt_item = torch.cat([tgt_item, torch.LongTensor([eos])])

        if self.append_bos:
            if tgt_item is not None:
                bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos()
                if tgt_item[0] != bos:
                    tgt_item = torch.cat([torch.LongTensor([bos]), tgt_item])

            # BOS belongs at the front, so the *first* source token is the
            # one that has to be inspected.
            bos = self.src_dict.bos()
            if src_item[0] != bos:
                src_item = torch.cat([torch.LongTensor([bos]), src_item])

        if self.remove_eos_from_source:
            eos = self.src_dict.eos()
            if src_item[-1] == eos:
                src_item = src_item[:-1]

        example = {
            'id': index,
            'source': src_item,
            'target': tgt_item,
        }
        if self.align_dataset is not None:
            example['alignment'] = self.align_dataset[index]
        return example

    def __len__(self):
        """Return the number of examples, taken from the source dataset."""
        return len(self.src)

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch with the following keys:

                - `id` (LongTensor): example IDs in the original input order
                - `ntokens` (int): total number of tokens in the batch
                - `net_input` (dict): the input to the Model, containing keys:

                  - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
                    the source sentence of shape `(bsz, src_len)`. Padding will
                    appear on the left if *left_pad_source* is ``True``.
                  - `src_lengths` (LongTensor): 1D Tensor of the unpadded
                    lengths of each source sentence of shape `(bsz)`
                  - `prev_output_tokens` (LongTensor): a padded 2D Tensor of
                    tokens in the target sentence, shifted right by one
                    position for teacher forcing, of shape `(bsz, tgt_len)`.
                    This key will not be present if *input_feeding* is
                    ``False``.  Padding will appear on the left if
                    *left_pad_target* is ``True``.

                - `target` (LongTensor): a padded 2D Tensor of tokens in the
                  target sentence of shape `(bsz, tgt_len)`. Padding will appear
                  on the left if *left_pad_target* is ``True``.
        """
        return collate(
            samples,
            pad_idx=self.src_dict.pad(),
            eos_idx=self.eos,
            left_pad_source=self.left_pad_source,
            left_pad_target=self.left_pad_target,
            input_feeding=self.input_feeding,
        )

    def get_dummy_batch(self, num_tokens, max_positions, src_len=128, tgt_len=128):
        """Return a dummy batch with a given number of tokens."""
        src_len, tgt_len = utils.resolve_max_positions(
            (src_len, tgt_len),
            max_positions,
            (self.max_source_positions, self.max_target_positions),
        )
        # At least one sentence is always produced, even for tiny budgets.
        bsz = max(num_tokens // max(src_len, tgt_len), 1)
        return self.collater([
            {
                'id': i,
                'source': self.src_dict.dummy_sentence(src_len),
                'target': self.tgt_dict.dummy_sentence(tgt_len) if self.tgt_dict is not None else None,
            }
            for i in range(bsz)
        ])

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return max(self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0)

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return (self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0)

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self))
        else:
            indices = np.arange(len(self))
        if self.buckets is None:
            # sort by target length, then source length; mergesort is stable,
            # so the earlier ordering is kept among equal lengths
            if self.tgt_sizes is not None:
                indices = indices[
                    np.argsort(self.tgt_sizes[indices], kind='mergesort')
                ]
            return indices[np.argsort(self.src_sizes[indices], kind='mergesort')]
        else:
            # sort by bucketed_num_tokens, which is:
            #   max(padded_src_len, padded_tgt_len)
            return indices[
                np.argsort(self.bucketed_num_tokens[indices], kind='mergesort')
            ]

    @property
    def supports_prefetch(self):
        """Whether the source and (if present) target datasets both support
        prefetching."""
        return (
            getattr(self.src, 'supports_prefetch', False)
            and (getattr(self.tgt, 'supports_prefetch', False) or self.tgt is None)
        )

    def prefetch(self, indices):
        """Forward a prefetch request for *indices* to the source dataset and
        to the target and alignment datasets when they are present."""
        self.src.prefetch(indices)
        if self.tgt is not None:
            self.tgt.prefetch(indices)
        if self.align_dataset is not None:
            self.align_dataset.prefetch(indices)
コード例 #4
0
class APEDataset(LanguagePairDataset):
    """A :class:`LanguagePairDataset` with four extra optional streams.

    Besides ``source`` and ``target``, every example carries the items of the
    ``mt``, ``term``, ``src_factor`` and ``mt_factor`` datasets (``None`` for
    a stream that was not supplied). With ``input_type == "concatenate"`` the
    ``mt`` item is additionally spliced into the source sequence.
    """

    # Attribute names of the optional extra datasets; the size array of each
    # one is stored under ``<name>_sizes``.
    _EXTRA_STREAMS = ('mt', 'term', 'src_factor', 'mt_factor')

    def __init__(self,
                 src,
                 src_sizes,
                 src_dict,
                 tgt=None,
                 tgt_sizes=None,
                 tgt_dict=None,
                 mt=None,
                 mt_sizes=None,
                 term=None,
                 term_sizes=None,
                 src_factor=None,
                 src_factor_sizes=None,
                 mt_factor=None,
                 mt_factor_sizes=None,
                 left_pad_source=True,
                 left_pad_target=False,
                 shuffle=True,
                 input_feeding=True,
                 remove_eos_from_source=False,
                 append_eos_to_target=False,
                 align_dataset=None,
                 append_bos=False,
                 eos=None,
                 num_buckets=0,
                 input_type='src_only'):
        """
        Add mt to LanguagePairDataset

        Additional Args:
            mt (torch.utils.data.Dataset, optional): mt dataset to wrap
            mt_sizes (List[int], optional): mt sentence lengths
            term (torch.utils.data.Dataset, optional): term dataset to wrap
            term_sizes (List[int], optional): term sentence lengths
            src_factor (torch.utils.data.Dataset, optional): src_factor
                dataset to wrap
            src_factor_sizes (List[int], optional): src_factor lengths
            mt_factor (torch.utils.data.Dataset, optional): mt_factor
                dataset to wrap
            mt_factor_sizes (List[int], optional): mt_factor lengths
            input_type (str, optional): forwarded to ``collate``; the value
                ``"concatenate"`` also makes ``__getitem__`` append the mt
                item to the source (default: ``'src_only'``).
        """
        # The extra streams must exist before the base constructor runs: when
        # bucketing is enabled it sizes the buckets through self.num_tokens,
        # which this class overrides to read the ``*_sizes`` arrays below.
        self.mt = mt
        self.mt_sizes = self._as_size_array(mt_sizes)
        self.term = term
        self.term_sizes = self._as_size_array(term_sizes)
        self.src_factor = src_factor
        self.src_factor_sizes = self._as_size_array(src_factor_sizes)
        self.mt_factor = mt_factor
        self.mt_factor_sizes = self._as_size_array(mt_factor_sizes)
        self.input_type = input_type

        super().__init__(
            src,
            src_sizes,
            src_dict,
            tgt=tgt,
            tgt_sizes=tgt_sizes,
            tgt_dict=tgt_dict,
            left_pad_source=left_pad_source,
            left_pad_target=left_pad_target,
            shuffle=shuffle,
            input_feeding=input_feeding,
            remove_eos_from_source=remove_eos_from_source,
            append_eos_to_target=append_eos_to_target,
            align_dataset=align_dataset,
            append_bos=append_bos,
            eos=eos,
            num_buckets=num_buckets,
        )

        if num_buckets > 0:
            from fairseq.data import BucketPadLengthDataset
            for name in self._EXTRA_STREAMS:
                dataset = getattr(self, name)
                if dataset is None:
                    continue
                # NOTE(review): every extra stream, including src_factor, is
                # padded with the target pad index and target pad side, as in
                # the original code - confirm this is intended for src_factor.
                bucketed = BucketPadLengthDataset(
                    dataset,
                    sizes=getattr(self, name + '_sizes'),
                    num_buckets=num_buckets,
                    pad_idx=self.tgt_dict.pad(),
                    left_pad=self.left_pad_target,
                )
                setattr(self, name, bucketed)
                setattr(self, name + '_sizes', bucketed.sizes)
                logger.info('bucketing {} lengths: {}'.format(
                    name, list(bucketed.buckets)))

            # The sizes of the extra streams changed after the base class
            # derived its bucket shapes, so derive them again from the padded
            # lengths of all streams.
            num_tokens = np.vectorize(self.num_tokens, otypes=[np.int64])
            self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
            self.buckets = [
                (None, padded_len)
                for padded_len in np.unique(self.bucketed_num_tokens)
            ]

    @staticmethod
    def _as_size_array(sizes):
        """Return *sizes* as a numpy array, keeping ``None`` as ``None``."""
        return np.array(sizes) if sizes is not None else None

    def _extra_sizes(self, index):
        """Return the size of each extra stream at *index* (0 when absent)."""
        sizes = []
        for name in self._EXTRA_STREAMS:
            stream_sizes = getattr(self, name + '_sizes')
            sizes.append(stream_sizes[index] if stream_sizes is not None else 0)
        return sizes

    def __getitem__(self, index):
        """Return the base example extended with the extra stream items.

        The keys ``mt``, ``term``, ``src_factor`` and ``mt_factor`` are always
        added (``None`` for a missing stream). EOS is appended and BOS is
        prepended to each of them under the same flags that govern the target,
        and both changes are kept when both flags are set.
        """
        example = super().__getitem__(index)

        # Append EOS to end of the extra sentences if they do not have an EOS
        # and prepend BOS if they do not start with one, mirroring what is
        # done for the target sentence.
        eos = None
        if self.append_eos_to_target:
            eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos()
        bos = None
        if self.append_bos:
            bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos()

        items = {}
        for name in self._EXTRA_STREAMS:
            dataset = getattr(self, name)
            item = dataset[index] if dataset is not None else None
            if item is not None:
                if eos is not None and item[-1] != eos:
                    item = torch.cat([item, torch.LongTensor([eos])])
                if bos is not None and item[0] != bos:
                    item = torch.cat([torch.LongTensor([bos]), item])
            items[name] = item

        if self.input_type == "concatenate":
            # New source = source without its last token, '<sep>', the mt item
            # without its first and last token, then the source's last token.
            src_item = example["source"]
            mt_item = items['mt']
            last_token = [src_item[-1]]
            mt_sep = [self.src_dict.index('<sep>')]
            example['source'] = torch.cat([
                src_item[:-1],
                torch.LongTensor(mt_sep),
                mt_item[1:-1],
                torch.LongTensor(last_token),
            ])

        example["mt"] = items['mt']
        example["term"] = items['term']
        example["src_factor"] = items['src_factor']
        example["mt_factor"] = items['mt_factor']
        return example

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch."""
        return collate(
            samples,
            pad_idx=self.src_dict.pad(),
            eos_idx=self.eos,
            left_pad_source=self.left_pad_source,
            left_pad_target=self.left_pad_target,
            input_feeding=self.input_feeding,
            input_type=self.input_type,
        )

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return max(
            self.src_sizes[index],
            self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
            *self._extra_sizes(index),
        )

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return (
            self.src_sizes[index],
            self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
            *self._extra_sizes(index),
        )

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self))
        else:
            indices = np.arange(len(self))
        if self.buckets is None:
            # Stable sorts applied in sequence: target, mt, term, src_factor,
            # mt_factor and finally source length, so source length is the
            # primary key and each earlier sort only breaks ties.
            secondary_keys = [self.tgt_sizes] + [
                getattr(self, name + '_sizes') for name in self._EXTRA_STREAMS
            ]
            for sizes in secondary_keys:
                if sizes is not None:
                    indices = indices[np.argsort(sizes[indices],
                                                 kind='mergesort')]
            return indices[np.argsort(self.src_sizes[indices],
                                      kind='mergesort')]
        else:
            # sort by bucketed_num_tokens, i.e. the values of
            # self.num_tokens computed when the buckets were built
            return indices[np.argsort(self.bucketed_num_tokens[indices],
                                      kind='mergesort')]

    @property
    def supports_prefetch(self):
        """Whether the source and every attached optional dataset (target and
        the four extra streams) support prefetching."""
        optional = [self.tgt] + [
            getattr(self, name) for name in self._EXTRA_STREAMS
        ]
        return (getattr(self.src, 'supports_prefetch', False)
                and all(dataset is None
                        or getattr(dataset, 'supports_prefetch', False)
                        for dataset in optional))

    def prefetch(self, indices):
        """Forward a prefetch request for *indices* to the source dataset and
        to the target and extra stream datasets that are present."""
        self.src.prefetch(indices)
        if self.tgt is not None:
            self.tgt.prefetch(indices)
        for name in self._EXTRA_STREAMS:
            dataset = getattr(self, name)
            if dataset is not None:
                dataset.prefetch(indices)
コード例 #5
0
class LanguagePairDataset(FairseqDataset):
    """
    A pair of torch.utils.data.Datasets.

    Args:
        src (torch.utils.data.Dataset): source dataset to wrap
        src_sizes (List[int]): source sentence lengths
        src_dict (~fairseq.data.Dictionary): source vocabulary
        tgt (torch.utils.data.Dataset, optional): target dataset to wrap
        tgt_sizes (List[int], optional): target sentence lengths
        tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary
        left_pad_source (bool, optional): pad source tensors on the left side
            (default: True).
        left_pad_target (bool, optional): pad target tensors on the left side
            (default: False).
        shuffle (bool, optional): shuffle dataset elements before batching
            (default: True).
        input_feeding (bool, optional): create a shifted version of the targets
            to be passed into the model for teacher forcing (default: True).
        remove_eos_from_source (bool, optional): if set, removes eos from end
            of source if it's present (default: False).
        append_eos_to_target (bool, optional): if set, appends eos to end of
            target if it's absent (default: False).
        align_dataset (torch.utils.data.Dataset, optional): dataset
            containing alignments.
        constraints (Tensor, optional): 2d tensor with a concatenated, zero-
            delimited list of constraints for each sentence.
        append_bos (bool, optional): if set, appends bos to the beginning of
            source/target sentence.
        num_buckets (int, optional): if set to a value greater than 0, then
            batches will be bucketed into the given number of batch shapes.
        src_lang_id (int, optional): source language ID, if set, the collated batch
            will contain a field 'src_lang_id' in 'net_input' which indicates the
            source language of the samples.
        tgt_lang_id (int, optional): target language ID, if set, the collated batch
            will contain a field 'tgt_lang_id' which indicates the target language
             of the samples.
    """

    def __init__(
        self,
        src,
        src_sizes,
        src_dict,
        tgt=None,
        tgt_sizes=None,
        tgt_dict=None,
        left_pad_source=True,
        left_pad_target=False,
        shuffle=True,
        input_feeding=True,
        remove_eos_from_source=False,
        append_eos_to_target=False,
        align_dataset=None,
        constraints=None,
        append_bos=False,
        eos=None,
        num_buckets=0,
        src_lang_id=None,
        tgt_lang_id=None,
        pad_to_multiple=1,
        add_lang_token=False,
        lang_pair=None,
        shuffle_lang_pair=False,
        args=None,
    ):
        """Wrap the datasets and, for the mask-decode task, set up masking.

        The data-related arguments are described in the class docstring. The
        remaining ones are stored on the instance as given:

        Args:
            pad_to_multiple: stored as ``self.pad_to_multiple``.
            add_lang_token: stored as ``self.add_lang_token``.
            lang_pair: stored as ``self.lang_pair``.
            shuffle_lang_pair: stored as ``self.shuffle_lang_pair``.
            args: optional task arguments. When ``args.task`` equals
                ``"translation_from_pretrained_maskdecode_multi"`` the fields
                ``mask_random``, ``replace_length``, ``mask_length`` and (for
                ``"span-poisson"``) ``poisson_lambda`` are read to configure
                masking.

        Raises:
            AssertionError: if the dictionaries disagree on pad/eos/unk, if
                source and target differ in length, or if alignments are
                given without target sizes.
            ValueError: if the masking arguments in *args* are invalid.
        """
        if tgt_dict is not None:
            assert src_dict.pad() == tgt_dict.pad()
            assert src_dict.eos() == tgt_dict.eos()
            assert src_dict.unk() == tgt_dict.unk()
        if tgt is not None:
            assert len(src) == len(
                tgt
            ), "Source and target must contain the same number of examples"
        self.src = src
        self.tgt = tgt
        self.src_sizes = np.array(src_sizes)
        self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
        # One row per example holding (src_len, tgt_len); just the source
        # lengths when there is no target side.
        self.sizes = (
            np.vstack((self.src_sizes, self.tgt_sizes)).T
            if self.tgt_sizes is not None
            else self.src_sizes
        )
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict
        self.left_pad_source = left_pad_source
        self.left_pad_target = left_pad_target
        self.shuffle = shuffle
        self.input_feeding = input_feeding
        self.remove_eos_from_source = remove_eos_from_source
        self.append_eos_to_target = append_eos_to_target
        self.align_dataset = align_dataset
        self.add_lang_token = add_lang_token
        self.lang_pair = lang_pair
        self.args = args
        if self.align_dataset is not None:
            assert (
                self.tgt_sizes is not None
            ), "Both source and target needed when alignments are provided"
        self.constraints = constraints
        self.append_bos = append_bos
        self.eos = eos if eos is not None else src_dict.eos()
        self.src_lang_id = src_lang_id
        self.tgt_lang_id = tgt_lang_id
        if num_buckets > 0:
            from fairseq.data import BucketPadLengthDataset

            self.src = BucketPadLengthDataset(
                self.src,
                sizes=self.src_sizes,
                num_buckets=num_buckets,
                pad_idx=self.src_dict.pad(),
                left_pad=self.left_pad_source,
            )
            self.src_sizes = self.src.sizes
            logger.info("bucketing source lengths: {}".format(list(self.src.buckets)))
            if self.tgt is not None:
                self.tgt = BucketPadLengthDataset(
                    self.tgt,
                    sizes=self.tgt_sizes,
                    num_buckets=num_buckets,
                    pad_idx=self.tgt_dict.pad(),
                    left_pad=self.left_pad_target,
                )
                self.tgt_sizes = self.tgt.sizes
                logger.info(
                    "bucketing target lengths: {}".format(list(self.tgt.buckets))
                )

            # determine bucket sizes using self.num_tokens, which will return
            # the padded lengths (thanks to BucketPadLengthDataset).
            # np.int64 is used because the ``np.long`` alias was removed from
            # NumPy in 1.24.
            num_tokens = np.vectorize(self.num_tokens, otypes=[np.int64])
            self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
            self.buckets = [
                (None, padded_len)
                for padded_len in np.unique(self.bucketed_num_tokens)
            ]
        else:
            self.buckets = None
        self.pad_to_multiple = pad_to_multiple
        self.shuffle_lang_pair = shuffle_lang_pair
        self.src_mask_index = self.src_dict.index("<mask>")
        # tgt_dict is optional; without it the source vocabulary is the only
        # one available, so the target-side lookups fall back to it instead
        # of failing on None.
        target_vocab = self.tgt_dict if self.tgt_dict is not None else self.src_dict
        self.tgt_mask_index = target_vocab.index("<mask>")
        self.padding_index = target_vocab.index("<pad>")
        self.ratio = 0.5

        # The masking attributes below (random_ratio, replace_length,
        # mask_span_distribution) are only defined for this task.
        if args is not None and args.task == "translation_from_pretrained_maskdecode_multi":
            self.random_ratio = args.mask_random
            self.replace_length = args.replace_length
            if self.replace_length not in [-1, 0, 1]:
                raise ValueError(f"invalid arg: replace_length={self.replace_length}")
            if args.mask_length not in ["subword", "word", "span-poisson"]:
                raise ValueError(f"invalid arg: mask-length={args.mask_length}")
            if args.mask_length == "subword" and args.replace_length not in [0, 1]:
                raise ValueError("if using subwords, use replace-length=1 or 0")
            self.mask_span_distribution = None
            if args.mask_length == "span-poisson":
                # Tabulate the Poisson probabilities e^-l * l^k / k! for
                # k = 0..127, stopping early once a term drops below 1e-7,
                # and sample span lengths from the resulting categorical.
                _lambda = args.poisson_lambda
                lambda_to_the_k = 1
                e_to_the_minus_lambda = math.exp(-_lambda)
                k_factorial = 1
                ps = []
                for k in range(0, 128):
                    ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
                    lambda_to_the_k *= _lambda
                    k_factorial *= k + 1
                    if ps[-1] < 0.0000001:
                        break
                ps = torch.FloatTensor(ps)
                self.mask_span_distribution = torch.distributions.Categorical(ps)


    def get_batch_shapes(self):
        """Return the bucket shapes stored on ``self.buckets`` (``None`` when
        the dataset was built without bucketing)."""
        shapes = self.buckets
        return shapes


    def word_starts(self, source):
        """Flag the positions of *source* that begin a new word.

        Each token id is looked up in ``self.tgt_dict``; a position is flagged
        with 1 when its symbol starts with the ``"▁"`` marker and 0 otherwise.
        The first and last positions are always cleared.

        Args:
            source: iterable of token ids that can index ``self.tgt_dict``.

        Returns:
            torch.LongTensor: one 0/1 flag per position of *source* (empty
            when *source* is empty).
        """
        # startswith() also copes with an empty symbol, where indexing its
        # first character would raise IndexError.
        flags = [
            1 if self.tgt_dict[token].startswith("▁") else 0
            for token in source
        ]
        is_word_start = torch.tensor(flags, dtype=torch.long)
        if is_word_start.numel() > 0:
            is_word_start[0] = 0
            is_word_start[-1] = 0
        return is_word_start

    def add_whole_word_mask(self, source, p):
        """Apply whole-word masking noise to a token sequence.

        Word starts are taken from :meth:`word_starts`. A fraction ``p`` of
        them is selected at random; each selected start is either replaced by
        ``self.tgt_mask_index`` (or, with probability ``self.random_ratio``,
        by a random vocabulary id) or deleted, depending on
        ``self.replace_length``. When ``self.mask_span_distribution`` is set,
        span lengths (in words) are sampled from it, and zero-length spans are
        turned into insertions via :meth:`add_insertion_noise`.

        Args:
            source (Tensor): 1-D tensor of token ids. NOTE: it is written to
                in place whenever ``self.replace_length != 0``.
            p (float): fraction of word starts to mask.

        Returns:
            Tensor: the noised sequence. ``source`` itself is returned
            unchanged when there is nothing to mask.
        """
        is_word_start = self.word_starts(source)
        num_to_mask = int(math.ceil(is_word_start.float().sum() * p))
        num_inserts = 0
        if num_to_mask == 0:
            return source

        if self.mask_span_distribution is not None:
            lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,))
            # Make sure we have enough to mask
            cum_length = torch.cumsum(lengths, 0)
            while cum_length[-1] < num_to_mask:
                lengths = torch.cat(
                    [
                        lengths,
                        self.mask_span_distribution.sample(sample_shape=(num_to_mask,)),
                    ],
                    dim=0,
                )
                cum_length = torch.cumsum(lengths, 0)

            # Trim to masking budget: find the first span at which the
            # cumulative length reaches the budget and shorten that span so
            # the total is exactly ``num_to_mask``.
            i = 0
            while cum_length[i] < num_to_mask:
                i += 1
            lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1])
            # From here on ``num_to_mask`` counts spans rather than words.
            num_to_mask = i + 1
            lengths = lengths[:num_to_mask]

            # Handle 0-length mask (inserts) separately
            lengths = lengths[lengths > 0]
            num_inserts = num_to_mask - lengths.size(0)
            num_to_mask -= num_inserts
            if num_to_mask == 0:
                # Only zero-length spans were sampled: insert, don't mask.
                return self.add_insertion_noise(source, num_inserts / source.size(0))

            assert (lengths > 0).all()
        else:
            # No span distribution: every masked span covers exactly one word.
            lengths = torch.ones((num_to_mask,)).long()
        assert is_word_start[-1] == 0
        word_starts = is_word_start.nonzero(as_tuple=False)
        # Pick ``num_to_mask`` distinct word-start positions at random.
        indices = word_starts[
            torch.randperm(word_starts.size(0))[:num_to_mask]
        ].squeeze(1)
        # Per-span flag: use a random token instead of the mask token.
        mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio

        source_length = source.size(0)
        assert source_length - 1 not in indices
        to_keep = torch.ones(source_length, dtype=torch.bool)
        is_word_start[
            -1
        ] = 255  # acts as a long length, so spans don't go over the end of doc
        if self.replace_length == 0:
            # Drop the span's first token entirely.
            to_keep[indices] = 0
        else:
            # keep index, but replace it with [MASK]
            source[indices] = self.tgt_mask_index
            source[indices[mask_random]] = torch.randint(
                1, len(self.tgt_dict), size=(mask_random.sum(),)
            )

        if self.mask_span_distribution is not None:
            assert len(lengths.size()) == 1
            assert lengths.size() == indices.size()
            # The first word of each span has been handled above.
            lengths -= 1
            while indices.size(0) > 0:
                assert lengths.size() == indices.size()
                # Advancing onto a token that starts a new word consumes one
                # unit of the span's remaining length.
                lengths -= is_word_start[indices + 1].long()
                uncompleted = lengths >= 0
                indices = indices[uncompleted] + 1
                mask_random = mask_random[uncompleted]
                lengths = lengths[uncompleted]
                if self.replace_length != -1:
                    # delete the continuation token
                    to_keep[indices] = 0
                else:
                    # keep index, but replace it with [MASK]
                    source[indices] = self.tgt_mask_index
                    source[indices[mask_random]] = torch.randint(
                        1, len(self.tgt_dict), size=(mask_random.sum(),)
                    )
        else:
            # A bit faster when all lengths are 1
            while indices.size(0) > 0:
                # Keep walking while the next token is still part of the
                # same word (i.e. it is not flagged as a word start).
                uncompleted = is_word_start[indices + 1] == 0
                indices = indices[uncompleted] + 1
                mask_random = mask_random[uncompleted]
                if self.replace_length != -1:
                    # delete the continuation token
                    to_keep[indices] = 0
                else:
                    # keep index, but replace it with [MASK]
                    source[indices] = self.tgt_mask_index
                    source[indices[mask_random]] = torch.randint(
                        1, len(self.tgt_dict), size=(mask_random.sum(),)
                    )

                assert source_length - 1 not in indices

        # Remove every position marked for deletion (creates a new tensor).
        source = source[to_keep]

        if num_inserts > 0:
            source = self.add_insertion_noise(source, num_inserts / source.size(0))

        return source

    def add_insertion_noise(self, tokens, p):
        if p == 0.0:
            return tokens

        num_tokens = len(tokens)
        n = int(math.ceil(num_tokens * p))

        noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1
        noise_mask = torch.zeros(size=(num_tokens + n,), dtype=torch.bool)
        noise_mask[noise_indices] = 1
        result = torch.LongTensor(n + len(tokens)).fill_(-1)

        num_random = int(math.ceil(n * self.random_ratio))
        result[noise_indices[num_random:]] = self.tgt_mask_index
        result[noise_indices[:num_random]] = torch.randint(
            low=1, high=len(self.tgt_dict), size=(num_random,)
        )

        result[~noise_mask] = tokens

        assert (result >= 0).all()
        return result

    def add_rolling_noise(self, tokens):
        """Rotate the interior of ``tokens`` by a random offset.

        The first and last tokens stay in place; the tokens between them are
        rotated so that a randomly chosen interior position comes first.

        Args:
            tokens (Tensor): 1-D tensor of token ids.

        Returns:
            Tensor: a new tensor holding the rotated sequence.
        """
        upper = max(1, tokens.size(-1) - 1) + 1
        pivot = np.random.randint(1, upper)
        head, tail = tokens[0:1], tokens[-1:]
        pieces = (head, tokens[pivot:-1], tokens[1:pivot], tail)
        return torch.cat(pieces, dim=0)


    def denoising_mono(self,src_tokens):
        source=src_tokens
        target=src_tokens.tolist().copy()
        source = self.add_whole_word_mask(source, self.args.mask)
        source = self.add_rolling_noise(source)
        return torch.tensor(source),torch.tensor(target)


    def span_masking_mono(self, src_tokens):
        src_list = src_tokens.tolist()
        lang_id = src_list[0]
        start, length = self.mask_interval(len(src_list))
        source = []
        for i, w in enumerate(src_list):
            if i >= start and i < start + length:
                w = self.mask_word(w)
            if w is not None:
                source.append(w)

        output = [self.padding_index] * len(src_list)
        output[0] = lang_id
        output[start:start + length] = src_list[start:start + length].copy()
        if start + length < len(src_list) and  output[start + length] == self.padding_index:
            output[start + length] = self.tgt_dict.eos_index

        target = [self.tgt_mask_index] * len(src_list)
        target[0] = lang_id
        target[-1] = self.tgt_dict.eos_index
        target[start:start + length] = src_list[start: start + length].copy()

        # print(start)
        # print("src_raw:{}\nsrc_input:{}\ntgt_out:{}\ntgt_input:{}\n".format(
        #     " ".join([self.tgt_dict[vocab_i] for vocab_i in src_list]),
        #     " ".join([self.tgt_dict[vocab_i] for vocab_i in source]),
        #     " ".join([self.tgt_dict[vocab_i] for vocab_i in output]),
        #     " ".join([self.tgt_dict[vocab_i] for vocab_i in target])))

        assert len(target) == len(output)
        return torch.tensor(source), torch.tensor(target), torch.tensor(output)

    def __getitem__(self, index):
        tgt_item = self.tgt[index] if self.tgt is not None else None
        src_item = self.src[index]

        prev_output_tokens = None
        is_swap = random.random()
        if self.lang_pair == None or self.lang_pair[0] != self.lang_pair[1]:
            if is_swap > 0.5 and self.shuffle_lang_pair:
                src_item, tgt_item = tgt_item, src_item
        else:
            if self.args.bart_mono:
                src_item,tgt_item = self.denoising_mono(src_item)
            else:
                src_item, prev_output_tokens, tgt_item = self.span_masking_mono(src_item)

        # Append EOS to end of tgt sentence if it does not have an EOS and remove
        # EOS from end of src sentence if it exists. This is useful when we use
        # use existing datasets for opposite directions i.e., when we want to
        # use tgt_dataset as src_dataset and vice versa
        # print(self.append_eos_to_target, self.append_bos, self.remove_eos_from_source) : all False
        if self.append_eos_to_target:
            eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos()
            if self.tgt and self.tgt[index][-1] != eos:
                tgt_item = torch.cat([self.tgt[index], torch.LongTensor([eos])])

        if self.append_bos:
            bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos()
            if self.tgt and self.tgt[index][0] != bos:
                tgt_item = torch.cat([torch.LongTensor([bos]), self.tgt[index]])

            bos = self.src_dict.bos()
            if self.src[index][0] != bos:
                src_item = torch.cat([torch.LongTensor([bos]), self.src[index]])

        if self.remove_eos_from_source:
            eos = self.src_dict.eos()
            if self.src[index][-1] == eos:
                src_item = self.src[index][:-1]
        example = {
            "id": index,
            "source": src_item,
            "target": tgt_item,
            "prev_output_tokens": prev_output_tokens,
        }

        # print("src_input:{}\ntgt_input:{}\n".format(
        #      " ".join([self.src_dict[vocab_i] for vocab_i in src_item]),
        #      " ".join([self.tgt_dict[vocab_i] for vocab_i in tgt_item])))

        if self.align_dataset is not None:
            example["alignment"] = self.align_dataset[index]
        if self.constraints is not None:
            example["constraints"] = self.constraints[index]
        return example

    def __len__(self):
        return len(self.src)

    def collater(self, samples, pad_to_length=None):
        """Collate a list of samples into one mini-batch.

        Args:
            samples (List[dict]): the samples to merge.
            pad_to_length (dict, optional): maximum lengths to pad to, given
                as ``{'source': source_pad_to_length,
                'target': target_pad_to_length}``.

        Returns:
            dict: the mini-batch, with these entries:

                - `id` (LongTensor): example IDs in the original input order
                - `ntokens` (int): total number of tokens in the batch
                - `net_input` (dict): what is fed to the model:

                  - `src_tokens` (LongTensor): padded 2D tensor of source
                    tokens, shape `(bsz, src_len)`; padded on the left when
                    *left_pad_source* is ``True``.
                  - `src_lengths` (LongTensor): 1D tensor of unpadded source
                    lengths, shape `(bsz)`
                  - `prev_output_tokens` (LongTensor): padded 2D tensor of
                    target tokens shifted right by one position for teacher
                    forcing, shape `(bsz, tgt_len)`; absent when
                    *input_feeding* is ``False``; padded on the left when
                    *left_pad_target* is ``True``.
                  - `src_lang_id` (LongTensor): source language ID of every
                    sample, added only when ``self.src_lang_id`` is set

                - `target` (LongTensor): padded 2D tensor of target tokens,
                  shape `(bsz, tgt_len)`; padded on the left when
                  *left_pad_target* is ``True``.
                - `tgt_lang_id` (LongTensor): target language ID of every
                  sample, added only when ``self.tgt_lang_id`` is set
        """
        batch = collate(
            samples,
            pad_idx=self.src_dict.pad(),
            eos_idx=self.eos,
            left_pad_source=self.left_pad_source,
            left_pad_target=self.left_pad_target,
            input_feeding=self.input_feeding,
            pad_to_length=pad_to_length,
            pad_to_multiple=self.pad_to_multiple,
        )

        # Nothing more to attach when no language IDs are configured.
        if self.src_lang_id is None and self.tgt_lang_id is None:
            return batch

        src_tokens = batch["net_input"]["src_tokens"]
        batch_size = src_tokens.size(0)

        if self.src_lang_id is not None:
            src_ids = torch.LongTensor([[self.src_lang_id]]).expand(batch_size, 1)
            # ``.to(src_tokens)`` matches the dtype/device of the batch.
            batch["net_input"]["src_lang_id"] = src_ids.to(src_tokens)
        if self.tgt_lang_id is not None:
            tgt_ids = torch.LongTensor([[self.tgt_lang_id]]).expand(batch_size, 1)
            batch["tgt_lang_id"] = tgt_ids.to(src_tokens)

        return batch

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""

        return max(
            self.src_sizes[index],
            self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
        )

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return (
            self.src_sizes[index],
            self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
        )

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self)).astype(np.int64)
        else:
            indices = np.arange(len(self), dtype=np.int64)
        if self.buckets is None:
            # sort by target length, then source length
            if self.tgt_sizes is not None:
                indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")]
            return indices[np.argsort(self.src_sizes[indices], kind="mergesort")]
        else:
            # sort by bucketed_num_tokens, which is:
            #   max(padded_src_len, padded_tgt_len)
            return indices[
                np.argsort(self.bucketed_num_tokens[indices], kind="mergesort")
            ]

    @property
    def supports_prefetch(self):
        return getattr(self.src, "supports_prefetch", False) and (
                getattr(self.tgt, "supports_prefetch", False) or self.tgt is None
        )

    def prefetch(self, indices):
        self.src.prefetch(indices)
        if self.tgt is not None:
            self.tgt.prefetch(indices)
        if self.align_dataset is not None:
            self.align_dataset.prefetch(indices)

    def filter_indices_by_size(self, indices, max_sizes):
        """Drop the sample indices whose sizes exceed ``max_sizes``.

        Delegates to ``data_utils.filter_paired_dataset_indices_by_size``
        with this dataset's source and target sizes.

        Args:
            indices (np.array): original array of sample indices
            max_sizes (int or list[int] or tuple[int]): max sample size,
                can be defined separately for src and tgt (then list or tuple)

        Returns:
            np.array: filtered sample array
            list: list of removed indices
        """
        filtered = data_utils.filter_paired_dataset_indices_by_size(
            self.src_sizes, self.tgt_sizes, indices, max_sizes
        )
        return filtered

    def mask_start(self, end):
        """Draw the start position of a masked span.

        Returns 1 when ``end`` is at most 1 or with probability 0.2, ``end``
        itself with probability 0.2, and otherwise a uniform integer in
        ``[1, end)``.
        """
        draw = np.random.random()
        if end <= 1 or draw >= 0.8:
            return 1
        if draw >= 0.6:
            return end
        return np.random.randint(1, end)

    def mask_word(self, w):
        p = np.random.random()
        if p >= 0.2:
            return self.src_mask_index
        elif p >= 0.1:
            return np.random.randint(self.tgt_dict.nspecial, len(self.tgt_dict))
        else:
            return w

    def random_word(self, w, pred_probs):
        self.pred_probs = [0, 0, 1]
        cands = [self.src_mask_index, np.random.randint(self.tgt_dict.nspecial, len(self.tgt_dict)), w]
        prob = torch.multinomial(self.pred_probs, 1, replacement=True)
        return cands[prob]

    def mask_interval(self, l):
        mask_length = round(l * self.ratio)
        mask_length = max(1, mask_length)
        mask_start = self.mask_start(l - mask_length)
        return mask_start, mask_length