import logging
from typing import Optional

import numpy as np
import torch

from fairseq.data import FairseqDataset, data_utils
from espresso.data import AliScpCachedDataset

logger = logging.getLogger(__name__)

# NOTE: each example below also relies on a module-level collate() helper
# defined alongside the original class; it is not included in this excerpt.


class AsrChainDataset(FairseqDataset):
    """
    A pair of torch.utils.data.Datasets, plus an optional text dataset.

    Args:
        src (torch.utils.data.Dataset): source dataset to wrap
        src_sizes (List[int]): source sentence lengths
        tgt (espresso.data.NumeratorGraphDataset, optional): target numerator graph dataset to wrap
        tgt_sizes (List[int], optional): target sizes (num of states in the numerator graph)
        text (torch.utils.data.Dataset, optional): text dataset to wrap
        shuffle (bool, optional): shuffle dataset elements before batching
            (default: True).
        num_buckets (int, optional): if set to a value greater than 0, then
            batches will be bucketed into the given number of batch shapes.
    """
    def __init__(
        self,
        src,
        src_sizes,
        tgt=None,
        tgt_sizes=None,
        text=None,
        shuffle=True,
        num_buckets=0,
    ):
        self.src = src
        self.tgt = tgt
        self.src_sizes = np.array(src_sizes)
        self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
        self.text = text
        self.shuffle = shuffle
        self.epoch = 1
        num_before_matching = len(self.src.utt_ids)
        if self.tgt is not None:
            self._match_src_tgt()
        if self.text is not None:
            changed = self._match_src_text()
            if self.tgt is not None and changed:
                self._match_src_tgt()
        num_after_matching = len(self.src.utt_ids)
        num_removed = num_before_matching - num_after_matching
        if num_removed > 0:
            logger.warning(
                "Removed {} examples due to empty numerator graphs or missing entries, "
                "{} remaining".format(num_removed, num_after_matching))

        if num_buckets > 0:
            from espresso.data import FeatBucketPadLengthDataset
            self.src = FeatBucketPadLengthDataset(
                self.src,
                sizes=self.src_sizes,
                num_buckets=num_buckets,
                pad_idx=0.0,
                left_pad=False,
            )
            self.src_sizes = self.src.sizes
            logger.info("bucketing source lengths: {}".format(
                list(self.src.buckets)))

            # determine bucket sizes using self.num_tokens, which will return
            # the padded lengths (thanks to FeatBucketPadLengthDataset)
            num_tokens = np.vectorize(self.num_tokens, otypes=[np.int64])
            self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
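            # each batch shape is a (batch_size, num_tokens) pair; a batch_size
            # of None leaves it unconstrained (derived from --max-tokens), so
            # only the padded length is fixed per bucket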
            self.buckets = [
                (None, num_tokens)
                for num_tokens in np.unique(self.bucketed_num_tokens)
            ]
        else:
            self.buckets = None

    def _match_src_tgt(self):
        """Makes utterances in src and tgt the same order in terms of
        their utt_ids. Removes those that are only present in one of them."""
        assert self.tgt is not None
        if self.src.utt_ids == self.tgt.utt_ids:
            return
        tgt_utt_ids_set = set(self.tgt.utt_ids)
        src_indices = [
            i for i, id in enumerate(self.src.utt_ids) if id in tgt_utt_ids_set
        ]
        self.src.filter_and_reorder(src_indices)
        self.src_sizes = np.array(self.src.sizes)
        try:
            tgt_indices = list(map(self.tgt.utt_ids.index, self.src.utt_ids))
        except ValueError:
            raise ValueError(
                "Unable to find some utt_id(s) in tgt. which is unlikely to happen. "
                "Something must be wrong.")
        self.tgt.filter_and_reorder(tgt_indices)
        self.tgt_sizes = np.array(self.tgt.sizes)
        assert self.src.utt_ids == self.tgt.utt_ids

    def _match_src_text(self):
        """Makes utterances in src and text the same order in terms of
        their utt_ids. Removes those that are only present in one of them."""
        assert self.text is not None
        if self.src.utt_ids == self.text.utt_ids:
            return False
        text_utt_ids_set = set(self.text.utt_ids)
        src_indices = [
            i for i, id in enumerate(self.src.utt_ids)
            if id in text_utt_ids_set
        ]
        self.src.filter_and_reorder(src_indices)
        self.src_sizes = np.array(self.src.sizes)
        try:
            text_indices = list(map(self.text.utt_ids.index, self.src.utt_ids))
        except ValueError:
            raise ValueError(
                "Unable to find some utt_id(s) in text. which is unlikely to happen. "
                "Something must be wrong.")
        self.text.filter_and_reorder(text_indices)
        assert self.src.utt_ids == self.text.utt_ids
        return True

    def get_batch_shapes(self):
        return self.buckets

    def __getitem__(self, index):
        tgt_item = self.tgt[index] if self.tgt is not None else None
        text_item = self.text[index][1] if self.text is not None else None
        src_item = self.src[index]
        example = {
            "id": index,
            "utt_id": self.src.utt_ids[index],
            "source": src_item,
            "target": tgt_item,
            "text": text_item,
        }
        return example

    def __len__(self):
        return len(self.src)

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch with the following keys:

                - `id` (LongTensor): example IDs in the original input order
                - `utt_id` (List[str]): list of utterance ids
                - `nsentences` (int): batch size
                - `ntokens` (int): total number of tokens in the batch
                - `net_input` (dict): the input to the Model, containing keys:

                  - `src_tokens` (FloatTensor): a padded 3D Tensor of features in
                    the source of shape `(bsz, src_len, feat_dim)`.
                  - `src_lengths` (IntTensor): 1D Tensor of the unpadded
                    lengths of each source sequence of shape `(bsz)`

                - `target` (ChainGraphBatch): an instance representing a batch of
                    numerator graphs
                - `text` (List[str]): list of original text
        """
        return collate(samples, src_bucketed=(self.buckets is not None))

    def num_tokens(self, index):
        """Return the number of frames in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return self.src_sizes[index]

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return (self.src_sizes[index],
                self.tgt_sizes[index] if self.tgt_sizes is not None else 0)

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self))
        else:
            indices = np.arange(len(self))
        if self.buckets is None:
            # sort by target length, then source length
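            # (mergesort is stable, so the final source-length sort keeps the
            # target-length order among examples of equal source length)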
            if self.tgt_sizes is not None:
                indices = indices[np.argsort(self.tgt_sizes[indices],
                                             kind="mergesort")]
            return indices[np.argsort(self.src_sizes[indices],
                                      kind="mergesort")]
        else:
            # sort by bucketed_num_tokens, which is padded_src_len
            return indices[np.argsort(self.bucketed_num_tokens[indices],
                                      kind="mergesort")]

    @property
    def supports_prefetch(self):
        return getattr(self.src, "supports_prefetch", False)

    def prefetch(self, indices):
        """Only prefetch src."""
        self.src.prefetch(indices)

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        self.epoch = epoch
        if hasattr(self.src, "set_epoch"):
            self.src.set_epoch(epoch)
        if self.tgt is not None and hasattr(self.tgt, "set_epoch"):
            self.tgt.set_epoch(epoch)
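

# Usage sketch (illustrative only): `feats` and `graphs` are placeholders for
# espresso-style datasets (a feature dataset exposing `utt_ids`, `sizes` and
# `filter_and_reorder()`, and an espresso.data.NumeratorGraphDataset); they are
# assumptions of this sketch, not fixed espresso APIs.
def _demo_chain_batch(feats, graphs, batch_size=8):
    dataset = AsrChainDataset(feats, feats.sizes, tgt=graphs, tgt_sizes=graphs.sizes)
    indices = dataset.ordered_indices()  # shuffled, then stably length-sorted
    samples = [dataset[int(i)] for i in indices[:batch_size]]
    batch = dataset.collater(samples)
    # batch["net_input"]["src_tokens"]: padded (bsz, src_len, feat_dim) features
    return batch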


# Example 2
class AsrDataset(FairseqDataset):
    """
    A pair of torch.utils.data.Datasets.

    Args:
        src (torch.utils.data.Dataset): source dataset to wrap
        src_sizes (List[int]): source sentence lengths
        tgt (torch.utils.data.Dataset, optional): target dataset to wrap
        tgt_sizes (List[int], optional): target sentence lengths
        dictionary (~fairseq.data.Dictionary, optional): target vocabulary
        left_pad_source (bool, optional): pad source tensors on the left side
            (default: False).
        left_pad_target (bool, optional): pad target tensors on the left side
            (default: False).
        shuffle (bool, optional): shuffle dataset elements before batching
            (default: True).
        input_feeding (bool, optional): create a shifted version of the targets
            to be passed into the model for teacher forcing (default: True).
        num_buckets (int, optional): if set to a value greater than 0, then
            batches will be bucketed into the given number of batch shapes.
    """
    def __init__(
        self,
        src,
        src_sizes,
        tgt=None,
        tgt_sizes=None,
        dictionary=None,
        left_pad_source=False,
        left_pad_target=False,
        shuffle=True,
        input_feeding=True,
        num_buckets=0,
    ):
        self.src = src
        self.tgt = tgt
        self.src_sizes = np.array(src_sizes)
        self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
        self.dictionary = dictionary
        self.left_pad_source = left_pad_source
        self.left_pad_target = left_pad_target
        self.shuffle = shuffle
        self.input_feeding = input_feeding
        if self.tgt is not None:
            self._match_src_tgt()

        if num_buckets > 0:
            from espresso.data import FeatBucketPadLengthDataset, TextBucketPadLengthDataset
            self.src = FeatBucketPadLengthDataset(
                self.src,
                sizes=self.src_sizes,
                num_buckets=num_buckets,
                pad_idx=0.0,
                left_pad=False,
            )
            self.src_sizes = self.src.sizes
            logger.info('bucketing source lengths: {}'.format(
                list(self.src.buckets)))
            if self.tgt is not None:
                self.tgt = TextBucketPadLengthDataset(
                    self.tgt,
                    sizes=self.tgt_sizes,
                    num_buckets=num_buckets,
                    pad_idx=self.dictionary.pad(),
                    left_pad=False,
                )
                self.tgt_sizes = self.tgt.sizes
                logger.info('bucketing target lengths: {}'.format(
                    list(self.tgt.buckets)))

            # determine bucket sizes using self.num_tokens, which will return
            # the padded lengths (thanks to FeatBucketPadLengthDataset)
            num_tokens = np.vectorize(self.num_tokens, otypes=[np.int64])
            self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
            self.buckets = [
                (None, num_tokens)
                for num_tokens in np.unique(self.bucketed_num_tokens)
            ]
        else:
            self.buckets = None

    def _match_src_tgt(self):
        """Makes utterances in src and tgt the same order in terms of
        their utt_ids. Removes those that are only present in one of them."""
        assert self.tgt is not None
        if self.src.utt_ids == self.tgt.utt_ids:
            return
        tgt_utt_ids_set = set(self.tgt.utt_ids)
        src_indices = [
            i for i, id in enumerate(self.src.utt_ids) if id in tgt_utt_ids_set
        ]
        self.src.filter_and_reorder(src_indices)
        self.src_sizes = np.array(self.src.sizes)
        try:
            tgt_indices = list(map(self.tgt.utt_ids.index, self.src.utt_ids))
        except ValueError:
            raise ValueError(
                'Unable to find some utt_id(s) in tgt, which is unlikely to happen. '
                'Something must be wrong.')
        self.tgt.filter_and_reorder(tgt_indices)
        self.tgt_sizes = np.array(self.tgt.sizes)
        assert self.src.utt_ids == self.tgt.utt_ids

    def get_batch_shapes(self):
        return self.buckets

    def __getitem__(self, index):
        tgt_item = self.tgt[index][0] if self.tgt is not None else None
        raw_text_item = self.tgt[index][1] if self.tgt is not None else None
        src_item = self.src[index]
        example = {
            'id': index,
            'utt_id': self.src.utt_ids[index],
            'source': src_item,
            'target': tgt_item,
            'target_raw_text': raw_text_item,
        }
        return example

    def __len__(self):
        return len(self.src)

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch with the following keys:

                - `id` (LongTensor): example IDs in the original input order
                - `utt_id` (List[str]): list of utterance ids
                - `nsentences` (int): batch size
                - `ntokens` (int): total number of tokens in the batch
                - `net_input` (dict): the input to the Model, containing keys:

                  - `src_tokens` (FloatTensor): a padded 3D Tensor of features in
                    the source of shape `(bsz, src_len, feat_dim)`. Padding will
                    appear on the left if *left_pad_source* is ``True``.
                  - `src_lengths` (IntTensor): 1D Tensor of the unpadded
                    lengths of each source sequence of shape `(bsz)`
                  - `prev_output_tokens` (LongTensor): a padded 2D Tensor of
                    tokens in the target sentence, shifted right by one
                    position for teacher forcing, of shape `(bsz, tgt_len)`.
                    This key will not be present if *input_feeding* is
                    ``False``.  Padding will appear on the left if
                    *left_pad_target* is ``True``.

                - `target` (LongTensor): a padded 2D Tensor of tokens in the
                  target sentence of shape `(bsz, tgt_len)`. Padding will appear
                  on the left if *left_pad_target* is ``True``.
                - `target_raw_text` (List[str]): list of original text
        """
        return collate(
            samples,
            pad_idx=self.dictionary.pad(),
            eos_idx=self.dictionary.eos(),
            left_pad_source=self.left_pad_source,
            left_pad_target=self.left_pad_target,
            input_feeding=self.input_feeding,
            src_bucketed=(self.buckets is not None),
        )

    def num_tokens(self, index):
        """Return the number of frames in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return self.src_sizes[index]

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return (self.src_sizes[index],
                self.tgt_sizes[index] if self.tgt_sizes is not None else 0)

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self))
        else:
            indices = np.arange(len(self))
        if self.buckets is None:
            # sort by target length, then source length
            if self.tgt_sizes is not None:
                indices = indices[np.argsort(self.tgt_sizes[indices],
                                             kind='mergesort')]
            return indices[np.argsort(self.src_sizes[indices],
                                      kind='mergesort')]
        else:
            # sort by bucketed_num_tokens, which is padded_src_len
            return indices[np.argsort(self.bucketed_num_tokens[indices],
                                      kind='mergesort')]

    @property
    def supports_prefetch(self):
        return getattr(self.src, 'supports_prefetch', False)

    def prefetch(self, indices):
        """Only prefetch src."""
        self.src.prefetch(indices)

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        if hasattr(self.src, 'set_epoch'):
            self.src.set_epoch(epoch)
        if self.tgt is not None and hasattr(self.tgt, 'set_epoch'):
            self.tgt.set_epoch(epoch)
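

# Usage sketch (illustrative only): `feats`, `tokens` and `dictionary` are
# placeholders (espresso-style source/target datasets and a fairseq
# Dictionary). With input_feeding=True the collated batch also carries
# net_input['prev_output_tokens'], the right-shifted targets used for
# teacher forcing.
def _demo_asr_batch(feats, tokens, dictionary, batch_size=4):
    dataset = AsrDataset(
        feats, feats.sizes, tgt=tokens, tgt_sizes=tokens.sizes,
        dictionary=dictionary, input_feeding=True,
    )
    samples = [dataset[i] for i in range(min(batch_size, len(dataset)))]
    batch = dataset.collater(samples)
    return batch['net_input']['prev_output_tokens'], batch['target']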


# Example 3
class AsrDataset(FairseqDataset):
    """
    A pair of torch.utils.data.Datasets.

    Args:
        src (torch.utils.data.Dataset): source dataset to wrap
        src_sizes (List[int]): source sentence lengths
        tgt (torch.utils.data.Dataset, optional): target dataset to wrap
        tgt_sizes (List[int], optional): target sentence lengths
        dictionary (~fairseq.data.Dictionary, optional): target vocabulary
        left_pad_source (bool, optional): pad source tensors on the left side
            (default: False).
        left_pad_target (bool, optional): pad target tensors on the left side
            (default: False).
        shuffle (bool, optional): shuffle dataset elements before batching
            (default: True).
        input_feeding (bool, optional): create a shifted version of the targets
            to be passed into the model for teacher forcing (default: True).
        constraints (Tensor, optional): 2d tensor with a concatenated, zero-
            delimited list of constraints for each sentence.
        num_buckets (int, optional): if set to a value greater than 0, then
            batches will be bucketed into the given number of batch shapes.
        src_lang_id (int, optional): source language ID, if set, the collated batch
            will contain a field "src_lang_id" in "net_input" which indicates the
            source language of the samples.
        tgt_lang_id (int, optional): target language ID, if set, the collated batch
            will contain a field "tgt_lang_id" which indicates the target language
            of the samples.
        pad_to_multiple (int, optional): pad src/tgt lengths to a multiple of this value
    """
    def __init__(
        self,
        src,
        src_sizes,
        tgt=None,
        tgt_sizes=None,
        dictionary=None,
        left_pad_source=False,
        left_pad_target=False,
        shuffle=True,
        input_feeding=True,
        constraints=None,
        num_buckets=0,
        src_lang_id=None,
        tgt_lang_id=None,
        pad_to_multiple=1,
    ):
        self.src = src
        self.tgt = tgt
        self.src_sizes = np.array(src_sizes)
        self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
        assert dictionary is not None
        self.dictionary = dictionary
        self.left_pad_source = left_pad_source
        self.left_pad_target = left_pad_target
        self.shuffle = shuffle
        self.input_feeding = input_feeding
        self.constraints = constraints
        self.src_lang_id = src_lang_id
        self.tgt_lang_id = tgt_lang_id
        if self.tgt is not None:
            self._match_src_tgt()
        self.sizes = (np.vstack((self.src_sizes, self.tgt_sizes)).T
                      if self.tgt_sizes is not None else self.src_sizes)

        if num_buckets > 0:
            from espresso.data import FeatBucketPadLengthDataset, TextBucketPadLengthDataset

            self.src = FeatBucketPadLengthDataset(
                self.src,
                sizes=self.src_sizes,
                num_buckets=num_buckets,
                pad_idx=0.0,
                left_pad=False,
            )
            self.src_sizes = self.src.sizes
            logger.info("bucketing source lengths: {}".format(
                list(self.src.buckets)))
            if self.tgt is not None:
                self.tgt = TextBucketPadLengthDataset(
                    self.tgt,
                    sizes=self.tgt_sizes,
                    num_buckets=num_buckets,
                    pad_idx=self.dictionary.pad(),
                    left_pad=False,
                )
                self.tgt_sizes = self.tgt.sizes
                logger.info("bucketing target lengths: {}".format(
                    list(self.tgt.buckets)))

            # determine bucket sizes using self.num_tokens, which will return
            # the padded lengths (thanks to FeatBucketPadLengthDataset)
            num_tokens = np.vectorize(self.num_tokens, otypes=[np.int64])
            self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
            self.buckets = [
                (None, num_tokens)
                for num_tokens in np.unique(self.bucketed_num_tokens)
            ]
        else:
            self.buckets = None
        self.pad_to_multiple = pad_to_multiple

    def _match_src_tgt(self):
        """Makes utterances in src and tgt the same order in terms of
        their utt_ids. Removes those that are only present in one of them."""
        assert self.tgt is not None
        if self.src.utt_ids == self.tgt.utt_ids:
            return
        tgt_utt_ids_set = set(self.tgt.utt_ids)
        src_indices = [
            i for i, id in enumerate(self.src.utt_ids) if id in tgt_utt_ids_set
        ]
        self.src.filter_and_reorder(src_indices)
        self.src_sizes = np.array(self.src.sizes)
        try:
            tgt_indices = list(map(self.tgt.utt_ids.index, self.src.utt_ids))
        except ValueError:
            raise ValueError(
                "Unable to find some utt_id(s) in tgt. which is unlikely to happen. "
                "Something must be wrong.")
        self.tgt.filter_and_reorder(tgt_indices)
        self.tgt_sizes = np.array(self.tgt.sizes)
        assert self.src.utt_ids == self.tgt.utt_ids

    def get_batch_shapes(self):
        return self.buckets

    def __getitem__(self, index):
        tgt_item = self.tgt[index][0] if self.tgt is not None else None
        raw_text_item = self.tgt[index][1] if self.tgt is not None else None
        src_item = self.src[index]
        example = {
            "id": index,
            "utt_id": self.src.utt_ids[index],
            "source": src_item,
            "target": tgt_item,
            "target_raw_text": raw_text_item,
        }
        if self.constraints is not None:
            example["constraints"] = self.constraints[index]
        return example

    def __len__(self):
        return len(self.src)

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate
            pad_to_length (dict, optional): a dictionary of
                {"source": source_pad_to_length, "target": target_pad_to_length}
                to indicate the max length to pad to in source and target respectively.

        Returns:
            dict: a mini-batch with the following keys:

                - `id` (LongTensor): example IDs in the original input order
                - `utt_id` (List[str]): list of utterance ids
                - `nsentences` (int): batch size
                - `ntokens` (int): total number of tokens in the batch
                - `net_input` (dict): the input to the Model, containing keys:

                  - `src_tokens` (FloatTensor): a padded 3D Tensor of features in
                    the source of shape `(bsz, src_len, feat_dim)`. Padding will
                    appear on the left if *left_pad_source* is ``True``.
                  - `src_lengths` (IntTensor): 1D Tensor of the unpadded
                    lengths of each source sequence of shape `(bsz)`
                  - `prev_output_tokens` (LongTensor): a padded 2D Tensor of
                    tokens in the target sentence, shifted right by one
                    position for teacher forcing, of shape `(bsz, tgt_len)`.
                    This key will not be present if *input_feeding* is
                    ``False``.  Padding will appear on the left if
                    *left_pad_target* is ``True``.
                  - `src_lang_id` (LongTensor): a long Tensor which contains source
                    language IDs of each sample in the batch

                - `target` (LongTensor): a padded 2D Tensor of tokens in the
                  target sentence of shape `(bsz, tgt_len)`. Padding will appear
                  on the left if *left_pad_target* is ``True``.
                - `target_raw_text` (List[str]): list of original text
                - `tgt_lang_id` (LongTensor): a long Tensor which contains target language
                  IDs of each sample in the batch
        """
        res = collate(
            samples,
            pad_idx=self.dictionary.pad(),
            eos_idx=self.dictionary.eos(),
            left_pad_source=self.left_pad_source,
            left_pad_target=self.left_pad_target,
            input_feeding=self.input_feeding,
            pad_to_length=pad_to_length,
            pad_to_multiple=self.pad_to_multiple,
            src_bucketed=(self.buckets is not None),
        )
        if self.src_lang_id is not None or self.tgt_lang_id is not None:
            src_tokens = res["net_input"]["src_tokens"]
            bsz = src_tokens.size(0)
            # keep the IDs as LongTensors: move them to the features' device
            # only, rather than .to(src_tokens), which would also cast them to
            # the features' float dtype
            if self.src_lang_id is not None:
                res["net_input"]["src_lang_id"] = (torch.LongTensor(
                    [[self.src_lang_id]]).expand(bsz, 1).to(src_tokens.device))
            if self.tgt_lang_id is not None:
                res["tgt_lang_id"] = (torch.LongTensor(
                    [[self.tgt_lang_id]]).expand(bsz, 1).to(src_tokens.device))
        return res

    def num_tokens(self, index):
        """Return the number of frames in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return self.src_sizes[index]

    def num_tokens_vec(self, indices):
        """Return the number of tokens for a set of positions defined by indices.
        This value is used to enforce ``--max-tokens`` during batching."""
        sizes = self.src_sizes[indices]
        return sizes

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return (
            self.src_sizes[index],
            self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
        )

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self)).astype(np.int64)
        else:
            indices = np.arange(len(self), dtype=np.int64)
        if self.buckets is None:
            # sort by target length, then source length
            if self.tgt_sizes is not None:
                indices = indices[np.argsort(self.tgt_sizes[indices],
                                             kind="mergesort")]
            return indices[np.argsort(self.src_sizes[indices],
                                      kind="mergesort")]
        else:
            # sort by bucketed_num_tokens, which is padded_src_len
            return indices[np.argsort(self.bucketed_num_tokens[indices],
                                      kind="mergesort")]

    @property
    def supports_prefetch(self):
        return getattr(self.src, "supports_prefetch", False)

    def prefetch(self, indices):
        """Only prefetch src."""
        self.src.prefetch(indices)

    def filter_indices_by_size(self, indices, max_sizes):
        """Filter a list of sample indices. Remove those that are longer
        than specified in max_sizes.

        Args:
            indices (np.array): original array of sample indices
            max_sizes (int or list[int] or tuple[int]): max sample size,
                can be defined separately for src and tgt (then list or tuple)

        Returns:
            np.array: filtered sample array
            list: list of removed indices
        """
        return data_utils.filter_paired_dataset_indices_by_size(
            self.src_sizes,
            self.tgt_sizes,
            indices,
            max_sizes,
        )

    @property
    def supports_fetch_outside_dataloader(self):
        """Whether this dataset supports fetching outside the workers of the dataloader."""
        return False

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return False  # to avoid running out of CPU RAM

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        if hasattr(self.src, "set_epoch"):
            self.src.set_epoch(epoch)
        if self.tgt is not None and hasattr(self.tgt, "set_epoch"):
            self.tgt.set_epoch(epoch)
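

# Usage sketch (illustrative only) of the padding and language-ID controls in
# this variant: pad_to_multiple=8 rounds padded lengths up to a multiple of 8,
# pad_to_length forces minimum padded sizes, and src_lang_id/tgt_lang_id are
# expanded to (bsz, 1) tensors in the collated batch. `feats`, `tokens` and
# `dictionary` are placeholder inputs as in the sketch above.
def _demo_lang_id_batch(feats, tokens, dictionary, batch_size=4):
    dataset = AsrDataset(
        feats, feats.sizes, tgt=tokens, tgt_sizes=tokens.sizes,
        dictionary=dictionary, pad_to_multiple=8,
        src_lang_id=0, tgt_lang_id=1,
    )
    samples = [dataset[i] for i in range(min(batch_size, len(dataset)))]
    batch = dataset.collater(samples, pad_to_length={"source": 800, "target": 100})
    return batch["net_input"]["src_lang_id"], batch["tgt_lang_id"]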


# Example 4
class AsrXentDataset(FairseqDataset):
    """
    A pair of torch.utils.data.Datasets, plus an optional text dataset.

    Args:
        src (torch.utils.data.Dataset): source dataset to wrap
        src_sizes (List[int]): source sentence lengths
        tgt (espresso.data.AliScpCachedDataset, optional): target alignment dataset to wrap
        tgt_sizes (List[int], optional): target sizes (lengths of the alignments)
        text (torch.utils.data.Dataset, optional): text dataset to wrap
        shuffle (bool, optional): shuffle dataset elements before batching
            (default: True).
        num_buckets (int, optional): if set to a value greater than 0, then
            batches will be bucketed into the given number of batch shapes.
        pad_to_multiple (int, optional): pad src/tgt lengths to a multiple of this value
        seed (int, optional): random seed for generating a chunk from an utterance.
        chunk_width (int, optional): chunk width for chunk-wise training.
        chunk_left_context (int, optional): number of frames appended to the left of a chunk.
        chunk_right_context (int, optional): number of frames appended to the right of a chunk.
        label_delay (int, optional): offset of the alignments as prediction labels. Can be
            useful in architectures such as asymmetric convolutions, unidirectional LSTMs, etc.
        random_chunking (bool, optional): whether to randomly chunk each utterance, or to
            obtain chunks sequentially within it. True for train and False for valid/test data.
    """
    def __init__(
        self,
        src,
        src_sizes,
        tgt: Optional[AliScpCachedDataset] = None,
        tgt_sizes=None,
        text=None,
        shuffle=True,
        num_buckets=0,
        pad_to_multiple=1,
        seed=1,
        chunk_width=None,
        chunk_left_context=0,   # default to 0 (not None) so the checks below and
        chunk_right_context=0,  # the num_tokens() arithmetic work when unset
        label_delay=0,
        random_chunking=True,
    ):
        self.src = src
        self.tgt = tgt
        self.src_sizes = np.array(src_sizes)
        self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
        self.text = text
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = 1
        assert chunk_width is None or chunk_width > 0
        self.chunk_width = chunk_width
        assert chunk_left_context >= 0 and chunk_right_context >= 0
        self.chunk_left_context = chunk_left_context
        self.chunk_right_context = chunk_right_context
        assert (label_delay < 0 and -label_delay <= chunk_right_context) or \
            (label_delay >= 0 and (chunk_width is None or label_delay < chunk_width))
        self.label_delay = label_delay
        self.random_chunking = random_chunking
        if self.tgt is not None:
            self._match_src_tgt()
        if self.text is not None:
            changed = self._match_src_text()
            if self.tgt is not None and changed:
                self._match_src_tgt()
        self.sizes = (np.vstack((self.src_sizes, self.tgt_sizes)).T
                      if self.tgt_sizes is not None else self.src_sizes)

        if chunk_width is not None:
            # remove those whose lengths are shorter than chunk_size
            indices = np.flatnonzero(self.src.sizes >= chunk_width)
            if len(indices) < self.src.size:
                logger.warning(
                    "Removing {} examples whose lengths are shorter than chunk_size={}"
                    .format(self.src.size - len(indices), chunk_width))
                self.src.filter_and_reorder(indices)
                if self.tgt is not None:
                    self.tgt.filter_and_reorder(indices)
                if self.text is not None:
                    self.text.filter_and_reorder(indices)
                logger.warning("Done removal. {} examples remaining".format(
                    len(indices)))

        if num_buckets > 0:
            from fairseq.data import BucketPadLengthDataset
            from espresso.data import FeatBucketPadLengthDataset
            self.src = FeatBucketPadLengthDataset(
                self.src,
                sizes=self.src_sizes,
                num_buckets=num_buckets,
                pad_idx=0.0,
                left_pad=False,
            )
            self.src_sizes = self.src.sizes
            logger.info("bucketing source lengths: {}".format(
                list(self.src.buckets)))
            if self.tgt is not None:
                self.tgt = BucketPadLengthDataset(
                    self.tgt,
                    sizes=self.tgt_sizes,
                    num_buckets=num_buckets,
                    pad_idx=-100,  # no dictionary here; match the pad index used by collate()
                    left_pad=False,
                )
                self.tgt_sizes = self.tgt.sizes
                logger.info("bucketing target lengths: {}".format(
                    list(self.tgt.buckets)))

            # determine bucket sizes using self.num_tokens, which will return
            # the padded lengths (thanks to FeatBucketPadLengthDataset)
            num_tokens = np.vectorize(self.num_tokens, otypes=[np.int64])
            self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
            self.buckets = [
                (None, num_tokens)
                for num_tokens in np.unique(self.bucketed_num_tokens)
            ]
        else:
            self.buckets = None
        self.pad_to_multiple = pad_to_multiple

    def _match_src_tgt(self):
        """Makes utterances in src and tgt the same order in terms of
        their utt_ids. Removes those that are only present in one of them."""
        assert self.tgt is not None
        if self.src.utt_ids == self.tgt.utt_ids:
            assert np.all(self.src.sizes == self.tgt.sizes), \
                "frame and alignment lengths mismatch"
            return
        tgt_utt_ids_set = set(self.tgt.utt_ids)
        src_indices = [
            i for i, id in enumerate(self.src.utt_ids) if id in tgt_utt_ids_set
        ]
        self.src.filter_and_reorder(src_indices)
        self.src_sizes = np.array(self.src.sizes)
        try:
            tgt_indices = list(map(self.tgt.utt_ids.index, self.src.utt_ids))
        except ValueError:
            raise ValueError(
                "Unable to find some utt_id(s) in tgt. which is unlikely to happen. "
                "Something must be wrong.")
        self.tgt.filter_and_reorder(tgt_indices)
        self.tgt_sizes = np.array(self.tgt.sizes)
        assert self.src.utt_ids == self.tgt.utt_ids
        assert np.all(self.src.sizes == self.tgt.sizes), \
            "frame and alignment lengths mismatch"

    def _match_src_text(self):
        """Makes utterances in src and text the same order in terms of
        their utt_ids. Removes those that are only present in one of them."""
        assert self.text is not None
        if self.src.utt_ids == self.text.utt_ids:
            return False
        text_utt_ids_set = set(self.text.utt_ids)
        src_indices = [
            i for i, id in enumerate(self.src.utt_ids)
            if id in text_utt_ids_set
        ]
        self.src.filter_and_reorder(src_indices)
        self.src_sizes = np.array(self.src.sizes)
        try:
            text_indices = list(map(self.text.utt_ids.index, self.src.utt_ids))
        except ValueError:
            raise ValueError(
                "Unable to find some utt_id(s) in text. which is unlikely to happen. "
                "Something must be wrong.")
        self.text.filter_and_reorder(text_indices)
        assert self.src.utt_ids == self.text.utt_ids
        return True

    def get_batch_shapes(self):
        return self.buckets

    def __getitem__(self, index):
        tgt_item = self.tgt[index] if self.tgt is not None else None
        text_item = self.text[index][1] if self.text is not None else None
        src_item = self.src[index]
        example = {
            "id": index,
            "utt_id": self.src.utt_ids[index],
            "source": src_item,
            "target": tgt_item,
            "text": text_item,
        }
        return example

    def __len__(self):
        return len(self.src)

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate
            pad_to_length (dict, optional): a dictionary of
                {'source': source_pad_to_length, 'target': target_pad_to_length}
                to indicate the max length to pad to in source and target respectively.

        Returns:
            dict: a mini-batch with the following keys:

                - `id` (LongTensor): example IDs in the original input order
                - `utt_id` (List[str]): list of utterance ids
                - `nsentences` (int): batch size
                - `ntokens` (int): total number of tokens in the batch
                - `net_input` (dict): the input to the Model, containing keys:

                  - `src_tokens` (FloatTensor): a padded 3D Tensor of features in
                    the source of shape `(bsz, src_len, feat_dim)`.
                  - `src_lengths` (IntTensor): 1D Tensor of the unpadded
                    lengths of each source sequence of shape `(bsz)`

                - `target` (LongTensor): a padded 2D Tensor of indices in the
                  target alignments of shape `(bsz, tgt_len)`
                - `text` (List[str]): list of original text
        """
        # pad_idx=-100 matches the default in criterions
        return collate(
            samples,
            pad_idx=-100,
            chunk_width=self.chunk_width,
            chunk_left_context=self.chunk_left_context,
            chunk_right_context=self.chunk_right_context,
            label_delay=self.label_delay,
            seed=self.seed,
            epoch=self.epoch,
            pad_to_length=pad_to_length,
            pad_to_multiple=self.pad_to_multiple,
            src_bucketed=(self.buckets is not None),
            random_chunking=self.random_chunking,
        )

    def num_tokens(self, index):
        """Return the number of frames in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        if self.chunk_width is None:
            return self.src_sizes[index]
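        # with chunking enabled, every example costs the same fixed number of
        # frames: the chunk itself plus its left and right context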
        return self.chunk_width + self.chunk_left_context + self.chunk_right_context

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return (self.src_sizes[index],
                self.tgt_sizes[index] if self.tgt_sizes is not None else 0)

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self)).astype(np.int64)
        else:
            indices = np.arange(len(self), dtype=np.int64)
        if self.buckets is None:
            # sort by target length, then source length
            if self.tgt_sizes is not None:
                indices = indices[np.argsort(self.tgt_sizes[indices],
                                             kind="mergesort")]
            return indices[np.argsort(self.src_sizes[indices],
                                      kind="mergesort")]
        else:
            # sort by bucketed_num_tokens, which is padded_src_len
            return indices[np.argsort(self.bucketed_num_tokens[indices],
                                      kind="mergesort")]

    @property
    def supports_prefetch(self):
        return getattr(self.src, "supports_prefetch", False)

    def prefetch(self, indices):
        self.src.prefetch(indices)
        if self.tgt is not None:
            self.tgt.prefetch(indices)

    def filter_indices_by_size(self, indices, max_sizes):
        """ Filter a list of sample indices. Remove those that are longer
            than specified in max_sizes.

        Args:
            indices (np.array): original array of sample indices
            max_sizes (int or list[int] or tuple[int]): max sample size,
                can be defined separately for src and tgt (then list or tuple)

        Returns:
            np.array: filtered sample array
            list: list of removed indices
        """
        return data_utils.filter_paired_dataset_indices_by_size(
            self.src_sizes,
            self.tgt_sizes,
            indices,
            max_sizes,
        )

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return False  # to avoid running out of CPU RAM

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        self.epoch = epoch
        if hasattr(self.src, "set_epoch"):
            self.src.set_epoch(epoch)
        if self.tgt is not None and hasattr(self.tgt, "set_epoch"):
            self.tgt.set_epoch(epoch)
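

# Usage sketch (illustrative only) of chunk-wise training: with chunk_width
# set, utterances shorter than the chunk are dropped in __init__, and
# num_tokens() charges every example the fixed cost chunk_width + left + right
# context. `feats` and `alis` are placeholders for espresso-style feature and
# alignment (AliScpCachedDataset) datasets.
def _demo_xent_chunks(feats, alis, batch_size=4):
    dataset = AsrXentDataset(
        feats, feats.sizes, tgt=alis, tgt_sizes=alis.sizes,
        chunk_width=150, chunk_left_context=10, chunk_right_context=10,
        random_chunking=True,
    )
    dataset.set_epoch(1)  # seeds the per-epoch random chunk selection in collate()
    samples = [dataset[i] for i in range(min(batch_size, len(dataset)))]
    return dataset.collater(samples)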