def generate(self,
                 mels,
                 save_path: Union[str, Path, None],
                 batched,
                 target,
                 overlap,
                 mu_law,
                 silent=False):
        """Autoregressively synthesise a waveform from a mel spectrogram.

        Args:
            mels: Mel spectrogram, converted with ``torch.as_tensor`` and
                moved to the model's device. It is indexed as a 3-D tensor
                whose last axis is time (presumably
                ``(batch, n_mels, frames)`` -- TODO confirm against callers).
            save_path: Destination handed to ``save_wav``. ``None`` skips
                writing the audio to disk.
            batched: When true, the upsampled features are folded into
                overlapping segments with ``fold_with_overlap`` and the
                generated segments are merged by ``xfade_and_unfold``.
            target: Forwarded to ``fold_with_overlap`` / ``xfade_and_unfold``.
            overlap: Forwarded to ``fold_with_overlap`` / ``xfade_and_unfold``.
            mu_law: Decode the samples with ``decode_mu_law``. Ignored unless
                ``self.mode == 'RAW'``.
            silent: Suppress the ``gen_display`` progress output.

        Returns:
            The generated samples as a float64 NumPy array, truncated to at
            most ``(frames - 1) * hop_length`` samples and faded out at the
            end.

        Raises:
            RuntimeError: If ``self.mode`` is neither ``'MOL'`` nor ``'RAW'``.
        """
        if self.mode not in ('MOL', 'RAW'):
            # Fail before touching the model so an unsupported mode does not
            # leave it stuck in eval mode after a wasted upsampling pass.
            raise RuntimeError(f"Unknown model mode value - {self.mode}")

        self.eval()

        device = next(
            self.parameters()).device  # use same device as parameters

        # mu-law decoding only makes sense for the categorical ('RAW') output
        mu_law = mu_law if self.mode == 'RAW' else False

        output = []
        start = time.time()
        rnn1 = self.get_gru_cell(self.rnn1)
        rnn2 = self.get_gru_cell(self.rnn2)

        with torch.no_grad():

            mels = torch.as_tensor(mels, device=device)
            wave_len = (mels.size(-1) - 1) * self.hop_length
            mels = self.pad_tensor(mels.transpose(1, 2),
                                   pad=self.pad,
                                   side='both')
            mels, aux = self.upsample(mels.transpose(1, 2))

            if batched:
                mels = self.fold_with_overlap(mels, target, overlap)
                aux = self.fold_with_overlap(aux, target, overlap)

            b_size, seq_len, _ = mels.size()

            h1 = torch.zeros(b_size, self.rnn_dims, device=device)
            h2 = torch.zeros(b_size, self.rnn_dims, device=device)
            x = torch.zeros(b_size, 1, device=device)

            # The auxiliary features are consumed as four equal-width slices,
            # one per conditioning point in the network below.
            d = self.aux_dims
            aux_split = [aux[:, :, d * i:d * (i + 1)] for i in range(4)]

            for i in range(seq_len):

                m_t = mels[:, i, :]

                a1_t, a2_t, a3_t, a4_t = \
                    (a[:, i, :] for a in aux_split)

                # x carries the previous sample into this step
                x = torch.cat([x, m_t, a1_t], dim=1)
                x = self.I(x)
                h1 = rnn1(x, h1)

                x = x + h1
                inp = torch.cat([x, a2_t], dim=1)
                h2 = rnn2(inp, h2)

                x = x + h2
                x = torch.cat([x, a3_t], dim=1)
                x = F.relu(self.fc1(x))

                x = torch.cat([x, a4_t], dim=1)
                x = F.relu(self.fc2(x))

                logits = self.fc3(x)

                if self.mode == 'MOL':
                    sample = sample_from_discretized_mix_logistic(
                        logits.unsqueeze(0).transpose(1, 2))
                    output.append(sample.view(-1))
                    x = sample.transpose(0, 1)
                else:  # 'RAW' (validated above)
                    posterior = F.softmax(logits, dim=1)
                    distrib = torch.distributions.Categorical(posterior)

                    # map class index [0, n_classes - 1] onto [-1, 1]
                    sample = 2 * distrib.sample().float() / (self.n_classes -
                                                             1.) - 1.
                    output.append(sample)
                    x = sample.unsqueeze(-1)

                if not silent and i % 100 == 0:
                    self.gen_display(i, seq_len, b_size, start)

        output = torch.stack(output).transpose(0, 1)
        output = output.cpu().numpy()
        output = output.astype(np.float64)

        if mu_law:
            output = decode_mu_law(output, self.n_classes, False)

        if batched:
            output = self.xfade_and_unfold(output, target, overlap)
        else:
            output = output[0]

        # Fade-out at the end to avoid signal cutting out suddenly. The ramp
        # is clamped to the number of samples actually available: a fixed
        # 20-frame ramp cannot be broadcast onto a shorter output.
        output = output[:wave_len]
        fade_len = min(20 * self.hop_length, len(output))
        if fade_len > 0:
            output[-fade_len:] *= np.linspace(1, 0, fade_len)

        if save_path is not None:
            save_wav(output, save_path)

        # NOTE(review): this always switches the model to training mode, even
        # if the caller had put it in eval mode before calling generate().
        self.train()

        return output
    def generate(self, mels, batched, target, overlap):
        """Autoregressively synthesise a waveform from a mel spectrogram.

        Args:
            mels: Mel spectrogram tensor; it is moved to the model's device.
                It is indexed as a 3-D tensor whose last axis is time
                (presumably ``(batch, n_mels, frames)`` -- TODO confirm).
            batched: When true, the upsampled features are folded into
                overlapping segments with ``fold_with_overlap`` and the
                generated segments are merged by ``xfade_and_unfold``.
            target: Forwarded to ``fold_with_overlap`` / ``xfade_and_unfold``.
            overlap: Forwarded to ``fold_with_overlap`` / ``xfade_and_unfold``.

        Returns:
            The generated samples as a float64 NumPy array, truncated to at
            most ``(frames - 1) * hop_length`` samples and faded out at the
            end.

        Raises:
            RuntimeError: If ``self.mode`` is not ``'mold'``, ``'gauss'`` or
                an ``int``.
        """
        categorical = isinstance(self.mode, int)
        if not categorical and self.mode not in ('mold', 'gauss'):
            # Fail before touching the model so an unsupported mode does not
            # leave it stuck in eval mode after a wasted upsampling pass.
            raise RuntimeError(f"Unknown model mode value - {self.mode}")

        self.eval()

        # Run on whatever device holds the weights instead of assuming the
        # current CUDA device (works on CPU and on non-default GPUs).
        device = next(self.parameters()).device

        output = []
        start = time.time()
        rnn1 = self.get_gru_cell(self.rnn1)
        rnn2 = self.get_gru_cell(self.rnn2)

        with torch.no_grad():

            mels = mels.to(device)
            wave_len = (mels.size(-1) - 1) * self.hop_length
            mels = self.pad_tensor(mels.transpose(1, 2), pad=self.pad, side='both')
            mels, aux = self.upsample(mels.transpose(1, 2))

            if batched:
                mels = self.fold_with_overlap(mels, target, overlap)
                if aux is not None:
                    aux = self.fold_with_overlap(aux, target, overlap)

            b_size, seq_len, _ = mels.size()

            h1 = torch.zeros(b_size, self.rnn_dims, device=device)
            h2 = torch.zeros(b_size, self.rnn_dims, device=device)
            x = torch.zeros(b_size, 1, device=device)

            if self.use_aux_net:
                # The auxiliary features are consumed as four equal-width
                # slices, one per conditioning point in the network below.
                d = self.aux_dims
                aux_split = [aux[:, :, d * i:d * (i + 1)] for i in range(4)]

            for i in range(seq_len):

                m_t = mels[:, i, :]

                if self.use_aux_net:
                    a1_t, a2_t, a3_t, a4_t = \
                        (a[:, i, :] for a in aux_split)
                    x = torch.cat([x, m_t, a1_t], dim=1)
                else:
                    x = torch.cat([x, m_t], dim=1)
                x = self.I(x)
                h1 = rnn1(x, h1)

                x = x + h1
                inp = torch.cat([x, a2_t], dim=1) if self.use_aux_net else x
                h2 = rnn2(inp, h2)

                x = x + h2
                x = torch.cat([x, a3_t], dim=1) if self.use_aux_net else x
                x = F.relu(self.fc1(x))

                x = torch.cat([x, a4_t], dim=1) if self.use_aux_net else x
                x = F.relu(self.fc2(x))

                logits = self.fc3(x)

                if self.mode == 'mold':
                    sample = sample_from_discretized_mix_logistic(logits.unsqueeze(0).transpose(1, 2))
                    output.append(sample.view(-1))
                    x = sample.transpose(0, 1).to(device)
                elif self.mode == 'gauss':
                    sample = sample_from_gaussian(logits.unsqueeze(0).transpose(1, 2))
                    output.append(sample.view(-1))
                    x = sample.transpose(0, 1).to(device)
                else:  # integer mode: categorical output (validated above)
                    posterior = F.softmax(logits, dim=1)
                    distrib = torch.distributions.Categorical(posterior)

                    # map class index [0, n_classes - 1] onto [-1, 1]
                    sample = 2 * distrib.sample().float() / (self.n_classes - 1.) - 1.
                    output.append(sample)
                    x = sample.unsqueeze(-1)

                if i % 100 == 0:
                    self.gen_display(i, seq_len, b_size, start)

        output = torch.stack(output).transpose(0, 1)
        output = output.cpu().numpy()
        output = output.astype(np.float64)

        if batched:
            output = self.xfade_and_unfold(output, target, overlap)
        else:
            output = output[0]

        if self.mulaw and categorical:
            output = ap.mulaw_decode(output, self.mode)

        # Fade-out at the end to avoid signal cutting out suddenly. The ramp
        # is clamped to the number of samples actually available: a fixed
        # 20-frame ramp cannot be broadcast onto a shorter output.
        output = output[:wave_len]
        fade_len = min(20 * self.hop_length, len(output))
        if fade_len > 0:
            output[-fade_len:] *= np.linspace(1, 0, fade_len)

        # NOTE(review): this always switches the model to training mode, even
        # if the caller had put it in eval mode before calling generate().
        self.train()
        return output
