Exemplo n.º 1
0
    def audio_to_normal_dist(
            self, *, spec: torch.Tensor,
            audio: torch.Tensor) -> (torch.Tensor, list, list):
        """Run the flow steps forward, mapping ``audio`` into latent space.

        The waveform is regrouped with ``split_view(audio, self.n_group, 1)``
        and permuted so that dim 1 holds the grouped channels. Each of the
        ``self.n_flows`` steps applies ``self.convinv[k]``, then an affine
        coupling whose shift and log-scale come from ``self.wavenet[k]``
        conditioned on ``spec``. At every step where ``k % self.n_early_every
        == 0 and k > 0``, the first ``self.n_early_size`` channels are set
        aside and skip all remaining steps.

        Args:
            spec: conditioning tensor, handed unchanged to every
                ``self.wavenet[k]``. No upsampling happens here (presumably
                done by the caller; confirm).
            audio: waveform batch; regrouped along dim 1 by ``split_view``.

        Returns:
            ``(z, log_s_list, log_det_W_list)``:
            - ``z`` concatenates, along dim 1, the set-aside channels and the
              final flow state.
            - ``log_s_list`` holds each step's ``log_s``.
            - ``log_det_W_list`` holds each step's ``log_det_W``.
        """
        state = split_view(audio, self.n_group, 1).permute(0, 2, 1)
        set_aside = []
        log_scales = []
        log_dets = []

        for step in range(self.n_flows):
            # Operand order kept as in the original: the modulo is evaluated first.
            if step % self.n_early_every == 0 and step > 0:
                set_aside.append(state[:, :self.n_early_size, :])
                state = state[:, self.n_early_size:, :]

            state, step_log_det = self.convinv[step](state)
            log_dets.append(step_log_det)

            half = state.size(1) // 2
            passthrough, transformed = state[:, :half, :], state[:, half:, :]

            # Affine coupling: the untouched half (plus spec) yields the shift
            # for ``transformed`` (first ``half`` output channels) and its
            # log-scale (remaining channels).
            net_out = self.wavenet[step]((passthrough, spec))
            shift, log_scale = net_out[:, :half, :], net_out[:, half:, :]
            transformed = torch.exp(log_scale) * transformed + shift
            log_scales.append(log_scale)

            state = torch.cat([passthrough, transformed], 1)

        set_aside.append(state)
        return torch.cat(set_aside, 1), log_scales, log_dets
Exemplo n.º 2
0
    def audio_to_normal_dist(self, *, spec: torch.Tensor,
                             audio: torch.Tensor) -> (torch.Tensor, float):
        """Map ``audio`` to latent space by repeating one shared flow step.

        The last frame of ``spec`` is dropped. If the remaining length (dim 2)
        differs from the grouped audio's dim 2, it is resized with
        ``F.interpolate`` (default mode). The same ``self.conv`` and
        ``self.wn`` modules are applied on every one of the ``self.n_flows``
        iterations.

        Args:
            spec: conditioning spectrogram; dim 2 is its time axis.
            audio: waveform batch; regrouped along dim 1 by ``split_view``.

        Returns:
            ``(z, logdet)``:
            - ``z`` is the transformed grouped audio.
            - ``logdet`` is the sum of every ``log_det_W`` returned by
              ``self.conv`` plus ``torch.sum`` of every ``log_s``.
            Despite the ``float`` annotation, ``logdet`` is a tensor whenever
            ``n_flows > 0``, and the int ``0`` when ``n_flows == 0``.
        """
        conditioning = spec[:, :, :-1]
        latent = split_view(audio, self.n_group, 1).permute(0, 2, 1)
        n_steps = latent.size(2)
        if conditioning.size(2) != n_steps:
            conditioning = F.interpolate(conditioning, size=n_steps)

        total_logdet = 0
        for _ in range(self.n_flows):
            latent, conv_log_det = self.conv(latent)
            total_logdet += conv_log_det

            half = latent.size(1) // 2
            kept = latent[:, :half, :]
            moved = latent[:, half:, :]

            # Affine coupling parameterised by the shared network ``self.wn``.
            net_out = self.wn((kept, conditioning))
            shift = net_out[:, :half, :]
            log_scale = net_out[:, half:, :]
            moved = torch.exp(log_scale) * moved + shift
            total_logdet += torch.sum(log_scale)

            latent = torch.cat([kept, moved], 1)

        return latent, total_logdet
