class LanguagePairDataset(FairseqDataset): """ A pair of torch.utils.data.Datasets. Args: src (torch.utils.data.Dataset): source dataset to wrap src_sizes (List[int]): source sentence lengths src_dict (~fairseq.data.Dictionary): source vocabulary tgt (torch.utils.data.Dataset, optional): target dataset to wrap tgt_sizes (List[int], optional): target sentence lengths tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary left_pad_source (bool, optional): pad source tensors on the left side (default: True). left_pad_target (bool, optional): pad target tensors on the left side (default: False). shuffle (bool, optional): shuffle dataset elements before batching (default: True). input_feeding (bool, optional): create a shifted version of the targets to be passed into the model for teacher forcing (default: True). remove_eos_from_source (bool, optional): if set, removes eos from end of source if it's present (default: False). append_eos_to_target (bool, optional): if set, appends eos to end of target if it's absent (default: False). align_dataset (torch.utils.data.Dataset, optional): dataset containing alignments. constraints (Tensor, optional): 2d tensor with a concatenated, zero- delimited list of constraints for each sentence. append_bos (bool, optional): if set, appends bos to the beginning of source/target sentence. num_buckets (int, optional): if set to a value greater than 0, then batches will be bucketed into the given number of batch shapes. src_lang_id (int, optional): source language ID, if set, the collated batch will contain a field 'src_lang_id' in 'net_input' which indicates the source language of the samples. tgt_lang_id (int, optional): target language ID, if set, the collated batch will contain a field 'tgt_lang_id' which indicates the target language of the samples. 
""" def __init__( self, src, src_sizes, src_dict, tgt=None, tgt_sizes=None, tgt_dict=None, left_pad_source=True, left_pad_target=False, shuffle=True, input_feeding=True, remove_eos_from_source=False, append_eos_to_target=False, align_dataset=None, constraints=None, append_bos=False, eos=None, num_buckets=0, src_lang_id=None, tgt_lang_id=None, ): if tgt_dict is not None: assert src_dict.pad() == tgt_dict.pad() assert src_dict.eos() == tgt_dict.eos() assert src_dict.unk() == tgt_dict.unk() if tgt is not None: assert len(src) == len( tgt ), "Source and target must contain the same number of examples" self.src = src self.tgt = tgt self.src_sizes = np.array(src_sizes) self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None self.src_dict = src_dict self.tgt_dict = tgt_dict self.left_pad_source = left_pad_source self.left_pad_target = left_pad_target self.shuffle = shuffle self.input_feeding = input_feeding self.remove_eos_from_source = remove_eos_from_source self.append_eos_to_target = append_eos_to_target self.align_dataset = align_dataset if self.align_dataset is not None: assert self.tgt_sizes is not None, "Both source and target needed when alignments are provided" self.constraints = constraints self.append_bos = append_bos self.eos = (eos if eos is not None else src_dict.eos()) self.src_lang_id = src_lang_id self.tgt_lang_id = tgt_lang_id if num_buckets > 0: from fairseq.data import BucketPadLengthDataset self.src = BucketPadLengthDataset( self.src, sizes=self.src_sizes, num_buckets=num_buckets, pad_idx=self.src_dict.pad(), left_pad=self.left_pad_source, ) self.src_sizes = self.src.sizes logger.info('bucketing source lengths: {}'.format( list(self.src.buckets))) if self.tgt is not None: self.tgt = BucketPadLengthDataset( self.tgt, sizes=self.tgt_sizes, num_buckets=num_buckets, pad_idx=self.tgt_dict.pad(), left_pad=self.left_pad_target, ) self.tgt_sizes = self.tgt.sizes logger.info('bucketing target lengths: {}'.format( list(self.tgt.buckets))) # determine bucket sizes using self.num_tokens, which will return # the padded lengths (thanks to BucketPadLengthDataset) num_tokens = np.vectorize(self.num_tokens, otypes=[np.long]) self.bucketed_num_tokens = num_tokens(np.arange(len(self.src))) self.buckets = [ (None, num_tokens) for num_tokens in np.unique(self.bucketed_num_tokens) ] else: self.buckets = None def get_batch_shapes(self): return self.buckets def __getitem__(self, index): tgt_item = self.tgt[index] if self.tgt is not None else None src_item = self.src[index] # Append EOS to end of tgt sentence if it does not have an EOS and remove # EOS from end of src sentence if it exists. 
This is useful when we use # use existing datasets for opposite directions i.e., when we want to # use tgt_dataset as src_dataset and vice versa if self.append_eos_to_target: eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos() if self.tgt and self.tgt[index][-1] != eos: tgt_item = torch.cat( [self.tgt[index], torch.LongTensor([eos])]) if self.append_bos: bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos() if self.tgt and self.tgt[index][0] != bos: tgt_item = torch.cat( [torch.LongTensor([bos]), self.tgt[index]]) bos = self.src_dict.bos() if self.src[index][0] != bos: src_item = torch.cat( [torch.LongTensor([bos]), self.src[index]]) if self.remove_eos_from_source: eos = self.src_dict.eos() if self.src[index][-1] == eos: src_item = self.src[index][:-1] example = { 'id': index, 'source': src_item, 'target': tgt_item, } if self.align_dataset is not None: example['alignment'] = self.align_dataset[index] if self.constraints is not None: example["constraints"] = self.constraints[index] return example def __len__(self): return len(self.src) def collater(self, samples, pad_to_length=None): """Merge a list of samples to form a mini-batch. Args: samples (List[dict]): samples to collate pad_to_length (dict, optional): a dictionary of {'source': source_pad_to_length, 'target': target_pad_to_length} to indicate the max length to pad to in source and target respectively. Returns: dict: a mini-batch with the following keys: - `id` (LongTensor): example IDs in the original input order - `ntokens` (int): total number of tokens in the batch - `net_input` (dict): the input to the Model, containing keys: - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in the source sentence of shape `(bsz, src_len)`. Padding will appear on the left if *left_pad_source* is ``True``. - `src_lengths` (LongTensor): 1D Tensor of the unpadded lengths of each source sentence of shape `(bsz)` - `prev_output_tokens` (LongTensor): a padded 2D Tensor of tokens in the target sentence, shifted right by one position for teacher forcing, of shape `(bsz, tgt_len)`. This key will not be present if *input_feeding* is ``False``. Padding will appear on the left if *left_pad_target* is ``True``. - `src_lang_id` (LongTensor): a long Tensor which contains source language IDs of each sample in the batch - `target` (LongTensor): a padded 2D Tensor of tokens in the target sentence of shape `(bsz, tgt_len)`. Padding will appear on the left if *left_pad_target* is ``True``. - `tgt_lang_id` (LongTensor): a long Tensor which contains target language IDs of each sample in the batch """ res = collate( samples, pad_idx=self.src_dict.pad(), eos_idx=self.eos, left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target, input_feeding=self.input_feeding, pad_to_length=pad_to_length, ) if self.src_lang_id is not None or self.tgt_lang_id is not None: src_tokens = res['net_input']['src_tokens'] bsz = src_tokens.size(0) if self.src_lang_id is not None: res['net_input']['src_lang_id'] = torch.LongTensor( [[self.src_lang_id]]).expand(bsz, 1).to(src_tokens) if self.tgt_lang_id is not None: res['tgt_lang_id'] = torch.LongTensor( [[self.tgt_lang_id]]).expand(bsz, 1).to(src_tokens) return res def num_tokens(self, index): """Return the number of tokens in a sample. This value is used to enforce ``--max-tokens`` during batching.""" return max(self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0) def size(self, index): """Return an example's size as a float or tuple. 
This value is used when filtering a dataset with ``--max-positions``.""" return (self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0) def ordered_indices(self): """Return an ordered list of indices. Batches will be constructed based on this order.""" if self.shuffle: indices = np.random.permutation(len(self)).astype(np.int64) else: indices = np.arange(len(self), dtype=np.int64) if self.buckets is None: # sort by target length, then source length if self.tgt_sizes is not None: indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')] return indices[np.argsort(self.src_sizes[indices], kind='mergesort')] else: # sort by bucketed_num_tokens, which is: # max(padded_src_len, padded_tgt_len) return indices[np.argsort(self.bucketed_num_tokens[indices], kind='mergesort')] @property def supports_prefetch(self): return (getattr(self.src, 'supports_prefetch', False) and (getattr(self.tgt, 'supports_prefetch', False) or self.tgt is None)) def prefetch(self, indices): self.src.prefetch(indices) if self.tgt is not None: self.tgt.prefetch(indices) if self.align_dataset is not None: self.align_dataset.prefetch(indices) def filter_indices_by_size(self, indices, max_sizes): """ Filter a list of sample indices. Remove those that are longer than specified in max_sizes. Args: indices (np.array): original array of sample indices max_sizes (int or list[int] or tuple[int]): max sample size, can be defined separately for src and tgt (then list or tuple) Returns: np.array: filtered sample array list: list of removed indices """ if max_sizes is None: return indices, [] if type(max_sizes) in (int, float): max_src_size, max_tgt_size = max_sizes, max_sizes else: max_src_size, max_tgt_size = max_sizes if self.tgt_sizes is None: ignored = indices[self.src_sizes[indices] > max_src_size] else: ignored = indices[(self.src_sizes[indices] > max_src_size) | (self.tgt_sizes[indices] > max_tgt_size)] if len(ignored) > 0: if self.tgt_sizes is None: indices = indices[self.src_sizes[indices] <= max_src_size] else: indices = indices[(self.src_sizes[indices] <= max_src_size) & (self.tgt_sizes[indices] <= max_tgt_size)] return indices, ignored.tolist()
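# ordered_indices() above optionally shuffles and then applies two stable (mergesort)
# argsorts, so batches end up grouped primarily by source length and secondarily by
# target length. A minimal, self-contained sketch of that ordering logic, assuming
# only numpy; the sizes and the `seed` argument below are invented for illustration:

import numpy as np


def length_sorted_indices(src_sizes, tgt_sizes=None, shuffle=True, seed=0):
    """Mimic LanguagePairDataset.ordered_indices() for the non-bucketed case."""
    rng = np.random.RandomState(seed)
    n = len(src_sizes)
    indices = rng.permutation(n).astype(np.int64) if shuffle else np.arange(n, dtype=np.int64)
    if tgt_sizes is not None:
        # stable sort by target length first ...
        indices = indices[np.argsort(np.asarray(tgt_sizes)[indices], kind="mergesort")]
    # ... then by source length, so source length dominates the final order
    return indices[np.argsort(np.asarray(src_sizes)[indices], kind="mergesort")]


# ties in source length keep the (shuffled) target-length order
print(length_sorted_indices([5, 3, 5, 2], [7, 1, 4, 9]))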
class AsrXentDataset(FairseqDataset): """ A pair of torch.utils.data.Datasets. Args: src (torch.utils.data.Dataset): source dataset to wrap src_sizes (List[int]): source sentence lengths tgt (espresso.data.AliScpCachedDataset, optional): target alignment dataset to wrap tgt_sizes (List[int], optional): target sizes (num of states in the numerator graph) tgt_vocab_size (int, optional): used for setting padding index text (torch.utils.data.Dataset, optional): text dataset to wrap shuffle (bool, optional): shuffle dataset elements before batching (default: True). num_buckets (int, optional): if set to a value greater than 0, then batches will be bucketed into the given number of batch shapes. pad_to_multiple (int, optional): pad src/tgt lengths to a multiple of this value seed (int, optional): random seed for generating a chunk from an utterance. chunk_width (int, optional): chunk width for chunk-wise training. chunk_left_context (int, optional): number of frames appended to the left of a chunk. chunk_right_context (int, optional): number of frames appended to the right of a chunk. label_delay (int, optional): offset of the alignments as prediction labels. Can be useful in archs such as asymmetric convolution, unidirectional LSTM, etc. random_chunking (bool, optional): wether do random chunking from utterance, or sequntially obtain chunks within each utterance. True for train and False for valid/test data. """ def __init__( self, src, src_sizes, tgt: Optional[AliScpCachedDataset] = None, tgt_sizes=None, text=None, shuffle=True, num_buckets=0, pad_to_multiple=1, seed=1, chunk_width=None, chunk_left_context=None, chunk_right_context=None, label_delay=0, random_chunking=True, ): self.src = src self.tgt = tgt self.src_sizes = np.array(src_sizes) self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None self.text = text self.shuffle = shuffle self.seed = seed self.epoch = 1 assert chunk_width is None or chunk_width > 0 self.chunk_width = chunk_width assert chunk_left_context >= 0 and chunk_right_context >= 0 self.chunk_left_context = chunk_left_context self.chunk_right_context = chunk_right_context assert (label_delay < 0 and -label_delay <= chunk_right_context) or \ (label_delay >= 0 and (chunk_width is None or label_delay < chunk_width)) self.label_delay = label_delay self.random_chunking = random_chunking if self.tgt is not None: self._match_src_tgt() if self.text is not None: changed = self._match_src_text() if self.tgt is not None and changed: self._match_src_tgt() self.sizes = np.vstack( (self.src_sizes, self.tgt_sizes )).T if self.tgt_sizes is not None else self.src_sizes if chunk_width is not None: # remove those whose lengths are shorter than chunk_size indices = np.flatnonzero(self.src.sizes >= chunk_width) if len(indices) < self.src.size: logger.warning( "Removing {} examples whose lengths are shorter than chunk_size={}" .format(self.src.size - len(indices), chunk_width)) self.src.filter_and_reorder(indices) if self.tgt is not None: self.tgt.filter_and_reorder(indices) if self.text is not None: self.text.filter_and_reorder(indices) logger.warning("Done removal. 
{} examples remaining".format( len(indices))) if num_buckets > 0: from fairseq.data import BucketPadLengthDataset from espresso.data import FeatBucketPadLengthDataset self.src = FeatBucketPadLengthDataset( self.src, sizes=self.src_sizes, num_buckets=num_buckets, pad_idx=0.0, left_pad=False, ) self.src_sizes = self.src.sizes logger.info("bucketing source lengths: {}".format( list(self.src.buckets))) if self.tgt is not None: self.tgt = BucketPadLengthDataset( self.tgt, sizes=self.tgt_sizes, num_buckets=num_buckets, pad_idx=self.dictionary.pad(), left_pad=False, ) self.tgt_sizes = self.tgt.sizes logger.info("bucketing target lengths: {}".format( list(self.tgt.buckets))) # determine bucket sizes using self.num_tokens, which will return # the padded lengths (thanks to FeatBucketPadLengthDataset) num_tokens = np.vectorize(self.num_tokens, otypes=[np.long]) self.bucketed_num_tokens = num_tokens(np.arange(len(self.src))) self.buckets = [ (None, num_tokens) for num_tokens in np.unique(self.bucketed_num_tokens) ] else: self.buckets = None self.pad_to_multiple = pad_to_multiple def _match_src_tgt(self): """Makes utterances in src and tgt the same order in terms of their utt_ids. Removes those that are only present in one of them.""" assert self.tgt is not None if self.src.utt_ids == self.tgt.utt_ids: assert np.all(self.src.sizes == self.tgt.sizes ), "frame and alignment lengths mismatch" return tgt_utt_ids_set = set(self.tgt.utt_ids) src_indices = [ i for i, id in enumerate(self.src.utt_ids) if id in tgt_utt_ids_set ] self.src.filter_and_reorder(src_indices) self.src_sizes = np.array(self.src.sizes) try: tgt_indices = list(map(self.tgt.utt_ids.index, self.src.utt_ids)) except ValueError: raise ValueError( "Unable to find some utt_id(s) in tgt. which is unlikely to happen. " "Something must be wrong.") self.tgt.filter_and_reorder(tgt_indices) self.tgt_sizes = np.array(self.tgt.sizes) assert self.src.utt_ids == self.tgt.utt_ids assert np.all(self.src.sizes == self.tgt.sizes), "frame and alignment lengths mismatch" def _match_src_text(self): """Makes utterances in src and text the same order in terms of their utt_ids. Removes those that are only present in one of them.""" assert self.text is not None if self.src.utt_ids == self.text.utt_ids: return False text_utt_ids_set = set(self.text.utt_ids) src_indices = [ i for i, id in enumerate(self.src.utt_ids) if id in text_utt_ids_set ] self.src.filter_and_reorder(src_indices) self.src_sizes = np.array(self.src.sizes) try: text_indices = list(map(self.text.utt_ids.index, self.src.utt_ids)) except ValueError: raise ValueError( "Unable to find some utt_id(s) in text. which is unlikely to happen. " "Something must be wrong.") self.text.filter_and_reorder(text_indices) assert self.src.utt_ids == self.text.utt_ids return True def get_batch_shapes(self): return self.buckets def __getitem__(self, index): tgt_item = self.tgt[index] if self.tgt is not None else None text_item = self.text[index][1] if self.text is not None else None src_item = self.src[index] example = { "id": index, "utt_id": self.src.utt_ids[index], "source": src_item, "target": tgt_item, "text": text_item, } return example def __len__(self): return len(self.src) def collater(self, samples, pad_to_length=None): """Merge a list of samples to form a mini-batch. Args: samples (List[dict]): samples to collate pad_to_length (dict, optional): a dictionary of {'source': source_pad_to_length, 'target': target_pad_to_length} to indicate the max length to pad to in source and target respectively. 
Returns: dict: a mini-batch with the following keys: - `id` (LongTensor): example IDs in the original input order - `utt_id` (List[str]): list of utterance ids - `nsentences` (int): batch size - `ntokens` (int): total number of tokens in the batch - `net_input` (dict): the input to the Model, containing keys: - `src_tokens` (FloatTensor): a padded 3D Tensor of features in the source of shape `(bsz, src_len, feat_dim)`. - `src_lengths` (IntTensor): 1D Tensor of the unpadded lengths of each source sequence of shape `(bsz)` - `target` (LongTensor): a padded 2D Tensor of indices in the target alignments of shape `(bsz, tgt_len)` - `text` (List[str]): list of original text """ # pad_idx=-100 matches the default in criterions return collate( samples, pad_idx=-100, chunk_width=self.chunk_width, chunk_left_context=self.chunk_left_context, chunk_right_context=self.chunk_right_context, label_delay=self.label_delay, seed=self.seed, epoch=self.epoch, pad_to_length=pad_to_length, pad_to_multiple=self.pad_to_multiple, src_bucketed=(self.buckets is not None), random_chunking=self.random_chunking, ) def num_tokens(self, index): """Return the number of frames in a sample. This value is used to enforce ``--max-tokens`` during batching.""" if self.chunk_width is None: return self.src_sizes[index] return self.chunk_width + self.chunk_left_context + self.chunk_right_context def size(self, index): """Return an example's size as a float or tuple. This value is used when filtering a dataset with ``--max-positions``.""" return (self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0) def ordered_indices(self): """Return an ordered list of indices. Batches will be constructed based on this order.""" if self.shuffle: indices = np.random.permutation(len(self)).astype(np.int64) else: indices = np.arange(len(self), dtype=np.int64) if self.buckets is None: # sort by target length, then source length if self.tgt_sizes is not None: indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")] return indices[np.argsort(self.src_sizes[indices], kind="mergesort")] else: # sort by bucketed_num_tokens, which is padded_src_len return indices[np.argsort(self.bucketed_num_tokens[indices], kind="mergesort")] @property def supports_prefetch(self): return getattr(self.src, "supports_prefetch", False) def prefetch(self, indices): self.src.prefetch(indices) if self.tgt is not None: self.tgt.prefetch(indices) def filter_indices_by_size(self, indices, max_sizes): """ Filter a list of sample indices. Remove those that are longer than specified in max_sizes. Args: indices (np.array): original array of sample indices max_sizes (int or list[int] or tuple[int]): max sample size, can be defined separately for src and tgt (then list or tuple) Returns: np.array: filtered sample array list: list of removed indices """ return data_utils.filter_paired_dataset_indices_by_size( self.src_sizes, self.tgt_sizes, indices, max_sizes, ) @property def can_reuse_epoch_itr_across_epochs(self): return False # to avoid running out of CPU RAM def set_epoch(self, epoch): super().set_epoch(epoch) self.epoch = epoch if hasattr(self.src, "set_epoch"): self.src.set_epoch(epoch) if self.tgt is not None and hasattr(self.tgt, "set_epoch"): self.tgt.set_epoch(epoch)
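# _match_src_tgt() and _match_src_text() above reconcile two datasets by utterance id:
# drop utterances that are missing from either side, then reorder the second dataset to
# follow the (filtered) source order. A standalone sketch of that reconciliation on
# plain Python lists; the utterance ids are invented for illustration:

def match_by_utt_id(src_utt_ids, tgt_utt_ids):
    """Return (src_keep_indices, tgt_reorder_indices), mimicking _match_src_tgt()."""
    tgt_set = set(tgt_utt_ids)
    # keep only source utterances that also exist in the target
    src_keep = [i for i, uid in enumerate(src_utt_ids) if uid in tgt_set]
    kept_ids = [src_utt_ids[i] for i in src_keep]
    # reorder the target so it lines up with the filtered source order
    tgt_reorder = [tgt_utt_ids.index(uid) for uid in kept_ids]
    return src_keep, tgt_reorder


print(match_by_utt_id(["utt1", "utt2", "utt3", "utt4"], ["utt3", "utt1", "utt4"]))
# ([0, 2, 3], [1, 0, 2])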
class APEDataset(LanguagePairDataset): def __init__(self, src, src_sizes, src_dict, tgt=None, tgt_sizes=None, tgt_dict=None, mt=None, mt_sizes=None, term=None, term_sizes=None, src_factor=None, src_factor_sizes=None, mt_factor=None, mt_factor_sizes=None, left_pad_source=True, left_pad_target=False, shuffle=True, input_feeding=True, remove_eos_from_source=False, append_eos_to_target=False, align_dataset=None, append_bos=False, eos=None, num_buckets=0, input_type='src_only'): """ Add mt to LanguagePairDataset Additional Args: mt (torch.utils.data.Dataset, optional): mt dataset to wrap mt_sizes (List[int], optional): mt sentence lengths """ super().__init__( src, src_sizes, src_dict, tgt=tgt, tgt_sizes=tgt_sizes, tgt_dict=tgt_dict, left_pad_source=left_pad_source, left_pad_target=left_pad_target, shuffle=shuffle, input_feeding=input_feeding, remove_eos_from_source=remove_eos_from_source, append_eos_to_target=append_eos_to_target, align_dataset=align_dataset, append_bos=append_bos, eos=eos, num_buckets=num_buckets, ) self.mt = mt self.mt_sizes = np.array(mt_sizes) if mt_sizes is not None else None self.term = term self.term_sizes = np.array( term_sizes) if term_sizes is not None else None self.src_factor = src_factor self.src_factor_sizes = np.array( src_factor_sizes) if src_factor_sizes is not None else None self.mt_factor = mt_factor self.mt_factor_sizes = np.array( mt_factor_sizes) if mt_factor_sizes is not None else None if num_buckets > 0: from fairseq.data import BucketPadLengthDataset if self.mt is not None: self.mt = BucketPadLengthDataset( self.mt, sizes=self.mt_sizes, num_buckets=num_buckets, pad_idx=self.tgt_dict.pad(), left_pad=self.left_pad_target, ) self.mt_sizes = self.mt.sizes logger.info('bucketing mt lengths: {}'.format( list(self.mt.buckets))) if self.term is not None: self.term = BucketPadLengthDataset( self.term, sizes=self.term_sizes, num_buckets=num_buckets, pad_idx=self.tgt_dict.pad(), left_pad=self.left_pad_target, ) self.term_sizes = self.term.sizes logger.info('bucketing term lengths: {}'.format( list(self.term.buckets))) if self.src_factor is not None: self.src_factor = BucketPadLengthDataset( self.src_factor, sizes=self.src_factor_sizes, num_buckets=num_buckets, pad_idx=self.tgt_dict.pad(), left_pad=self.left_pad_target, ) self.src_factor_sizes = self.src_factor.sizes logger.info('bucketing src_factor lengths: {}'.format( list(self.src_factor.buckets))) if self.mt_factor is not None: self.mt_factor = BucketPadLengthDataset( self.mt_factor, sizes=self.mt_factor_sizes, num_buckets=num_buckets, pad_idx=self.tgt_dict.pad(), left_pad=self.left_pad_target, ) self.mt_factor_sizes = self.mt_factor.sizes logger.info('bucketing src_factor lengths: {}'.format( list(self.mt_factor.buckets))) self.input_type = input_type def __getitem__(self, index): example = super().__getitem__(index) mt_item = self.mt[index] if self.mt is not None else None term_item = self.term[index] if self.term is not None else None src_factor_item = self.src_factor[ index] if self.src_factor is not None else None mt_factor_item = self.mt_factor[ index] if self.mt_factor is not None else None # Append EOS to end of tgt sentence if it does not have an EOS and remove # EOS from end of src sentence if it exists. 
This is useful when we use # use existing datasets for opposite directions i.e., when we want to # use tgt_dataset as src_dataset and vice versa if self.append_eos_to_target: eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos() if self.mt and self.mt[index][-1] != eos: mt_item = torch.cat([self.mt[index], torch.LongTensor([eos])]) if self.term and self.term[index][-1] != eos: term_item = torch.cat( [self.term[index], torch.LongTensor([eos])]) if self.src_factor and self.src_factor[index][-1] != eos: src_factor_item = torch.cat( [self.src_factor[index], torch.LongTensor([eos])]) if self.mt_factor and self.mt_factor[index][-1] != eos: mt_factor_item = torch.cat( [self.mt_factor[index], torch.LongTensor([eos])]) if self.append_bos: bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos() if self.mt and self.mt[index][0] != bos: mt_item = torch.cat([torch.LongTensor([bos]), self.mt[index]]) if self.term and self.term[index][0] != bos: term_item = torch.cat( [torch.LongTensor([bos]), self.term[index]]) if self.src_factor and self.src_factor[index][0] != bos: src_factor_item = torch.cat( [torch.LongTensor([bos]), self.src_factor[index]]) if self.mt_factor and self.mt_factor[index][0] != bos: mt_factor_item = torch.cat( [torch.LongTensor([bos]), self.mt_factor[index]]) if self.input_type == "concatenate": src_item = example["source"] eos = [src_item[-1]] combined_item = src_item[:-1] mt_sep = [self.src_dict.index('<sep>')] combined_item = torch.cat( [combined_item, torch.LongTensor(mt_sep), mt_item[1:-1]]) combined_item = torch.cat([combined_item, torch.LongTensor(eos)]) example['source'] = combined_item example["mt"] = mt_item example["term"] = term_item example["src_factor"] = src_factor_item example["mt_factor"] = mt_factor_item return example def collater(self, samples): """Merge a list of samples to form a mini-batch.""" return collate( samples, pad_idx=self.src_dict.pad(), eos_idx=self.eos, left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target, input_feeding=self.input_feeding, input_type=self.input_type, ) def num_tokens(self, index): """Return the number of tokens in a sample. This value is used to enforce ``--max-tokens`` during batching.""" return max( self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0, self.mt_sizes[index] if self.mt_sizes is not None else 0, self.term_sizes[index] if self.term_sizes is not None else 0, self.src_factor_sizes[index] if self.src_factor_sizes is not None else 0, self.mt_factor_sizes[index] if self.mt_factor_sizes is not None else 0, ) def size(self, index): """Return an example's size as a float or tuple. This value is used when filtering a dataset with ``--max-positions``.""" return ( self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0, self.mt_sizes[index] if self.mt_sizes is not None else 0, self.term_sizes[index] if self.term_sizes is not None else 0, self.src_factor_sizes[index] if self.src_factor_sizes is not None else 0, self.mt_factor_sizes[index] if self.mt_factor_sizes is not None else 0, ) def ordered_indices(self): """Return an ordered list of indices. 
Batches will be constructed based on this order.""" if self.shuffle: indices = np.random.permutation(len(self)) else: indices = np.arange(len(self)) if self.buckets is None: # sort by target length, mt_length then source length if self.tgt_sizes is not None: indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')] if self.mt_sizes is not None: indices = indices[np.argsort(self.mt_sizes[indices], kind='mergesort')] if self.term_sizes is not None: indices = indices[np.argsort(self.term_sizes[indices], kind='mergesort')] if self.src_factor_sizes is not None: indices = indices[np.argsort(self.src_factor_sizes[indices], kind='mergesort')] if self.mt_factor_sizes is not None: indices = indices[np.argsort(self.mt_factor_sizes[indices], kind='mergesort')] return indices[np.argsort(self.src_sizes[indices], kind='mergesort')] else: # sort by bucketed_num_tokens, which is: # max(padded_src_len, padded_tgt_len) return indices[np.argsort(self.bucketed_num_tokens[indices], kind='mergesort')] @property def supports_prefetch(self): return (getattr(self.src, 'supports_prefetch', False) and (getattr(self.tgt, 'supports_prefetch', False) or self.tgt is None) and (getattr(self.mt, 'supports_prefetch', False) or self.mt is None) and (getattr(self.term, 'supports_prefetch', False) or self.term is None) and (getattr(self.src_factor, 'supports_prefetch', False) or self.src_factor is None) and (getattr(self.mt_factor, 'supports_prefetch', False) or self.mt_factor is None)) def prefetch(self, indices): self.src.prefetch(indices) if self.tgt is not None: self.tgt.prefetch(indices) if self.mt is not None: self.mt.prefetch(indices) if self.term is not None: self.term.prefetch(indices) if self.src_factor is not None: self.src_factor.prefetch(indices) if self.mt_factor is not None: self.mt_factor.prefetch(indices)
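# With input_type == "concatenate", APEDataset.__getitem__ splices a single source
# sequence out of src (minus its eos), a <sep> token, the mt item stripped of its first
# and last tokens, and the original eos. A small torch sketch of that splice; the toy
# ids below (2 = eos, 0 = bos, 99 = <sep>) stand in for real dictionary indices:

import torch


def concatenate_src_mt(src_item, mt_item, sep_idx):
    eos = src_item[-1:]                   # keep the original eos
    return torch.cat([
        src_item[:-1],                    # src without eos
        torch.LongTensor([sep_idx]),      # separator token
        mt_item[1:-1],                    # mt without its first/last token
        eos,
    ])


src = torch.LongTensor([11, 12, 13, 2])
mt = torch.LongTensor([0, 21, 22, 2])
print(concatenate_src_mt(src, mt, sep_idx=99))  # [11, 12, 13, 99, 21, 22, 2]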
def __init__(
    self,
    src,
    src_sizes,
    src_dict,
    tgt=None,
    tgt_sizes=None,
    tgt_dict=None,
    masking=False,
    src_bert_dataset=None,
    denoising=False,
    src_bart_dataset=None,
    src_electra_dataset=None,
    electra_pretrain=None,
    # extra_datasets=None,
    left_pad_source=True,
    left_pad_target=False,
    shuffle=True,
    input_feeding=True,
    remove_eos_from_source=False,
    append_eos_to_target=False,
    align_dataset=None,
    constraints=None,
    append_bos=False,
    eos=None,
    num_buckets=0,
    src_lang_id=None,
    tgt_lang_id=None,
    pad_to_multiple=1,
):
    if tgt_dict is not None:
        assert src_dict.pad() == tgt_dict.pad()
        assert src_dict.eos() == tgt_dict.eos()
        assert src_dict.unk() == tgt_dict.unk()
    if tgt is not None:
        assert len(src) == len(
            tgt
        ), "Source and target must contain the same number of examples"
    # denoising requires a BART-style source dataset
    assert not (denoising and src_bart_dataset is None)
    self.src = src
    self.tgt = tgt
    self.src_sizes = np.array(src_sizes)
    self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
    self.sizes = (
        np.vstack((self.src_sizes, self.tgt_sizes)).T
        if self.tgt_sizes is not None
        else self.src_sizes
    )
    self.src_dict = src_dict
    self.tgt_dict = tgt_dict
    self.left_pad_source = left_pad_source
    self.left_pad_target = left_pad_target
    self.shuffle = shuffle
    self.input_feeding = input_feeding
    self.remove_eos_from_source = remove_eos_from_source
    self.append_eos_to_target = append_eos_to_target
    self.align_dataset = align_dataset
    if self.align_dataset is not None:
        assert (
            self.tgt_sizes is not None
        ), "Both source and target needed when alignments are provided"
    self.constraints = constraints
    self.append_bos = append_bos
    self.eos = eos if eos is not None else src_dict.eos()
    self.src_lang_id = src_lang_id
    self.tgt_lang_id = tgt_lang_id
    if num_buckets > 0:
        from fairseq.data import BucketPadLengthDataset

        self.src = BucketPadLengthDataset(
            self.src,
            sizes=self.src_sizes,
            num_buckets=num_buckets,
            pad_idx=self.src_dict.pad(),
            left_pad=self.left_pad_source,
        )
        self.src_sizes = self.src.sizes
        logger.info("bucketing source lengths: {}".format(list(self.src.buckets)))
        if self.tgt is not None:
            self.tgt = BucketPadLengthDataset(
                self.tgt,
                sizes=self.tgt_sizes,
                num_buckets=num_buckets,
                pad_idx=self.tgt_dict.pad(),
                left_pad=self.left_pad_target,
            )
            self.tgt_sizes = self.tgt.sizes
            logger.info("bucketing target lengths: {}".format(list(self.tgt.buckets)))
        # determine bucket sizes using self.num_tokens, which will return
        # the padded lengths (thanks to BucketPadLengthDataset)
        num_tokens = np.vectorize(self.num_tokens, otypes=[np.compat.long])
        self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
        self.buckets = [
            (None, num_tokens) for num_tokens in np.unique(self.bucketed_num_tokens)
        ]
    else:
        self.buckets = None
    self.pad_to_multiple = pad_to_multiple
    # self.extra_data = extra_datasets if extra_datasets else None
    if hasattr(self.src, "pad_dict"):
        self.pad_dict = self.src.pad_dict
    else:
        self.pad_dict = {}
class LanguagePairDataset(FairseqDataset): """ A pair of torch.utils.data.Datasets. Args: src (torch.utils.data.Dataset): source dataset to wrap src_sizes (List[int]): source sentence lengths src_dict (~fairseq.data.Dictionary): source vocabulary tgt (torch.utils.data.Dataset, optional): target dataset to wrap tgt_sizes (List[int], optional): target sentence lengths tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary left_pad_source (bool, optional): pad source tensors on the left side (default: True). left_pad_target (bool, optional): pad target tensors on the left side (default: False). shuffle (bool, optional): shuffle dataset elements before batching (default: True). input_feeding (bool, optional): create a shifted version of the targets to be passed into the model for teacher forcing (default: True). remove_eos_from_source (bool, optional): if set, removes eos from end of source if it's present (default: False). append_eos_to_target (bool, optional): if set, appends eos to end of target if it's absent (default: False). align_dataset (torch.utils.data.Dataset, optional): dataset containing alignments. append_bos (bool, optional): if set, appends bos to the beginning of source/target sentence. num_buckets (int, optional): if set to a value greater than 0, then batches will be bucketed into the given number of batch shapes. """ def __init__( self, src, src_sizes, src_dict, tgt=None, tgt_sizes=None, tgt_dict=None, left_pad_source=True, left_pad_target=False, shuffle=True, input_feeding=True, remove_eos_from_source=False, append_eos_to_target=False, align_dataset=None, append_bos=False, eos=None, num_buckets=0, max_source_positions=1024, max_target_positions=1024, ): if tgt_dict is not None: assert src_dict.pad() == tgt_dict.pad() assert src_dict.eos() == tgt_dict.eos() assert src_dict.unk() == tgt_dict.unk() if tgt is not None: assert len(src) == len(tgt), "Source and target must contain the same number of examples" self.src = src self.tgt = tgt self.src_sizes = np.array(src_sizes) self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None self.src_dict = src_dict self.tgt_dict = tgt_dict self.left_pad_source = left_pad_source self.left_pad_target = left_pad_target self.max_source_positions = max_source_positions self.max_target_positions = max_target_positions self.shuffle = shuffle self.input_feeding = input_feeding self.remove_eos_from_source = remove_eos_from_source self.append_eos_to_target = append_eos_to_target self.align_dataset = align_dataset if self.align_dataset is not None: assert self.tgt_sizes is not None, "Both source and target needed when alignments are provided" self.append_bos = append_bos self.eos = (eos if eos is not None else src_dict.eos()) if num_buckets > 0: from fairseq.data import BucketPadLengthDataset self.src = BucketPadLengthDataset( self.src, sizes=self.src_sizes, num_buckets=num_buckets, pad_idx=self.src_dict.pad(), left_pad=self.left_pad_source, ) self.src_sizes = self.src.sizes logger.info('bucketing source lengths: {}'.format(list(self.src.buckets))) if self.tgt is not None: self.tgt = BucketPadLengthDataset( self.tgt, sizes=self.tgt_sizes, num_buckets=num_buckets, pad_idx=self.tgt_dict.pad(), left_pad=self.left_pad_target, ) self.tgt_sizes = self.tgt.sizes logger.info('bucketing target lengths: {}'.format(list(self.tgt.buckets))) # determine bucket sizes using self.num_tokens, which will return # the padded lengths (thanks to BucketPadLengthDataset) num_tokens = np.vectorize(self.num_tokens, otypes=[np.long]) 
self.bucketed_num_tokens = num_tokens(np.arange(len(self.src))) self.buckets = [ (None, num_tokens) for num_tokens in np.unique(self.bucketed_num_tokens) ] else: self.buckets = None def get_batch_shapes(self): return self.buckets def __getitem__(self, index): tgt_item = self.tgt[index] if self.tgt is not None else None src_item = self.src[index] # Append EOS to end of tgt sentence if it does not have an EOS and remove # EOS from end of src sentence if it exists. This is useful when we use # use existing datasets for opposite directions i.e., when we want to # use tgt_dataset as src_dataset and vice versa if self.append_eos_to_target: eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos() if self.tgt and self.tgt[index][-1] != eos: tgt_item = torch.cat([self.tgt[index], torch.LongTensor([eos])]) if self.append_bos: bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos() if self.tgt and self.tgt[index][0] != bos: tgt_item = torch.cat([torch.LongTensor([bos]), self.tgt[index]]) bos = self.src_dict.bos() if self.src[index][-1] != bos: src_item = torch.cat([torch.LongTensor([bos]), self.src[index]]) if self.remove_eos_from_source: eos = self.src_dict.eos() if self.src[index][-1] == eos: src_item = self.src[index][:-1] example = { 'id': index, 'source': src_item, 'target': tgt_item, } if self.align_dataset is not None: example['alignment'] = self.align_dataset[index] return example def __len__(self): return len(self.src) def collater(self, samples): """Merge a list of samples to form a mini-batch. Args: samples (List[dict]): samples to collate Returns: dict: a mini-batch with the following keys: - `id` (LongTensor): example IDs in the original input order - `ntokens` (int): total number of tokens in the batch - `net_input` (dict): the input to the Model, containing keys: - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in the source sentence of shape `(bsz, src_len)`. Padding will appear on the left if *left_pad_source* is ``True``. - `src_lengths` (LongTensor): 1D Tensor of the unpadded lengths of each source sentence of shape `(bsz)` - `prev_output_tokens` (LongTensor): a padded 2D Tensor of tokens in the target sentence, shifted right by one position for teacher forcing, of shape `(bsz, tgt_len)`. This key will not be present if *input_feeding* is ``False``. Padding will appear on the left if *left_pad_target* is ``True``. - `target` (LongTensor): a padded 2D Tensor of tokens in the target sentence of shape `(bsz, tgt_len)`. Padding will appear on the left if *left_pad_target* is ``True``. """ return collate( samples, pad_idx=self.src_dict.pad(), eos_idx=self.eos, left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target, input_feeding=self.input_feeding, ) def get_dummy_batch(self, num_tokens, max_positions, src_len=128, tgt_len=128): """Return a dummy batch with a given number of tokens.""" src_len, tgt_len = utils.resolve_max_positions( (src_len, tgt_len), max_positions, (self.max_source_positions, self.max_target_positions), ) bsz = max(num_tokens // max(src_len, tgt_len), 1) return self.collater([ { 'id': i, 'source': self.src_dict.dummy_sentence(src_len), 'target': self.tgt_dict.dummy_sentence(tgt_len) if self.tgt_dict is not None else None, } for i in range(bsz) ]) def num_tokens(self, index): """Return the number of tokens in a sample. 
This value is used to enforce ``--max-tokens`` during batching.""" return max(self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0) def size(self, index): """Return an example's size as a float or tuple. This value is used when filtering a dataset with ``--max-positions``.""" return (self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0) def ordered_indices(self): """Return an ordered list of indices. Batches will be constructed based on this order.""" if self.shuffle: indices = np.random.permutation(len(self)) else: indices = np.arange(len(self)) if self.buckets is None: # sort by target length, then source length if self.tgt_sizes is not None: indices = indices[ np.argsort(self.tgt_sizes[indices], kind='mergesort') ] return indices[np.argsort(self.src_sizes[indices], kind='mergesort')] else: # sort by bucketed_num_tokens, which is: # max(padded_src_len, padded_tgt_len) return indices[ np.argsort(self.bucketed_num_tokens[indices], kind='mergesort') ] @property def supports_prefetch(self): return ( getattr(self.src, 'supports_prefetch', False) and (getattr(self.tgt, 'supports_prefetch', False) or self.tgt is None) ) def prefetch(self, indices): self.src.prefetch(indices) if self.tgt is not None: self.tgt.prefetch(indices) if self.align_dataset is not None: self.align_dataset.prefetch(indices)
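# The collater docstring above describes `prev_output_tokens` as the target "shifted
# right by one position for teacher forcing": the eos moves to the front and the rest
# of the sentence follows. A small torch sketch of that shift for a single, unpadded
# target (fairseq's collate does this batch-wise, with padding, via
# data_utils.collate_tokens(..., move_eos_to_beginning=True)):

import torch


def shift_right_for_teacher_forcing(target):
    # eos first, then everything except the final eos
    return torch.cat([target[-1:], target[:-1]])


print(shift_right_for_teacher_forcing(torch.LongTensor([11, 12, 13, 2])))  # [2, 11, 12, 13]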
class LanguagePairDataset(FairseqDataset): """ A pair of torch.utils.data.Datasets. Args: src (torch.utils.data.Dataset): source dataset to wrap src_sizes (List[int]): source sentence lengths src_dict (~fairseq.data.Dictionary): source vocabulary tgt (torch.utils.data.Dataset, optional): target dataset to wrap tgt_sizes (List[int], optional): target sentence lengths tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary left_pad_source (bool, optional): pad source tensors on the left side (default: True). left_pad_target (bool, optional): pad target tensors on the left side (default: False). shuffle (bool, optional): shuffle dataset elements before batching (default: True). input_feeding (bool, optional): create a shifted version of the targets to be passed into the model for teacher forcing (default: True). remove_eos_from_source (bool, optional): if set, removes eos from end of source if it's present (default: False). append_eos_to_target (bool, optional): if set, appends eos to end of target if it's absent (default: False). align_dataset (torch.utils.data.Dataset, optional): dataset containing alignments. constraints (Tensor, optional): 2d tensor with a concatenated, zero- delimited list of constraints for each sentence. append_bos (bool, optional): if set, appends bos to the beginning of source/target sentence. num_buckets (int, optional): if set to a value greater than 0, then batches will be bucketed into the given number of batch shapes. src_lang_id (int, optional): source language ID, if set, the collated batch will contain a field 'src_lang_id' in 'net_input' which indicates the source language of the samples. tgt_lang_id (int, optional): target language ID, if set, the collated batch will contain a field 'tgt_lang_id' which indicates the target language of the samples. 
""" def __init__( self, src, src_sizes, src_dict, tgt=None, tgt_sizes=None, tgt_dict=None, left_pad_source=True, left_pad_target=False, shuffle=True, input_feeding=True, remove_eos_from_source=False, append_eos_to_target=False, align_dataset=None, constraints=None, append_bos=False, eos=None, num_buckets=0, src_lang_id=None, tgt_lang_id=None, pad_to_multiple=1, add_lang_token=False, lang_pair=None, shuffle_lang_pair=False, args=None, ): if tgt_dict is not None: assert src_dict.pad() == tgt_dict.pad() assert src_dict.eos() == tgt_dict.eos() assert src_dict.unk() == tgt_dict.unk() if tgt is not None: assert len(src) == len( tgt ), "Source and target must contain the same number of examples" self.src = src self.tgt = tgt self.src_sizes = np.array(src_sizes) self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None self.sizes = ( np.vstack((self.src_sizes, self.tgt_sizes)).T if self.tgt_sizes is not None else self.src_sizes ) self.src_dict = src_dict self.tgt_dict = tgt_dict self.left_pad_source = left_pad_source self.left_pad_target = left_pad_target self.shuffle = shuffle self.input_feeding = input_feeding self.remove_eos_from_source = remove_eos_from_source self.append_eos_to_target = append_eos_to_target self.align_dataset = align_dataset self.add_lang_token = add_lang_token self.lang_pair = lang_pair self.args=args if self.align_dataset is not None: assert ( self.tgt_sizes is not None ), "Both source and target needed when alignments are provided" self.constraints = constraints self.append_bos = append_bos self.eos = eos if eos is not None else src_dict.eos() self.src_lang_id = src_lang_id self.tgt_lang_id = tgt_lang_id if num_buckets > 0: from fairseq.data import BucketPadLengthDataset self.src = BucketPadLengthDataset( self.src, sizes=self.src_sizes, num_buckets=num_buckets, pad_idx=self.src_dict.pad(), left_pad=self.left_pad_source, ) self.src_sizes = self.src.sizes logger.info("bucketing source lengths: {}".format(list(self.src.buckets))) if self.tgt is not None: self.tgt = BucketPadLengthDataset( self.tgt, sizes=self.tgt_sizes, num_buckets=num_buckets, pad_idx=self.tgt_dict.pad(), left_pad=self.left_pad_target, ) self.tgt_sizes = self.tgt.sizes logger.info( "bucketing target lengths: {}".format(list(self.tgt.buckets)) ) # determine bucket sizes using self.num_tokens, which will return # the padded lengths (thanks to BucketPadLengthDataset) num_tokens = np.vectorize(self.num_tokens, otypes=[np.long]) self.bucketed_num_tokens = num_tokens(np.arange(len(self.src))) self.buckets = [ (None, num_tokens) for num_tokens in np.unique(self.bucketed_num_tokens) ] else: self.buckets = None self.pad_to_multiple = pad_to_multiple self.shuffle_lang_pair = shuffle_lang_pair self.src_mask_index = self.src_dict.index("<mask>") self.tgt_mask_index = self.tgt_dict.index("<mask>") self.padding_index = self.tgt_dict.index("<pad>") self.ratio = 0.5 if args != None and args.task == "translation_from_pretrained_maskdecode_multi": self.random_ratio = args.mask_random self.replace_length = args.replace_length if self.replace_length not in [-1, 0, 1]: raise ValueError(f"invalid arg: replace_length={self.replace_length}") if args.mask_length not in ["subword", "word", "span-poisson"]: raise ValueError(f"invalid arg: mask-length={args.mask_length}") if args.mask_length == "subword" and args.replace_length not in [0, 1]: raise ValueError(f"if using subwords, use replace-length=1 or 0") self.mask_span_distribution = None if self.args.mask_length == "span-poisson": _lambda = self.args.poisson_lambda 
lambda_to_the_k = 1 e_to_the_minus_lambda = math.exp(-_lambda) k_factorial = 1 ps = [] for k in range(0, 128): ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial) lambda_to_the_k *= _lambda k_factorial *= k + 1 if ps[-1] < 0.0000001: break ps = torch.FloatTensor(ps) self.mask_span_distribution = torch.distributions.Categorical(ps) #print(self.mask_span_distribution) def get_batch_shapes(self): return self.buckets def word_starts(self, source): is_word_start=[] sources = [self.tgt_dict[token] for token in source] if True: for token in sources: #print(token) if token[0] == "▁": is_word_start.append(1) else: is_word_start.append(0) is_word_start = torch.tensor(is_word_start) else: is_word_start = torch.ones(source.size()) is_word_start[0] = 0 is_word_start[-1] = 0 return is_word_start def add_whole_word_mask(self, source, p): is_word_start = self.word_starts(source) num_to_mask = int(math.ceil(is_word_start.float().sum() * p)) # print(num_to_mask) num_inserts = 0 if num_to_mask == 0: return source if self.mask_span_distribution is not None: lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,)) # print("mask_length:{}".format(lengths)) # exit(0) # Make sure we have enough to mask cum_length = torch.cumsum(lengths, 0) while cum_length[-1] < num_to_mask: lengths = torch.cat( [ lengths, self.mask_span_distribution.sample(sample_shape=(num_to_mask,)), ], dim=0, ) cum_length = torch.cumsum(lengths, 0) # Trim to masking budget i = 0 while cum_length[i] < num_to_mask: i += 1 lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1]) num_to_mask = i + 1 lengths = lengths[:num_to_mask] # Handle 0-length mask (inserts) separately lengths = lengths[lengths > 0] num_inserts = num_to_mask - lengths.size(0) num_to_mask -= num_inserts if num_to_mask == 0: return self.add_insertion_noise(source, num_inserts / source.size(0)) assert (lengths > 0).all() else: lengths = torch.ones((num_to_mask,)).long() assert is_word_start[-1] == 0 word_starts = is_word_start.nonzero(as_tuple=False) indices = word_starts[ torch.randperm(word_starts.size(0))[:num_to_mask] ].squeeze(1) mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio source_length = source.size(0) assert source_length - 1 not in indices to_keep = torch.ones(source_length, dtype=torch.bool) is_word_start[ -1 ] = 255 # acts as a long length, so spans don't go over the end of doc if self.replace_length == 0: to_keep[indices] = 0 else: # keep index, but replace it with [MASK] source[indices] = self.tgt_mask_index source[indices[mask_random]] = torch.randint( 1, len(self.tgt_dict), size=(mask_random.sum(),) ) if self.mask_span_distribution is not None: assert len(lengths.size()) == 1 assert lengths.size() == indices.size() lengths -= 1 while indices.size(0) > 0: assert lengths.size() == indices.size() lengths -= is_word_start[indices + 1].long() uncompleted = lengths >= 0 indices = indices[uncompleted] + 1 mask_random = mask_random[uncompleted] lengths = lengths[uncompleted] if self.replace_length != -1: # delete token to_keep[indices] = 0 else: # keep index, but replace it with [MASK] source[indices] = self.tgt_mask_index source[indices[mask_random]] = torch.randint( 1, len(self.tgt_dict), size=(mask_random.sum(),) ) else: # A bit faster when all lengths are 1 while indices.size(0) > 0: uncompleted = is_word_start[indices + 1] == 0 indices = indices[uncompleted] + 1 mask_random = mask_random[uncompleted] if self.replace_length != -1: # delete token to_keep[indices] = 0 else: # keep index, but replace it 
    def add_insertion_noise(self, tokens, p):
        if p == 0.0:
            return tokens

        num_tokens = len(tokens)
        n = int(math.ceil(num_tokens * p))

        noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1
        noise_mask = torch.zeros(size=(num_tokens + n,), dtype=torch.bool)
        noise_mask[noise_indices] = 1
        result = torch.LongTensor(n + len(tokens)).fill_(-1)

        num_random = int(math.ceil(n * self.random_ratio))
        result[noise_indices[num_random:]] = self.tgt_mask_index
        result[noise_indices[:num_random]] = torch.randint(
            low=1, high=len(self.tgt_dict), size=(num_random,)
        )

        result[~noise_mask] = tokens

        assert (result >= 0).all()
        return result

    def add_rolling_noise(self, tokens):
        offset = np.random.randint(1, max(1, tokens.size(-1) - 1) + 1)
        tokens = torch.cat(
            (tokens[0:1], tokens[offset:-1], tokens[1:offset], tokens[-1:]),
            dim=0,
        )
        return tokens

    def denoising_mono(self, src_tokens):
        """BART-style denoising: whole-word masking followed by sentence rolling."""
        source = src_tokens
        # copy the target first, since add_whole_word_mask mutates `source` in place
        target = src_tokens.tolist().copy()
        source = self.add_whole_word_mask(source, self.args.mask)
        source = self.add_rolling_noise(source)
        return source, torch.tensor(target)

    def span_masking_mono(self, src_tokens):
        """MASS-style span masking on a monolingual sentence.

        Returns (encoder input, decoder input, decoder target).
        """
        src_list = src_tokens.tolist()
        lang_id = src_list[0]
        start, length = self.mask_interval(len(src_list))

        source = []
        for i, w in enumerate(src_list):
            if i >= start and i < start + length:
                w = self.mask_word(w)
            if w is not None:
                source.append(w)

        output = [self.padding_index] * len(src_list)
        output[0] = lang_id
        output[start:start + length] = src_list[start:start + length]
        if start + length < len(src_list) and output[start + length] == self.padding_index:
            output[start + length] = self.tgt_dict.eos_index

        target = [self.tgt_mask_index] * len(src_list)
        target[0] = lang_id
        target[-1] = self.tgt_dict.eos_index
        target[start:start + length] = src_list[start:start + length]

        assert len(target) == len(output)
        return torch.tensor(source), torch.tensor(target), torch.tensor(output)
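    # Illustrative sketch (hypothetical ids, not produced by this code): what
    # span_masking_mono() roughly returns for a 6-token sentence
    # [lang_id, A, B, C, D, eos] when mask_interval() picks start=2, length=2.
    # <m> denotes self.tgt_mask_index, <pad> denotes self.padding_index, and
    # the exact encoder input depends on mask_word()'s 80/10/10 policy:
    #
    #     source (encoder input)  : [lang_id,  A,    <m>, <m>,  D,   eos ]
    #     target (decoder input)  : [lang_id, <m>,    B,   C,  <m>,  eos ]
    #     output (decoder target) : [lang_id, <pad>,  B,   C,  eos, <pad>]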
    def __getitem__(self, index):
        tgt_item = self.tgt[index] if self.tgt is not None else None
        src_item = self.src[index]
        prev_output_tokens = None

        is_swap = random.random()
        if self.lang_pair is None or self.lang_pair[0] != self.lang_pair[1]:
            # parallel pair: optionally swap the two sides of the sentence pair
            if is_swap > 0.5 and self.shuffle_lang_pair:
                src_item, tgt_item = tgt_item, src_item
        else:
            # monolingual pair: build a denoising or span-masking example
            if self.args.bart_mono:
                src_item, tgt_item = self.denoising_mono(src_item)
            else:
                src_item, prev_output_tokens, tgt_item = self.span_masking_mono(src_item)

        # Append EOS to end of tgt sentence if it does not have an EOS and remove
        # EOS from end of src sentence if it exists. This is useful when we use
        # existing datasets for opposite directions, i.e., when we want to use
        # tgt_dataset as src_dataset and vice versa.
        if self.append_eos_to_target:
            eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos()
            if self.tgt and self.tgt[index][-1] != eos:
                tgt_item = torch.cat([self.tgt[index], torch.LongTensor([eos])])

        if self.append_bos:
            bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos()
            if self.tgt and self.tgt[index][0] != bos:
                tgt_item = torch.cat([torch.LongTensor([bos]), self.tgt[index]])

            bos = self.src_dict.bos()
            if self.src[index][0] != bos:
                src_item = torch.cat([torch.LongTensor([bos]), self.src[index]])

        if self.remove_eos_from_source:
            eos = self.src_dict.eos()
            if self.src[index][-1] == eos:
                src_item = self.src[index][:-1]

        example = {
            "id": index,
            "source": src_item,
            "target": tgt_item,
            "prev_output_tokens": prev_output_tokens,
        }
        if self.align_dataset is not None:
            example["alignment"] = self.align_dataset[index]
        if self.constraints is not None:
            example["constraints"] = self.constraints[index]
        return example

    def __len__(self):
        return len(self.src)
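    # Illustrative sketch (hedged): the dict produced by __getitem__ for a
    # monolingual index on the span-masking path, with hypothetical tensors:
    #
    #     {
    #         "id": index,
    #         "source": <encoder input with the masked span>,
    #         "target": <decoder target from span_masking_mono>,
    #         "prev_output_tokens": <decoder input from span_masking_mono>,
    #     }
    #
    # For a parallel pair, "prev_output_tokens" stays None and the collater is
    # expected to build the right-shifted decoder input from "target" instead
    # (standard fairseq input feeding).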
    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate
            pad_to_length (dict, optional): a dictionary of
                {"source": source_pad_to_length, "target": target_pad_to_length}
                to indicate the max length to pad to in source and target
                respectively.

        Returns:
            dict: a mini-batch with the following keys:

                - `id` (LongTensor): example IDs in the original input order
                - `ntokens` (int): total number of tokens in the batch
                - `net_input` (dict): the input to the Model, containing keys:

                  - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
                    the source sentence of shape `(bsz, src_len)`. Padding will
                    appear on the left if *left_pad_source* is ``True``.
                  - `src_lengths` (LongTensor): 1D Tensor of the unpadded
                    lengths of each source sentence of shape `(bsz)`
                  - `prev_output_tokens` (LongTensor): a padded 2D Tensor of
                    tokens in the target sentence, shifted right by one
                    position for teacher forcing, of shape `(bsz, tgt_len)`.
                    This key will not be present if *input_feeding* is
                    ``False``. Padding will appear on the left if
                    *left_pad_target* is ``True``.
                  - `src_lang_id` (LongTensor): a long Tensor which contains
                    source language IDs of each sample in the batch

                - `target` (LongTensor): a padded 2D Tensor of tokens in the
                  target sentence of shape `(bsz, tgt_len)`. Padding will
                  appear on the left if *left_pad_target* is ``True``.
                - `tgt_lang_id` (LongTensor): a long Tensor which contains
                  target language IDs of each sample in the batch
        """
        res = collate(
            samples,
            pad_idx=self.src_dict.pad(),
            eos_idx=self.eos,
            left_pad_source=self.left_pad_source,
            left_pad_target=self.left_pad_target,
            input_feeding=self.input_feeding,
            pad_to_length=pad_to_length,
            pad_to_multiple=self.pad_to_multiple,
        )
        if self.src_lang_id is not None or self.tgt_lang_id is not None:
            src_tokens = res["net_input"]["src_tokens"]
            bsz = src_tokens.size(0)
            if self.src_lang_id is not None:
                res["net_input"]["src_lang_id"] = (
                    torch.LongTensor([[self.src_lang_id]]).expand(bsz, 1).to(src_tokens)
                )
            if self.tgt_lang_id is not None:
                res["tgt_lang_id"] = (
                    torch.LongTensor([[self.tgt_lang_id]]).expand(bsz, 1).to(src_tokens)
                )
        return res

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return max(
            self.src_sizes[index],
            self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
        )

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used
        when filtering a dataset with ``--max-positions``."""
        return (
            self.src_sizes[index],
            self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
        )

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self)).astype(np.int64)
        else:
            indices = np.arange(len(self), dtype=np.int64)
        if self.buckets is None:
            # sort by target length, then source length
            if self.tgt_sizes is not None:
                indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")]
            return indices[np.argsort(self.src_sizes[indices], kind="mergesort")]
        else:
            # sort by bucketed_num_tokens, which is:
            #   max(padded_src_len, padded_tgt_len)
            return indices[
                np.argsort(self.bucketed_num_tokens[indices], kind="mergesort")
            ]

    @property
    def supports_prefetch(self):
        return getattr(self.src, "supports_prefetch", False) and (
            getattr(self.tgt, "supports_prefetch", False) or self.tgt is None
        )

    def prefetch(self, indices):
        self.src.prefetch(indices)
        if self.tgt is not None:
            self.tgt.prefetch(indices)
        if self.align_dataset is not None:
            self.align_dataset.prefetch(indices)
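    # Illustrative sketch (hedged): how a fairseq-style epoch iterator typically
    # consumes this dataset.  `dataset` is a hypothetical instance of this class:
    #
    #     indices = dataset.ordered_indices()          # shuffled, then length-sorted
    #     samples = [dataset[i] for i in indices[:8]]  # individual examples
    #     batch = dataset.collater(samples)            # padded mini-batch
    #     batch["net_input"]["src_tokens"]             # (8, max_src_len) LongTensor
    #     batch["net_input"]["prev_output_tokens"]     # right-shifted decoder input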
    def filter_indices_by_size(self, indices, max_sizes):
        """Filter a list of sample indices. Remove those that are longer than
        specified in *max_sizes*.

        Args:
            indices (np.array): original array of sample indices
            max_sizes (int or list[int] or tuple[int]): max sample size,
                can be defined separately for src and tgt (then list or tuple)

        Returns:
            np.array: filtered sample array
            list: list of removed indices
        """
        return data_utils.filter_paired_dataset_indices_by_size(
            self.src_sizes,
            self.tgt_sizes,
            indices,
            max_sizes,
        )

    def mask_start(self, end):
        """Pick the start of the masked span: 1 or `end` each with probability
        0.2, otherwise uniform in [1, end)."""
        p = np.random.random()
        if p >= 0.8 or end <= 1:
            return 1
        elif p >= 0.6:
            return end
        else:
            return np.random.randint(1, end)

    def mask_word(self, w):
        """BERT-style 80/10/10 policy: mask, random token, or keep."""
        p = np.random.random()
        if p >= 0.2:
            return self.src_mask_index
        elif p >= 0.1:
            return np.random.randint(self.tgt_dict.nspecial, len(self.tgt_dict))
        else:
            return w

    def random_word(self, w, pred_probs):
        # pred_probs is overridden here, so the original word is always kept
        self.pred_probs = torch.FloatTensor([0, 0, 1])
        cands = [
            self.src_mask_index,
            np.random.randint(self.tgt_dict.nspecial, len(self.tgt_dict)),
            w,
        ]
        prob = torch.multinomial(self.pred_probs, 1, replacement=True)
        return cands[prob.item()]

    def mask_interval(self, l):
        """Return (start, length) of the contiguous span to mask; its length is
        round(l * self.ratio), at least 1."""
        mask_length = round(l * self.ratio)
        mask_length = max(1, mask_length)
        mask_start = self.mask_start(l - mask_length)
        return mask_start, mask_length
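    # Illustrative sketch (hedged): the span-masking policy implemented by
    # mask_interval(), mask_start() and mask_word() above, assuming a 20-token
    # sentence and a hypothetical self.ratio of 0.5:
    #
    #     mask_interval(20) -> (start, 10)   # span covers half of the sentence
    #     mask_start(10)    -> 1 with prob 0.2, 10 with prob 0.2,
    #                          otherwise uniform in [1, 10)
    #     mask_word(w)      -> self.src_mask_index with prob 0.8,
    #                          a random non-special token with prob 0.1,
    #                          the original token w with prob 0.1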