예제 #1
0
    def _target_mask(self, olens):
        """Build the mask used for masked (causal) self-attention on targets.

        The result is the element-wise AND of three masks: the non-padding
        mask of ``olens`` broadcast along the second-to-last axis, the mask
        returned by ``subsequent_mask`` (with a leading batch axis added), and
        the same non-padding mask broadcast along the last axis.

        Args:
            olens: Target lengths, passed unchanged to ``make_non_pad_mask``.

        Returns:
            Mask tensor located on the device of the module's first parameter.

        Examples:
            >>> olens = [5, 3]
            >>> self._target_mask(olens)
            tensor([[[1, 0, 0, 0, 0],
                     [1, 1, 0, 0, 0],
                     [1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 0],
                     [1, 1, 1, 1, 1]],
                    [[1, 0, 0, 0, 0],
                     [1, 1, 0, 0, 0],
                     [1, 1, 1, 0, 0],
                     [0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0]]], dtype=torch.uint8)

        """
        device = next(self.parameters()).device
        non_pad = make_non_pad_mask(olens).to(device)

        # The square mask is sized by the last axis of the non-padding mask.
        max_len = non_pad.size(-1)
        causal = subsequent_mask(max_len, device=non_pad.device).unsqueeze(0)

        along_last_axis = non_pad.unsqueeze(-2)
        along_row_axis = non_pad.unsqueeze(-1)
        return along_last_axis & causal & along_row_axis
예제 #2
0
    def _source_to_target_mask(self, ilens, olens):
        """Build the mask used for encoder-decoder (source-to-target) attention.

        The non-padding mask of ``ilens`` is broadcast along the
        second-to-last axis and AND-ed with the non-padding mask of ``olens``
        broadcast along the last axis.

        Args:
            ilens: Source lengths, passed unchanged to ``make_non_pad_mask``.
            olens: Target lengths, passed unchanged to ``make_non_pad_mask``.

        Returns:
            Mask tensor located on the device of the module's first parameter.

        Examples:
            >>> ilens = [4, 2]
            >>> olens = [5, 3]
            >>> self._source_to_target_mask(ilens, olens)
            tensor([[[1, 1, 1, 1],
                     [1, 1, 1, 1],
                     [1, 1, 1, 1],
                     [1, 1, 1, 1],
                     [1, 1, 1, 1]],
                    [[1, 1, 0, 0],
                     [1, 1, 0, 0],
                     [1, 1, 0, 0],
                     [0, 0, 0, 0],
                     [0, 0, 0, 0]]], dtype=torch.uint8)

        """
        device = next(self.parameters()).device
        src_non_pad = make_non_pad_mask(ilens).to(device)
        tgt_non_pad = make_non_pad_mask(olens).to(device)

        src_along_last = src_non_pad.unsqueeze(-2)
        tgt_along_rows = tgt_non_pad.unsqueeze(-1)
        return src_along_last & tgt_along_rows
예제 #3
0
 def _source_mask(self, ilens: torch.Tensor) -> torch.Tensor:
     """Build the mask used for self-attention over the source sequence.

     The non-padding mask of ``ilens`` is AND-ed with itself after being
     broadcast once along the second-to-last axis and once along the last
     axis.

     Args:
         ilens: Source lengths, passed unchanged to ``make_non_pad_mask``.

     Returns:
         Mask tensor located on the device of the module's first parameter.

     Examples:
         >>> ilens = [5, 3]
         >>> self._source_mask(ilens)
         tensor([[[1, 1, 1, 1, 1],
                  [1, 1, 1, 1, 1],
                  [1, 1, 1, 1, 1],
                  [1, 1, 1, 1, 1],
                  [1, 1, 1, 1, 1]],
                 [[1, 1, 1, 0, 0],
                  [1, 1, 1, 0, 0],
                  [1, 1, 1, 0, 0],
                  [0, 0, 0, 0, 0],
                  [0, 0, 0, 0, 0]]], dtype=torch.uint8)
     """
     device = next(self.parameters()).device
     non_pad = make_non_pad_mask(ilens).to(device=device)

     along_last_axis = non_pad.unsqueeze(-2)
     along_row_axis = non_pad.unsqueeze(-1)
     return along_last_axis & along_row_axis
