def collater(self,
                 samples: List[SpeechToSpeechDatasetItem],
                 return_order: bool = False) -> Dict:
        if len(samples) == 0:
            return {}
        indices = torch.tensor([x.index for x in samples], dtype=torch.long)
        frames = _collate_frames([x.source for x in samples],
                                 self.cfg.use_audio_input)
        # sort samples by descending number of frames
        n_frames = torch.tensor([x.source.size(0) for x in samples],
                                dtype=torch.long)
        n_frames, order = n_frames.sort(descending=True)
        indices = indices.index_select(0, order)
        frames = frames.index_select(0, order)

        target, prev_output_tokens, target_lengths = self._collate_target(
            samples)
        target = target.index_select(0, order)
        target_lengths = target_lengths.index_select(0, order)
        prev_output_tokens = prev_output_tokens.index_select(0, order)
        ntokens = sum(x.target.size(0) for x in samples)

        tgt_speakers = None
        if self.cfg.target_speaker_embed:
            tgt_speakers = _collate_frames([x.target_speaker for x in samples],
                                           is_audio_input=True).index_select(
                                               0, order)

        net_input = {
            "src_tokens": frames,
            "src_lengths": n_frames,
            "prev_output_tokens": prev_output_tokens,
            "tgt_speaker":
            tgt_speakers,  # TODO: unify "speaker" and "tgt_speaker"
        }
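        # when multilingual target text is available, overwrite the first
        # position of the (already sorted) prev_output_tokens with each
        # sample's target-language tag; the lookup goes through `order`
        # because the batch has been re-sorted by length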
        if self.tgt_texts is not None and samples[0].tgt_lang_tag is not None:
            for i in range(len(samples)):
                net_input["prev_output_tokens"][i][0] = samples[
                    order[i]].tgt_lang_tag
        out = {
            "id": indices,
            "net_input": net_input,
            "speaker":
            tgt_speakers,  # to support Tacotron2 loss for speech-to-spectrogram model
            "target": target,
            "target_lengths": target_lengths,
            "ntokens": ntokens,
            "nsentences": len(samples),
        }
        if return_order:
            out["order"] = order
        return out
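
# The `_collate_frames` helper used above is not shown in these snippets. The
# sketch below is a rough reconstruction (an assumption, not the exact upstream
# implementation) of what such a fairseq-style helper does: pad a list of
# variable-length utterances into one dense batch tensor (1-D waveforms when
# is_audio_input is True, otherwise 2-D feature matrices).
from typing import List

import torch


def _collate_frames_sketch(frames: List[torch.Tensor],
                           is_audio_input: bool = False) -> torch.Tensor:
    max_len = max(frame.size(0) for frame in frames)
    if is_audio_input:
        # raw waveforms collate to (batch, max_len)
        out = frames[0].new_zeros((len(frames), max_len))
    else:
        # feature matrices collate to (batch, max_len, feat_dim)
        out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1)))
    for i, v in enumerate(frames):
        # left-align each utterance; trailing positions stay zero-padded
        out[i, :v.size(0)] = v
    return out
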
Example #2
    def _collate_target(
        self, samples: List[SpeechToSpeechDatasetItem]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        if self.target_is_code:
            target = fairseq_data_utils.collate_tokens(
                [x.target for x in samples],
                self.tgt_dict.pad(),
                self.tgt_dict.eos(),
                left_pad=False,
                move_eos_to_beginning=False,
            )
            # convert stacked units to a single id
            pack_targets = [self.pack_units(x.target) for x in samples]
            prev_output_tokens = fairseq_data_utils.collate_tokens(
                pack_targets,
                self.tgt_dict.pad(),
                self.tgt_dict.eos(),
                left_pad=False,
                move_eos_to_beginning=True,
            )
            target_lengths = torch.tensor([x.size(0) for x in pack_targets],
                                          dtype=torch.long)
        else:
            target = _collate_frames([x.target for x in samples],
                                     is_audio_input=False)
            bsz, _, d = target.size()
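            # teacher-forcing input for continuous targets: shift right by one
            # step, feeding an all-zero frame at the first decoding position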
            prev_output_tokens = torch.cat(
                (target.new_full((bsz, 1, d), 0.0), target[:, :-1, :]), dim=1)
            target_lengths = torch.tensor([x.target.size(0) for x in samples],
                                          dtype=torch.long)

        return target, prev_output_tokens, target_lengths
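
# Standalone illustration (not from the original source) of the two
# teacher-forcing shifts performed in _collate_target above. collate_tokens is
# the fairseq helper used there; the pad/eos ids and toy tensors are arbitrary.
import torch
from fairseq.data import data_utils as fairseq_data_utils

pad_idx, eos_idx = 1, 2
units = [torch.tensor([5, 6, 7, eos_idx]), torch.tensor([8, 9, eos_idx])]

# discrete unit targets: prev_output_tokens is the target shifted right,
# with EOS moved to the first position and padding on the right
target = fairseq_data_utils.collate_tokens(units, pad_idx, eos_idx,
                                           left_pad=False,
                                           move_eos_to_beginning=False)
prev = fairseq_data_utils.collate_tokens(units, pad_idx, eos_idx,
                                         left_pad=False,
                                         move_eos_to_beginning=True)
# target -> [[5, 6, 7, 2], [8, 9, 2, 1]]
# prev   -> [[2, 5, 6, 7], [2, 8, 9, 1]]

# continuous (spectrogram) targets: the same shift, realised by prepending an
# all-zero frame and dropping the last frame
feats = torch.rand(2, 4, 80)  # (bsz, n_frames, n_mels)
bsz, _, d = feats.size()
prev_frames = torch.cat((feats.new_zeros((bsz, 1, d)), feats[:, :-1, :]), dim=1)
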
    def collater(self,
                 samples: List[TextToSpeechDatasetItem]) -> Dict[str, Any]:
        if len(samples) == 0:
            return {}

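        # for TTS the roles are swapped relative to the S2T-style items: the
        # text tokens (item.target) become the model's source, and the audio
        # features (item.source) become the training target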
        src_lengths, order = torch.tensor(
            [s.target.shape[0] for s in samples],
            dtype=torch.long).sort(descending=True)
        id_ = torch.tensor([s.index for s in samples],
                           dtype=torch.long).index_select(0, order)
        feat = _collate_frames([s.source for s in samples],
                               self.cfg.use_audio_input).index_select(
                                   0, order)
        target_lengths = torch.tensor([s.source.shape[0] for s in samples],
                                      dtype=torch.long).index_select(0, order)

        src_tokens = fairseq_data_utils.collate_tokens(
            [s.target for s in samples],
            self.tgt_dict.pad(),
            self.tgt_dict.eos(),
            left_pad=False,
            move_eos_to_beginning=False,
        ).index_select(0, order)

        speaker = None
        if self.speaker_to_id is not None:
            speaker = torch.tensor(
                [s.speaker_id for s in samples],
                dtype=torch.long).index_select(0, order).view(-1, 1)

        bsz, _, d = feat.size()
        prev_output_tokens = torch.cat(
            (feat.new_zeros((bsz, 1, d)), feat[:, :-1, :]), dim=1)

        durations, pitches, energies = None, None, None
        if self.durations is not None:
            durations = fairseq_data_utils.collate_tokens(
                [s.duration for s in samples], 0).index_select(0, order)
            assert src_tokens.shape[1] == durations.shape[1]
        if self.pitches is not None:
            pitches = _collate_frames([s.pitch for s in samples], True)
            pitches = pitches.index_select(0, order)
            assert src_tokens.shape[1] == pitches.shape[1]
        if self.energies is not None:
            energies = _collate_frames([s.energy for s in samples], True)
            energies = energies.index_select(0, order)
            assert src_tokens.shape[1] == energies.shape[1]
        src_texts = [self.tgt_dict.string(samples[i].target) for i in order]

        return {
            "id": id_,
            "net_input": {
                "src_tokens": src_tokens,
                "src_lengths": src_lengths,
                "prev_output_tokens": prev_output_tokens,
            },
            "speaker": speaker,
            "target": feat,
            "durations": durations,
            "pitches": pitches,
            "energies": energies,
            "target_lengths": target_lengths,
            "ntokens": sum(target_lengths).item(),
            "nsentences": len(samples),
            "src_texts": src_texts,
        }
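
# Usage sketch (hypothetical): a collater like the ones above is handed to the
# batching machinery as collate_fn. `tts_dataset` is an assumed, pre-built
# dataset instance; fairseq normally wraps this in its own EpochBatchIterator,
# but a plain PyTorch DataLoader works the same way for illustration.
from torch.utils.data import DataLoader


def make_loader(tts_dataset, batch_size: int = 8) -> DataLoader:
    # the dataset's own collater turns a list of items into the batch dict
    # returned above (id, net_input, target, target_lengths, ...)
    return DataLoader(tts_dataset,
                      batch_size=batch_size,
                      shuffle=False,
                      collate_fn=tts_dataset.collater)
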
Example #4
    def collater(
            self, samples: List[Tuple[int, torch.Tensor, torch.Tensor,
                                      torch.Tensor]]) -> Dict:
        super().collater(samples)
        if len(samples) == 0:
            return {}
        # each sample is (index, source frames, target tokens, transcript tokens)
        indices = torch.tensor([i for i, _, _, _ in samples], dtype=torch.long)
        frames = _collate_frames([s for _, s, _, _ in samples],
                                 self.data_cfg.use_audio_input)
        # sort samples by descending number of frames
        n_frames = torch.tensor([s.size(0) for _, s, _, _ in samples],
                                dtype=torch.long)
        n_frames, order = n_frames.sort(descending=True)
        indices = indices.index_select(0, order)
        frames = frames.index_select(0, order)

        target, target_lengths = None, None
        prev_output_tokens = None
        ntokens = None
        if self.tgt_texts is not None:
            target = fairseq_data_utils.collate_tokens(
                [t for _, _, t, _ in samples],
                self.tgt_dict.pad(),
                self.tgt_dict.eos(),
                left_pad=False,
                move_eos_to_beginning=False,
            )
            target = target.index_select(0, order)
            target_lengths = torch.tensor(
                [t.size(0) for _, _, t, _ in samples],
                dtype=torch.long).index_select(0, order)
            prev_output_tokens = fairseq_data_utils.collate_tokens(
                [t for _, _, t, _ in samples],
                self.tgt_dict.pad(),
                self.tgt_dict.eos(),
                left_pad=False,
                move_eos_to_beginning=True,
            )
            prev_output_tokens = prev_output_tokens.index_select(0, order)
            ntokens = sum(t.size(0) for _, _, t, _ in samples)

        # Source transcripts: mirror the target collation above over the
        # source-language text with the source dictionary (e.g. for an
        # auxiliary ASR objective)
        transcript, transcript_lengths = None, None
        prev_transcript_tokens = None
        ntokens_transcript = None
        if self.src_texts is not None:
            transcript = fairseq_data_utils.collate_tokens(
                [t for _, _, _, t in samples],
                self.src_dict.pad(),
                self.src_dict.eos(),
                left_pad=False,
                move_eos_to_beginning=False,
            )
            transcript = transcript.index_select(0, order)
            transcript_lengths = torch.tensor(
                [t.size(0) for _, _, _, t in samples],
                dtype=torch.long).index_select(0, order)
            prev_transcript_tokens = fairseq_data_utils.collate_tokens(
                [t for _, _, _, t in samples],
                self.src_dict.pad(),
                self.src_dict.eos(),
                left_pad=False,
                move_eos_to_beginning=True,
            )
            prev_transcript_tokens = prev_transcript_tokens.index_select(
                0, order)
            ntokens_transcript = sum(t.size(0) for _, _, _, t in samples)

        out = {
            "id": indices,
            "net_input": {
                "src_tokens": frames,
                "src_lengths": n_frames,
                "prev_output_tokens": prev_output_tokens,
                "prev_transcript_tokens": prev_transcript_tokens,
            },
            "target": target,
            "target_lengths": target_lengths,
            "transcript": transcript,
            "transcript_lengths": transcript_lengths,
            "ntokens": ntokens,
            "ntokens_transcript": ntokens_transcript,
            "nsentences": len(samples),
        }
        return out
Example #5
    def collater(
            self, samples: List[Tuple[int, torch.Tensor, torch.Tensor, int,
                                      int]]) -> Dict:
        if len(samples) == 0:
            return {}

        indices = torch.tensor([i for i, _, _, _, _ in samples],
                               dtype=torch.long)
        frames = _collate_frames([s for _, s, _, _, _ in samples],
                                 self.data_cfg.use_audio_input)

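        # aggregate the per-sample mask statistics carried by each item into
        # batch-level totals (total masked tokens and total hits)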
        tokens_masked = torch.tensor([i for _, _, _, i, _ in samples])
        hit = torch.tensor([i for _, _, _, _, i in samples])

        ntokens_masked = torch.sum(tokens_masked)
        nhit = torch.sum(hit)

        n_frames = torch.tensor([s.size(0) for _, s, _, _, _ in samples],
                                dtype=torch.long)
        n_frames, order = n_frames.sort(descending=True)
        indices = indices.index_select(0, order)
        frames = frames.index_select(0, order)

        target, target_lengths = None, None
        prev_output_tokens = None
        ntokens = None
        if self.tgt_texts is not None:
            target = fairseq_data_utils.collate_tokens(
                [t for _, _, t, _, _ in samples],
                self.tgt_dict.pad(),
                self.tgt_dict.eos(),
                left_pad=False,
                move_eos_to_beginning=False,
            )
            target = target.index_select(0, order)
            target_lengths = torch.tensor(
                [t.size(0) for _, _, t, _, _ in samples],
                dtype=torch.long).index_select(0, order)
            prev_output_tokens = fairseq_data_utils.collate_tokens(
                [t for _, _, t, _, _ in samples],
                self.tgt_dict.pad(),
                self.tgt_dict.eos(),
                left_pad=False,
                move_eos_to_beginning=True,
            )
            prev_output_tokens = prev_output_tokens.index_select(0, order)
            ntokens = sum(t.size(0) for _, _, t, _, _ in samples)

        out = {
            "id": indices,
            "net_input": {
                "src_tokens": frames,
                "src_lengths": n_frames,
                "prev_output_tokens": prev_output_tokens,
            },
            "target": target,
            "target_lengths": target_lengths,
            "ntokens": ntokens,
            "nsentences": len(samples),
            "ntokens_masked": ntokens_masked,
            "nhit": nhit
        }

        return out
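
# Sketch (not from the original source) of how the extra counters in the batch
# above might be consumed, e.g. to log a masked-prediction hit rate. The
# function name and this reduction are assumptions, not part of the dataset code.
def masked_hit_rate(batch) -> float:
    ntokens_masked = batch["ntokens_masked"]
    if int(ntokens_masked) == 0:
        # guard against batches that happen to contain no masked tokens
        return 0.0
    return float(batch["nhit"]) / float(ntokens_masked)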