Example #1
    def forward(self, xs, ilens, ys, olens, spembs=None):
        """Calculate forward propagation.
        Args:
            xs (Tensor): Batch of the padded sequences of character ids (B, Tmax).
            ilens (Tensor): Batch of lengths of each input sequence (B,).
            ys (Tensor):
                Batch of the padded sequence of target features (B, Lmax, odim).
            olens (Tensor): Batch of lengths of each output sequence (B,).
            spembs (Tensor, optional):
                Batch of speaker embedding vectors (B, spk_embed_dim).
        Returns:
            Tensor: Batch of durations (B, Tmax).
        """
        if isinstance(self.teacher_model, Transformer):
            att_ws = self._calculate_encoder_decoder_attentions(xs,
                                                                ilens,
                                                                ys,
                                                                olens,
                                                                spembs=spembs)
            # TODO(kan-bayashi): fix this issue
            # this does not work in the multi-GPU case; the registered buffer is not saved.
            if int(self.diag_head_idx) == -1:
                self._init_diagonal_head(att_ws)
            att_ws = att_ws[:, self.diag_head_idx]
        else:
            # NOTE(kan-bayashi): Here we assume that the teacher is Tacotron 2
            att_ws = self.teacher_model.calculate_all_attentions(
                xs, ilens, ys, spembs=spembs, keep_tensor=True)
        durations = [
            self._calculate_duration(att_w, ilen, olen)
            for att_w, ilen, olen in zip(att_ws, ilens, olens)
        ]

        return pad_list(durations, 0)
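
The per-utterance duration tensors have different lengths, so the last step re-pads them into a batch. A minimal sketch of that step, assuming pad_list is the Muskit helper imported in Example #3 (muskit.torch_utils.nets_utils), which right-pads a list of variable-length tensors to the longest one:

import torch
from muskit.torch_utils.nets_utils import pad_list  # same helper as imported in Example #3

# hypothetical per-utterance duration tensors of lengths 3, 2 and 4
durations = [
    torch.tensor([2, 5, 1]),
    torch.tensor([4, 3]),
    torch.tensor([1, 1, 6, 2]),
]
padded = pad_list(durations, 0)  # -> (B=3, Tmax=4), shorter rows right-padded with 0
# tensor([[2, 5, 1, 0],
#         [4, 3, 0, 0],
#         [1, 1, 6, 2]])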
Example #2
def common_collate_fn(
    data: Collection[Tuple[str, Dict[str, np.ndarray]]],
    float_pad_value: Union[float, int] = 0.0,
    int_pad_value: int = -32768,
    not_sequence: Collection[str] = (),
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
    """Concatenate ndarray-list to an array and convert to torch.Tensor.

    Examples:
        >>> from muskit.samplers.constant_batch_sampler import ConstantBatchSampler
        >>> import muskit.tasks.abs_task
        >>> from muskit.train.dataset import MuskitDataset
        >>> sampler = ConstantBatchSampler(...)
        >>> dataset = MuskitDataset(...)
        >>> keys = next(iter(sampler))
        >>> batch = [dataset[key] for key in keys]
        >>> batch = common_collate_fn(batch)
        >>> model(**batch)

        Note that the dict keys of the batch are propagated from
        those of the dataset as they are.

    """
    assert check_argument_types()
    uttids = [u for u, _ in data]
    data = [d for _, d in data]

    assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
    assert all(
        not k.endswith("_lengths") for k in data[0]
    ), f"*_lengths is reserved: {list(data[0])}"

    output = {}
    for key in data[0]:
        # NOTE(kamo):
        # Each model that finally consumes these values is responsible
        # for remapping the pad_value to the desired value for its task.
        if data[0][key].dtype.kind == "i":
            pad_value = int_pad_value
        else:
            pad_value = float_pad_value

        array_list = [d[key] for d in data]
        # Assume the first axis is length:
        # tensor_list: Batch x (Length, ...)
        tensor_list = [torch.from_numpy(a) for a in array_list]
        # tensor: (Batch, Length, ...)
        tensor = pad_list(tensor_list, pad_value)
        output[key] = tensor

        # lens: (Batch,)
        if key not in not_sequence:
            lens = torch.tensor([d[key].shape[0] for d in data], dtype=torch.long)
            output[key + "_lengths"] = lens

    output = (uttids, output)
    # logging.info(f'output:{output}')
    # TODO allow the tuple type
    # assert check_return_type(output)
    return output
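
A hedged usage sketch with a tiny hand-made batch in place of the sampler/dataset objects from the docstring; the array contents and the "text"/"feats" keys are illustrative only:

import numpy as np

# two utterances; integer arrays get int_pad_value, float arrays get float_pad_value
batch = [
    ("utt1", {"text": np.array([1, 2, 3]), "feats": np.ones((5, 80), dtype=np.float32)}),
    ("utt2", {"text": np.array([4, 5]), "feats": np.ones((3, 80), dtype=np.float32)}),
]
uttids, tensors = common_collate_fn(batch)
# uttids                   -> ["utt1", "utt2"]
# tensors["text"]          -> shape (2, 3), second row padded with -32768
# tensors["text_lengths"]  -> tensor([3, 2])
# tensors["feats"]         -> shape (2, 5, 80), second utterance padded with 0.0
# tensors["feats_lengths"] -> tensor([5, 3])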
Example #3
def add_sos_eos(ys_pad, sos, eos, ignore_id):
    """Add <sos> and <eos> labels.

    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :param int sos: index of <sos>
    :param int eos: index of <eos>
    :param int ignore_id: index of padding
    :return: padded input tensor with <sos> prepended (B, Lmax + 1)
    :rtype: torch.Tensor
    :return: padded target tensor with <eos> appended (B, Lmax + 1)
    :rtype: torch.Tensor
    """
    from muskit.torch_utils.nets_utils import pad_list

    _sos = ys_pad.new([sos])
    _eos = ys_pad.new([eos])
    ys = [y[y != ignore_id] for y in ys_pad]  # parse padded ys
    ys_in = [torch.cat([_sos, y], dim=0) for y in ys]
    ys_out = [torch.cat([y, _eos], dim=0) for y in ys]
    return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)
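
A small worked example, assuming the sos/eos indices 1 and 2 and ignore_id = -1 for padding (all values here are illustrative):

import torch

ys_pad = torch.tensor([[3, 5, 7],
                       [4, 6, -1]])          # -1 marks padding
ys_in, ys_out = add_sos_eos(ys_pad, sos=1, eos=2, ignore_id=-1)
# ys_in : tensor([[1, 3, 5, 7],
#                 [1, 4, 6, 2]])             # <sos> prepended, right-padded with eos
# ys_out: tensor([[ 3,  5,  7,  2],
#                 [ 4,  6,  2, -1]])         # <eos> appended, right-padded with ignore_id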
Example #4
    def forward(
        self,
        input: torch.Tensor,
        input_lengths: torch.Tensor = None,
        feats_lengths: torch.Tensor = None,
        durations: torch.Tensor = None,
        durations_lengths: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # If not provided, assume that all inputs have the same length
        if input_lengths is None:
            input_lengths = (input.new_ones(input.shape[0], dtype=torch.long) *
                             input.shape[1])

        # Domain-conversion: e.g. Stft: time -> time-freq
        input_stft, energy_lengths = self.stft(input, input_lengths)

        assert input_stft.dim() >= 4, input_stft.shape
        assert input_stft.shape[-1] == 2, input_stft.shape

        # input_stft: (..., F, 2) -> (..., F)
        input_power = input_stft[..., 0]**2 + input_stft[..., 1]**2
        # sum over frequency (B, N, F) -> (B, N)
        energy = torch.sqrt(torch.clamp(input_power.sum(dim=2), min=1.0e-10))

        # (Optional): Adjust length to match with the mel-spectrogram
        if feats_lengths is not None:
            energy = [
                self._adjust_num_frames(e[:el].view(-1), fl)
                for e, el, fl in zip(energy, energy_lengths, feats_lengths)
            ]
            energy_lengths = feats_lengths

        # (Optional): Average by duration to calculate token-wise energy
        if self.use_token_averaged_energy:
            durations = durations * self.reduction_factor
            energy = [
                self._average_by_duration(e[:el].view(-1), d)
                for e, el, d in zip(energy, energy_lengths, durations)
            ]
            energy_lengths = durations_lengths

        # Padding
        if isinstance(energy, list):
            energy = pad_list(energy, 0.0)

        # Return with the shape (B, T, 1)
        return energy.unsqueeze(-1), energy_lengths
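
A minimal sketch of the per-frame energy computation above, using torch.stft directly in place of the module's self.stft wrapper; the n_fft/hop_length values are illustrative and not taken from the class configuration:

import torch

wav = torch.randn(2, 16000)                            # (B, T) dummy waveform
spec = torch.stft(wav, n_fft=1024, hop_length=256,
                  window=torch.hann_window(1024),
                  return_complex=True)                 # (B, F, N_frames)
power = spec.real ** 2 + spec.imag ** 2                # |STFT|^2
# sum over frequency, clamp, sqrt -- the same formula as in forward()
energy = torch.sqrt(torch.clamp(power.sum(dim=1), min=1.0e-10))   # (B, N_frames)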
Example #5
    def forward(
        self,
        input: torch.Tensor,
        input_lengths: torch.Tensor = None,
        feats_lengths: torch.Tensor = None,
        durations: torch.Tensor = None,
        durations_lengths: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # If not provided, assume that all inputs have the same length
        if input_lengths is None:
            input_lengths = (input.new_ones(input.shape[0], dtype=torch.long) *
                             input.shape[1])

        # F0 extraction
        pitch = [
            self._calculate_f0(x[:xl]) for x, xl in zip(input, input_lengths)
        ]

        # (Optional): Adjust length to match with the mel-spectrogram
        if feats_lengths is not None:
            pitch = [
                self._adjust_num_frames(p, fl).view(-1)
                for p, fl in zip(pitch, feats_lengths)
            ]

        # (Optional): Average by duration to calculate token-wise f0
        if self.use_token_averaged_f0:
            durations = durations * self.reduction_factor
            pitch = [
                self._average_by_duration(p, d).view(-1)
                for p, d in zip(pitch, durations)
            ]
            pitch_lengths = durations_lengths
        else:
            pitch_lengths = input.new_tensor([len(p) for p in pitch],
                                             dtype=torch.long)

        # Padding
        pitch = pad_list(pitch, 0.0)

        # Return with the shape (B, T, 1)
        return pitch.unsqueeze(-1), pitch_lengths
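
A hedged sketch of the duration-averaging step used by both the pitch and energy extractors: each token receives the mean of its frames. This mirrors a plain-mean variant of _average_by_duration; the real helper may additionally mask unvoiced (zero) frames:

import torch

def average_by_duration_sketch(x: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
    """Average frame-level values x (T,) into token-level values using durations d (N,)."""
    out, start = [], 0
    for dur in d.tolist():
        seg = x[start:start + dur]
        out.append(seg.mean() if dur > 0 else x.new_tensor(0.0))
        start += dur
    return torch.stack(out)

pitch = torch.tensor([100., 110., 120., 200., 210., 0.])
durations = torch.tensor([3, 2, 1])
print(average_by_duration_sketch(pitch, durations))   # tensor([110., 205., 0.])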
Example #6
    def forward(self, xs, ds, alpha=1.0):
        """Calculate forward propagation.
        Args:
            xs (Tensor): Batch of sequences of char or phoneme embeddings (B, Tmax, D).
            ds (LongTensor): Batch of durations for each input token (B, Tmax).
            alpha (float, optional): Alpha value to control speed of speech.
        Returns:
            Tensor: replicated input tensor based on durations (B, T*, D).
        """
        if alpha != 1.0:
            assert alpha > 0
            ds = torch.round(ds.float() * alpha).long()

        if ds.sum() == 0:
            logging.warning("predicted durations includes all 0 sequences. "
                            "fill the first element with 1.")
            # NOTE(kan-bayashi): This case must not be happened in teacher forcing.
            #   It will be happened in inference with a bad duration predictor.
            #   So we do not need to care the padded sequence case here.
            ds[ds.sum(dim=1).eq(0)] = 1

        repeat = [torch.repeat_interleave(x, d, dim=0) for x, d in zip(xs, ds)]
        return pad_list(repeat, self.pad_value)
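
A small worked example of the repetition step above, assuming pad_value = 0.0; each input token is repeated according to its duration and the resulting variable-length frame sequences are re-padded into a batch:

import torch
from muskit.torch_utils.nets_utils import pad_list  # same helper as imported in Example #3

xs = torch.arange(6.).view(2, 3, 1)       # (B=2, Tmax=3, D=1): tokens 0..2 and 3..5
ds = torch.tensor([[2, 0, 1],
                   [1, 3, 1]])            # per-token durations
repeat = [torch.repeat_interleave(x, d, dim=0) for x, d in zip(xs, ds)]
# repeat[0] holds frames [0, 0, 2]        -> length 3
# repeat[1] holds frames [3, 4, 4, 4, 5]  -> length 5
out = pad_list(repeat, 0.0)               # (2, 5, 1); first sample zero-padded to 5 frames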
Example #7
    def forward(self, x: torch.Tensor, x_lengths: torch.Tensor = None):
        """Forward function.

        Args:
            x: (Batch, Time, Freq)
            x_lengths: (Batch,)
        """

        if x_lengths is None or all(le == x_lengths[0] for le in x_lengths):
            # Note that the same warping is applied to every sample
            y = time_warp(x, window=self.window, mode=self.mode)
        else:
            # FIXME(kamo): I have no idea how to batchify time warping
            ys = []
            for i in range(x.size(0)):
                _y = time_warp(
                    x[i][None, : x_lengths[i]],
                    window=self.window,
                    mode=self.mode,
                )[0]
                ys.append(_y)
            y = pad_list(ys, 0.0)

        return y, x_lengths
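
A hedged sketch of the per-utterance fallback pattern above, with a trivial stand-in transform (a clone) in place of time_warp; the point is the slice / process / re-pad flow, not the warp itself:

import torch
from muskit.torch_utils.nets_utils import pad_list  # assumed, as in Example #3

def per_sample_transform(x: torch.Tensor, x_lengths: torch.Tensor) -> torch.Tensor:
    ys = []
    for i in range(x.size(0)):
        valid = x[i, : x_lengths[i]]      # (T_i, Freq) with padding stripped
        _y = valid.clone()                # stand-in for time_warp(valid[None], ...)[0]
        ys.append(_y)
    return pad_list(ys, 0.0)              # back to (Batch, Tmax, Freq)

x = torch.randn(2, 7, 4)
x_lengths = torch.tensor([7, 5])
y = per_sample_transform(x, x_lengths)    # (2, 7, 4); second sample zero-padded after frame 5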
Example #8
    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        singing: torch.Tensor,
        singing_lengths: torch.Tensor,
        durations: Optional[torch.Tensor] = None,
        durations_lengths: Optional[torch.Tensor] = None,
        score: Optional[torch.Tensor] = None,
        score_lengths: Optional[torch.Tensor] = None,
        pitch: Optional[torch.Tensor] = None,
        pitch_lengths: Optional[torch.Tensor] = None,
        tempo: Optional[torch.Tensor] = None,
        tempo_lengths: Optional[torch.Tensor] = None,
        energy: Optional[torch.Tensor] = None,
        energy_lengths: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        flag_IsValid=False,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Caclualte outputs and return the loss tensor.
        Args:
            text (Tensor): Text index tensor (B, T_text).
            text_lengths (Tensor): Text length tensor (B,).
            singing (Tensor): Singing waveform tensor (B, T_wav).
            singing_lengths (Tensor): Singing length tensor (B,).
            durations (Optional[Tensor]): Duration tensor (phone-id sequence).
            durations_lengths (Optional[Tensor]): Duration length tensor (B,).
            score (Optional[Tensor]): Score tensor.
            score_lengths (Optional[Tensor]): Score length tensor (B,).
            pitch (Optional[Tensor]): Pitch tensor.
            pitch_lengths (Optional[Tensor]): Pitch length tensor (B,).
            energy (Optional[Tensor]): Energy tensor.
            energy_lengths (Optional[Tensor]): Energy length tensor (B,).
            spembs (Optional[Tensor]): Speaker embedding tensor (B, D).
            sids (Optional[Tensor]): Speaker ID tensor (B, 1).
            lids (Optional[Tensor]): Language ID tensor (B, 1).
        Returns:
            Tensor: Loss scalar tensor.
            Dict[str, float]: Statistics to be monitored.
            Tensor: Weight tensor to summarize losses.
        """
        with autocast(False):
            # if self.text_extract is not None and text is None:
            #     text, text_lengths = self.text_extract(
            #         input=text,
            #         input_lengths=text_lengths,
            #     )
            # Extract features
            # logging.info(f'singing.shape={singing.shape}, singing_lengths.shape={singing_lengths.shape}')
            if self.feats_extract is not None:
                feats, feats_lengths = self.feats_extract(
                    singing, singing_lengths
                )  # singing to spec feature (frame level)
            else:
                # Use precalculated feats (feats_type != raw case)
                feats, feats_lengths = singing, singing_lengths

            # Extract auxiliary features
            # score : 128 midi pitch
            # tempo : bpm
            # duration :
            #   input -> phone-id sequence | output -> frame level (mode within a window) or syllable level
            ds = None
            if isinstance(self.score_feats_extract, FrameScoreFeats):
                (
                    label,
                    label_lengths,
                    score,
                    score_lengths,
                    tempo,
                    tempo_lengths,
                ) = self.score_feats_extract(
                    durations=durations.unsqueeze(-1),
                    durations_lengths=durations_lengths,
                    score=score.unsqueeze(-1),
                    score_lengths=score_lengths,
                    tempo=tempo.unsqueeze(-1),
                    tempo_lengths=tempo_lengths,
                )

                label = label[:, : label_lengths.max()]  # for data-parallel

                # calculate durations, new text & text_length
                # Syllable Level duration info needs phone
                # NOTE(Shuai): Duplicate adjacent phones sometimes appear in the text files,
                # e.g. oniku_0000000000000000hato_0002
                # 10.951 11.107 sh
                # 11.107 11.336 i
                # 11.336 11.610 i
                # 11.610 11.657 k
                _text_cal = []
                _text_length_cal = []
                ds = []
                for i, _ in enumerate(label_lengths):
                    _phone = label[i, : label_lengths[i]]

                    _output, counts = torch.unique_consecutive(
                        _phone, return_counts=True
                    )

                    _text_cal.append(_output)
                    _text_length_cal.append(len(_output))
                    ds.append(counts)
                ds = pad_list(ds, pad_value=0).to(text.device)
                text = pad_list(_text_cal, pad_value=0).to(
                    text.device, dtype=torch.long
                )
                text_lengths = torch.tensor(_text_length_cal).to(text.device)

            elif isinstance(self.score_feats_extract, SyllableScoreFeats):
                extractMethod_frame = FrameScoreFeats(
                    fs=self.score_feats_extract.fs,
                    n_fft=self.score_feats_extract.n_fft,
                    win_length=self.score_feats_extract.win_length,
                    hop_length=self.score_feats_extract.hop_length,
                    window=self.score_feats_extract.window,
                    center=self.score_feats_extract.center,
                )

                # logging.info(f"extractMethod_frame: {extractMethod_frame}")

                (
                    labelFrame,
                    labelFrame_lengths,
                    scoreFrame,
                    scoreFrame_lengths,
                    tempoFrame,
                    tempoFrame_lengths,
                ) = extractMethod_frame(
                    durations=durations.unsqueeze(-1),
                    durations_lengths=durations_lengths,
                    score=score.unsqueeze(-1),
                    score_lengths=score_lengths,
                    tempo=tempo.unsqueeze(-1),
                    tempo_lengths=tempo_lengths,
                )

                labelFrame = labelFrame[
                    :, : labelFrame_lengths.max()
                ]  # for data-parallel
                scoreFrame = scoreFrame[
                    :, : scoreFrame_lengths.max()
                ]  # for data-parallel

                # Extract Syllable Level label, score, tempo information from Frame Level
                (
                    label,
                    label_lengths,
                    score,
                    score_lengths,
                    tempo,
                    tempo_lengths,
                ) = self.score_feats_extract(
                    durations=labelFrame,
                    durations_lengths=labelFrame_lengths,
                    score=scoreFrame,
                    score_lengths=scoreFrame_lengths,
                    tempo=tempoFrame,
                    tempo_lengths=tempoFrame_lengths,
                )

                # calculate durations, represent syllable encoder outputs to feats mapping
                # Syllable Level duration info needs phone & midi
                ds = []
                for i, _ in enumerate(labelFrame_lengths):
                    assert labelFrame_lengths[i] == scoreFrame_lengths[i]
                    assert label_lengths[i] == score_lengths[i]

                    frame_length = labelFrame_lengths[i]
                    _phoneFrame = labelFrame[i, :frame_length]
                    _midiFrame = scoreFrame[i, :frame_length]

                    # Clean _phoneFrame & _midiFrame
                    for index in range(frame_length):
                        if _phoneFrame[index] == 0 and _midiFrame[index] == 0:
                            frame_length -= 1
                            feats_lengths[i] -= 1

                    syllable_length = label_lengths[i]
                    _phoneSyllable = label[i, :syllable_length]
                    _midiSyllable = score[i, :syllable_length]

                    # logging.info(f"_phoneFrame: {_phoneFrame}, _midiFrame: {_midiFrame}")
                    # logging.info(f"_phoneSyllable: {_phoneSyllable}, _midiSyllable: {_midiSyllable}, _tempoSyllable: {tempo[i]}")

                    start_index = 0
                    ds_tmp = []
                    flag_finish = 0
                    for index in range(syllable_length):
                        _findPhone = _phoneSyllable[index]
                        _findMidi = _midiSyllable[index]
                        _length = 0
                        if flag_finish == 1:
                            # Fix error in _phoneSyllable & _midiSyllable
                            label[i, index] = 0
                            score[i, index] = 0
                            tempo[i, index] = 0
                            label_lengths[i] -= 1
                            score_lengths[i] -= 1
                            tempo_lengths[i] -= 1
                        else:
                            for indexFrame in range(start_index, frame_length):
                                if (
                                    _phoneFrame[indexFrame] == _findPhone
                                    and _midiFrame[indexFrame] == _findMidi
                                ):
                                    _length += 1
                                else:
                                    # logging.info(f"_findPhone: {_findPhone}, _findMidi: {_findMidi}, _length: {_length}")
                                    ds_tmp.append(_length)
                                    start_index = indexFrame
                                    break
                                if indexFrame == frame_length - 1:
                                    # logging.info(f"_findPhone: {_findPhone}, _findMidi: {_findMidi}, _length: {_length}")
                                    flag_finish = 1
                                    ds_tmp.append(_length)
                                    start_index = indexFrame
                                    # logging.info("Finish")
                                    break

                    # logging.info(f"ds_tmp: {ds_tmp}, sum(ds_tmp): {sum(ds_tmp)}, frame_length: {frame_length}, feats_lengths[i]: {feats_lengths[i]}")
                    assert (
                        sum(ds_tmp) == frame_length and sum(ds_tmp) == feats_lengths[i]
                    )

                    ds.append(torch.tensor(ds_tmp))
                ds = pad_list(ds, pad_value=0).to(label.device)

            if self.pitch_extract is not None and pitch is None:
                pitch, pitch_lengths = self.pitch_extract(
                    input=singing,
                    input_lengths=singing_lengths,
                    feats_lengths=feats_lengths,
                )

            if self.energy_extract is not None and energy is None:
                energy, energy_lengths = self.energy_extract(
                    singing,
                    singing_lengths,
                    feats_lengths=feats_lengths,
                    durations=durations,
                    durations_lengths=durations_lengths,
                )

            # Normalize
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)
            if self.pitch_normalize is not None:
                pitch, pitch_lengths = self.pitch_normalize(pitch, pitch_lengths)
            if self.energy_normalize is not None:
                energy, energy_lengths = self.energy_normalize(energy, energy_lengths)

        # Make batch for svs inputs
        batch = dict(
            text=text,
            text_lengths=text_lengths,
            feats=feats,
            feats_lengths=feats_lengths,
            flag_IsValid=flag_IsValid,
        )

        if spembs is not None:
            batch.update(spembs=spembs)
        if sids is not None:
            batch.update(sids=sids)
        if lids is not None:
            batch.update(lids=lids)
        if label is not None:
            label = label.to(dtype=torch.long)
            batch.update(label=label, label_lengths=label_lengths)
        if score is not None and pitch is None:
            score = score.to(dtype=torch.long)
            batch.update(midi=score, midi_lengths=score_lengths)
        if tempo is not None:
            tempo = tempo.to(dtype=torch.long)
            batch.update(tempo=tempo, tempo_lengths=tempo_lengths)
        if ds is not None:
            batch.update(ds=ds)
        if self.pitch_extract is not None and pitch is not None:
            batch.update(midi=pitch, midi_lengths=pitch_lengths)
        if self.energy_extract is not None and energy is not None:
            batch.update(energy=energy, energy_lengths=energy_lengths)
        if self.svs.require_raw_singing:
            batch.update(singing=singing, singing_lengths=singing_lengths)
        return self.svs(**batch)
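
The phone-to-duration conversion in the FrameScoreFeats branch hinges on torch.unique_consecutive; a minimal standalone sketch of that collapse, with a made-up frame-level label sequence:

import torch

# frame-level phone ids for one utterance (hypothetical values)
label = torch.tensor([7, 7, 7, 3, 3, 9, 9, 9, 9])
phones, counts = torch.unique_consecutive(label, return_counts=True)
# phones -> tensor([7, 3, 9])   : the deduplicated text / label sequence
# counts -> tensor([3, 2, 4])   : the per-phone durations (ds); counts.sum() == number of frames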
Example #9
    def inference(
        self,
        text: torch.Tensor,
        durations: torch.Tensor,
        score: torch.Tensor,
        singing: Optional[torch.Tensor] = None,
        pitch: Optional[torch.Tensor] = None,
        tempo: Optional[torch.Tensor] = None,
        energy: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        **decode_config,
    ) -> Dict[str, torch.Tensor]:
        """Caclualte features and return them as a dict.
        Args:
            text (Tensor): Text index tensor (T_text).
            singing (Tensor): Singing waveform tensor (T_wav).
            spembs (Optional[Tensor]): Speaker embedding tensor (D,).
            sids (Optional[Tensor]): Speaker ID tensor (1,).
            lids (Optional[Tensor]): Language ID tensor (1,).
            durations (Tensor): Duration tensor.
            pitch (Optional[Tensor]): Pitch tensor.
            energy (Optional[Tensor]): Energy tensor.
        Returns:
            Dict[str, Tensor]: Dict of outputs.
        """
        durations_lengths = torch.tensor([len(durations)])
        score_lengths = torch.tensor([len(score)])
        tempo_lengths = torch.tensor([len(tempo)])

        assert durations_lengths == score_lengths and durations_lengths == tempo_lengths

        # the unsqueeze of singing must happen here, otherwise the return dimensions of the STFT are wrong
        text = text.unsqueeze(0)  # for data-parallel
        durations = durations.unsqueeze(0)  # for data-parallel
        score = score.unsqueeze(0)  # for data-parallel
        tempo = tempo.unsqueeze(0)  # for data-parallel

        # Extract auxiliary features
        # score : 128 midi pitch
        # tempo : bpm
        # duration :
        #   input -> phone-id sequence | output -> frame level (mode within a window) or syllable level
        ds = None
        batch_size = text.size(0)
        assert batch_size == 1
        if isinstance(self.score_feats_extract, FrameScoreFeats):
            (
                label,
                label_lengths,
                score,
                score_lengths,
                tempo,
                tempo_lengths,
            ) = self.score_feats_extract(
                durations=durations.unsqueeze(-1),
                durations_lengths=durations_lengths,
                score=score.unsqueeze(-1),
                score_lengths=score_lengths,
                tempo=tempo.unsqueeze(-1),
                tempo_lengths=tempo_lengths,
            )

            # calculate durations, new text & text_length
            # Syllable Level duration info needs phone
            # NOTE(Shuai): Duplicate adjacent phones sometimes appear in the text files,
            # e.g. oniku_0000000000000000hato_0002
            # 10.951 11.107 sh
            # 11.107 11.336 i
            # 11.336 11.610 i
            # 11.610 11.657 k
            _text_cal = []
            _text_length_cal = []
            ds = []
            for i in range(batch_size):
                _phone = label[i]

                _output, counts = torch.unique_consecutive(_phone, return_counts=True)

                _text_cal.append(_output)
                _text_length_cal.append(len(_output))
                ds.append(counts)
            ds = pad_list(ds, pad_value=0).to(text.device)
            text = pad_list(_text_cal, pad_value=0).to(text.device, dtype=torch.long)
            text_lengths = torch.tensor(_text_length_cal).to(text.device)

        elif isinstance(self.score_feats_extract, SyllableScoreFeats):
            extractMethod_frame = FrameScoreFeats(
                fs=self.score_feats_extract.fs,
                n_fft=self.score_feats_extract.n_fft,
                win_length=self.score_feats_extract.win_length,
                hop_length=self.score_feats_extract.hop_length,
                window=self.score_feats_extract.window,
                center=self.score_feats_extract.center,
            )

            # logging.info(f"extractMethod_frame: {extractMethod_frame}")

            (
                labelFrame,
                labelFrame_lengths,
                scoreFrame,
                scoreFrame_lengths,
                tempoFrame,
                tempoFrame_lengths,
            ) = extractMethod_frame(
                durations=durations.unsqueeze(-1),
                durations_lengths=durations_lengths,
                score=score.unsqueeze(-1),
                score_lengths=score_lengths,
                tempo=tempo.unsqueeze(-1),
                tempo_lengths=tempo_lengths,
            )

            labelFrame = labelFrame[:, : labelFrame_lengths.max()]  # for data-parallel
            scoreFrame = scoreFrame[:, : scoreFrame_lengths.max()]  # for data-parallel

            # Extract Syllable Level label, score, tempo information from Frame Level
            (
                label,
                label_lengths,
                score,
                score_lengths,
                tempo,
                tempo_lengths,
            ) = self.score_feats_extract(
                durations=labelFrame,
                durations_lengths=labelFrame_lengths,
                score=scoreFrame,
                score_lengths=scoreFrame_lengths,
                tempo=tempoFrame,
                tempo_lengths=tempoFrame_lengths,
            )

            # calculate durations, represent syllable encoder outputs to feats mapping
            # Syllable Level duration info needs phone & midi
            ds = []
            for i, _ in enumerate(labelFrame_lengths):
                assert labelFrame_lengths[i] == scoreFrame_lengths[i]
                assert label_lengths[i] == score_lengths[i]

                frame_length = labelFrame_lengths[i]
                _phoneFrame = labelFrame[i, :frame_length]
                _midiFrame = scoreFrame[i, :frame_length]

                # Clean _phoneFrame & _midiFrame
                for index in range(frame_length):
                    if _phoneFrame[index] == 0 and _midiFrame[index] == 0:
                        frame_length -= 1
                        feats_lengths[i] -= 1

                syllable_length = label_lengths[i]
                _phoneSyllable = label[i, :syllable_length]
                _midiSyllable = score[i, :syllable_length]

                # logging.info(f"_phoneFrame: {_phoneFrame}, _midiFrame: {_midiFrame}")
                # logging.info(f"_phoneSyllable: {_phoneSyllable}, _midiSyllable: {_midiSyllable}, _tempoSyllable: {tempo[i]}")

                start_index = 0
                ds_tmp = []
                flag_finish = 0
                for index in range(syllable_length):
                    _findPhone = _phoneSyllable[index]
                    _findMidi = _midiSyllable[index]
                    _length = 0
                    if flag_finish == 1:
                        # Fix error in _phoneSyllable & _midiSyllable
                        label[i, index] = 0
                        score[i, index] = 0
                        tempo[i, index] = 0
                        label_lengths[i] -= 1
                        score_lengths[i] -= 1
                        tempo_lengths[i] -= 1
                    else:
                        for indexFrame in range(start_index, frame_length):
                            if (
                                _phoneFrame[indexFrame] == _findPhone
                                and _midiFrame[indexFrame] == _findMidi
                            ):
                                _length += 1
                            else:
                                # logging.info(f"_findPhone: {_findPhone}, _findMidi: {_findMidi}, _length: {_length}")
                                ds_tmp.append(_length)
                                start_index = indexFrame
                                break
                            if indexFrame == frame_length - 1:
                                # logging.info(f"_findPhone: {_findPhone}, _findMidi: {_findMidi}, _length: {_length}")
                                flag_finish = 1
                                ds_tmp.append(_length)
                                start_index = indexFrame
                                # logging.info("Finish")
                                break

                logging.info(
                    f"ds_tmp: {ds_tmp}, sum(ds_tmp): {sum(ds_tmp)}, frame_length: {frame_length}, feats_lengths[i]: {feats_lengths[i]}"
                )
                assert sum(ds_tmp) == frame_length and sum(ds_tmp) == feats_lengths[i]

                ds.append(torch.tensor(ds_tmp))
            ds = pad_list(ds, pad_value=0).to(label.device)

        input_dict = dict(text=text)

        if score is not None and pitch is None:
            score = score.to(dtype=torch.long)
            input_dict["midi"] = score
        if durations is not None:
            label = label.to(dtype=torch.long)
            input_dict["label"] = label
        if ds is not None:
            input_dict.update(ds=ds)
        if tempo is not None:
            tempo = tempo.to(dtype=torch.long)
            input_dict.update(tempo=tempo)
        if spembs is not None:
            input_dict.update(spembs=spembs)
        if sids is not None:
            input_dict.update(sids=sids)
        if lids is not None:
            input_dict.update(lids=lids)

        # output_dict = self.svs.inference(**input_dict, **decode_config)
        outs, probs, att_ws = self.svs.inference(**input_dict)

        if self.normalize is not None:
            # NOTE: normalize.inverse is in-place operation
            outs_denorm = self.normalize.inverse(outs.clone()[None])[0][0]
        else:
            outs_denorm = outs

        return outs, outs_denorm, probs, att_ws
    def forward(self, feats, midi, label, spk_ids, feats_lengths, midi_lengths, label_lengths):
        """forward.
        Args:
            feats: Batch of lengths (B, T_feats, adim).
            midi: Batch of lengths (B, T_feats).
            label: Batch of lengths (B, T_feats).
            feats_lengths: Batch of input lengths (B,). Note that the feats, midi, label are time-aligned on frame level.
        """
        h_masks = self._source_mask(feats_lengths)              # (B, 1, T_feats)
        midi_loss, label_loss, speaker_loss = 0, 0, 0
        masked_midi_outs, masked_label_outs, speaker_predict = None, None, None
        batch_size = feats.shape[0]

        if "midi" in self.predict_type:
            # midi predict
            zs_midi, _ = self.predictor_midi(feats, h_masks)        # (B, T_feats, adim=80)
            zs_midi = self.linear_out_midi(zs_midi)                 # (B, T_feats, midi classes=129)

            # loss calculation
            if self.predict_criterion_type == "CrossEntropy":
                probs_midi = zs_midi                                    # (B, T_feats, midi classes=129) 
                midi_loss = self.predictor_criterion(probs_midi, midi, feats_lengths)

                # make midi predict output for cycle
                masked_probs_midi = probs_midi * h_masks.permute(0,2,1)                             # (B, T_feats, midi classes=129)
                masked_midi_outs = torch.argmax(F.softmax(masked_probs_midi,dim=-1), dim=-1)        # (B, T_feats)

            elif self.predict_criterion_type == "CTC":
                # aggregate G.T.-midi
                # NOTE(Shuai): midi needs "+1" at the start of the loss calculation and "-1" at the end of prediction,
                # because index 0 is reserved for <blank> in the CTC loss. Note that CTC loss cannot produce time-aligned midi predictions for cycle singing.
                _midi_cal = []
                _midi_length_cal = []
                ds = []
                for i, _ in enumerate(midi_lengths):
                    _midi = midi[i, :midi_lengths[i]] + 1

                    _output, counts = torch.unique_consecutive(_midi, return_counts=True)
                    
                    _midi_cal.append(_output)
                    _midi_length_cal.append(len(_output))
                    ds.append(counts)
                # ds = pad_list(ds, pad_value=0).to(midi.device)
                midi = pad_list(_midi_cal, pad_value=0).to(midi.device, dtype=torch.long)
                midi_lengths = torch.tensor(_midi_length_cal).to(midi.device)

                # logging.info(f"midi: {midi.shape}")
                # logging.info(f"midi: {midi}")
                # logging.info(f"midi_lengths: {midi_lengths}")
                # quit()

                probs_midi = F.log_softmax(zs_midi, dim=-1).permute(1,0,2)  # CTC expects shape (T, N_batch, num_classes)
                midi_loss = self.predictor_criterion(probs_midi, midi, feats_lengths, midi_lengths)


        if "label" in self.predict_type:
            # label predict
            zs_label, _ = self.predictor_label(feats, h_masks)                                      # (B, T_feats, adim=80)
            zs_label = self.linear_out_label(zs_label)                                              # (B, T_feats, label classes=50)

            if self.predict_criterion_type == "CrossEntropy":
                probs_label = zs_label                                                              # (B, T_feats, label classes=50)
                label_loss = self.predictor_criterion(probs_label, label, feats_lengths)

                # make label predict output for cycle
                masked_probs_label = probs_label * h_masks.permute(0,2,1)                           # (B, T_feats, label classes=50)
                masked_label_outs = torch.argmax(F.softmax(masked_probs_label,dim=-1), dim=-1)      # (B, T_feats)

            elif self.predict_criterion_type == "CTC":
                # aggregate G.T.-label
                probs_label = F.log_softmax(zs_label, dim=-1).permute(1,0,2)                        # CTC expects shape (T, N_batch, num_classes)
                label_loss = self.predictor_criterion(probs_label, label, feats_lengths, label_lengths)
        

        if "spk" in self.predict_type:
            # speaker predict
            packed_spk_embed = nn.utils.rnn.pack_padded_sequence(feats, feats_lengths.type(torch.int64).to("cpu"), batch_first=True, enforce_sorted=False)
            packed_spk_out, hn = self.predictor_spk(packed_spk_embed)
            spk_out, _ = nn.utils.rnn.pad_packed_sequence(packed_spk_out, batch_first=True)
            # hn: (num_directions * num_layers, N_batch, eunits)

            hn = hn.reshape(batch_size, -1)                                                         # (N_batch, num_directions * num_layers * eunits_spk)
            probs_spk = self.linear_out_spk(hn)                                                     # (N_batch, speaker classes=4)
            speaker_loss = self.predictor_spk_criterion(probs_spk, spk_ids)
            speaker_predict = torch.argmax(F.softmax(probs_spk,dim=-1), dim=-1)                     # (B, 1)


        return midi_loss, label_loss, speaker_loss, masked_midi_outs, masked_label_outs, speaker_predict
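
A hedged sketch of the CTC branch's target preparation and loss call with small dummy tensors; torch.nn.CTCLoss(blank=0) is assumed for predictor_criterion, which matches the (log_probs, targets, input_lengths, target_lengths) call pattern above:

import torch
import torch.nn.functional as F

B, T, num_classes = 2, 6, 129                # 128 midi pitches shifted by +1; index 0 = <blank>
logits = torch.randn(B, T, num_classes)      # stand-in for linear_out_midi(predictor_midi(...))

# frame-level G.T. midi: collapse repeats and shift by +1 to reserve 0 for <blank>
midi = torch.tensor([[5, 5, 7, 7, 7, 0],
                     [3, 3, 3, 3, 9, 9]])
midi_lengths = torch.tensor([6, 6])
targets, target_lengths = [], []
for i in range(B):
    t, _ = torch.unique_consecutive(midi[i, : midi_lengths[i]] + 1, return_counts=True)
    targets.append(t)
    target_lengths.append(len(t))
targets = torch.nn.utils.rnn.pad_sequence(targets, batch_first=True)   # (B, Lmax)
target_lengths = torch.tensor(target_lengths)

log_probs = F.log_softmax(logits, dim=-1).permute(1, 0, 2)             # CTC expects (T, B, C)
feats_lengths = torch.full((B,), T)
loss = torch.nn.CTCLoss(blank=0)(log_probs, targets, feats_lengths, target_lengths)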