class AsrChainDataset(FairseqDataset):
    """
    A pair of torch.utils.data.Datasets.

    Args:
        src (torch.utils.data.Dataset): source dataset to wrap
        src_sizes (List[int]): source sentence lengths
        tgt (espresso.data.NumeratorGraphDataset, optional): target numerator graph dataset to wrap
        tgt_sizes (List[int], optional): target sizes (num of states in the numerator graph)
        text (torch.utils.data.Dataset, optional): text dataset to wrap
        shuffle (bool, optional): shuffle dataset elements before batching
            (default: True).
        num_buckets (int, optional): if set to a value greater than 0, then
            batches will be bucketed into the given number of batch shapes.
    """

    def __init__(
        self, src, src_sizes, tgt=None, tgt_sizes=None, text=None,
        shuffle=True, num_buckets=0,
    ):
        self.src = src
        self.tgt = tgt
        self.src_sizes = np.array(src_sizes)
        self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
        self.text = text
        self.shuffle = shuffle
        self.epoch = 1
        num_before_matching = len(self.src.utt_ids)
        if self.tgt is not None:
            self._match_src_tgt()
        if self.text is not None:
            changed = self._match_src_text()
            if self.tgt is not None and changed:
                self._match_src_tgt()
        num_after_matching = len(self.src.utt_ids)
        num_removed = num_before_matching - num_after_matching
        if num_removed > 0:
            logger.warning(
                "Removed {} examples due to empty numerator graphs or missing entries, "
                "{} remaining".format(num_removed, num_after_matching)
            )

        if num_buckets > 0:
            from espresso.data import FeatBucketPadLengthDataset
            self.src = FeatBucketPadLengthDataset(
                self.src,
                sizes=self.src_sizes,
                num_buckets=num_buckets,
                pad_idx=0.0,
                left_pad=False,
            )
            self.src_sizes = self.src.sizes
            logger.info("bucketing source lengths: {}".format(list(self.src.buckets)))

            # determine bucket sizes using self.num_tokens, which will return
            # the padded lengths (thanks to FeatBucketPadLengthDataset)
            num_tokens = np.vectorize(self.num_tokens, otypes=[np.long])
            self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
            self.buckets = [
                (None, num_tokens)
                for num_tokens in np.unique(self.bucketed_num_tokens)
            ]
        else:
            self.buckets = None

    def _match_src_tgt(self):
        """Makes utterances in src and tgt the same order in terms of
        their utt_ids. Removes those that are only present in one of them."""
        assert self.tgt is not None
        if self.src.utt_ids == self.tgt.utt_ids:
            return
        tgt_utt_ids_set = set(self.tgt.utt_ids)
        src_indices = [
            i for i, id in enumerate(self.src.utt_ids) if id in tgt_utt_ids_set
        ]
        self.src.filter_and_reorder(src_indices)
        self.src_sizes = np.array(self.src.sizes)
        try:
            tgt_indices = list(map(self.tgt.utt_ids.index, self.src.utt_ids))
        except ValueError:
            raise ValueError(
                "Unable to find some utt_id(s) in tgt, which is unlikely to happen. "
                "Something must be wrong."
            )
        self.tgt.filter_and_reorder(tgt_indices)
        self.tgt_sizes = np.array(self.tgt.sizes)
        assert self.src.utt_ids == self.tgt.utt_ids

    def _match_src_text(self):
        """Makes utterances in src and text the same order in terms of
        their utt_ids. Removes those that are only present in one of them."""
        assert self.text is not None
        if self.src.utt_ids == self.text.utt_ids:
            return False
        text_utt_ids_set = set(self.text.utt_ids)
        src_indices = [
            i for i, id in enumerate(self.src.utt_ids) if id in text_utt_ids_set
        ]
        self.src.filter_and_reorder(src_indices)
        self.src_sizes = np.array(self.src.sizes)
        try:
            text_indices = list(map(self.text.utt_ids.index, self.src.utt_ids))
        except ValueError:
            raise ValueError(
                "Unable to find some utt_id(s) in text, which is unlikely to happen. "
                "Something must be wrong."
            )
        self.text.filter_and_reorder(text_indices)
        assert self.src.utt_ids == self.text.utt_ids
        return True

    def get_batch_shapes(self):
        return self.buckets

    def __getitem__(self, index):
        tgt_item = self.tgt[index] if self.tgt is not None else None
        text_item = self.text[index][1] if self.text is not None else None
        src_item = self.src[index]
        example = {
            "id": index,
            "utt_id": self.src.utt_ids[index],
            "source": src_item,
            "target": tgt_item,
            "text": text_item,
        }
        return example

    def __len__(self):
        return len(self.src)

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch with the following keys:

                - `id` (LongTensor): example IDs in the original input order
                - `utt_id` (List[str]): list of utterance ids
                - `nsentences` (int): batch size
                - `ntokens` (int): total number of tokens in the batch
                - `net_input` (dict): the input to the Model, containing keys:

                  - `src_tokens` (FloatTensor): a padded 3D Tensor of features in
                    the source of shape `(bsz, src_len, feat_dim)`.
                  - `src_lengths` (IntTensor): 1D Tensor of the unpadded lengths of
                    each source sequence of shape `(bsz)`

                - `target` (ChainGraphBatch): an instance representing a batch of
                  numerator graphs
                - `text` (List[str]): list of original text
        """
        return collate(samples, src_bucketed=(self.buckets is not None))

    def num_tokens(self, index):
        """Return the number of frames in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return self.src_sizes[index]

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return (
            self.src_sizes[index],
            self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
        )

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self))
        else:
            indices = np.arange(len(self))
        if self.buckets is None:
            # sort by target length, then source length
            if self.tgt_sizes is not None:
                indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")]
            return indices[np.argsort(self.src_sizes[indices], kind="mergesort")]
        else:
            # sort by bucketed_num_tokens, which is padded_src_len
            return indices[np.argsort(self.bucketed_num_tokens[indices], kind="mergesort")]

    @property
    def supports_prefetch(self):
        return getattr(self.src, "supports_prefetch", False)

    def prefetch(self, indices):
        """Only prefetch src."""
        self.src.prefetch(indices)

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        self.epoch = epoch
        if hasattr(self.src, "set_epoch"):
            self.src.set_epoch(epoch)
        if self.tgt is not None and hasattr(self.tgt, "set_epoch"):
            self.tgt.set_epoch(epoch)
class AsrDataset(FairseqDataset):
    """
    A pair of torch.utils.data.Datasets.

    Args:
        src (torch.utils.data.Dataset): source dataset to wrap
        src_sizes (List[int]): source sentence lengths
        tgt (torch.utils.data.Dataset, optional): target dataset to wrap
        tgt_sizes (List[int], optional): target sentence lengths
        dictionary (~fairseq.data.Dictionary, optional): target vocabulary
        left_pad_source (bool, optional): pad source tensors on the left side
            (default: False).
        left_pad_target (bool, optional): pad target tensors on the left side
            (default: False).
        shuffle (bool, optional): shuffle dataset elements before batching
            (default: True).
        input_feeding (bool, optional): create a shifted version of the targets
            to be passed into the model for teacher forcing (default: True).
        num_buckets (int, optional): if set to a value greater than 0, then
            batches will be bucketed into the given number of batch shapes.
    """

    def __init__(
        self, src, src_sizes,
        tgt=None, tgt_sizes=None, dictionary=None,
        left_pad_source=False, left_pad_target=False,
        shuffle=True, input_feeding=True, num_buckets=0,
    ):
        self.src = src
        self.tgt = tgt
        self.src_sizes = np.array(src_sizes)
        self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
        self.dictionary = dictionary
        self.left_pad_source = left_pad_source
        self.left_pad_target = left_pad_target
        self.shuffle = shuffle
        self.input_feeding = input_feeding
        if self.tgt is not None:
            self._match_src_tgt()

        if num_buckets > 0:
            from espresso.data import FeatBucketPadLengthDataset, TextBucketPadLengthDataset
            self.src = FeatBucketPadLengthDataset(
                self.src,
                sizes=self.src_sizes,
                num_buckets=num_buckets,
                pad_idx=0.0,
                left_pad=False,
            )
            self.src_sizes = self.src.sizes
            logger.info('bucketing source lengths: {}'.format(list(self.src.buckets)))
            if self.tgt is not None:
                self.tgt = TextBucketPadLengthDataset(
                    self.tgt,
                    sizes=self.tgt_sizes,
                    num_buckets=num_buckets,
                    pad_idx=self.dictionary.pad(),
                    left_pad=False,
                )
                self.tgt_sizes = self.tgt.sizes
                logger.info('bucketing target lengths: {}'.format(list(self.tgt.buckets)))

            # determine bucket sizes using self.num_tokens, which will return
            # the padded lengths (thanks to FeatBucketPadLengthDataset)
            num_tokens = np.vectorize(self.num_tokens, otypes=[np.long])
            self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
            self.buckets = [
                (None, num_tokens)
                for num_tokens in np.unique(self.bucketed_num_tokens)
            ]
        else:
            self.buckets = None

    def _match_src_tgt(self):
        """Makes utterances in src and tgt the same order in terms of
        their utt_ids. Removes those that are only present in one of them."""
        assert self.tgt is not None
        if self.src.utt_ids == self.tgt.utt_ids:
            return
        tgt_utt_ids_set = set(self.tgt.utt_ids)
        src_indices = [
            i for i, id in enumerate(self.src.utt_ids) if id in tgt_utt_ids_set
        ]
        self.src.filter_and_reorder(src_indices)
        self.src_sizes = np.array(self.src.sizes)
        try:
            tgt_indices = list(map(self.tgt.utt_ids.index, self.src.utt_ids))
        except ValueError:
            raise ValueError(
                'Unable to find some utt_id(s) in tgt, which is unlikely to happen. '
                'Something must be wrong.'
            )
        self.tgt.filter_and_reorder(tgt_indices)
        self.tgt_sizes = np.array(self.tgt.sizes)
        assert self.src.utt_ids == self.tgt.utt_ids

    def get_batch_shapes(self):
        return self.buckets

    def __getitem__(self, index):
        tgt_item = self.tgt[index][0] if self.tgt is not None else None
        raw_text_item = self.tgt[index][1] if self.tgt is not None else None
        src_item = self.src[index]
        example = {
            'id': index,
            'utt_id': self.src.utt_ids[index],
            'source': src_item,
            'target': tgt_item,
            'target_raw_text': raw_text_item,
        }
        return example

    def __len__(self):
        return len(self.src)

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch with the following keys:

                - `id` (LongTensor): example IDs in the original input order
                - `utt_id` (List[str]): list of utterance ids
                - `nsentences` (int): batch size
                - `ntokens` (int): total number of tokens in the batch
                - `net_input` (dict): the input to the Model, containing keys:

                  - `src_tokens` (FloatTensor): a padded 3D Tensor of features in
                    the source of shape `(bsz, src_len, feat_dim)`. Padding will
                    appear on the left if *left_pad_source* is ``True``.
                  - `src_lengths` (IntTensor): 1D Tensor of the unpadded lengths of
                    each source sequence of shape `(bsz)`
                  - `prev_output_tokens` (LongTensor): a padded 2D Tensor of tokens in
                    the target sentence, shifted right by one position for teacher
                    forcing, of shape `(bsz, tgt_len)`. This key will not be present
                    if *input_feeding* is ``False``. Padding will appear on the left
                    if *left_pad_target* is ``True``.

                - `target` (LongTensor): a padded 2D Tensor of tokens in the target
                  sentence of shape `(bsz, tgt_len)`. Padding will appear on the left
                  if *left_pad_target* is ``True``.
                - `target_raw_text` (List[str]): list of original text
        """
        return collate(
            samples,
            pad_idx=self.dictionary.pad(),
            eos_idx=self.dictionary.eos(),
            left_pad_source=self.left_pad_source,
            left_pad_target=self.left_pad_target,
            input_feeding=self.input_feeding,
            src_bucketed=(self.buckets is not None),
        )

    def num_tokens(self, index):
        """Return the number of frames in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return self.src_sizes[index]

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return (
            self.src_sizes[index],
            self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
        )

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self))
        else:
            indices = np.arange(len(self))
        if self.buckets is None:
            # sort by target length, then source length
            if self.tgt_sizes is not None:
                indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')]
            return indices[np.argsort(self.src_sizes[indices], kind='mergesort')]
        else:
            # sort by bucketed_num_tokens, which is padded_src_len
            return indices[np.argsort(self.bucketed_num_tokens[indices], kind='mergesort')]

    @property
    def supports_prefetch(self):
        return getattr(self.src, 'supports_prefetch', False)

    def prefetch(self, indices):
        """Only prefetch src."""
        self.src.prefetch(indices)

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        if hasattr(self.src, 'set_epoch'):
            self.src.set_epoch(epoch)
        if self.tgt is not None and hasattr(self.tgt, 'set_epoch'):
            self.tgt.set_epoch(epoch)
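

# A small standalone sketch of the sorting logic used by ordered_indices() above:
# a stable sort by target length followed by a stable sort by source length yields
# indices ordered by source length, with ties broken by target length. All sizes
# below are made-up example values.
def _demo_ordered_indices_sorting():
    src_sizes = np.array([12, 7, 12, 7])
    tgt_sizes = np.array([3, 9, 2, 4])
    indices = np.arange(len(src_sizes))
    indices = indices[np.argsort(tgt_sizes[indices], kind='mergesort')]
    indices = indices[np.argsort(src_sizes[indices], kind='mergesort')]
    # -> [3, 1, 2, 0]: the two length-7 sources come first, ties broken by target length
    return indices
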
class AsrDataset(FairseqDataset):
    """
    A pair of torch.utils.data.Datasets.

    Args:
        src (torch.utils.data.Dataset): source dataset to wrap
        src_sizes (List[int]): source sentence lengths
        tgt (torch.utils.data.Dataset, optional): target dataset to wrap
        tgt_sizes (List[int], optional): target sentence lengths
        dictionary (~fairseq.data.Dictionary, optional): target vocabulary
        left_pad_source (bool, optional): pad source tensors on the left side
            (default: False).
        left_pad_target (bool, optional): pad target tensors on the left side
            (default: False).
        shuffle (bool, optional): shuffle dataset elements before batching
            (default: True).
        input_feeding (bool, optional): create a shifted version of the targets
            to be passed into the model for teacher forcing (default: True).
        constraints (Tensor, optional): 2d tensor with a concatenated, zero-delimited
            list of constraints for each sentence.
        num_buckets (int, optional): if set to a value greater than 0, then
            batches will be bucketed into the given number of batch shapes.
        src_lang_id (int, optional): source language ID, if set, the collated batch
            will contain a field "src_lang_id" in "net_input" which indicates the
            source language of the samples.
        tgt_lang_id (int, optional): target language ID, if set, the collated batch
            will contain a field "tgt_lang_id" which indicates the target language
            of the samples.
        pad_to_multiple (int, optional): pad src/tgt lengths to a multiple of this value
    """

    def __init__(
        self, src, src_sizes,
        tgt=None, tgt_sizes=None, dictionary=None,
        left_pad_source=False, left_pad_target=False,
        shuffle=True, input_feeding=True,
        constraints=None,
        num_buckets=0,
        src_lang_id=None,
        tgt_lang_id=None,
        pad_to_multiple=1,
    ):
        self.src = src
        self.tgt = tgt
        self.src_sizes = np.array(src_sizes)
        self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
        assert dictionary is not None
        self.dictionary = dictionary
        self.left_pad_source = left_pad_source
        self.left_pad_target = left_pad_target
        self.shuffle = shuffle
        self.input_feeding = input_feeding
        self.constraints = constraints
        self.src_lang_id = src_lang_id
        self.tgt_lang_id = tgt_lang_id
        if self.tgt is not None:
            self._match_src_tgt()
        self.sizes = (
            np.vstack((self.src_sizes, self.tgt_sizes)).T
            if self.tgt_sizes is not None
            else self.src_sizes
        )

        if num_buckets > 0:
            from espresso.data import FeatBucketPadLengthDataset, TextBucketPadLengthDataset
            self.src = FeatBucketPadLengthDataset(
                self.src,
                sizes=self.src_sizes,
                num_buckets=num_buckets,
                pad_idx=0.0,
                left_pad=False,
            )
            self.src_sizes = self.src.sizes
            logger.info("bucketing source lengths: {}".format(list(self.src.buckets)))
            if self.tgt is not None:
                self.tgt = TextBucketPadLengthDataset(
                    self.tgt,
                    sizes=self.tgt_sizes,
                    num_buckets=num_buckets,
                    pad_idx=self.dictionary.pad(),
                    left_pad=False,
                )
                self.tgt_sizes = self.tgt.sizes
                logger.info("bucketing target lengths: {}".format(list(self.tgt.buckets)))

            # determine bucket sizes using self.num_tokens, which will return
            # the padded lengths (thanks to FeatBucketPadLengthDataset)
            num_tokens = np.vectorize(self.num_tokens, otypes=[np.long])
            self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
            self.buckets = [
                (None, num_tokens)
                for num_tokens in np.unique(self.bucketed_num_tokens)
            ]
        else:
            self.buckets = None
        self.pad_to_multiple = pad_to_multiple

    def _match_src_tgt(self):
        """Makes utterances in src and tgt the same order in terms of
        their utt_ids. Removes those that are only present in one of them."""
        assert self.tgt is not None
        if self.src.utt_ids == self.tgt.utt_ids:
            return
        tgt_utt_ids_set = set(self.tgt.utt_ids)
        src_indices = [
            i for i, id in enumerate(self.src.utt_ids) if id in tgt_utt_ids_set
        ]
        self.src.filter_and_reorder(src_indices)
        self.src_sizes = np.array(self.src.sizes)
        try:
            tgt_indices = list(map(self.tgt.utt_ids.index, self.src.utt_ids))
        except ValueError:
            raise ValueError(
                "Unable to find some utt_id(s) in tgt, which is unlikely to happen. "
                "Something must be wrong."
            )
        self.tgt.filter_and_reorder(tgt_indices)
        self.tgt_sizes = np.array(self.tgt.sizes)
        assert self.src.utt_ids == self.tgt.utt_ids

    def get_batch_shapes(self):
        return self.buckets

    def __getitem__(self, index):
        tgt_item = self.tgt[index][0] if self.tgt is not None else None
        raw_text_item = self.tgt[index][1] if self.tgt is not None else None
        src_item = self.src[index]
        example = {
            "id": index,
            "utt_id": self.src.utt_ids[index],
            "source": src_item,
            "target": tgt_item,
            "target_raw_text": raw_text_item,
        }
        if self.constraints is not None:
            example["constraints"] = self.constraints[index]
        return example

    def __len__(self):
        return len(self.src)

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate
            pad_to_length (dict, optional): a dictionary of
                {"source": source_pad_to_length, "target": target_pad_to_length}
                to indicate the max length to pad to in source and target respectively.

        Returns:
            dict: a mini-batch with the following keys:

                - `id` (LongTensor): example IDs in the original input order
                - `utt_id` (List[str]): list of utterance ids
                - `nsentences` (int): batch size
                - `ntokens` (int): total number of tokens in the batch
                - `net_input` (dict): the input to the Model, containing keys:

                  - `src_tokens` (FloatTensor): a padded 3D Tensor of features in
                    the source of shape `(bsz, src_len, feat_dim)`. Padding will
                    appear on the left if *left_pad_source* is ``True``.
                  - `src_lengths` (IntTensor): 1D Tensor of the unpadded lengths of
                    each source sequence of shape `(bsz)`
                  - `prev_output_tokens` (LongTensor): a padded 2D Tensor of tokens in
                    the target sentence, shifted right by one position for teacher
                    forcing, of shape `(bsz, tgt_len)`. This key will not be present
                    if *input_feeding* is ``False``. Padding will appear on the left
                    if *left_pad_target* is ``True``.
                  - `src_lang_id` (LongTensor): a long Tensor which contains source
                    language IDs of each sample in the batch

                - `target` (LongTensor): a padded 2D Tensor of tokens in the target
                  sentence of shape `(bsz, tgt_len)`. Padding will appear on the left
                  if *left_pad_target* is ``True``.
                - `target_raw_text` (List[str]): list of original text
                - `tgt_lang_id` (LongTensor): a long Tensor which contains target
                  language IDs of each sample in the batch
        """
        res = collate(
            samples,
            pad_idx=self.dictionary.pad(),
            eos_idx=self.dictionary.eos(),
            left_pad_source=self.left_pad_source,
            left_pad_target=self.left_pad_target,
            input_feeding=self.input_feeding,
            pad_to_length=pad_to_length,
            pad_to_multiple=self.pad_to_multiple,
            src_bucketed=(self.buckets is not None),
        )
        if self.src_lang_id is not None or self.tgt_lang_id is not None:
            src_tokens = res["net_input"]["src_tokens"]
            bsz = src_tokens.size(0)
            if self.src_lang_id is not None:
                res["net_input"]["src_lang_id"] = (
                    torch.LongTensor([[self.src_lang_id]]).expand(bsz, 1).to(src_tokens)
                )
            if self.tgt_lang_id is not None:
                res["tgt_lang_id"] = (
                    torch.LongTensor([[self.tgt_lang_id]]).expand(bsz, 1).to(src_tokens)
                )
        return res

    def num_tokens(self, index):
        """Return the number of frames in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return self.src_sizes[index]

    def num_tokens_vec(self, indices):
        """Return the number of tokens for a set of positions defined by indices.
        This value is used to enforce ``--max-tokens`` during batching."""
        sizes = self.src_sizes[indices]
        return sizes

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return (
            self.src_sizes[index],
            self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
        )

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self)).astype(np.int64)
        else:
            indices = np.arange(len(self), dtype=np.int64)
        if self.buckets is None:
            # sort by target length, then source length
            if self.tgt_sizes is not None:
                indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")]
            return indices[np.argsort(self.src_sizes[indices], kind="mergesort")]
        else:
            # sort by bucketed_num_tokens, which is padded_src_len
            return indices[np.argsort(self.bucketed_num_tokens[indices], kind="mergesort")]

    @property
    def supports_prefetch(self):
        return getattr(self.src, "supports_prefetch", False)

    def prefetch(self, indices):
        """Only prefetch src."""
        self.src.prefetch(indices)

    def filter_indices_by_size(self, indices, max_sizes):
        """Filter a list of sample indices. Remove those that are longer
        than specified in max_sizes.

        Args:
            indices (np.array): original array of sample indices
            max_sizes (int or list[int] or tuple[int]): max sample size,
                can be defined separately for src and tgt (then list or tuple)

        Returns:
            np.array: filtered sample array
            list: list of removed indices
        """
        return data_utils.filter_paired_dataset_indices_by_size(
            self.src_sizes, self.tgt_sizes, indices, max_sizes,
        )

    @property
    def supports_fetch_outside_dataloader(self):
        """Whether this dataset supports fetching outside the workers of the dataloader."""
        return False

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return False  # to avoid running out of CPU RAM

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        if hasattr(self.src, "set_epoch"):
            self.src.set_epoch(epoch)
        if self.tgt is not None and hasattr(self.tgt, "set_epoch"):
            self.tgt.set_epoch(epoch)
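

# A minimal sketch of the shape of the language-ID fields that collater() attaches when
# src_lang_id/tgt_lang_id are set: a single ID expanded to one row per example in the
# batch. The batch size and ID values below are hypothetical.
def _demo_lang_id_field():
    import torch

    bsz, src_lang_id = 3, 5
    src_lang_ids = torch.LongTensor([[src_lang_id]]).expand(bsz, 1)
    assert src_lang_ids.shape == (3, 1)
    assert src_lang_ids.tolist() == [[5], [5], [5]]
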
class AsrXentDataset(FairseqDataset):
    """
    A pair of torch.utils.data.Datasets.

    Args:
        src (torch.utils.data.Dataset): source dataset to wrap
        src_sizes (List[int]): source sentence lengths
        tgt (espresso.data.AliScpCachedDataset, optional): target alignment dataset to wrap
        tgt_sizes (List[int], optional): target sizes (lengths of the alignments)
        tgt_vocab_size (int, optional): used for setting padding index
        text (torch.utils.data.Dataset, optional): text dataset to wrap
        shuffle (bool, optional): shuffle dataset elements before batching
            (default: True).
        num_buckets (int, optional): if set to a value greater than 0, then
            batches will be bucketed into the given number of batch shapes.
        pad_to_multiple (int, optional): pad src/tgt lengths to a multiple of this value
        seed (int, optional): random seed for generating a chunk from an utterance.
        chunk_width (int, optional): chunk width for chunk-wise training.
        chunk_left_context (int, optional): number of frames prepended to the left of a chunk.
        chunk_right_context (int, optional): number of frames appended to the right of a chunk.
        label_delay (int, optional): offset of the alignments as prediction labels. Can be
            useful in architectures such as asymmetric convolution, unidirectional LSTM, etc.
        random_chunking (bool, optional): whether to do random chunking from utterances, or
            sequentially obtain chunks within each utterance. True for train and False for
            valid/test data.
    """

    def __init__(
        self, src, src_sizes,
        tgt: Optional[AliScpCachedDataset] = None, tgt_sizes=None,
        text=None,
        shuffle=True, num_buckets=0, pad_to_multiple=1,
        seed=1, chunk_width=None,
        chunk_left_context=None, chunk_right_context=None,
        label_delay=0, random_chunking=True,
    ):
        self.src = src
        self.tgt = tgt
        self.src_sizes = np.array(src_sizes)
        self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
        self.text = text
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = 1
        assert chunk_width is None or chunk_width > 0
        self.chunk_width = chunk_width
        assert chunk_left_context >= 0 and chunk_right_context >= 0
        self.chunk_left_context = chunk_left_context
        self.chunk_right_context = chunk_right_context
        assert (label_delay < 0 and -label_delay <= chunk_right_context) or \
            (label_delay >= 0 and (chunk_width is None or label_delay < chunk_width))
        self.label_delay = label_delay
        self.random_chunking = random_chunking
        if self.tgt is not None:
            self._match_src_tgt()
        if self.text is not None:
            changed = self._match_src_text()
            if self.tgt is not None and changed:
                self._match_src_tgt()
        self.sizes = (
            np.vstack((self.src_sizes, self.tgt_sizes)).T
            if self.tgt_sizes is not None
            else self.src_sizes
        )

        if chunk_width is not None:
            # remove those whose lengths are shorter than chunk_size
            indices = np.flatnonzero(self.src.sizes >= chunk_width)
            if len(indices) < self.src.size:
                logger.warning(
                    "Removing {} examples whose lengths are shorter than chunk_size={}".format(
                        self.src.size - len(indices), chunk_width
                    )
                )
                self.src.filter_and_reorder(indices)
                if self.tgt is not None:
                    self.tgt.filter_and_reorder(indices)
                if self.text is not None:
                    self.text.filter_and_reorder(indices)
                logger.warning("Done removal. {} examples remaining".format(len(indices)))

        if num_buckets > 0:
            from fairseq.data import BucketPadLengthDataset
            from espresso.data import FeatBucketPadLengthDataset
            self.src = FeatBucketPadLengthDataset(
                self.src,
                sizes=self.src_sizes,
                num_buckets=num_buckets,
                pad_idx=0.0,
                left_pad=False,
            )
            self.src_sizes = self.src.sizes
            logger.info("bucketing source lengths: {}".format(list(self.src.buckets)))
            if self.tgt is not None:
                self.tgt = BucketPadLengthDataset(
                    self.tgt,
                    sizes=self.tgt_sizes,
                    num_buckets=num_buckets,
                    pad_idx=self.dictionary.pad(),
                    left_pad=False,
                )
                self.tgt_sizes = self.tgt.sizes
                logger.info("bucketing target lengths: {}".format(list(self.tgt.buckets)))

            # determine bucket sizes using self.num_tokens, which will return
            # the padded lengths (thanks to FeatBucketPadLengthDataset)
            num_tokens = np.vectorize(self.num_tokens, otypes=[np.long])
            self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
            self.buckets = [
                (None, num_tokens)
                for num_tokens in np.unique(self.bucketed_num_tokens)
            ]
        else:
            self.buckets = None
        self.pad_to_multiple = pad_to_multiple

    def _match_src_tgt(self):
        """Makes utterances in src and tgt the same order in terms of
        their utt_ids. Removes those that are only present in one of them."""
        assert self.tgt is not None
        if self.src.utt_ids == self.tgt.utt_ids:
            assert np.all(self.src.sizes == self.tgt.sizes), "frame and alignment lengths mismatch"
            return
        tgt_utt_ids_set = set(self.tgt.utt_ids)
        src_indices = [
            i for i, id in enumerate(self.src.utt_ids) if id in tgt_utt_ids_set
        ]
        self.src.filter_and_reorder(src_indices)
        self.src_sizes = np.array(self.src.sizes)
        try:
            tgt_indices = list(map(self.tgt.utt_ids.index, self.src.utt_ids))
        except ValueError:
            raise ValueError(
                "Unable to find some utt_id(s) in tgt, which is unlikely to happen. "
                "Something must be wrong."
            )
        self.tgt.filter_and_reorder(tgt_indices)
        self.tgt_sizes = np.array(self.tgt.sizes)
        assert self.src.utt_ids == self.tgt.utt_ids
        assert np.all(self.src.sizes == self.tgt.sizes), "frame and alignment lengths mismatch"

    def _match_src_text(self):
        """Makes utterances in src and text the same order in terms of
        their utt_ids. Removes those that are only present in one of them."""
        assert self.text is not None
        if self.src.utt_ids == self.text.utt_ids:
            return False
        text_utt_ids_set = set(self.text.utt_ids)
        src_indices = [
            i for i, id in enumerate(self.src.utt_ids) if id in text_utt_ids_set
        ]
        self.src.filter_and_reorder(src_indices)
        self.src_sizes = np.array(self.src.sizes)
        try:
            text_indices = list(map(self.text.utt_ids.index, self.src.utt_ids))
        except ValueError:
            raise ValueError(
                "Unable to find some utt_id(s) in text, which is unlikely to happen. "
                "Something must be wrong."
            )
        self.text.filter_and_reorder(text_indices)
        assert self.src.utt_ids == self.text.utt_ids
        return True

    def get_batch_shapes(self):
        return self.buckets

    def __getitem__(self, index):
        tgt_item = self.tgt[index] if self.tgt is not None else None
        text_item = self.text[index][1] if self.text is not None else None
        src_item = self.src[index]
        example = {
            "id": index,
            "utt_id": self.src.utt_ids[index],
            "source": src_item,
            "target": tgt_item,
            "text": text_item,
        }
        return example

    def __len__(self):
        return len(self.src)

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate
            pad_to_length (dict, optional): a dictionary of
                {"source": source_pad_to_length, "target": target_pad_to_length}
                to indicate the max length to pad to in source and target respectively.

        Returns:
            dict: a mini-batch with the following keys:

                - `id` (LongTensor): example IDs in the original input order
                - `utt_id` (List[str]): list of utterance ids
                - `nsentences` (int): batch size
                - `ntokens` (int): total number of tokens in the batch
                - `net_input` (dict): the input to the Model, containing keys:

                  - `src_tokens` (FloatTensor): a padded 3D Tensor of features in
                    the source of shape `(bsz, src_len, feat_dim)`.
                  - `src_lengths` (IntTensor): 1D Tensor of the unpadded lengths of
                    each source sequence of shape `(bsz)`

                - `target` (LongTensor): a padded 2D Tensor of indices in the
                  target alignments of shape `(bsz, tgt_len)`
                - `text` (List[str]): list of original text
        """
        # pad_idx=-100 matches the default in criterions
        return collate(
            samples,
            pad_idx=-100,
            chunk_width=self.chunk_width,
            chunk_left_context=self.chunk_left_context,
            chunk_right_context=self.chunk_right_context,
            label_delay=self.label_delay,
            seed=self.seed,
            epoch=self.epoch,
            pad_to_length=pad_to_length,
            pad_to_multiple=self.pad_to_multiple,
            src_bucketed=(self.buckets is not None),
            random_chunking=self.random_chunking,
        )

    def num_tokens(self, index):
        """Return the number of frames in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        if self.chunk_width is None:
            return self.src_sizes[index]
        return self.chunk_width + self.chunk_left_context + self.chunk_right_context

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return (
            self.src_sizes[index],
            self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
        )

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self)).astype(np.int64)
        else:
            indices = np.arange(len(self), dtype=np.int64)
        if self.buckets is None:
            # sort by target length, then source length
            if self.tgt_sizes is not None:
                indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")]
            return indices[np.argsort(self.src_sizes[indices], kind="mergesort")]
        else:
            # sort by bucketed_num_tokens, which is padded_src_len
            return indices[np.argsort(self.bucketed_num_tokens[indices], kind="mergesort")]

    @property
    def supports_prefetch(self):
        return getattr(self.src, "supports_prefetch", False)

    def prefetch(self, indices):
        self.src.prefetch(indices)
        if self.tgt is not None:
            self.tgt.prefetch(indices)

    def filter_indices_by_size(self, indices, max_sizes):
        """Filter a list of sample indices. Remove those that are longer
        than specified in max_sizes.

        Args:
            indices (np.array): original array of sample indices
            max_sizes (int or list[int] or tuple[int]): max sample size,
                can be defined separately for src and tgt (then list or tuple)

        Returns:
            np.array: filtered sample array
            list: list of removed indices
        """
        return data_utils.filter_paired_dataset_indices_by_size(
            self.src_sizes, self.tgt_sizes, indices, max_sizes,
        )

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return False  # to avoid running out of CPU RAM

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        self.epoch = epoch
        if hasattr(self.src, "set_epoch"):
            self.src.set_epoch(epoch)
        if self.tgt is not None and hasattr(self.tgt, "set_epoch"):
            self.tgt.set_epoch(epoch)
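

# A short sketch of the effective per-example length under chunk-wise training, mirroring
# num_tokens() above: with chunk_width set, every example counts as a fixed-size chunk plus
# its left/right context, regardless of utterance length. All numbers below are hypothetical.
def _demo_chunked_num_tokens():
    chunk_width, chunk_left_context, chunk_right_context = 150, 10, 10
    utterance_lengths = [893, 412, 1570]
    # without chunking, --max-tokens counts full utterance lengths: [893, 412, 1570]
    # with chunking, each example counts as chunk_width + contexts:  [170, 170, 170]
    return [chunk_width + chunk_left_context + chunk_right_context for _ in utterance_lengths]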