def collater(
    self, samples: List[SpeechToSpeechDatasetItem], return_order: bool = False
) -> Dict:
    if len(samples) == 0:
        return {}
    indices = torch.tensor([x.index for x in samples], dtype=torch.long)
    frames = _collate_frames([x.source for x in samples], self.cfg.use_audio_input)
    # sort samples by descending number of frames
    n_frames = torch.tensor([x.source.size(0) for x in samples], dtype=torch.long)
    n_frames, order = n_frames.sort(descending=True)
    indices = indices.index_select(0, order)
    frames = frames.index_select(0, order)

    target, prev_output_tokens, target_lengths = self._collate_target(samples)
    target = target.index_select(0, order)
    target_lengths = target_lengths.index_select(0, order)
    prev_output_tokens = prev_output_tokens.index_select(0, order)
    ntokens = sum(x.target.size(0) for x in samples)

    tgt_speakers = None
    if self.cfg.target_speaker_embed:
        tgt_speakers = _collate_frames(
            [x.target_speaker for x in samples], is_audio_input=True
        ).index_select(0, order)

    net_input = {
        "src_tokens": frames,
        "src_lengths": n_frames,
        "prev_output_tokens": prev_output_tokens,
        "tgt_speaker": tgt_speakers,  # TODO: unify "speaker" and "tgt_speaker"
    }
    if self.tgt_texts is not None and samples[0].tgt_lang_tag is not None:
        for i in range(len(samples)):
            net_input["prev_output_tokens"][i][0] = samples[order[i]].tgt_lang_tag
    out = {
        "id": indices,
        "net_input": net_input,
        "speaker": tgt_speakers,  # to support Tacotron2 loss for speech-to-spectrogram model
        "target": target,
        "target_lengths": target_lengths,
        "ntokens": ntokens,
        "nsentences": len(samples),
    }
    if return_order:
        out["order"] = order
    return out
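# `_collate_frames` is called above but not defined in this excerpt. Below is a
# minimal sketch consistent with the call sites: raw waveforms collate to a
# zero-padded (bsz, max_len) tensor when is_audio_input=True, and feature
# frames to (bsz, max_len, d) otherwise. The actual fairseq helper may differ
# in details.
import torch
from typing import List

def _collate_frames_sketch(
    frames: List[torch.Tensor], is_audio_input: bool = False
) -> torch.Tensor:
    max_len = max(frame.size(0) for frame in frames)
    if is_audio_input:
        out = frames[0].new_zeros((len(frames), max_len))
    else:
        out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1)))
    for i, v in enumerate(frames):
        # copy each example into the left of its zero-padded row
        out[i, : v.size(0)] = v
    return out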
def _collate_target(
    self, samples: List[SpeechToSpeechDatasetItem]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    if self.target_is_code:
        # discrete-unit targets: collate as padded token sequences
        target = fairseq_data_utils.collate_tokens(
            [x.target for x in samples],
            self.tgt_dict.pad(),
            self.tgt_dict.eos(),
            left_pad=False,
            move_eos_to_beginning=False,
        )
        # convert stacked units to a single id
        pack_targets = [self.pack_units(x.target) for x in samples]
        prev_output_tokens = fairseq_data_utils.collate_tokens(
            pack_targets,
            self.tgt_dict.pad(),
            self.tgt_dict.eos(),
            left_pad=False,
            move_eos_to_beginning=True,
        )
        target_lengths = torch.tensor(
            [x.size(0) for x in pack_targets], dtype=torch.long
        )
    else:
        # spectrogram targets: pad frames, then build teacher-forcing inputs
        # by shifting the target right with an all-zero first frame
        target = _collate_frames([x.target for x in samples], is_audio_input=False)
        bsz, _, d = target.size()
        prev_output_tokens = torch.cat(
            (target.new_full((bsz, 1, d), 0.0), target[:, :-1, :]), dim=1
        )
        target_lengths = torch.tensor(
            [x.target.size(0) for x in samples], dtype=torch.long
        )
    return target, prev_output_tokens, target_lengths
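# Standalone illustration (not dataset code) of the teacher-forcing shift used
# in `_collate_target` for spectrogram targets: `prev_output_tokens` is the
# target delayed by one step, with an all-zero "go" frame prepended.
import torch

target = torch.arange(1, 7, dtype=torch.float).view(1, 3, 2)  # (bsz=1, T=3, d=2)
bsz, _, d = target.size()
prev = torch.cat((target.new_full((bsz, 1, d), 0.0), target[:, :-1, :]), dim=1)
assert prev[0, 0].eq(0).all()                    # zero "go" frame at t=0
assert torch.equal(prev[:, 1:], target[:, :-1])  # frames shifted right by one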
def collater(self, samples: List[TextToSpeechDatasetItem]) -> Dict[str, Any]:
    if len(samples) == 0:
        return {}
    # note: items store audio features in `source` and text tokens in `target`;
    # for TTS the text becomes the encoder input and the audio the target
    src_lengths, order = torch.tensor(
        [s.target.shape[0] for s in samples], dtype=torch.long
    ).sort(descending=True)
    id_ = torch.tensor(
        [s.index for s in samples], dtype=torch.long
    ).index_select(0, order)
    feat = _collate_frames(
        [s.source for s in samples], self.cfg.use_audio_input
    ).index_select(0, order)
    target_lengths = torch.tensor(
        [s.source.shape[0] for s in samples], dtype=torch.long
    ).index_select(0, order)

    src_tokens = fairseq_data_utils.collate_tokens(
        [s.target for s in samples],
        self.tgt_dict.pad(),
        self.tgt_dict.eos(),
        left_pad=False,
        move_eos_to_beginning=False,
    ).index_select(0, order)

    speaker = None
    if self.speaker_to_id is not None:
        speaker = (
            torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
            .index_select(0, order)
            .view(-1, 1)
        )

    bsz, _, d = feat.size()
    prev_output_tokens = torch.cat(
        (feat.new_zeros((bsz, 1, d)), feat[:, :-1, :]), dim=1
    )

    durations, pitches, energies = None, None, None
    if self.durations is not None:
        durations = fairseq_data_utils.collate_tokens(
            [s.duration for s in samples], 0
        ).index_select(0, order)
        assert src_tokens.shape[1] == durations.shape[1]
    if self.pitches is not None:
        pitches = _collate_frames([s.pitch for s in samples], True)
        pitches = pitches.index_select(0, order)
        assert src_tokens.shape[1] == pitches.shape[1]
    if self.energies is not None:
        energies = _collate_frames([s.energy for s in samples], True)
        energies = energies.index_select(0, order)
        assert src_tokens.shape[1] == energies.shape[1]
    src_texts = [self.tgt_dict.string(samples[i].target) for i in order]
    return {
        "id": id_,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
            "prev_output_tokens": prev_output_tokens,
        },
        "speaker": speaker,
        "target": feat,
        "durations": durations,
        "pitches": pitches,
        "energies": energies,
        "target_lengths": target_lengths,
        "ntokens": sum(target_lengths).item(),
        "nsentences": len(samples),
        "src_texts": src_texts,
    }
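# Toy demonstration (standalone) of the length-sort pattern shared by these
# collaters: sort lengths in descending order once, then reorder every
# per-example field with the same permutation so the batch stays aligned.
import torch

lengths = torch.tensor([3, 7, 5], dtype=torch.long)
ids = torch.tensor([10, 11, 12], dtype=torch.long)
feats = torch.stack([torch.full((4,), float(i)) for i in range(3)])
lengths_sorted, order = lengths.sort(descending=True)
ids_sorted = ids.index_select(0, order)
feats_sorted = feats.index_select(0, order)
assert lengths_sorted.tolist() == [7, 5, 3]
assert ids_sorted.tolist() == [11, 12, 10]
assert feats_sorted[0, 0].item() == 1.0  # row of the longest example comes first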
def collater(
    self, samples: List[Tuple[int, torch.Tensor, torch.Tensor]]
) -> Dict:
    super().collater(samples)
    if len(samples) == 0:
        return {}
    indices = torch.tensor([i for i, _, _ in samples], dtype=torch.long)
    frames = _collate_frames(
        [s for _, s, _ in samples], self.data_cfg.use_audio_input
    )
    # sort samples by descending number of frames
    n_frames = torch.tensor([s.size(0) for _, s, _ in samples], dtype=torch.long)
    n_frames, order = n_frames.sort(descending=True)
    indices = indices.index_select(0, order)
    frames = frames.index_select(0, order)

    target, target_lengths = None, None
    prev_output_tokens = None
    ntokens = None
    if self.tgt_texts is not None:
        target = fairseq_data_utils.collate_tokens(
            [t for _, _, t in samples],
            self.tgt_dict.pad(),
            self.tgt_dict.eos(),
            left_pad=False,
            move_eos_to_beginning=False,
        )
        target = target.index_select(0, order)
        target_lengths = torch.tensor(
            [t.size(0) for _, _, t in samples], dtype=torch.long
        ).index_select(0, order)
        prev_output_tokens = fairseq_data_utils.collate_tokens(
            [t for _, _, t in samples],
            self.tgt_dict.pad(),
            self.tgt_dict.eos(),
            left_pad=False,
            move_eos_to_beginning=True,
        )
        prev_output_tokens = prev_output_tokens.index_select(0, order)
        ntokens = sum(t.size(0) for _, _, t in samples)

    # Source transcripts: note that the same token stream is collated again,
    # here with the source dictionary's pad/eos ids
    transcript, transcript_lengths = None, None
    prev_transcript_tokens = None
    ntokens_transcript = None
    if self.src_texts is not None:
        transcript = fairseq_data_utils.collate_tokens(
            [t for _, _, t in samples],
            self.src_dict.pad(),
            self.src_dict.eos(),
            left_pad=False,
            move_eos_to_beginning=False,
        )
        transcript = transcript.index_select(0, order)
        transcript_lengths = torch.tensor(
            [t.size(0) for _, _, t in samples], dtype=torch.long
        ).index_select(0, order)
        prev_transcript_tokens = fairseq_data_utils.collate_tokens(
            [t for _, _, t in samples],
            self.src_dict.pad(),
            self.src_dict.eos(),
            left_pad=False,
            move_eos_to_beginning=True,
        )
        prev_transcript_tokens = prev_transcript_tokens.index_select(0, order)
        ntokens_transcript = sum(t.size(0) for _, _, t in samples)

    out = {
        "id": indices,
        "net_input": {
            "src_tokens": frames,
            "src_lengths": n_frames,
            "prev_output_tokens": prev_output_tokens,
            "prev_transcript_tokens": prev_transcript_tokens,
        },
        "target": target,
        "target_lengths": target_lengths,
        "transcript": transcript,
        "transcript_lengths": transcript_lengths,
        "ntokens": ntokens,
        "ntokens_transcript": ntokens_transcript,
        "nsentences": len(samples),
    }
    return out
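# Simplified sketch (hypothetical, right-pad only) of the
# fairseq_data_utils.collate_tokens behavior the calls above rely on: pad 1D
# token tensors to the batch maximum and optionally rotate EOS to the front to
# build decoder inputs. The real fairseq helper supports more options
# (left padding, custom BOS handling, etc.).
import torch
from typing import List

def collate_tokens_sketch(
    values: List[torch.Tensor],
    pad_idx: int,
    eos_idx: int,
    move_eos_to_beginning: bool = False,
) -> torch.Tensor:
    size = max(v.size(0) for v in values)
    res = values[0].new_full((len(values), size), pad_idx)
    for i, v in enumerate(values):
        if move_eos_to_beginning:
            # EOS moves from the end to position 0; the rest shifts right
            assert v[-1] == eos_idx
            res[i, 0] = eos_idx
            res[i, 1 : v.size(0)] = v[:-1]
        else:
            res[i, : v.size(0)] = v
    return res

a, b = torch.tensor([5, 6, 2]), torch.tensor([7, 2])  # 2 = eos, 1 = pad
assert collate_tokens_sketch([a, b], 1, 2).tolist() == [[5, 6, 2], [7, 2, 1]]
assert collate_tokens_sketch([a, b], 1, 2, True).tolist() == [[2, 5, 6], [2, 7, 1]]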
def collater(
    self, samples: List[Tuple[int, torch.Tensor, torch.Tensor, int, int]]
) -> Dict:
    if len(samples) == 0:
        return {}
    indices = torch.tensor([i for i, _, _, _, _ in samples], dtype=torch.long)
    frames = _collate_frames(
        [s for _, s, _, _, _ in samples], self.data_cfg.use_audio_input
    )
    # per-sample masked-token bookkeeping, summed over the batch
    tokens_masked = torch.tensor([i for _, _, _, i, _ in samples])
    hit = torch.tensor([i for _, _, _, _, i in samples])
    ntokens_masked = torch.sum(tokens_masked)
    nhit = torch.sum(hit)
    # sort samples by descending number of frames
    n_frames = torch.tensor(
        [s.size(0) for _, s, _, _, _ in samples], dtype=torch.long
    )
    n_frames, order = n_frames.sort(descending=True)
    indices = indices.index_select(0, order)
    frames = frames.index_select(0, order)

    target, target_lengths = None, None
    prev_output_tokens = None
    ntokens = None
    if self.tgt_texts is not None:
        target = fairseq_data_utils.collate_tokens(
            [t for _, _, t, _, _ in samples],
            self.tgt_dict.pad(),
            self.tgt_dict.eos(),
            left_pad=False,
            move_eos_to_beginning=False,
        )
        target = target.index_select(0, order)
        target_lengths = torch.tensor(
            [t.size(0) for _, _, t, _, _ in samples], dtype=torch.long
        ).index_select(0, order)
        prev_output_tokens = fairseq_data_utils.collate_tokens(
            [t for _, _, t, _, _ in samples],
            self.tgt_dict.pad(),
            self.tgt_dict.eos(),
            left_pad=False,
            move_eos_to_beginning=True,
        )
        prev_output_tokens = prev_output_tokens.index_select(0, order)
        ntokens = sum(t.size(0) for _, _, t, _, _ in samples)

    out = {
        "id": indices,
        "net_input": {
            "src_tokens": frames,
            "src_lengths": n_frames,
            "prev_output_tokens": prev_output_tokens,
        },
        "target": target,
        "target_lengths": target_lengths,
        "ntokens": ntokens,
        "nsentences": len(samples),
        "ntokens_masked": ntokens_masked,
        "nhit": nhit,
    }
    return out
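# Hypothetical consumer of the extra bookkeeping fields in this batch: the
# fraction of masked positions that registered a hit. Field names follow the
# dict built above; the ratio itself is illustrative, not from the source.
import torch

batch = {"ntokens_masked": torch.tensor(48), "nhit": torch.tensor(36)}
hit_rate = batch["nhit"].float() / batch["ntokens_masked"].clamp(min=1).float()
assert abs(hit_rate.item() - 0.75) < 1e-6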