Esempio n. 3
0
    def generate(self, mels, batched, target, overlap):
        """Autoregressively synthesise a waveform from a mel spectrogram.

        Only the ``'mold'`` (mixture of logistics) output mode is implemented.

        Args:
            mels: Mel spectrogram tensor; it is moved to the model's device.
                It is indexed as a 3-D tensor (presumably
                ``(batch, n_mels, frames)`` -- TODO confirm).
            batched: When true, the upsampled features are folded into
                overlapping segments with ``fold_with_overlap`` and the
                generated segments are merged by ``xfade_and_unfold``.
            target: Forwarded to ``fold_with_overlap`` / ``xfade_and_unfold``.
            overlap: Forwarded to ``fold_with_overlap`` / ``xfade_and_unfold``.

        Returns:
            The generated samples as a float64 NumPy array.

        Raises:
            RuntimeError: If ``self.mode`` is not ``'mold'``.
        """
        # TODO: implement other modes
        if self.mode != 'mold':
            # Fail before touching the model so an unsupported mode does not
            # leave it stuck in eval mode after a wasted upsampling pass.
            raise RuntimeError(f"Unknown model mode value - {self.mode}")

        self.eval()

        # Run on whatever device holds the weights instead of assuming the
        # current CUDA device (works on CPU and on non-default GPUs).
        device = next(self.parameters()).device

        output = []
        start = time.time()
        rnn1 = self.get_gru_cell(self.rnn1)
        rnn2 = self.get_gru_cell(self.rnn2)

        with torch.no_grad():

            mels = mels.to(device)
            mels = self.pad_tensor(mels.transpose(1, 2),
                                   pad=self.pad,
                                   side='both')
            mels, aux = self.upsample(mels.transpose(1, 2))

            if batched:
                mels = self.fold_with_overlap(mels, target, overlap)
                aux = self.fold_with_overlap(aux, target, overlap)

            b_size, seq_len, _ = mels.size()

            h1 = torch.zeros(b_size, self.rnn_dims, device=device)
            h2 = torch.zeros(b_size, self.rnn_dims, device=device)
            x = torch.zeros(b_size, 1, device=device)

            # The auxiliary features are consumed as four equal-width slices,
            # one per conditioning point in the network below.
            d = self.aux_dims
            aux_split = [aux[:, :, d * i:d * (i + 1)] for i in range(4)]

            for i in range(seq_len):

                m_t = mels[:, i, :]

                a1_t, a2_t, a3_t, a4_t = \
                    (a[:, i, :] for a in aux_split)

                # x carries the previous sample into this step
                x = torch.cat([x, m_t, a1_t], dim=1)
                x = self.I(x)
                h1 = rnn1(x, h1)

                x = x + h1
                inp = torch.cat([x, a2_t], dim=1)
                h2 = rnn2(inp, h2)

                x = x + h2
                x = torch.cat([x, a3_t], dim=1)
                x = F.relu(self.fc1(x))

                x = torch.cat([x, a4_t], dim=1)
                x = F.relu(self.fc2(x))

                logits = self.fc3(x)

                sample = sample_from_discretized_mix_logistic(
                    logits.unsqueeze(0).transpose(1, 2))
                output.append(sample.view(-1))
                x = sample.transpose(0, 1).to(device)

                if i % 100 == 0:
                    self.gen_display(i, seq_len, b_size, start)

        output = torch.stack(output).transpose(0, 1)
        output = output.cpu().numpy()
        output = output.astype(np.float64)

        if batched:
            output = self.xfade_and_unfold(output, target, overlap)
        else:
            output = output[0]

        # NOTE(review): this always switches the model to training mode, even
        # if the caller had put it in eval mode before calling generate().
        self.train()
        return output