Exemplo n.º 3
0
    def audio_to_normal_dist(
            self, *, spec: torch.Tensor,
            audio: torch.Tensor) -> (torch.Tensor, list, list):
        """Run the flow steps forward, mapping ``audio`` into latent space.

        Preparing ``spec``:
        - Upsample it with ``self.upsample``.
        - Check it is not shorter than the audio.
        - Trim it to the audio length.
        - Regroup it with ``split_view`` along dim 2.

        Each of the ``self.n_flows`` steps applies ``self.convinv[k]``, then an
        affine coupling whose shift and log-scale come from
        ``self.wavenet[k]``. At every step where ``k % self.n_early_every == 0
        and k > 0``, the first ``self.n_early_size`` channels are set aside and
        skip the remaining steps.

        Args:
            spec: conditioning spectrogram; dim 2 is its time axis.
            audio: waveform batch; dim 1 is the sample axis.

        Returns:
            ``(z, log_s_list, log_det_W_list)``:
            - ``z`` concatenates, along dim 1, the set-aside channels and the
              final flow state.
            - ``log_s_list`` holds each step's ``log_s``.
            - ``log_det_W_list`` holds each step's ``log_det_W``.

        Raises:
            ValueError: if the upsampled spectrogram's dim 2 is shorter than
                the audio's dim 1.
        """
        #  Upsample spectrogram to size of audio
        spec = self.upsample(spec)
        # Explicit check rather than ``assert`` so the validation survives
        # ``python -O``. Otherwise a short spec would only fail later, deep
        # inside the flow, with an opaque shape error.
        if spec.size(2) < audio.size(1):
            raise ValueError(
                f"Upsampled spectrogram length ({spec.size(2)}) is shorter "
                f"than the audio length ({audio.size(1)}).")
        if spec.size(2) > audio.size(1):
            spec = spec[:, :, :audio.size(1)]

        # Regroup spec along its time axis with the same n_group used for the
        # audio below, presumably so both share a matching time axis when fed
        # to the wavenet (confirm against split_view).
        spec = split_view(spec, self.n_group, 2).permute(0, 2, 1, 3)
        spec = spec.contiguous().view(spec.size(0), spec.size(1), -1)
        spec = spec.permute(0, 2, 1)

        audio = split_view(audio, self.n_group, 1).permute(0, 2, 1)
        output_audio = []
        log_s_list = []
        log_det_W_list = []

        for k in range(self.n_flows):
            # Set aside early outputs; ``k > 0`` leaves the first step intact.
            if k % self.n_early_every == 0 and k > 0:
                output_audio.append(audio[:, :self.n_early_size, :])
                audio = audio[:, self.n_early_size:, :]

            audio, log_det_W = self.convinv[k](audio)
            log_det_W_list.append(log_det_W)

            n_half = audio.size(1) // 2
            audio_0 = audio[:, :n_half, :]
            audio_1 = audio[:, n_half:, :]

            # Affine coupling: audio_0 passes through unchanged and, together
            # with spec, produces the shift (b) and log-scale (log_s) that
            # transform audio_1.
            output = self.wavenet[k]((audio_0, spec))
            log_s = output[:, n_half:, :]
            b = output[:, :n_half, :]
            audio_1 = torch.exp(log_s) * audio_1 + b
            log_s_list.append(log_s)

            audio = torch.cat([audio_0, audio_1], 1)

        output_audio.append(audio)
        return torch.cat(output_audio, 1), log_s_list, log_det_W_list
