Example #1
import torch as T

from TTS.tts.utils.helpers import segment  # helper under test (path as in the Coqui TTS repo)


def segment_test():
    # a 0..11 ramp repeated for a batch of 8, with a channel dim: [8, 1, 12]
    x = T.arange(0, 12, dtype=T.float)  # equivalent to the deprecated T.range(0, 11)
    x = x.repeat(8, 1).unsqueeze(1)
    # random segment start index per batch item
    segment_ids = T.randint(0, 7, (8,))

    segments = segment(x, segment_ids, segment_size=4)
    for idx, start_indx in enumerate(segment_ids):
        # each extracted segment must equal the corresponding slice of the input
        assert x[idx, :, start_indx : start_indx + 4].sum() == segments[idx, :, :].sum()
Example #2
import torch as T

from TTS.tts.utils.helpers import segment  # helper under test (path as in the Coqui TTS repo)


def segment_test():
    x = T.arange(0, 12, dtype=T.float)  # 0..11 ramp; replaces the deprecated T.range(0, 11)
    x = x.repeat(8, 1).unsqueeze(1)  # [8, 1, 12]
    segment_ids = T.randint(0, 7, (8,))  # random start index per batch item

    segments = segment(x, segment_ids, segment_size=4)
    for idx, start_indx in enumerate(segment_ids):
        assert x[idx, :, start_indx : start_indx + 4].sum() == segments[idx, :, :].sum()

    # a window that can run past the end of the tensor must fail without padding
    try:
        segments = segment(x, segment_ids, segment_size=10)
    except Exception:
        pass
    else:
        raise Exception("Should have failed")

    # with `pad_short=True` the input is zero-padded, so the same call succeeds
    segments = segment(x, segment_ids, segment_size=10, pad_short=True)
    for idx, start_indx in enumerate(segment_ids):
        assert x[idx, :, start_indx : start_indx + 10].sum() == segments[idx, :, :].sum()
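
The `segment` helper itself is not part of this listing. From the assertions above, a behaviorally equivalent sketch looks like the following; `segment_sketch` is a hypothetical re-implementation inferred from the test, and the real helper in the TTS repo may differ in details such as the exception raised or how padding is applied.

import torch


def segment_sketch(x, segment_indices, segment_size=4, pad_short=False):
    """Cut a fixed-size window from the last axis of every batch item."""
    if pad_short:
        # zero-pad the time axis so that every requested window fits
        max_end = int(segment_indices.max()) + segment_size
        if max_end > x.shape[-1]:
            x = torch.nn.functional.pad(x, (0, max_end - x.shape[-1]))
    # without padding, a window that runs past the end yields a shorter slice,
    # so the stack below fails, matching the "Should have failed" branch above
    return torch.stack(
        [x[i, :, s : s + segment_size] for i, s in enumerate(segment_indices)]
    )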
Example #3
    def train_step(self, batch: dict, criterion: nn.Module,
                   optimizer_idx: int) -> Tuple[Dict, Dict]:
        """Perform a single training step. Run the model forward pass and compute losses.

        Args:
            batch (Dict): Input tensors.
            criterion (nn.Module): Loss layer designed for the model.
            optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks.

        Returns:
            Tuple[Dict, Dict]: Model outputs and computed losses.
        """
        # pylint: disable=attribute-defined-outside-init
        if optimizer_idx not in [0, 1]:
            raise ValueError(" [!] Unexpected `optimizer_idx`.")

        if optimizer_idx == 0:
            text_input = batch["text_input"]
            text_lengths = batch["text_lengths"]
            mel_lengths = batch["mel_lengths"]
            linear_input = batch["linear_input"]
            d_vectors = batch["d_vectors"]
            speaker_ids = batch["speaker_ids"]
            waveform = batch["waveform"]

            # generator pass
            outputs = self.forward(
                text_input,
                text_lengths,
                linear_input.transpose(1, 2),
                mel_lengths,
                aux_input={
                    "d_vectors": d_vectors,
                    "speaker_ids": speaker_ids
                },
            )

            # cache tensors for the discriminator
            self.y_disc_cache = None
            self.wav_seg_disc_cache = None
            self.y_disc_cache = outputs["model_outputs"]
            wav_seg = segment(
                waveform.transpose(1, 2),
                outputs["slice_ids"] * self.config.audio.hop_length,
                self.args.spec_segment_size * self.config.audio.hop_length,
            )
            self.wav_seg_disc_cache = wav_seg
            outputs["waveform_seg"] = wav_seg

            # compute discriminator scores and features
            (
                outputs["scores_disc_fake"],
                outputs["feats_disc_fake"],
                _,
                outputs["feats_disc_real"],
            ) = self.disc(outputs["model_outputs"], wav_seg)

            # compute losses
            with autocast(enabled=False):  # use float32 for the criterion
                loss_dict = criterion[optimizer_idx](
                    waveform_hat=outputs["model_outputs"].float(),
                    waveform=wav_seg.float(),
                    z_p=outputs["z_p"].float(),
                    logs_q=outputs["logs_q"].float(),
                    m_p=outputs["m_p"].float(),
                    logs_p=outputs["logs_p"].float(),
                    z_len=mel_lengths,
                    scores_disc_fake=outputs["scores_disc_fake"],
                    feats_disc_fake=outputs["feats_disc_fake"],
                    feats_disc_real=outputs["feats_disc_real"],
                    loss_duration=outputs["loss_duration"],
                )

        elif optimizer_idx == 1:
            # discriminator pass
            outputs = {}

            # compute scores and features
            outputs["scores_disc_fake"], _, outputs[
                "scores_disc_real"], _ = self.disc(self.y_disc_cache.detach(),
                                                   self.wav_seg_disc_cache)

            # compute loss
            with autocast(enabled=False):  # use float32 for the criterion
                loss_dict = criterion[optimizer_idx](
                    outputs["scores_disc_real"],
                    outputs["scores_disc_fake"],
                )
        return outputs, loss_dict
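
The branch on `optimizer_idx` implies an alternating update: the generator pass (index 0) must run first so the cached tensors are available for the discriminator pass (index 1). A minimal driver loop is sketched below; `model`, `criterion` (indexable by `optimizer_idx`), `optimizers`, `loader`, and the `"loss"` key are illustrative assumptions, not the project's actual Trainer API.

# hypothetical alternating-optimizer loop around `train_step`
for batch in loader:
    for optimizer_idx in (0, 1):  # 0: generator, 1: discriminator
        optimizers[optimizer_idx].zero_grad()
        outputs, loss_dict = model.train_step(batch, criterion, optimizer_idx)
        # the criterion is assumed to aggregate its terms under the "loss" key
        loss_dict["loss"].backward()
        optimizers[optimizer_idx].step()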
Example #4
    def forward(
        self,
        x: torch.Tensor,
        x_lengths: torch.Tensor,
        y: torch.Tensor,
        y_lengths: torch.Tensor,
        waveform: torch.Tensor,
        aux_input={
            "d_vectors": None,
            "speaker_ids": None,
            "language_ids": None
        },
    ) -> Dict:
        """Forward pass of the model.

        Args:
            x (torch.Tensor): Batch of input character sequence IDs.
            x_lengths (torch.Tensor): Batch of input character sequence lengths.
            y (torch.Tensor): Batch of input spectrograms.
            y_lengths (torch.Tensor): Batch of input spectrogram lengths.
            waveform (torch.Tensor): Batch of ground-truth waveforms per sample.
            aux_input (dict, optional): Auxiliary inputs for multi-speaker and multi-lingual training.
                Defaults to {"d_vectors": None, "speaker_ids": None, "language_ids": None}.

        Returns:
            Dict: model outputs keyed by the output name.

        Shapes:
            - x: :math:`[B, T_seq]`
            - x_lengths: :math:`[B]`
            - y: :math:`[B, C, T_spec]`
            - y_lengths: :math:`[B]`
            - waveform: :math:`[B, 1, T_wav]`
            - d_vectors: :math:`[B, C, 1]`
            - speaker_ids: :math:`[B]`
            - language_ids: :math:`[B]`
        """
        outputs = {}
        sid, g, lid = self._set_cond_input(aux_input)
        # speaker embedding
        if self.args.use_speaker_embedding and sid is not None:
            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]

        # language embedding
        lang_emb = None
        if self.args.use_language_embedding and lid is not None:
            lang_emb = self.emb_l(lid).unsqueeze(-1)

        x, m_p, logs_p, x_mask = self.text_encoder(x,
                                                   x_lengths,
                                                   lang_emb=lang_emb)

        # posterior encoder
        z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)

        # flow layers
        z_p = self.flow(z, y_mask, g=g)

        # find the alignment path
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
        with torch.no_grad():
            o_scale = torch.exp(-2 * logs_p)
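            # the four terms below expand the diagonal-Gaussian log-likelihood
            # log N(z_p; m_p, exp(logs_p)), summed over channels:
            #   -0.5*log(2*pi) - logs_p - 0.5 * (z_p - m_p)^2 * exp(-2 * logs_p)
            # split into parts that depend on z_p only, on m_p only, and on both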
            logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p,
                              [1]).unsqueeze(-1)  # [b, t, 1]
            logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p**2)])
            logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
            logp4 = torch.sum(-0.5 * (m_p**2) * o_scale,
                              [1]).unsqueeze(-1)  # [b, t, 1]
            logp = logp2 + logp3 + logp1 + logp4
            attn = maximum_path(logp,
                                attn_mask.squeeze(1)).unsqueeze(1).detach()

        # duration predictor
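        # attn has shape [B, 1, T_seq, T_spec]; summing over the spectrogram axis
        # gives the number of frames the alignment assigned to each input token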
        attn_durations = attn.sum(3)
        if self.args.use_sdp:
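            # the stochastic duration predictor returns its duration loss directly;
            # it is normalized below by the number of unmasked input positions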
            loss_duration = self.duration_predictor(
                x.detach() if self.args.detach_dp_input else x,
                x_mask,
                attn_durations,
                g=g.detach()
                if self.args.detach_dp_input and g is not None else g,
                lang_emb=lang_emb.detach() if self.args.detach_dp_input
                and lang_emb is not None else lang_emb,
            )
            loss_duration = loss_duration / torch.sum(x_mask)
        else:
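            # the deterministic predictor regresses log-durations against the
            # alignment durations (squared error in the log domain)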
            attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
            log_durations = self.duration_predictor(
                x.detach() if self.args.detach_dp_input else x,
                x_mask,
                g=g.detach()
                if self.args.detach_dp_input and g is not None else g,
                lang_emb=lang_emb.detach() if self.args.detach_dp_input
                and lang_emb is not None else lang_emb,
            )
            loss_duration = torch.sum((log_durations - attn_log_durations)**2,
                                      [1, 2]) / torch.sum(x_mask)
        outputs["loss_duration"] = loss_duration

        # expand prior
        m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
        logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])

        # select a random feature segment for the waveform decoder
        z_slice, slice_ids = rand_segments(z, y_lengths,
                                           self.spec_segment_size)
        o = self.waveform_decoder(z_slice, g=g)

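        # slice the matching ground-truth waveform segment; spectrogram frame
        # indices are converted to audio sample indices via the hop length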
        wav_seg = segment(
            waveform,
            slice_ids * self.config.audio.hop_length,
            self.args.spec_segment_size * self.config.audio.hop_length,
        )

        if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None:
            # concatenate generated and GT waveforms
            wavs_batch = torch.cat((wav_seg, o), dim=0)

            # resample audio to speaker encoder sample_rate
            # pylint: disable=W0105
            if self.audio_transform is not None:
                wavs_batch = self.audio_transform(wavs_batch)

            pred_embs = self.speaker_manager.speaker_encoder.forward(
                wavs_batch, l2_norm=True)

            # split generated and GT speaker embeddings
            gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)
        else:
            gt_spk_emb, syn_spk_emb = None, None

        outputs.update({
            "model_outputs": o,
            "alignments": attn.squeeze(1),
            "z": z,
            "z_p": z_p,
            "m_p": m_p,
            "logs_p": logs_p,
            "m_q": m_q,
            "logs_q": logs_q,
            "waveform_seg": wav_seg,
            "gt_spk_emb": gt_spk_emb,
            "syn_spk_emb": syn_spk_emb,
        })
        return outputs