Esempio n. 4
0
    def generate(self, mels, save_path: Union[str, Path], batched, target, overlap,
                 mu_law):
        """Synthesise a waveform with the multi-band (4 sub-band) WaveRNN.

        Each step predicts one sample per sub-band through the heads
        ``fc30`` .. ``fc33``; the sub-band signals are merged into a single
        waveform with ``PQMF().synthesis`` and always written with
        ``save_wav``.

        Args:
            mels: Mel spectrogram, converted with ``torch.as_tensor`` and
                moved to the model's device before being passed to
                ``self.upsample``.
            save_path: Destination handed to ``save_wav``.
            batched: When true, the upsampled features are folded into
                overlapping segments with ``fold_with_overlap`` and the
                generated segments are merged by ``xfade_and_unfold``.
            target: Forwarded to ``fold_with_overlap``.
            overlap: Forwarded to ``fold_with_overlap`` / ``xfade_and_unfold``.
            mu_law: Decode the samples with ``decode_mu_law``. Ignored unless
                ``self.mode == 'RAW'``.

        Returns:
            The synthesised waveform as a NumPy array (the squeezed output of
            the PQMF synthesis). No truncation or fade-out is applied.

        Raises:
            RuntimeError: If ``self.mode`` is neither ``'MOL'`` nor ``'RAW'``.
        """
        if self.mode not in ('MOL', 'RAW'):
            # Fail before touching the model so an unsupported mode does not
            # leave it stuck in eval mode after a wasted upsampling pass.
            raise RuntimeError(f"Unknown model mode value - {self.mode}")

        self.eval()

        device = next(self.parameters()).device  # use same device as parameters

        # mu-law decoding only makes sense for the categorical ('RAW') output
        mu_law = mu_law if self.mode == 'RAW' else False

        output = []
        start = time.time()
        rnn1 = self.get_gru_cell(self.rnn1)
        rnn2 = self.get_gru_cell(self.rnn2)

        mypqmf = PQMF()

        # One output head per sub-band.
        heads = (self.fc30, self.fc31, self.fc32, self.fc33)

        # Inference only: without no_grad the autograd graph would keep
        # growing through the recurrent state for the whole sequence.
        with torch.no_grad():
            mels = torch.as_tensor(mels, device=device)
            mels, aux = self.upsample(mels)

            if batched:
                mels = self.fold_with_overlap(mels, target, overlap, self.pad_val)
                aux = self.fold_with_overlap(aux, target, overlap, self.pad_val)

            b_size, seq_len, _ = mels.size()

            h1 = torch.zeros(b_size, self.rnn_dims, device=device)
            h2 = torch.zeros(b_size, self.rnn_dims, device=device)

            # previous sample of every sub-band, fed back at each step
            x = torch.zeros(b_size, len(heads), device=device)

            # The auxiliary features are consumed as four equal-width slices,
            # one per conditioning point in the network below.
            d = self.aux_dims
            aux_split = [aux[:, :, d * i:d * (i + 1)] for i in range(4)]

            for i in range(seq_len):

                m_t = mels[:, i, :]
                a1_t, a2_t, a3_t, a4_t = \
                    (a[:, i, :] for a in aux_split)

                x = torch.cat([x, m_t, a1_t], dim=1)
                x = self.I(x)

                h1 = rnn1(x, h1)
                x = x + h1
                inp = torch.cat([x, a2_t], dim=1)
                h2 = rnn2(inp, h2)
                x = x + h2
                x = torch.cat([x, a3_t], dim=1)
                x = F.relu(self.fc1(x))
                x = torch.cat([x, a4_t], dim=1)
                x = F.relu(self.fc2(x))

                band_logits = [head(x) for head in heads]  # each (batch, num_classes)

                if self.mode == 'MOL':
                    band_samples = [
                        sample_from_discretized_mix_logistic(
                            logits.unsqueeze(0).transpose(1, 2))
                        for logits in band_logits
                    ]
                    # NOTE(review): assumes the sampler returns (1, batch) per
                    # band, as the single-band WaveRNN usage implies -- confirm.
                    # Stacking bands on dim 0 and transposing gives the same
                    # (batch, subbands) layout the 'RAW' branch produces.
                    sample = torch.cat(band_samples, dim=0).transpose(0, 1)
                else:  # 'RAW' (validated above)
                    band_samples = []
                    for logits in band_logits:
                        posterior = F.softmax(logits, dim=1)
                        distrib = torch.distributions.Categorical(posterior)
                        # label -> float in [-1, 1]
                        band_samples.append(
                            2 * distrib.sample().float() / (self.n_classes - 1.) - 1.)
                    sample = torch.stack(band_samples, dim=-1)

                output.append(sample)
                x = sample  # (batch, subbands)

                if i % 100 == 0:
                    self.gen_display(i, seq_len, b_size, start)

            output = torch.stack(output).squeeze()  # (T//sub_bands, sub_bands)

        output = output.cpu().numpy()
        # np.float was removed in NumPy 1.24; it was an alias of float64.
        output = output.astype(np.float64)

        if mu_law:
            output = decode_mu_law(output, self.n_classes, False)

        if batched:
            output = self.xfade_and_unfold(output, overlap)

        # (batch, sub_band, T//sub_band) -> (batch, 1, T)
        output = mypqmf.synthesis(
            torch.tensor(output, dtype=torch.float).unsqueeze(0).transpose(1, 2)).numpy()
        output = output.squeeze()

        save_wav(output, save_path)

        # NOTE(review): this always switches the model to training mode, even
        # if the caller had put it in eval mode before calling generate().
        self.train()
        return output