예제 #4
0
    def forward(
        self,
        xs: torch.Tensor,
        ilens: torch.Tensor,
        ys: torch.Tensor,
        olens: torch.Tensor,
        ds: torch.Tensor,
        es: torch.Tensor,
        ps: torch.Tensor,
        avg_mel: torch.Tensor = None,
        phn_level_predictor: bool = False
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Run the training forward pass and compute the total loss.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            olens (LongTensor): Batch of the lengths of each target (B,).
            ds (Tensor): Duration targets, compared against the duration
                outputs of ``self._forward`` (masked with the input mask).
            es (Tensor): Energy targets, compared against the energy outputs
                (masked with the target-length mask).
            ps (Tensor): Pitch targets, compared against the pitch outputs
                (masked with the target-length mask).
            avg_mel (Tensor, optional): Forwarded unchanged to
                ``self._forward``.
            phn_level_predictor (bool): Forwarded to ``self._forward``; when
                true, ``self.acoustic_criterion`` is additionally evaluated on
                the last two outputs of ``self._forward`` and added to the
                loss.

        Returns:
            tuple: ``(loss, report_keys)`` where ``loss`` is the summed loss
            tensor and ``report_keys`` is a list of single-entry dicts with
            the individual loss values.
        """
        # remove unnecessary padded part (for multi-gpus)
        xs = xs[:, :max(ilens)]  # [B, Tmax]
        ys = ys[:, :max(olens)]  # [B, Lmax, odim]

        # forward propagation
        before_outs, after_outs, d_outs, e_outs, p_outs, phn, ys_phn = self._forward(
            xs,
            ilens,
            olens,
            ds,
            es,
            ps,
            is_inference=False,
            avg_mel=avg_mel,
            phn_level_predictor=phn_level_predictor)

        # apply mask to remove padded part
        if self.use_masking:
            in_masks = make_non_pad_mask(ilens).to(xs.device)
            d_outs = d_outs.masked_select(in_masks)
            ds = ds.masked_select(in_masks)
            out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
            mel_masks = make_non_pad_mask(olens).to(ys.device)
            before_outs = before_outs.masked_select(out_masks)
            es = es.masked_select(mel_masks)
            ps = ps.masked_select(mel_masks)
            e_outs = e_outs.masked_select(mel_masks)
            p_outs = p_outs.masked_select(mel_masks)
            after_outs = (after_outs.masked_select(out_masks)
                          if after_outs is not None else None)
            ys = ys.masked_select(out_masks)
            if phn is not None and ys_phn is not None:
                phn = phn.masked_select(in_masks.unsqueeze(-1))
                ys_phn = ys_phn.masked_select(in_masks.unsqueeze(-1))

        acoustic_loss = 0

        if phn_level_predictor:
            acoustic_loss = self.acoustic_criterion(ys_phn, phn)

        # calculate loss
        before_loss = self.criterion(before_outs, ys)
        # Seed l1_loss with the "before" term so it is defined even when
        # _forward returns no "after" outputs (previously an UnboundLocalError).
        l1_loss = before_loss
        after_loss = 0
        if after_outs is not None:
            after_loss = self.criterion(after_outs, ys)
            l1_loss = before_loss + after_loss
        duration_loss = self.duration_criterion(d_outs, ds)
        energy_loss = self.energy_criterion(e_outs, es)
        pitch_loss = self.pitch_criterion(p_outs, ps)

        # make weighted mask and apply it
        # NOTE(review): this branch multiplies the losses element-wise, so it
        # presumably expects criteria that return unreduced losses - confirm.
        if self.use_weighted_masking:
            out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
            out_weights = out_masks.float() / out_masks.sum(
                dim=1, keepdim=True).float()
            out_weights /= ys.size(0) * ys.size(2)
            duration_masks = make_non_pad_mask(ilens).to(ys.device)
            duration_weights = (
                duration_masks.float() /
                duration_masks.sum(dim=1, keepdim=True).float())
            duration_weights /= ds.size(0)

            # apply weight
            l1_loss = l1_loss.mul(out_weights).masked_select(out_masks).sum()
            duration_loss = (duration_loss.mul(duration_weights).masked_select(
                duration_masks).sum())

        loss = l1_loss + duration_loss + energy_loss + pitch_loss + acoustic_loss

        # after_loss is a plain 0 when there are no "after" outputs; only call
        # .item() when it is an actual criterion result.
        after_loss_value = (after_loss.item()
                            if after_outs is not None else float(after_loss))
        report_keys = [
            {"l1_loss": l1_loss.item()},
            {"before_loss": before_loss.item()},
            {"after_loss": after_loss_value},
            {"duration_loss": duration_loss.item()},
            {"energy_loss": energy_loss.item()},
            {"pitch_loss": pitch_loss.item()},
            # NOTE(review): key spelling and the raw (non-.item()) value are
            # kept as-is because callers may depend on them.
            {"acostic_loss": acoustic_loss},
            {"loss": loss.item()},
        ]

        # self.reporter.report(report_keys)

        return loss, report_keys
예제 #5
0
    def forward(self, xs, ilens, ys, olens, ds, es, ps, *args, **kwargs):
        """Run the training forward pass and compute the total loss.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            olens (LongTensor): Batch of the lengths of each target (B,).
            ds (Tensor): Duration targets, compared against the duration
                outputs of ``self._forward`` (masked with the input mask).
            es (Tensor): Energy targets, compared against the energy outputs
                (masked with the target-length mask).
            ps (Tensor): Pitch targets, compared against the pitch outputs
                (masked with the target-length mask).
            *args: Accepted for call compatibility; not used here.
            **kwargs: Accepted for call compatibility; not used here.

        Returns:
            tuple: ``(loss, report_keys)`` where ``loss`` is the summed loss
            tensor and ``report_keys`` is a list of single-entry dicts with
            the individual loss values (plus the encoder/decoder alpha values
            when ``self.use_scaled_pos_enc`` is set).

        """
        # remove unnecessary padded part (for multi-gpus)
        xs = xs[:, :max(ilens)]
        ys = ys[:, :max(olens)]

        # forward propagation
        before_outs, after_outs, d_outs, e_outs, p_outs = self._forward(
            xs, ilens, ys, olens, ds, es, ps, is_inference=False)

        # modify mod part of groundtruth
        # NOTE(review): the condition reads hp.reduction_factor while the
        # truncation uses self.reduction_factor - confirm they always agree.
        if hp.reduction_factor > 1:
            olens = olens.new(
                [olen - olen % self.reduction_factor for olen in olens])
            max_olen = max(olens)
            ys = ys[:, :max_olen]

        # apply mask to remove padded part
        if self.use_masking:
            in_masks = make_non_pad_mask(ilens).to(xs.device)
            d_outs = d_outs.masked_select(in_masks)
            ds = ds.masked_select(in_masks)
            out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
            mel_masks = make_non_pad_mask(olens).to(ys.device)
            before_outs = before_outs.masked_select(out_masks)
            es = es.masked_select(mel_masks)
            ps = ps.masked_select(mel_masks)
            e_outs = e_outs.masked_select(mel_masks)
            p_outs = p_outs.masked_select(mel_masks)
            after_outs = (after_outs.masked_select(out_masks)
                          if after_outs is not None else None)
            ys = ys.masked_select(out_masks)

        # calculate loss
        before_loss = self.criterion(before_outs, ys)
        # Seed l1_loss with the "before" term so it is defined even when
        # _forward returns no "after" outputs (previously an UnboundLocalError).
        l1_loss = before_loss
        after_loss = 0
        if after_outs is not None:
            after_loss = self.criterion(after_outs, ys)
            l1_loss = before_loss + after_loss
        duration_loss = self.duration_criterion(d_outs, ds)
        energy_loss = self.energy_criterion(e_outs, es)
        pitch_loss = self.pitch_criterion(p_outs, ps)

        # make weighted mask and apply it
        # NOTE(review): this branch multiplies the losses element-wise, so it
        # presumably expects criteria that return unreduced losses - confirm.
        if hp.use_weighted_masking:
            out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
            out_weights = out_masks.float() / out_masks.sum(
                dim=1, keepdim=True).float()
            out_weights /= ys.size(0) * ys.size(2)
            duration_masks = make_non_pad_mask(ilens).to(ys.device)
            duration_weights = (
                duration_masks.float() /
                duration_masks.sum(dim=1, keepdim=True).float())
            duration_weights /= ds.size(0)

            # apply weight
            l1_loss = l1_loss.mul(out_weights).masked_select(out_masks).sum()
            duration_loss = (duration_loss.mul(duration_weights).masked_select(
                duration_masks).sum())

        loss = l1_loss + duration_loss + energy_loss + pitch_loss

        # after_loss is a plain 0 when there are no "after" outputs; only call
        # .item() when it is an actual criterion result.
        after_loss_value = (after_loss.item()
                            if after_outs is not None else float(after_loss))
        report_keys = [
            {"l1_loss": l1_loss.item()},
            {"before_loss": before_loss.item()},
            {"after_loss": after_loss_value},
            {"duration_loss": duration_loss.item()},
            {"energy_loss": energy_loss.item()},
            {"pitch_loss": pitch_loss.item()},
            {"loss": loss.item()},
        ]

        # report extra information
        if self.use_scaled_pos_enc:
            report_keys += [
                {"encoder_alpha": self.encoder.embed[-1].alpha.data.item()},
                {"decoder_alpha": self.decoder.embed[-1].alpha.data.item()},
            ]
        # self.reporter.report(report_keys)

        return loss, report_keys