Example #1
def compute_spect(wav):
    # compute spectrogram
    D = pySTFT(wav).T
    D_mel = np.dot(D, mel_basis)
    D_db = 20 * np.log10(np.maximum(min_level, D_mel)) - 16
    S = (D_db + 100) / 100

    S = S[np.newaxis, :, :]
    if S.shape[1] <= 192:
        S, _ = pad_seq_to_2(S, 192)
    uttr = torch.from_numpy(S.astype(np.float32)).to(device)

    return uttr
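
pad_seq_to_2 is not defined in these snippets. A minimal sketch of the zero-padding helper the calls above appear to rely on; the signature and return values are assumptions inferred from usage, not the original implementation:

import numpy as np

def pad_seq_to_2(x, len_out=192):
    # zero-pad a (1, T, D) array along the time axis to len_out frames;
    # returns the padded array and the amount of padding that was added
    len_pad = len_out - x.shape[1]
    assert len_pad >= 0, 'input is longer than the target length'
    return np.pad(x, ((0, 0), (0, len_pad), (0, 0)), 'constant'), len_pad
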
Example #2
def extract_f0(wav, fs):
    f0_rapt = sptk.rapt(wav.astype(np.float32) * 32768,
                        fs,
                        256,
                        min=lo,
                        max=hi,
                        otype=2)
    index_nonzero = (f0_rapt != -1e10)
    mean_f0, std_f0 = np.mean(f0_rapt[index_nonzero]), np.std(
        f0_rapt[index_nonzero])
    f0_norm = speaker_normalization(f0_rapt, index_nonzero, mean_f0, std_f0)

    f0_quantized = quantize_f0_numpy(f0_norm)[0]
    f0_onehot = f0_quantized[np.newaxis, :, :]
    print(f0_onehot.shape)

    if f0_onehot.shape[1] <= 192:
        f0_onehot, _ = pad_seq_to_2(f0_onehot, 192)

    return torch.from_numpy(f0_onehot).to(device)
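
speaker_normalization and quantize_f0_numpy come from the SpeechSplit utilities and are not shown here. A rough sketch of the behavior these examples appear to assume (voiced frames mapped to roughly [0, 1], then one-hot quantized with a dedicated unvoiced bin); the bin count and clipping constants are assumptions:

import numpy as np

def speaker_normalization(f0, index_nonzero, mean_f0, std_f0):
    # map voiced frames to roughly [0, 1]; unvoiced frames are left untouched
    f0 = f0.astype(float).copy()
    f0[index_nonzero] = (f0[index_nonzero] - mean_f0) / std_f0 / 4.0
    f0[index_nonzero] = np.clip(f0[index_nonzero], -1, 1)
    f0[index_nonzero] = (f0[index_nonzero] + 1) / 2.0
    return f0

def quantize_f0_numpy(x, num_bins=256):
    # one-hot quantize a normalized F0 contour; bin 0 marks unvoiced frames
    x = x.astype(float).copy()
    uv = (x <= 0)
    x[uv] = 0.0
    idx = np.rint(np.clip(x, 0, 1) * (num_bins - 1)).astype(np.int64) + 1
    idx[uv] = 0
    onehot = np.zeros((len(x), num_bins + 1), dtype=np.float32)
    onehot[np.arange(len(x)), idx] = 1.0
    return onehot, idx
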
Example #3
    def train(self):
        # Set data loader.
        data_loader = self.vcc_loader
        
        # Fetch fixed inputs for debugging.
        data_iter = iter(data_loader)
        
        # Start training from scratch or resume training.
        start_iters = 0
        if self.resume_iters:
            print('Resuming ...')
            start_iters = self.resume_iters
            self.num_iters += self.resume_iters
            self.restore_model(self.resume_iters)
            self.print_optimizer(self.g_optimizer, 'G_optimizer')
            self.print_optimizer(self.ortho_optimizer, 'OrthoDisen_optimizer')
                        
        # Learning rate cache for decaying.
        g_lr = self.g_lr
        print('Current learning rates, g_lr: {}.'.format(g_lr))
        ortho_lr = self.ortho_lr
        print('Current learning rates, ortho_lr: {}.'.format(ortho_lr))
        
        # Print logs in specified order
        keys = ['G/loss_id']
        keys_ortho = ['ortho/loss_dis']
            
        # Start training.
        print('Start training...')
        start_time = time.time()
        for i in range(start_iters, self.num_iters):

            # =================================================================================== #
            #                             1. Preprocess input data                                #
            # =================================================================================== #

            # Fetch mel-spectrograms and labels.
            try:
                x_real_org, emb_org, f0_org, len_org = next(data_iter)
            except StopIteration:
                data_iter = iter(data_loader)
                x_real_org, emb_org, f0_org, len_org = next(data_iter)
            
            x_real_org = x_real_org.to(self.device)
            # x_real_org: mel-spectrogram?
            emb_org = emb_org.to(self.device)
            len_org = len_org.to(self.device)
            f0_org = f0_org.to(self.device)

            # =================================================================================== #
            #                               2. Train the generator                                #
            # =================================================================================== #
            
            self.model = self.model.train()
                        
            # Identity mapping loss
            x_f0 = torch.cat((x_real_org, f0_org), dim=-1)
            x_f0_intrp = self.Interp(x_f0, len_org) 
            f0_org_intrp = quantize_f0_torch(x_f0_intrp[:,:,-1])[0]
            x_f0_intrp_org = torch.cat((x_f0_intrp[:,:,:-1], f0_org_intrp), dim=-1)
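            # self.Interp is assumed to be SpeechSplit-style random resampling
            # (InterpLnr): random segments are stretched or compressed so the
            # content encoder cannot rely on rhythm, and the resampled F0
            # channel is re-quantized before being concatenated back on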
            
            # x_identic = self.G(x_f0_intrp_org, x_real_org, emb_org)
            x_identic = self.model(x_f0_intrp_org, x_real_org, emb_org)
            # identity mapping: all inputs come from the original speaker
            g_loss_id = F.mse_loss(x_real_org, x_identic, reduction='mean')
            # x_real_org: ground truth; x_identic: prediction (MSE reconstruction loss)
           
            # Backward and optimize.
            g_loss = g_loss_id
            self.reset_grad()
            # zero_grad(): sets the gradients of all model parameters to zero;
            # gradients accumulate by default, so they must be cleared before each backward pass
            g_loss.backward()
            # backward() computes the gradient of every parameter
            self.g_optimizer.step()
            # the optimizer holds references to these parameters;
            # step() updates each parameter from its gradient
            # https://www.zhihu.com/question/266160054

            # =================================================================================== #
            #                             3. Train the orthogonal module                           #
            # =================================================================================== #

            self.ortho = self.ortho.train()
            # TODO: rework step 3 (train the orthogonal module) following Zhang Jingxuan's version plus the BERT loss

            # Identity mapping loss
            x_f0 = torch.cat((x_real_org, f0_org), dim=-1)
            x_f0_intrp = self.Interp(x_f0, len_org)
            f0_org_intrp = quantize_f0_torch(x_f0_intrp[:, :, -1])[0]
            x_f0_intrp_org = torch.cat((x_f0_intrp[:, :, :-1], f0_org_intrp), dim=-1)

            x_identic = self.G(x_f0_intrp_org, x_real_org, emb_org)
            # identity mapping: all inputs come from the original speaker
            ortho_loss_id = F.mse_loss(x_real_org, x_identic, reduction='mean')
            # MSE reconstruction loss

            # Backward and optimize.
            ortho_loss = ortho_loss_id
            self.reset_grad()
            # zero_grad(): clear accumulated gradients before the backward pass
            ortho_loss.backward()
            # backward() computes the gradient of every parameter
            self.ortho_optimizer.step()
            # the optimizer holds references to these parameters;
            # step() updates each parameter from its gradient
            # https://www.zhihu.com/question/266160054

            # Logging.
            loss = {}
            loss['G/loss_id'] = g_loss_id.item()

            # =================================================================================== #
            #                                 4. Miscellaneous                                     #
            # =================================================================================== #

            # Print out training information.
            if (i+1) % self.log_step == 0:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]".format(et, i+1, self.num_iters)
                for tag in keys:
                    log += ", {}: {:.8f}".format(tag, loss[tag])
                print(log)

                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.writer.add_scalar(tag, value, i+1)

            # Save model checkpoints.
            if (i+1) % self.model_save_step == 0:
                # model_save_step : 1000
                G_path = os.path.join(self.model_save_dir, '{}.ckpt'.format(i+1))
                torch.save({'model': self.G.state_dict(),
                            'optimizer_G': self.g_optimizer.state_dict(),
                            'optimizer_Ortho': self.ortho_optimizer.state_dict()}, G_path)
                # saves both the model and the optimizer states
                print('Saved model checkpoints into {}...'.format(self.model_save_dir))            

            # Validation.
            if (i+1) % self.sample_step == 0:
                self.G = self.G.eval()
                with torch.no_grad():
                    loss_val = []
                    for val_sub in validation_pt:
                        # validation_pt: validation data loaded from demo.pkl
                        emb_org_val = torch.from_numpy(val_sub[1]).to(self.device)
                        # speaker one-hot embedding
                        for k in range(2, 3):
                            x_real_pad, _ = pad_seq_to_2(val_sub[k][0][np.newaxis,:,:], 408)
                            len_org = torch.tensor([val_sub[k][2]]).to(self.device) 
                            f0_org = np.pad(val_sub[k][1], (0, 408-val_sub[k][2]), 'constant', constant_values=(0, 0))
                            f0_quantized = quantize_f0_numpy(f0_org)[0]
                            f0_onehot = f0_quantized[np.newaxis, :, :]
                            f0_org_val = torch.from_numpy(f0_onehot).to(self.device) 
                            x_real_pad = torch.from_numpy(x_real_pad).to(self.device) 
                            x_f0 = torch.cat((x_real_pad, f0_org_val), dim=-1)
                            x_identic_val = self.G(x_f0, x_real_pad, emb_org_val)
                            g_loss_val = F.mse_loss(x_real_pad, x_identic_val, reduction='sum')
                            loss_val.append(g_loss_val.item())
                val_loss = np.mean(loss_val) 
                print('Validation loss: {}'.format(val_loss))
                if self.use_tensorboard:
                    self.writer.add_scalar('Validation_loss', val_loss, i+1)

            # plot test samples
            if (i+1) % self.sample_step == 0:
                self.G = self.G.eval()
                with torch.no_grad():
                    for val_sub in validation_pt:
                        emb_org_val = torch.from_numpy(val_sub[1]).to(self.device)         
                        for k in range(2, 3):
                            x_real_pad, _ = pad_seq_to_2(val_sub[k][0][np.newaxis,:,:], 408)
                            len_org = torch.tensor([val_sub[k][2]]).to(self.device) 
                            f0_org = np.pad(val_sub[k][1], (0, 408-val_sub[k][2]), 'constant', constant_values=(0, 0))
                            f0_quantized = quantize_f0_numpy(f0_org)[0]
                            f0_onehot = f0_quantized[np.newaxis, :, :]
                            f0_org_val = torch.from_numpy(f0_onehot).to(self.device) 
                            x_real_pad = torch.from_numpy(x_real_pad).to(self.device)
                            # the next three lines zero out one input component at a time
                            x_f0 = torch.cat((x_real_pad, f0_org_val), dim=-1)
                            x_f0_F = torch.cat((x_real_pad, torch.zeros_like(f0_org_val)), dim=-1)
                            x_f0_C = torch.cat((torch.zeros_like(x_real_pad), f0_org_val), dim=-1)
                            
                            x_identic_val = self.G(x_f0, x_real_pad, emb_org_val)
                            x_identic_woF = self.G(x_f0_F, x_real_pad, emb_org_val)
                            x_identic_woR = self.G(x_f0, torch.zeros_like(x_real_pad), emb_org_val)
                            # woR:without rhythm
                            x_identic_woC = self.G(x_f0_C, x_real_pad, emb_org_val)
                            
                            melsp_gd_pad = x_real_pad[0].cpu().numpy().T
                            # ground truth
                            melsp_out = x_identic_val[0].cpu().numpy().T
                            # all components intact
                            melsp_woF = x_identic_woF[0].cpu().numpy().T
                            # without pitch
                            melsp_woR = x_identic_woR[0].cpu().numpy().T
                            # without rhythm
                            melsp_woC = x_identic_woC[0].cpu().numpy().T
                            # without content
                            
                            min_value = np.min(np.hstack([melsp_gd_pad, melsp_out, melsp_woF, melsp_woR, melsp_woC]))
                            max_value = np.max(np.hstack([melsp_gd_pad, melsp_out, melsp_woF, melsp_woR, melsp_woC]))
                            
                            fig, (ax1,ax2,ax3,ax4,ax5) = plt.subplots(5, 1, sharex=True)
                            im1 = ax1.imshow(melsp_gd_pad, aspect='auto', vmin=min_value, vmax=max_value)
                            im2 = ax2.imshow(melsp_out, aspect='auto', vmin=min_value, vmax=max_value)
                            im3 = ax3.imshow(melsp_woC, aspect='auto', vmin=min_value, vmax=max_value)
                            im4 = ax4.imshow(melsp_woR, aspect='auto', vmin=min_value, vmax=max_value)
                            im5 = ax5.imshow(melsp_woF, aspect='auto', vmin=min_value, vmax=max_value)
                            plt.savefig(f'{self.sample_dir}/{i+1}_{val_sub[0]}_{k}.png', dpi=150)
                            plt.close(fig)
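
The solver's reset_grad helper is not shown in any of these examples. Since this variant keeps two optimizers (g_optimizer and ortho_optimizer), it presumably just clears both gradient buffers; a minimal sketch of such a method, offered as an assumption rather than the original implementation:

    def reset_grad(self):
        # clear accumulated gradients of both optimizers before the next backward pass
        self.g_optimizer.zero_grad()
        self.ortho_optimizer.zero_grad()
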
Example #4
    def train(self):
        # Set data loader.
        data_loader = self.vcc_loader

        # Fetch fixed inputs for debugging.
        data_iter = iter(data_loader)

        # Start training from scratch or resume training.
        start_iters = 0
        if self.resume_iters:
            print('Resuming ...')
            start_iters = self.resume_iters
            self.num_iters += self.resume_iters
            self.restore_model(self.resume_iters)
            self.print_optimizer(self.g_optimizer, 'G_optimizer')

        # Learning rate cache for decaying.
        g_lr = self.g_lr
        print('Current learning rates, g_lr: {}.'.format(g_lr))

        # Print logs in specified order
        keys = ['G/loss_id']

        # Start training.
        print('Start training...')
        start_time = time.time()
        for i in range(start_iters, self.num_iters):

            # =================================================================================== #
            #                             1. Preprocess input data                                #
            # =================================================================================== #

            # Fetch mel-spectrograms and labels.

            try:
                x_real_org, emb_org, f0_org, len_org = next(data_iter)
                x_real_trg, emb_trg, f0_trg, len_trg = next(data_iter)
            except StopIteration:
                data_iter = iter(data_loader)
                x_real_org, emb_org, f0_org, len_org = next(data_iter)
                x_real_trg, emb_trg, f0_trg, len_trg = next(data_iter)

            x_real_org = x_real_org.to(self.device)
            emb_org = emb_org.to(self.device)
            len_org = len_org.to(self.device)
            f0_org = f0_org.to(self.device)

            f0_list = []
            for j in range(f0_trg.shape[0]):  # use a separate index so the outer counter i is not overwritten
                log_f0 = f0_trg[j].cpu().numpy()
                flatten_log_f0 = log_f0.flatten()
                f0_trg_quantized = quantize_f0_numpy(flatten_log_f0)[0]
                f0_trg_onehot = torch.from_numpy(f0_trg_quantized).to(
                    self.device)
                f0_list.append(f0_trg_onehot)
            f0_trg = torch.stack(f0_list)
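            # f0_trg is now a batch of one-hot pitch codes of shape
            # (batch, frames, bins); in the usual SpeechSplit setup the last
            # dimension is 257 (256 pitch bins plus one unvoiced bin)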

            # =================================================================================== #
            #                               2. Train the generator                                #
            # =================================================================================== #

            self.G = self.G.train()

            # Identity mapping loss
            x_f0 = torch.cat((x_real_org, f0_org), dim=-1)
            x_f0_intrp = self.Interp(x_f0, len_org)
            f0_org_intrp = quantize_f0_torch(x_f0_intrp[:, :, -1])[0]
            x_f0_intrp_org = torch.cat((x_f0_intrp[:, :, :-1], f0_org_intrp),
                                       dim=-1)

            x_identic = self.G(x_f0_intrp_org, x_real_org, emb_org)
            g_loss_id = F.mse_loss(x_real_org, x_identic, reduction='mean')

            f0_pred = self.P(x_real_org, f0_trg)
            p_loss_id = F.mse_loss(f0_pred, f0_trg, reduction='mean')

            # Backward and optimize.
            g_loss = g_loss_id
            self.reset_grad()
            g_loss.backward()
            self.g_optimizer.step()

            # Logging.
            loss = {}
            loss['G/loss_id'] = g_loss_id.item()

            # =================================================================================== #
            #                                 4. Miscellaneous                                    #
            # =================================================================================== #

            # Print out training information.
            if (i + 1) % self.log_step == 0:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]".format(
                    et, i + 1, self.num_iters)
                for tag in keys:
                    log += ", {}: {:.8f}".format(tag, loss[tag])
                print(log)

                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.writer.add_scalar(tag, value, i + 1)

            # Save model checkpoints.
            if (i + 1) % self.model_save_step == 0:
                G_path = os.path.join(self.model_save_dir,
                                      '{}-G.ckpt'.format(i + 1))
                torch.save(
                    {
                        'model': self.G.state_dict(),
                        'optimizer': self.g_optimizer.state_dict()
                    }, G_path)
                print('Saved model checkpoints into {}...'.format(
                    self.model_save_dir))

            # Validation.
            if (i + 1) % self.sample_step == 0:
                self.G = self.G.eval()
                with torch.no_grad():
                    loss_val = []
                    for val_sub in validation_pt:
                        emb_org_val = torch.from_numpy(val_sub[1]).to(
                            self.device)
                        for k in range(2, 3):
                            x_real_pad, _ = pad_seq_to_2(
                                val_sub[k][0][np.newaxis, :, :], 192)
                            len_org = torch.tensor([val_sub[k][2]
                                                    ]).to(self.device)
                            f0_org = np.pad(val_sub[k][1],
                                            (0, 192 - val_sub[k][2]),
                                            'constant',
                                            constant_values=(0, 0))
                            f0_quantized = quantize_f0_numpy(f0_org)[0]
                            f0_onehot = f0_quantized[np.newaxis, :, :]
                            f0_org_val = torch.from_numpy(f0_onehot).to(
                                self.device)
                            x_real_pad = torch.from_numpy(x_real_pad).to(
                                self.device)
                            x_f0 = torch.cat((x_real_pad, f0_org_val), dim=-1)
                            x_identic_val = self.G(x_f0, x_real_pad,
                                                   emb_org_val)
                            g_loss_val = F.mse_loss(x_real_pad,
                                                    x_identic_val,
                                                    reduction='sum')
                            loss_val.append(g_loss_val.item())
                val_loss = np.mean(loss_val)
                print('Validation loss: {}'.format(val_loss))
                if self.use_tensorboard:
                    self.writer.add_scalar('Validation_loss', val_loss, i + 1)

            # plot test samples
            if (i + 1) % self.sample_step == 0:
                self.G = self.G.eval()
                with torch.no_grad():
                    for val_sub in validation_pt:
                        emb_org_val = torch.from_numpy(val_sub[1]).to(
                            self.device)
                        for k in range(2, 3):
                            x_real_pad, _ = pad_seq_to_2(
                                val_sub[k][0][np.newaxis, :, :], 192)
                            len_org = torch.tensor([val_sub[k][2]
                                                    ]).to(self.device)
                            f0_org = np.pad(val_sub[k][1],
                                            (0, 192 - val_sub[k][2]),
                                            'constant',
                                            constant_values=(0, 0))
                            f0_quantized = quantize_f0_numpy(f0_org)[0]
                            f0_onehot = f0_quantized[np.newaxis, :, :]
                            f0_org_val = torch.from_numpy(f0_onehot).to(
                                self.device)
                            x_real_pad = torch.from_numpy(x_real_pad).to(
                                self.device)
                            x_f0 = torch.cat((x_real_pad, f0_org_val), dim=-1)
                            x_f0_F = torch.cat(
                                (x_real_pad, torch.zeros_like(f0_org_val)),
                                dim=-1)
                            x_f0_C = torch.cat(
                                (torch.zeros_like(x_real_pad), f0_org_val),
                                dim=-1)

                            x_identic_val = self.G(x_f0, x_real_pad,
                                                   emb_org_val)
                            x_identic_woF = self.G(x_f0_F, x_real_pad,
                                                   emb_org_val)
                            x_identic_woR = self.G(
                                x_f0, torch.zeros_like(x_real_pad),
                                emb_org_val)
                            x_identic_woC = self.G(x_f0_C, x_real_pad,
                                                   emb_org_val)

                            melsp_gd_pad = x_real_pad[0].cpu().numpy().T
                            melsp_out = x_identic_val[0].cpu().numpy().T
                            melsp_woF = x_identic_woF[0].cpu().numpy().T
                            melsp_woR = x_identic_woR[0].cpu().numpy().T
                            melsp_woC = x_identic_woC[0].cpu().numpy().T

                            min_value = np.min(
                                np.hstack([
                                    melsp_gd_pad, melsp_out, melsp_woF,
                                    melsp_woR, melsp_woC
                                ]))
                            max_value = np.max(
                                np.hstack([
                                    melsp_gd_pad, melsp_out, melsp_woF,
                                    melsp_woR, melsp_woC
                                ]))

                            fig, (ax1, ax2, ax3, ax4,
                                  ax5) = plt.subplots(5, 1, sharex=True)
                            im1 = ax1.imshow(melsp_gd_pad,
                                             aspect='auto',
                                             vmin=min_value,
                                             vmax=max_value)
                            im2 = ax2.imshow(melsp_out,
                                             aspect='auto',
                                             vmin=min_value,
                                             vmax=max_value)
                            im3 = ax3.imshow(melsp_woC,
                                             aspect='auto',
                                             vmin=min_value,
                                             vmax=max_value)
                            im4 = ax4.imshow(melsp_woR,
                                             aspect='auto',
                                             vmin=min_value,
                                             vmax=max_value)
                            im5 = ax5.imshow(melsp_woF,
                                             aspect='auto',
                                             vmin=min_value,
                                             vmax=max_value)
                            plt.savefig(
                                f'{self.sample_dir}/{i+1}_{val_sub[0]}_{k}.png',
                                dpi=150)
                            plt.close(fig)
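
For reference, a checkpoint written by the saving block above can be restored with the standard torch.load / load_state_dict pattern. The path below is illustrative, and G / g_optimizer are assumed to be the same objects used during training:

import torch

ckpt = torch.load('models/100000-G.ckpt', map_location='cpu')  # illustrative path
G.load_state_dict(ckpt['model'])
g_optimizer.load_state_dict(ckpt['optimizer'])
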
Example #5
metadata = pickle.load(open('assets/spmel/test.pkl', "rb"))

sbmt_i = metadata[0]
# P226
# record for speaker i, with three fields:
# 0: speaker name
# 1: speaker one-hot embedding
# 2: the speaker's mel-spectrogram, F0, length, and uid

emb_org = torch.from_numpy(sbmt_i[1][np.newaxis, :]).to(device)
# one-hot encoding of speaker i

x_org, f0_org, len_org, uid_org = sbmt_i[2]
# utterance data for speaker i (a single selected utterance?)
uttr_org_pad, len_org_pad = pad_seq_to_2(x_org[np.newaxis, :, :], 400)
uttr_org_pad = torch.from_numpy(uttr_org_pad).to(device)
f0_org_pad = np.pad(f0_org, (0, 400 - len_org),
                    'constant',
                    constant_values=(0, 0))
f0_org_quantized = quantize_f0_numpy(f0_org_pad)[0]
f0_org_onehot = f0_org_quantized[np.newaxis, :, :]
f0_org_onehot = torch.from_numpy(f0_org_onehot).to(device)
uttr_f0_org = torch.cat((uttr_org_pad, f0_org_onehot), dim=-1)

sbmt_j = metadata[-1]
# P231
emb_trg = torch.from_numpy(sbmt_j[1][np.newaxis, :]).to(device)
x_trg, f0_trg, len_trg, uid_trg = sbmt_j[2]
uttr_trg_pad, len_trg_pad = pad_seq_to_2(x_trg[np.newaxis, :, :], 400)
uttr_trg_pad = torch.from_numpy(uttr_trg_pad).to(device)
    def train(self):
        # Set data loader.
        data_loader = self.vcc_loader

        # Fetch fixed inputs for debugging.
        data_iter = iter(data_loader)

        # Start training from scratch or resume training.
        start_iters = 0
        if self.resume_iters:
            print('Resuming ...', flush=True)
            start_iters = self.resume_iters
            self.num_iters += self.resume_iters
            self.restore_model(self.resume_iters)
            self.print_optimizer(self.optimizer_main, 'G_optimizer')
            # self.print_optimizer(self.optimizer_ortho, 'OrthoDisen_optimizer')

        # criterion = ParrotLoss(self.hparams).cuda()

        # Learning rate cache for decaying.
        lr_main = self.lr_main
        print('Current learning rates, lr_g: {}.'.format(lr_main), flush=True)
        # lr_ortho = self.lr_ortho
        # print('Current learning rates, lr_ortho: {}.'.format(lr_ortho))

        # Print logs in specified order
        keys = ['overall/loss_id', 'main/loss_id', 'ortho/loss_id']

        # Start training.
        print('Start training...', flush=True)
        start_time = time.time()
        for i in range(start_iters, self.num_iters):

            # =================================================================================== #
            #                             1. Preprocess input data                                #
            # =================================================================================== #

            # Fetch mel-spectrograms and labels.
            try:
                x_real_org, emb_org, f0_org, len_org = next(data_iter)
            except StopIteration:
                data_iter = iter(data_loader)
                x_real_org, emb_org, f0_org, len_org = next(data_iter)

            x_real_org = x_real_org.to(self.device)
            # x_real_org: mel-spectrogram?
            emb_org = emb_org.to(self.device)
            len_org = len_org.to(self.device)
            f0_org = f0_org.to(self.device)

            # =================================================================================== #
            #                               2. Train the generator                                #
            # =================================================================================== #

            self.model.train()
            b = list(self.model.named_parameters())
            # Identity mapping loss
            x_f0 = torch.cat((x_real_org, f0_org), dim=-1)
            # x_f0 : [batch_size, seq_len, dim_feature]
            x_f0_intrp = self.Interp(x_f0, len_org)
            # x_f0_intrp : [batch_size, seq_len, dim_feature]
            f0_org_intrp = quantize_f0_torch(x_f0_intrp[:, :, -1])[0]
            x_f0_intrp_org = torch.cat((x_f0_intrp[:, :, :-1], f0_org_intrp),
                                       dim=-1)
            # x_f0_intrp_org : [batch_size, seq_len, dim_feature]
            self.model = self.model.to(self.device)

            # x_identic = self.G(x_f0_intrp_org, x_real_org, emb_org)
            mel_outputs, feature_predicts, ortho_inputs_integrals, mask_parts, invert_masks, alignments = self.model(
                x_f0_intrp_org, x_real_org, emb_org, len_org, len_org, len_org)
            loss_main_id = F.mse_loss(x_real_org,
                                      mel_outputs,
                                      reduction='mean')

            loss_ortho_id_L1 = self.loss_o(
                ortho_inputs_integrals[0].cuda(),
                feature_predicts[0].cuda() * invert_masks[0].cuda() +
                ortho_inputs_integrals[0].cuda() * mask_parts[0].cuda())

            temp = feature_predicts[1].cuda() * invert_masks[1].cuda(
            ) + ortho_inputs_integrals[1].cuda() * mask_parts[1].cuda()
            loss_ortho_id_BCE = self.loss_BCE(
                feature_predicts[1].cuda() * invert_masks[1].cuda() +
                ortho_inputs_integrals[1].cuda() * mask_parts[1].cuda(),
                ortho_inputs_integrals[1].cuda())

            loss_main = loss_main_id
            loss_ortho_id = loss_ortho_id_L1 + loss_ortho_id_BCE

            loss_ortho = loss_ortho_id
            # w_decay = w_ini * decay_rate ** ((i + 1) / decay_steps)
            w_ini = 100
            decay_rate = 0.999
            decay_steps = 12500
            w_decay = w_ini * math.pow(decay_rate, (i + 1) / decay_steps)
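            # with these constants the weight decays very slowly: at i + 1 = 12500
            # the exponent is 1, so w_decay = 100 * 0.999 ** 1 ≈ 99.9, and after
            # 125000 iterations w_decay = 100 * 0.999 ** 10 ≈ 99.0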

            loss_overall_id = self.w_main * w_decay * loss_main + self.w_ortho * loss_ortho
            loss_overall = loss_overall_id / (self.w_main * w_ini)
            # identity mapping: all inputs come from the original speaker
            # loss_main_id = F.mse_loss(x_real_org, x_identic, reduction='mean')
            # MSE loss
            # Backward and optimize.
            # loss_main = loss_main_id
            # loss_ortho = loss_ortho_id
            self.reset_grad()
            # train on loss_main
            # for p in self.parameters_orth:
            #     p.requires_grad_(requires_grad=False)
            # zero_grad(): sets the gradients of all model parameters to zero;
            # gradients accumulate by default, so they must be cleared before each backward pass
            loss_overall.backward()
            self.optimizer_main.step()

            # for p in self.parameters_orth:
            #     p.requires_grad_(requires_grad=True)
            # for p in self.parameters_main:
            #     p.requires_grad_(requires_grad=False)
            #
            # loss_ortho.backward()
            # self.optimizer_ortho.step()

            # loss.backward() computes the gradient of every parameter
            # the optimizer holds references to these parameters;
            # step() updates each parameter from its gradient
            # https://www.zhihu.com/question/266160054

            # Logging.
            loss = {}
            loss['overall/loss_id'] = loss_overall_id.item()
            loss['main/loss_id'] = loss_main_id.item()
            loss['ortho/loss_id'] = loss_ortho_id.item()

            # =================================================================================== #
            #                                 4. Miscellaneous                                     #
            # =================================================================================== #

            # Print out training information.
            if (i + 1) % self.log_step == 0:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]".format(
                    et, i + 1, self.num_iters)
                for tag in keys:
                    log += ", {}: {:.8f}".format(tag, loss[tag])
                print(log, flush=True)

                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.writer.add_scalar(tag, value, i + 1)

            # Save model checkpoints.
            if (i + 1) % self.model_save_step == 0:
                # model_save_step : 1000
                model_path = os.path.join(self.model_save_dir,
                                          '{}.ckpt'.format(i + 1))
                torch.save(
                    {
                        'model': self.model.state_dict(),
                        'optimizer_main': self.optimizer_main.state_dict()
                    }, model_path)
                # saves both the model and the optimizer state
                print('Saved model checkpoints at iteration {} into {}...'.
                      format(i + 1, self.model_save_dir),
                      flush=True)

            # Validation.
            if (i + 1) % self.sample_step == 0:
                self.model = self.model.eval()
                with torch.no_grad():
                    loss_overall_val_list = []
                    loss_main_val_list = []
                    loss_frame_main_val_list = []
                    loss_ortho_val_list = []
                    for val_sub in validation_pt:
                        # validation_pt: validation data loaded from demo.pkl
                        emb_org_val = torch.from_numpy(
                            val_sub[1][np.newaxis, :]).to(self.device)
                        # speaker one-hot embedding
                        for k in range(2, 3):
                            x_real_pad, _ = pad_seq_to_2(
                                val_sub[k][0][np.newaxis, :, :], 408)
                            len_org = torch.tensor([val_sub[k][2]
                                                    ]).to(self.device)
                            f0_org = np.pad(val_sub[k][1],
                                            (0, 408 - val_sub[k][2]),
                                            'constant',
                                            constant_values=(0, 0))
                            f0_quantized = quantize_f0_numpy(f0_org)[0]
                            f0_onehot = f0_quantized[np.newaxis, :, :]
                            f0_org_val = torch.from_numpy(f0_onehot).to(
                                self.device)
                            x_real_pad = torch.from_numpy(x_real_pad).to(
                                self.device)
                            x_f0 = torch.cat((x_real_pad, f0_org_val), dim=-1)

                            mel_outputs, feature_predicts, ortho_inputs_integrals, mask_parts, invert_masks, alignments = self.model(
                                x_f0, x_real_pad, emb_org_val, len_org,
                                len_org, len_org)
                            loss_main_val = F.mse_loss(x_real_pad,
                                                       mel_outputs,
                                                       reduction='sum')
                            loss_frame_main_val = F.mse_loss(x_real_pad,
                                                             mel_outputs,
                                                             reduction='mean')
                            loss_ortho_id_L1_val = self.loss_o(
                                ortho_inputs_integrals[0].cuda(),
                                feature_predicts[0].cuda() *
                                invert_masks[0].cuda() +
                                ortho_inputs_integrals[0].cuda() *
                                mask_parts[0].cuda())

                            loss_ortho_id_BCE_val = self.loss_BCE(
                                feature_predicts[1].cuda() *
                                invert_masks[1].cuda() +
                                ortho_inputs_integrals[1].cuda() *
                                mask_parts[1].cuda(),
                                ortho_inputs_integrals[1].cuda())

                            loss_ortho_val = loss_ortho_id_L1_val + loss_ortho_id_BCE_val
                            # w_decay = w_ini * decay_rate ** ((i + 1) / decay_steps)
                            # (an earlier setting used w_ini = 0.5, decay_rate = 0.1, decay_steps = 1000)
                            w_ini = 100
                            decay_rate = 0.999
                            decay_steps = 12500
                            w_decay = w_ini * math.pow(decay_rate,
                                                       (i + 1) / decay_steps)
                            loss_overall_id = self.w_main * w_decay * loss_main_val + self.w_ortho * loss_ortho_val
                            loss_overall_id = loss_overall_id / (w_ini *
                                                                 self.w_main)
                            # loss_overall_id = self.w_main * loss_main_val + self.w_ortho * loss_ortho_val
                            # append each loss term to its own list
                            loss_overall_val_list.append(
                                loss_overall_id.item())
                            loss_main_val_list.append(loss_main_val.item())
                            loss_ortho_val_list.append(loss_ortho_val.item())
                            loss_frame_main_val_list.append(
                                loss_frame_main_val.item())
                val_overall_loss = np.mean(loss_overall_val_list)
                val_main_loss = np.mean(loss_main_val_list)
                val_ortho_loss = np.mean(loss_ortho_val_list)
                val_frame_main_loss = np.mean(loss_frame_main_val_list)
                print(
                    'Validation overall loss : {}, main loss: {}, ortho loss: {}, frame_main_loss:{}'
                    .format(val_overall_loss, val_main_loss, val_ortho_loss,
                            val_frame_main_loss),
                    flush=True)
                y_pred = [
                    mel_outputs, alignments[0], alignments[1], alignments[2]
                ]
                # y = mel_targets
                self.logger.log_validation(self.model,
                                           y=x_real_pad,
                                           y_pred=y_pred,
                                           iteration=i + 1)
                if self.use_tensorboard:
                    self.writer.add_scalar('Validation_overall_loss',
                                           val_overall_loss, i + 1)
                    self.writer.add_scalar('Validation_main_loss',
                                           val_main_loss, i + 1)
                    self.writer.add_scalar('Validation_frame_main_loss',
                                           val_frame_main_loss, i + 1)
                    self.writer.add_scalar('Validation_ortho_loss',
                                           val_ortho_loss, i + 1)

            # plot test samples
            if (i + 1) % self.sample_step == 0:
                self.model = self.model.eval()
                with torch.no_grad():
                    for val_sub in validation_pt[:3]:
                        emb_org_val = torch.from_numpy(
                            val_sub[1][np.newaxis, :]).to(self.device)
                        for k in range(2, 3):
                            x_real_pad, _ = pad_seq_to_2(
                                val_sub[k][0][np.newaxis, :, :], 408)
                            len_org = torch.tensor([val_sub[k][2]
                                                    ]).to(self.device)
                            f0_org = np.pad(val_sub[k][1],
                                            (0, 408 - val_sub[k][2]),
                                            'constant',
                                            constant_values=(0, 0))
                            f0_quantized = quantize_f0_numpy(f0_org)[0]
                            f0_onehot = f0_quantized[np.newaxis, :, :]
                            f0_org_val = torch.from_numpy(f0_onehot).to(
                                self.device)
                            x_real_pad = torch.from_numpy(x_real_pad).to(
                                self.device)
                            # the next three lines zero out one input component at a time
                            x_f0 = torch.cat((x_real_pad, f0_org_val), dim=-1)
                            x_f0_F = torch.cat(
                                (x_real_pad, torch.zeros_like(f0_org_val)),
                                dim=-1)
                            x_f0_C = torch.cat(
                                (torch.zeros_like(x_real_pad), f0_org_val),
                                dim=-1)

                            x_identic_val, _, _, _, _, _ = self.model(
                                x_f0, x_real_pad, emb_org_val, len_org,
                                len_org, len_org)
                            x_identic_woF, _, _, _, _, _ = self.model(
                                x_f0_F, x_real_pad, emb_org_val, len_org,
                                len_org, len_org)
                            x_identic_woR, _, _, _, _, _ = self.model(
                                x_f0, torch.zeros_like(x_real_pad),
                                emb_org_val, len_org, len_org, len_org)
                            x_identic_woC, _, _, _, _, _ = self.model(
                                x_f0_C, x_real_pad, emb_org_val, len_org,
                                len_org, len_org)
                            x_identic_woU, _, _, _, _, _ = self.model(
                                x_f0, x_real_pad,
                                torch.zeros_like(emb_org_val), len_org,
                                len_org, len_org)

                            melsp_gd_pad = x_real_pad[0].cpu().numpy().T
                            # ground truth
                            melsp_out = x_identic_val[0].cpu().numpy().T
                            # all components intact
                            melsp_woF = x_identic_woF[0].cpu().numpy().T
                            # without pitch
                            melsp_woR = x_identic_woR[0].cpu().numpy().T
                            # without rhythm
                            melsp_woC = x_identic_woC[0].cpu().numpy().T
                            # without content
                            melsp_woU = x_identic_woU[0].cpu().numpy().T
                            # without the speaker embedding

                            min_value = np.min(
                                np.hstack([
                                    melsp_gd_pad, melsp_out, melsp_woF,
                                    melsp_woR, melsp_woC, melsp_woU
                                ]))
                            max_value = np.max(
                                np.hstack([
                                    melsp_gd_pad, melsp_out, melsp_woF,
                                    melsp_woR, melsp_woC, melsp_woU
                                ]))

                            fig, (ax1, ax2, ax3, ax4, ax5,
                                  ax6) = plt.subplots(6, 1, sharex=True)
                            im1 = ax1.imshow(melsp_gd_pad,
                                             aspect='auto',
                                             vmin=min_value,
                                             vmax=max_value)
                            im2 = ax2.imshow(melsp_out,
                                             aspect='auto',
                                             vmin=min_value,
                                             vmax=max_value)
                            im3 = ax3.imshow(melsp_woC,
                                             aspect='auto',
                                             vmin=min_value,
                                             vmax=max_value)
                            im4 = ax4.imshow(melsp_woR,
                                             aspect='auto',
                                             vmin=min_value,
                                             vmax=max_value)
                            im5 = ax5.imshow(melsp_woF,
                                             aspect='auto',
                                             vmin=min_value,
                                             vmax=max_value)
                            im6 = ax6.imshow(melsp_woU,
                                             aspect='auto',
                                             vmin=min_value,
                                             vmax=max_value)
                            plt.savefig(
                                f'{self.sample_dir}/{i + 1}_{val_sub[0]}_{k}_{val_sub[2][3]}.png',
                                dpi=150)
                            plt.close(fig)
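
self.loss_o and self.loss_BCE are not defined in this excerpt. Judging from the names and how they are applied to the masked feature predictions, plausible choices are the standard PyTorch criteria below; this is an assumption, not the author's code:

import torch.nn as nn

# assumed definitions, matching how loss_o and loss_BCE are used above
loss_o = nn.L1Loss(reduction='mean')     # L1 term, cf. loss_ortho_id_L1
loss_BCE = nn.BCELoss(reduction='mean')  # BCE term; use BCEWithLogitsLoss if the predictions are raw logits
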
Example #7
metadata = pickle.load(
    open('/hd0/speechsplit/preprocessed/spmel/train.pkl', "rb"))

sbmt_i = metadata[0]
emb_org = torch.from_numpy(sbmt_i[1]).unsqueeze(0).to(device)

root_dir = "/hd0/speechsplit/preprocessed/spmel"
feat_dir = "/hd0/speechsplit/preprocessed/raptf0"

# load the mel-spectrogram and F0 contour
x_org = np.load(os.path.join(root_dir, sbmt_i[2]))
f0_org = np.load(os.path.join(feat_dir, sbmt_i[2]))
len_org = x_org.shape[0]
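# out_len: the target padded length in frames, assumed to be defined earlier in
# the script (the other examples pad to 192 or 408 frames)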

uttr_org_pad, len_org_pad = pad_seq_to_2(x_org[np.newaxis, :, :], out_len)
uttr_org_pad = torch.from_numpy(uttr_org_pad).to(device)
f0_org_pad = np.pad(f0_org, (0, out_len - len_org),
                    'constant',
                    constant_values=(0, 0))
f0_org_quantized = quantize_f0_numpy(f0_org_pad)[0]
f0_org_onehot = f0_org_quantized[np.newaxis, :, :]
f0_org_onehot = torch.from_numpy(f0_org_onehot).to(device)
uttr_f0_org = torch.cat((uttr_org_pad, f0_org_onehot), dim=-1)

sbmt_j = metadata[1]
emb_trg = torch.from_numpy(sbmt_j[1]).unsqueeze(0).to(device)
#x_trg, f0_trg, len_trg, uid_trg = sbmt_j[2]
x_trg = np.load(os.path.join(root_dir, sbmt_j[2]))
f0_trg = np.load(os.path.join(feat_dir, sbmt_j[2]))
len_trg = x_trg.shape[0]
Example #8
f0_rapt = sptk.rapt(wav.astype(np.float32) * 32768,
                    fs,
                    256,
                    min=lo,
                    max=hi,
                    otype=2)
index_nonzero = (f0_rapt != -1e10)
mean_f0, std_f0 = np.mean(f0_rapt[index_nonzero]), np.std(
    f0_rapt[index_nonzero])
f0_norm = speaker_normalization(f0_rapt, index_nonzero, mean_f0, std_f0)

assert len(S) == len(f0_rapt)

# S:utterance spectrogram, f0_norm:normalized pitch contour
f0_quantized = quantize_f0_numpy(f0_norm)[0]
f0_onehot = f0_quantized[np.newaxis, :, :]
print(f0_onehot.shape)
f0_onehot = f0_onehot[:, :192, :]
f0_onehot, _ = pad_seq_to_2(f0_onehot, 192)
f0_onehot = torch.from_numpy(f0_onehot).to(device)

# concat pitch contour to freq axis (cols)
S = S[np.newaxis, :192, :]
S, _ = pad_seq_to_2(S, 192)
uttr = torch.from_numpy(S.astype(np.float32)).to(device)

#f0_onehot = tr.zeros_like(f0_onehot)
uttr_f0 = torch.cat((uttr, f0_onehot), dim=-1)

# Generate back from components
emb = tr.zeros(1, 82).to(device)
print(uttr_f0.shape, uttr.shape, emb.shape)

# uttr_f0 = tr.zeros_like(uttr_f0)
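
The last snippet stops before synthesis. Following the G(x_f0, x, emb) call pattern of the training loops above, the reconstruction from these components would look roughly like the sketch below; the generator handle G is an assumption (a loaded SpeechSplit-style model), not part of the original code:

# illustrative only: mirrors the generator call pattern used in the examples above
with torch.no_grad():
    spect_out = G(uttr_f0, uttr, emb)
print(spect_out.shape)  # e.g. (1, 192, n_mels) for a SpeechSplit-style generator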