Example #1
def binary_divergence_masked(input, target, lengths):
    """Provides a non-vanishing gradient, but is not zero even when the spectrograms are identical.

    Inspired by https://github.com/r9y9/deepvoice3_pytorch/blob/897f31e57eb6ec2f0cafa8dc62968e60f6a96407/train.py#L537
    """
    # Per-element binary cross-entropy with logits: -t * x + log(1 + exp(x))
    input_logits = logit(input)
    z = -target * input_logits + torch.log1p(torch.exp(input_logits))
    # Mask out padded time steps before averaging
    m = mask(input.shape, lengths, dim=1).float().to(input.device)

    return masked_mean(z, m)
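Every example on this page calls a mask helper that is not shown, and Example #1 additionally uses logit and masked_mean. The following is only a minimal sketch of what those helpers might look like, assuming lengths holds the number of valid steps per batch item and that the masked axis is indexed by dim; it is not the repository's actual implementation.

import torch

def mask(shape, lengths, dim=1):
    """Boolean mask of `shape` that is True where the index along `dim` is below `lengths`."""
    lengths = torch.as_tensor(lengths).long()
    positions = torch.arange(shape[dim], device=lengths.device)
    m = positions.unsqueeze(0) < lengths.unsqueeze(1)  # (batch, shape[dim])
    view = [1] * len(shape)
    view[0], view[dim] = shape[0], shape[dim]
    return m.view(view).expand(shape)  # broadcast over the remaining dimensions

def logit(x, eps=1e-8):
    """Inverse sigmoid, clamped away from 0 and 1 for numerical stability."""
    x = x.clamp(eps, 1.0 - eps)
    return torch.log(x) - torch.log1p(-x)

def masked_mean(values, m):
    """Mean of `values` over the positions where the mask `m` is 1."""
    return (values * m).sum() / m.sum()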
Example #2
def masked_huber(input, target, lengths):
    """
    Always masks the first non-batch dimension (usually time).

    :param input: predicted spectrogram, shape (batch, time, ...)
    :param target: reference spectrogram, same shape as input
    :param lengths: number of valid time steps per batch item
    :return: smooth L1 (Huber) loss averaged over the unmasked elements
    """
    m = mask(input.shape, lengths, dim=1).float().to(input.device)
    return F.smooth_l1_loss(input * m, target * m, reduction='sum') / m.sum()
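For a quick sanity check, the masked losses can be exercised on dummy tensors. This usage sketch assumes the hypothetical mask helper sketched under Example #1 and made-up shapes.

import torch
import torch.nn.functional as F

batch, time, bins = 4, 100, 80
pred = torch.rand(batch, time, bins)       # predicted spectrogram
ref = torch.rand(batch, time, bins)        # reference spectrogram
lengths = torch.tensor([100, 80, 60, 30])  # valid frames per utterance

# Padded frames are zeroed out and excluded from the normalisation
print(masked_huber(pred, ref, lengths).item())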
Example #3
    def inference(self, texts, alpha=1.0):
        texts, tlens = self.processor(texts)
        texts = torch.from_numpy(texts).long().to(self.device)
        # Append a few zero tokens as end padding
        texts = torch.cat(
            (texts, torch.zeros(len(texts), 5).long().to(self.device)), dim=-1)
        tlens = torch.Tensor(tlens).to(self.device)
        with torch.no_grad():
            melspecs, prd_durans = self.model((texts, tlens, None, alpha))
        melspecs = self.normalizer.inverse(melspecs)
        # Mask frames beyond the total predicted duration of each utterance
        msk = mask(melspecs.shape, prd_durans.sum(dim=-1).long(),
                   dim=1).to(self.device)
        # -11.5129 == log(1e-5), the fill value for padded log-mel frames;
        # permute to (batch, mel bins, time)
        melspecs = melspecs.masked_fill(~msk, -11.5129).permute(0, 2, 1)
        # Append three frames of "silence" padding at the end
        melspecs = torch.cat(
            (melspecs, -11.5129 *
             torch.ones(len(melspecs), melspecs.size(1), 3).to(self.device)),
            dim=-1)
        return melspecs
Example #4
    def forward(self, inputs):
        texts, tlens, mels, pos = inputs

        # Shift the mel frames one step to the right (teacher forcing):
        # pad one zero frame at the start of the time axis and drop the last frame
        mels = ZeroPad2d((0, 0, 1, 0))(mels)[:, :-1, :]
        keys, values = self.text_enc(texts)
        queries = self.spec_enc(mels)

        # Attention mask of shape (batch, query time, key time) that hides padded text positions
        msk = mask(shape=(len(keys), queries.shape[1], keys.shape[1]),
                   lengths=tlens,
                   dim=-1).to(texts.device)

        if pos:
            # Positional encodings with different position rates (w) for keys and queries
            keys += positional_encoding(keys.shape[-1], keys.shape[1],
                                        w=6.42).to(keys.device)
            queries += positional_encoding(queries.shape[-1],
                                           queries.shape[1],
                                           w=1).to(queries.device)

        seeds, attns = self.attention(queries, keys, values, mask=msk)
        melspecs = self.spec_dec(seeds + queries)
        return melspecs, attns
Example #5
def l1_masked(input, target, lengths):
    m = mask(input.shape, lengths, dim=1).float().to(input.device)
    # Sum of absolute errors over unmasked elements, normalised by the number of valid entries
    return F.l1_loss(input * m, target * m, reduction='sum') / m.sum()
Example #6
    def forward(self, input, target, lengths):
        m = mask(input.shape, lengths, dim=1).float().to(input.device)
        # Assumes self.l1 uses reduction='sum', so dividing by m.sum() gives a mean over valid entries
        return self.l1(input * m, target * m) / m.sum()
Example #7
def masked_ssim(input, target, lengths):
    m = mask(input.shape, lengths, dim=1).float().to(input.device)
    input, target = input * m, target * m
    # Treat each spectrogram as a single-channel image (batch, 1, time, bins) and use 1 - SSIM as the loss
    return 1 - ssim(input.unsqueeze(1), target.unsqueeze(1))
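Note that masked_ssim turns each masked spectrogram into a single-channel image via unsqueeze(1), so whatever ssim function is in scope (presumably an image SSIM implementation such as pytorch_msssim.ssim; that library choice is an assumption here) receives tensors of shape (batch, 1, time, bins). Returning 1 - ssim(...) converts the similarity score into a loss that decreases as the spectrograms become more alike.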