Esempio n. 5
0
    def generate(self, mels, save_path, batched, target, overlap, mu_law):

        output = []
        mu_law = mu_law if self.mode == 'RAW' else False

        start = time.time()
        rnn1 = self.get_gru_cell(self.rnn1)
        rnn2 = self.get_gru_cell(self.rnn2)

        with torch.no_grad():

            mels = mels.cuda()
            wave_len = mels.size(-1) * self.hop_length
            mels = self.pad_tensor(mels.transpose(1, 2),
                                   self.pad,
                                   self.pad_val,
                                   side='both')
            mels, aux = self.upsample(mels.transpose(1, 2))

            if batched:
                mels = self.fold_with_overlap(mels, target, overlap,
                                              self.pad_val)
                aux = self.fold_with_overlap(aux, target, overlap,
                                             self.pad_val)

            b_size, seq_len, _ = mels.size()

            h1 = torch.zeros(b_size, self.rnn_dims).half().cuda()
            h2 = torch.zeros(b_size, self.rnn_dims).half().cuda()
            x = torch.zeros(b_size, 1).half().cuda()

            d = self.aux_dims
            aux_split = [aux[:, :, d * i:d * (i + 1)] for i in range(4)]

            for i in range(seq_len):

                m_t = mels[:, i, :]

                a1_t, a2_t, a3_t, a4_t = (a[:, i, :] for a in aux_split)

                x = torch.cat([x, m_t, a1_t], dim=1)
                x = self.I(x)
                h1 = rnn1(x, h1)

                x = x + h1
                inp = torch.cat([x, a2_t], dim=1)
                h2 = rnn2(inp, h2)

                x = x + h2
                x = torch.cat([x, a3_t], dim=1)
                x = F.relu(self.fc1(x))

                x = torch.cat([x, a4_t], dim=1)
                x = F.relu(self.fc2(x))

                logits = self.fc3(x)

                if self.mode == 'MOL':
                    sample = sample_from_discretized_mix_logistic(
                        logits.unsqueeze(0).transpose(1, 2))
                    output.append(sample.view(-1))
                    # x = torch.FloatTensor([[sample]]).cuda()
                    x = sample.half().transpose(0, 1).cuda()
                elif self.mode == 'RAW':
                    posterior = F.softmax(logits.float(), dim=1)
                    distrib = torch.distributions.Categorical(posterior)
                    # label -> float
                    sample = 2 * distrib.sample().float() / (self.n_classes -
                                                             1.) - 1.
                    output.append(sample)
                    x = sample.half().unsqueeze(-1)
                else:
                    raise RuntimeError("Unknown model mode value - ",
                                       self.mode)

                if i % 100 == 0: self.gen_display(i, seq_len, b_size, start)

        output = torch.stack(output).transpose(0, 1)
        output = output.cpu().numpy()
        output = output.astype(np.float64)

        if mu_law:
            output = decode_mu_law(output, self.n_classes, False)

        if batched:
            output = self.xfade_and_unfold(output, target, overlap, -1)
        else:
            output = output[0]

        end = time.time()
        print(f'Elapsed {end - start} seconds')
        return save_wav(output[:wave_len], save_path)