Exemplo n.º 4
0
    def norm_dist_to_audio(self, *, spec, z=None, sigma: float = 1.0) -> torch.Tensor:
        """Invert the flow: turn latent noise into audio conditioned on ``spec``.

        Preparing ``spec``:
        - Upsample it with ``self.upsample``.
        - Trim it to ``self.time_cutoff`` when that is non-zero.
        - Regroup it with ``split_view`` along dim 2.

        When ``z`` is None it is sampled as ``sigma * randn``, with shape
        ``(batch, self.n_group, spec.size(2))`` and on ``spec``'s device and
        dtype. The flow steps are then undone in reverse order (``k`` from
        ``self.n_flows - 1`` down to 0). At every step where
        ``k % self.n_early_every == 0 and k > 0``, the next
        ``self.n_early_size`` channels of ``z`` are prepended to the state.

        Args:
            spec: conditioning spectrogram (presumably (batch, n_mels, frames);
                confirm against callers).
            z: optional latent tensor. When given, ``sigma`` is unused.
                Presumably expected with shape
                ``(batch, self.n_group, <regrouped spec length>)`` without the
                extra 2D dim; confirm.
            sigma: standard-deviation multiplier applied to sampled noise.

        Returns:
            Audio tensor flattened to ``(batch, -1)``.
        """
        # Models converted to 2D layers get a trailing singleton dim on spec
        # (presumably for 2D convolutions; confirm).
        if self.converted_to_2D:
            spec = torch.unsqueeze(spec, 3)
        spec = self.upsample(spec)
        # Collapse any trailing dims back to (batch, channels, time).
        spec = spec.contiguous().view(spec.size(0), spec.size(1), -1)
        # trim conv artifacts. maybe pad spec to kernel multiple
        if self.time_cutoff != 0:
            spec = spec[:, :, :self.time_cutoff]

        # Regroup the time axis by n_group so spec.size(2) becomes the grouped
        # length (assumes split_view splits dim 2 into n_group-sized chunks;
        # confirm).
        spec = split_view(spec, self.n_group, 2).permute(0, 2, 1, 3)
        spec = spec.contiguous().view(spec.size(0), spec.size(1), -1)
        spec = spec.permute(0, 2, 1)

        # Latent shape matches the grouped audio: (batch, n_group, grouped length).
        z_size = torch.Size([spec.size(0), self.n_group, spec.size(2)])
        if z is None:
            z = sigma * torch.randn(z_size, device=spec.device).to(spec.dtype)

        if self.converted_to_2D:
            z = torch.unsqueeze(z, 3)
            spec = torch.unsqueeze(spec, 3)

        # The first n_remaining_channels channels seed the inverse pass. The
        # rest of z is held back and re-inserted at early-output steps below.
        audio, z = torch.split(
            z,
            [self.n_remaining_channels,
             z.size(1) - self.n_remaining_channels], 1)

        for k in reversed(range(self.n_flows)):
            n_half = self.n_halves[k]
            audio_0, audio_1 = torch.split(
                audio, [n_half, audio.size(1) - n_half], 1)

            output = self.wavenet[k]((audio_0, spec))

            # First n_half output channels are the shift, the rest the log-scale.
            b, s = torch.split(output, [n_half, output.size(1) - n_half], 1)

            # Undo the affine coupling: audio_1 = (audio_1 - b) / exp(s).
            audio_1 = audio_1 - b
            audio_1 = audio_1 / torch.exp(s)
            audio = torch.cat((audio_0, audio_1), 1)

            # Undo the invertible 1x1 mixing for this step (presumably 1x1;
            # only the reverse=True call is visible here).
            audio = self.convinv[k](audio, reverse=True)
            # Re-insert held-back latent channels in front of the current state
            # (presumably mirroring where the forward pass set them aside).
            if k % self.n_early_every == 0 and k > 0:
                z1, z = torch.split(
                    z, [self.n_early_size,
                        z.size(1) - self.n_early_size], 1)
                audio = torch.cat((z1, audio), 1)

        # Drop the extra 2D dim added above.
        if self.converted_to_2D:
            audio = audio.view(audio.size(0), audio.size(1), -1)
        # (batch, n_group, T) -> (batch, T, n_group) -> (batch, T * n_group),
        # presumably interleaving the groups back into one sample axis.
        return audio.permute(0, 2, 1).contiguous().view(audio.size(0), -1)
Exemplo n.º 5
0
 def get_upsample_factor(self) -> int:
     """
     Return the ratio between the grouped-audio length and the MelSpectrogram length.

     As the MelSpectrogram upsampling is done using interpolation, the
     upsampling factor is determined by the ratio of the MelSpectrogram length
     and the waveform length. The method works as follows:
     - Build a dummy all-ones waveform of ``train_ds.dataset.n_segments``
       samples.
     - Pass it through ``self.audio_to_melspec_precessor``.
     - Drop the spectrogram's last frame.
     - Integer-divide the grouped-audio length (dim 2 after ``split_view`` and
       ``permute``) by the spectrogram length (dim 2).

     Returns:
         An integer representing the upsampling factor
     """
     n_samples = self._cfg.train_ds.dataset.n_segments
     audio = torch.ones(1, n_samples)
     # The preprocessor takes one signal length per batch item. ``len(audio)``
     # would be the batch size (1), not the sample count, so pass the real
     # length of dim 1.
     spec, _ = self.audio_to_melspec_precessor(
         audio, torch.FloatTensor([audio.size(1)]))
     # Drop the final frame (presumably to match how the model trims spectrograms
     # before aligning them with audio; confirm).
     spec = spec[:, :, :-1]
     audio = split_view(audio, self._cfg.uniglow.n_group,
                        1).permute(0, 2, 1)
     # Compare time axes: grouped audio dim 2 vs. spectrogram frames (dim 2).
     upsample_factor = audio.shape[2] // spec.shape[2]
     return upsample_factor