# Example #1
class Solver(object):
    """训练操作的说明"""

    #  参数为数据加载器与配置
    def __init__(self, data_loader, config):
        """Store the data loader and every configuration value used for training/testing.

        Args:
            data_loader: iterable yielding (x_real, speaker_idx, one_hot_label) batches.
            config: namespace carrying hyper-parameters, directories and step sizes.
        """
        self.config = config
        self.data_loader = data_loader
        # Model configuration.
        # Weights of the cycle-consistency, domain-classification and
        # identity-mapping terms in the generator loss.
        self.lambda_cycle = config.lambda_cycle
        self.lambda_cls = config.lambda_cls
        self.lambda_identity = config.lambda_identity

        # Training configuration.
        # Training data directory.
        self.data_dir = config.data_dir
        # Test data directory.
        self.test_dir = config.test_dir
        # Mini-batch size.
        self.batch_size = config.batch_size
        # Total number of iterations for training the discriminator D.
        self.num_iters = config.num_iters
        # Number of iterations over which the learning rates decay.
        self.num_iters_decay = config.num_iters_decay
        # Learning rate of the generator G.
        self.g_lr = config.g_lr
        # Learning rate of the discriminator D.
        self.d_lr = config.d_lr
        # Learning rate of the domain classifier C.
        self.c_lr = config.c_lr
        # Number of D updates per G update.
        self.n_critic = config.n_critic
        # Adam beta1/beta2 coefficients.
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        # Step to resume training from (only truth-tested in train(); no restore here).
        self.resume_iters = config.resume_iters

        # Test configuration.
        # Checkpoint iteration to load when testing.
        self.test_iters = config.test_iters
        # ast.literal_eval safely parses the config string into a Python value.
        # Target speaker(s).
        self.trg_speaker = ast.literal_eval(config.trg_speaker)
        # Source speaker.
        self.src_speaker = config.src_speaker

        # Miscellaneous.
        # Whether to log training with tensorboard.
        self.use_tensorboard = config.use_tensorboard
        # Run on the first GPU when available, otherwise CPU.
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        # Binarizer mapping speaker names to one-hot labels.
        # NOTE(review): `speakers` is a module-level list defined elsewhere in the file.
        self.spk_enc = LabelBinarizer().fit(speakers)
        # Directories.
        # Log directory.
        self.log_dir = config.log_dir
        # Sample-output directory.
        self.sample_dir = config.sample_dir
        # Model checkpoint directory.
        self.model_save_dir = config.model_save_dir
        # Conversion-result directory.
        self.result_dir = config.result_dir

        # Step sizes.
        # Logging interval.
        self.log_step = config.log_step
        # Sampling interval.
        self.sample_step = config.sample_step
        # Checkpoint-saving interval.
        self.model_save_step = config.model_save_step
        # Learning-rate update interval.
        self.lr_update_step = config.lr_update_step

        # Build the models and, optionally, the tensorboard logger.
        self.build_model()
        if self.use_tensorboard:
            # Use a tensorboard logger.
            self.build_tensorboard()

    # 赋值三个模型器
    def build_model(self):
        """Instantiate G, D and C, give each its own Adam optimizer, and move them to self.device."""
        self.G = Generator()
        self.D = Discriminator()
        self.C = DomainClassifier()

        # Shared Adam moment coefficients (see torch.optim.Adam: betas=(beta1, beta2)).
        betas = [self.beta1, self.beta2]
        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr,
                                            betas)
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr,
                                            betas)
        self.c_optimizer = torch.optim.Adam(self.C.parameters(), self.c_lr,
                                            betas)

        # Print a summary of each network, then move it to the compute device.
        for tag, net in (('G', self.G), ('D', self.D), ('C', self.C)):
            self.print_network(net, tag)
            net.to(self.device)

    def print_network(self, model, name):
        """打印出网络的相关信息"""
        num_params = 0
        # Module.parameters()获取网络的参数
        # 计算模型网络对应的参数频次
        for p in model.parameters():
            # numel返回数组中元素的个数
            num_params += p.numel()
        print(model)
        print(name)
        print("参数的个数为:{}".format(num_params))

    # 使用tensorboard
    def build_tensorboard(self):
        """Create the tensorboard logger writing into self.log_dir."""
        from logger import Logger  # project-local logger module
        self.logger = Logger(self.log_dir)

    def update_lr(self, g_lr, d_lr, c_lr):
        """生成器、判别器和域分类器的衰减学习率"""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr
        for param_group in self.c_optimizer.param_groups:
            param_group['lr'] = c_lr

    # 训练方法
    def train(self):
        """Main training loop: update C and D every step, G every n_critic steps."""
        # Learning-rate caches used for linear decay.
        g_lr = self.g_lr
        d_lr = self.d_lr
        c_lr = self.c_lr
        # Iteration to start from.
        start_iters = 0
        # NOTE(review): resuming is not implemented here -- the flag is ignored.
        if self.resume_iters:
            pass
        # Per-speaker feature (de)normalizer.
        norm = Normalizer()
        # Iterator over the training data loader.
        data_iter = iter(self.data_loader)
        print('开始训练......')
        start_time = datetime.now()
        for i in range(start_iters, self.num_iters):
            # =================================================================================== #
            #                              1. Preprocess input data                               #
            # =================================================================================== #
            # Fetch real features, speaker indices and one-hot labels.
            try:
                x_real, speaker_idx_org, label_org = next(data_iter)
            except:
                # Iterator exhausted (or failed): restart an epoch.
                data_iter = iter(self.data_loader)
                x_real, speaker_idx_org, label_org = next(data_iter)

            # Generate target-domain labels by randomly permuting the source batch.
            rand_idx = torch.randperm(label_org.size(0))
            label_trg = label_org[rand_idx]
            speaker_idx_trg = speaker_idx_org[rand_idx]
            x_real = x_real.to(self.device)  # input features
            label_org = label_org.to(self.device)  # source one-hot labels
            label_trg = label_trg.to(self.device)  # target one-hot labels
            speaker_idx_org = speaker_idx_org.to(self.device)  # source indices
            speaker_idx_trg = speaker_idx_trg.to(self.device)  # target indices

            # =================================================================================== #
            #                        2. Train the classifier and discriminator                    #
            # =================================================================================== #
            # Domain-classification loss on real data (softmax cross-entropy).
            CELoss = nn.CrossEntropyLoss()
            cls_real = self.C(x_real)
            cls_loss_real = CELoss(input=cls_real, target=speaker_idx_org)
            # Clear stale gradients, backprop, and update only C.
            self.reset_grad()
            cls_loss_real.backward()
            self.c_optimizer.step()
            # Logging.
            loss = {}
            loss['C/C_loss'] = cls_loss_real.item()

            # D score on real samples with their source labels.
            out_r = self.D(x_real, label_org)
            # Generate fake samples conditioned on the target labels.
            x_fake = self.G(x_real, label_trg)
            # detach() blocks gradients from flowing back into G here.
            out_f = self.D(x_fake.detach(), label_trg)
            # Adversarial loss: real logits pushed to 1, fake logits to 0.
            # binary_cross_entropy_with_logits applies the sigmoid internally.
            d_loss_t = F.binary_cross_entropy_with_logits(input=out_f,target=torch.zeros_like(out_f, dtype=torch.float)) + \
                F.binary_cross_entropy_with_logits(input=out_r, target=torch.ones_like(out_r, dtype=torch.float))
            # Classification loss on the fake samples w.r.t. the target speakers.
            out_cls = self.C(x_fake)
            d_loss_cls = CELoss(input=out_cls, target=speaker_idx_trg)

            # WGAN-GP style gradient penalty on a random interpolation x_hat.
            alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
            x_hat = (alpha * x_real.data +
                     (1 - alpha) * x_fake.data).requires_grad_(True)
            out_src = self.D(x_hat, label_trg)
            d_loss_gp = self.gradient_penalty(out_src, x_hat)
            # Total discriminator loss.
            # NOTE(review): the gradient-penalty weight 5 is hard-coded here.
            d_loss = d_loss_t + self.lambda_cls * d_loss_cls + 5 * d_loss_gp
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # loss['D/d_loss_t'] = d_loss_t.item()
            # loss['D/loss_cls'] = d_loss_cls.item()
            # loss['D/D_gp'] = d_loss_gp.item()
            loss['D/D_loss'] = d_loss.item()

            # =================================================================================== #
            #                                3. Train the generator                               #
            # =================================================================================== #
            # G is updated once every n_critic discriminator updates.
            if (i + 1) % self.n_critic == 0:
                # Source -> target: fool D on the generated sample.
                x_fake = self.G(x_real, label_trg)
                g_out_src = self.D(x_fake, label_trg)
                g_loss_fake = F.binary_cross_entropy_with_logits(
                    input=g_out_src,
                    target=torch.ones_like(g_out_src, dtype=torch.float))
                # Classification loss on the real sample w.r.t. its true speaker.
                out_cls = self.C(x_real)
                g_loss_cls = CELoss(input=out_cls, target=speaker_idx_org)

                # Target -> source: cycle-consistency (reconstruction) loss.
                x_reconst = self.G(x_fake, label_org)
                # L1 (mean absolute error) between reconstruction and input.
                g_loss_rec = F.l1_loss(x_reconst, x_real)

                # Source -> source: identity-mapping loss.
                x_fake_iden = self.G(x_real, label_org)
                id_loss = F.l1_loss(x_fake_iden, x_real)

                # Weighted total generator loss.
                g_loss = g_loss_fake + self.lambda_cycle * g_loss_rec +\
                 self.lambda_cls * g_loss_cls + self.lambda_identity * id_loss
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # Logging.
                loss['G/loss_fake'] = g_loss_fake.item()
                loss['G/loss_rec'] = g_loss_rec.item()
                loss['G/loss_cls'] = g_loss_cls.item()
                loss['G/loss_id'] = id_loss.item()
                loss['G/g_loss'] = g_loss.item()
            # =================================================================================== #
            #                                  4. Miscellaneous                                   #
            # =================================================================================== #
            # Print training progress.
            if (i + 1) % self.log_step == 0:
                et = datetime.now() - start_time
                # Drop the microseconds from the elapsed-time string.
                et = str(et)[:-7]
                log = "耗时:[{}], 迭代次数:[{}/{}]".format(et, i + 1, self.num_iters)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)
                # Optionally push the scalars to tensorboard.
                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, i + 1)

            # Convert fixed test data for debugging/monitoring.
            if (i + 1) % self.sample_step == 0:
                # no_grad() disables gradient tracking to save memory.
                with torch.no_grad():
                    # Pick a random test speaker and its features.
                    d, speaker = TestSet(self.test_dir).test_data()
                    # Choose a random target speaker different from the source.
                    target = random.choice(
                        [x for x in speakers if x != speaker])
                    # One-hot label of the target speaker.
                    label_t = self.spk_enc.transform([target])[0]
                    label_t = np.asarray([label_t])
                    # Iterate over the test utterances.
                    for filename, content in d.items():
                        f0 = content['f0']
                        ap = content['ap']
                        # Zero-pad the spectrogram to a multiple of FRAMES columns.
                        sp_norm_pad = self.pad_coded_sp(
                            content['coded_sp_norm'])

                        convert_result = []
                        # Convert FRAMES columns at a time.
                        for start_idx in range(
                                0, sp_norm_pad.shape[1] - FRAMES + 1, FRAMES):
                            one_seg = sp_norm_pad[:,
                                                  start_idx:start_idx + FRAMES]

                            one_seg = torch.FloatTensor(one_seg).to(
                                self.device)
                            one_seg = one_seg.view(1, 1, one_seg.size(0),
                                                   one_seg.size(1))
                            l = torch.FloatTensor(label_t)
                            one_seg = one_seg.to(self.device)
                            l = l.to(self.device)
                            one_set_return = self.G(one_seg,
                                                    l).data.cpu().numpy()
                            one_set_return = np.squeeze(one_set_return)
                            # Undo the per-speaker normalization.
                            one_set_return = norm.backward_process(
                                one_set_return, target)
                            convert_result.append(one_set_return)

                        convert_con = np.concatenate(convert_result, axis=1)
                        # Trim the padding back to the original frame count.
                        convert_con = convert_con[:,
                                                  0:content['coded_sp_norm'].
                                                  shape[1]]
                        contigu = np.ascontiguousarray(convert_con.T,
                                                       dtype=np.float64)
                        # WORLD vocoder resynthesis.
                        decoded_sp = decode_spectral_envelope(contigu,
                                                              SAMPLE_RATE,
                                                              fft_size=FFTSIZE)
                        f0_converted = norm.pitch_conversion(
                            f0, speaker, target)
                        wav = synthesize(f0_converted, decoded_sp, ap,
                                         SAMPLE_RATE)

                        name = f'{speaker}-{target}_iter{i+1}_(unknown)'
                        path = os.path.join(self.sample_dir, name)
                        print(f'[save]:{path}')
                        librosa.output.write_wav(path, wav, SAMPLE_RATE)

            # Save model checkpoints.
            if (i + 1) % self.model_save_step == 0:
                G_path = os.path.join(self.model_save_dir,
                                      '{}-G.ckpt'.format(i + 1))
                D_path = os.path.join(self.model_save_dir,
                                      '{}-D.ckpt'.format(i + 1))
                C_path = os.path.join(self.model_save_dir,
                                      '{}-C.ckpt'.format(i + 1))
                torch.save(self.G.state_dict(), G_path)
                torch.save(self.D.state_dict(), D_path)
                torch.save(self.C.state_dict(), C_path)
                print('Saved model checkpoints into {}...'.format(
                    self.model_save_dir))

            # Linearly decay learning rates over the last num_iters_decay steps.
            if (i + 1) % self.lr_update_step == 0 and (i + 1) > (
                    self.num_iters - self.num_iters_decay):
                g_lr -= (self.g_lr / float(self.num_iters_decay))
                d_lr -= (self.d_lr / float(self.num_iters_decay))
                c_lr -= (self.c_lr / float(self.num_iters_decay))
                self.update_lr(g_lr, d_lr, c_lr)
                print('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(
                    g_lr, d_lr))

    def gradient_penalty(self, y, x):
        """计算梯度惩罚: (L2_norm(dy/dx) - 1)**2."""
        # 根据标签的纬度创建出权重矩阵
        weight = torch.ones(y.size()).to(self.device)
        # torch.autograd.grad计算并返回outputs对inputs的梯度dy/dx
        dydx = torch.autograd.grad(
            outputs=y,
            inputs=x,
            # 雅可比向量积中的“向量”,用来将梯度向量转换为梯度标量,并可以衡量y梯度的各个数据的权重
            grad_outputs=weight,
            retain_graph=True,
            # 保持计算的导数图
            create_graph=True,
            only_inputs=True)[0]
        # 将导数变为二维的数据并保持行数
        dydx = dydx.view(dydx.size(0), -1)
        # 计算导数的L2范数
        dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
        return torch.mean((dydx_l2norm - 1)**2)

    def reset_grad(self):
        """重置梯度变化缓冲区"""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()
        self.c_optimizer.zero_grad()

    def restore_model(self, resume_iters):
        """Load G, D and C weights saved at iteration `resume_iters`."""
        print('从{}步骤开始加载训练过的模型...'.format(resume_iters))
        for tag, net in (('G', self.G), ('D', self.D), ('C', self.C)):
            ckpt = os.path.join(self.model_save_dir,
                                '{}-{}.ckpt'.format(resume_iters, tag))
            # map_location keeps CPU-only hosts able to load GPU checkpoints.
            net.load_state_dict(
                torch.load(ckpt, map_location=lambda storage, loc: storage))

    @staticmethod
    def pad_coded_sp(coded_sp_norm):
        """Right-pad a (dims, frames) spectrogram with zero columns up to a FRAMES boundary."""
        n_frames = coded_sp_norm.shape[1]
        if n_frames < FRAMES:
            pad_length = FRAMES - n_frames
        else:
            # NOTE: when n_frames is an exact multiple of FRAMES this still pads a
            # full extra block of FRAMES zeros -- matching the original behaviour.
            pad_length = FRAMES - (n_frames - (n_frames // FRAMES) * FRAMES)

        zeros = np.zeros((coded_sp_norm.shape[0], pad_length))
        return np.hstack((coded_sp_norm, zeros))

    def test(self):
        """Convert the source speaker's audio to each target speaker with the trained StarGAN."""
        # Load the trained generator (and D/C) from the chosen checkpoint.
        self.restore_model(self.test_iters)
        norm = Normalizer()

        # Set up the test data for the configured source speaker.
        d, speaker = TestSet(self.test_dir).test_data(self.src_speaker)
        targets = self.trg_speaker

        for target in targets:
            print(target)
            assert target in speakers
            # One-hot label of the target speaker.
            label_t = self.spk_enc.transform([target])[0]
            label_t = np.asarray([label_t])

            with torch.no_grad():

                for filename, content in d.items():
                    f0 = content['f0']
                    ap = content['ap']
                    # Zero-pad the spectrogram to a multiple of FRAMES columns.
                    sp_norm_pad = self.pad_coded_sp(content['coded_sp_norm'])

                    convert_result = []
                    # Convert FRAMES columns at a time.
                    for start_idx in range(0,
                                           sp_norm_pad.shape[1] - FRAMES + 1,
                                           FRAMES):
                        one_seg = sp_norm_pad[:, start_idx:start_idx + FRAMES]

                        one_seg = torch.FloatTensor(one_seg).to(self.device)
                        one_seg = one_seg.view(1, 1, one_seg.size(0),
                                               one_seg.size(1))
                        l = torch.FloatTensor(label_t)
                        one_seg = one_seg.to(self.device)
                        l = l.to(self.device)
                        one_set_return = self.G(one_seg, l).data.cpu().numpy()
                        one_set_return = np.squeeze(one_set_return)
                        # Undo the per-speaker normalization.
                        one_set_return = norm.backward_process(
                            one_set_return, target)
                        convert_result.append(one_set_return)

                    convert_con = np.concatenate(convert_result, axis=1)
                    # Trim the zero padding back to the original frame count.
                    convert_con = convert_con[:, 0:content['coded_sp_norm'].
                                              shape[1]]
                    contigu = np.ascontiguousarray(convert_con.T,
                                                   dtype=np.float64)
                    # WORLD vocoder resynthesis.
                    decoded_sp = decode_spectral_envelope(contigu,
                                                          SAMPLE_RATE,
                                                          fft_size=FFTSIZE)
                    f0_converted = norm.pitch_conversion(f0, speaker, target)
                    wav = synthesize(f0_converted, decoded_sp, ap, SAMPLE_RATE)

                    name = f'{speaker}-{target}_iter{self.test_iters}_(unknown)'
                    path = os.path.join(self.result_dir, name)
                    print(f'[保存]:{path}')
                    librosa.output.write_wav(path, wav, SAMPLE_RATE)
class Solver(object):
    """Solver for training and testing StarGAN."""
    def __init__(self, train_loader, test_loader, config):
        """Initialize configurations.

        Args:
            train_loader: loader yielding (mc_real, spk_label, spk_one_hot) batches.
            test_loader: provides batches of test wav file paths.
            config: namespace with model, training and test hyper-parameters.
        """

        # Data loader.
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.sampling_rate = config.sampling_rate

        # Model configurations.
        self.num_speakers = config.num_speakers
        # Loss weights: classification, reconstruction, gradient penalty, identity.
        self.lambda_cls = config.lambda_cls
        self.lambda_rec = config.lambda_rec
        self.lambda_gp = config.lambda_gp
        self.lambda_id = config.lambda_id

        # Training configurations.
        self.batch_size = config.batch_size
        self.num_iters = config.num_iters
        self.num_iters_decay = config.num_iters_decay
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.c_lr = config.c_lr
        # Number of D updates per G update.
        self.n_critic = config.n_critic
        # Adam beta coefficients.
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.resume_iters = config.resume_iters

        # Test configurations.
        self.test_iters = config.test_iters

        # Miscellaneous.
        self.use_tensorboard = config.use_tensorboard
        # Run on GPU when available, otherwise CPU.
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        # Directories.
        self.log_dir = config.log_dir
        self.sample_dir = config.sample_dir
        self.model_save_dir = config.model_save_dir

        # Step size.
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.lr_update_step = config.lr_update_step

        # Build the model and tensorboard.
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

    def build_model(self):
        """Create the generator, discriminator and domain classifier.

        Each network gets its own Adam optimizer (betas = (beta1, beta2));
        all three networks are then moved to self.device.
        """
        self.generator = Generator()
        self.discriminator = Discriminator(num_speakers=self.num_speakers)
        self.classifier = DomainClassifier()

        self.g_optimizer = torch.optim.Adam(self.generator.parameters(),
                                            self.g_lr,
                                            [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.discriminator.parameters(),
                                            self.d_lr,
                                            [self.beta1, self.beta2])
        # BUG FIX: this optimizer previously received the discriminator's
        # parameters, so c_optimizer.step() never updated the domain
        # classifier -- it now optimizes self.classifier as intended.
        self.c_optimizer = torch.optim.Adam(self.classifier.parameters(),
                                            self.c_lr,
                                            [self.beta1, self.beta2])

        self.print_network(self.generator, 'Generator')
        self.print_network(self.discriminator, 'Discriminator')
        self.print_network(self.classifier, 'Domain Classifier')

        self.generator.to(self.device)
        self.discriminator.to(self.device)
        self.classifier.to(self.device)

    def print_network(self, model, name):
        """Print out the network information."""
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))

    def restore_model(self, resume_iters):
        """Restore generator, discriminator and classifier checkpoints from step `resume_iters`."""
        print(
            'Loading the trained models from step {}...'.format(resume_iters))
        nets = (('G', self.generator), ('D', self.discriminator),
                ('C', self.classifier))
        for tag, net in nets:
            ckpt_path = os.path.join(self.model_save_dir,
                                     '{}-{}.ckpt'.format(resume_iters, tag))
            # map_location lets GPU checkpoints load on CPU-only hosts.
            net.load_state_dict(
                torch.load(ckpt_path,
                           map_location=lambda storage, loc: storage))

    def build_tensorboard(self):
        """Instantiate the project Logger writing into self.log_dir."""
        from logger import Logger  # project-local logger module
        self.logger = Logger(self.log_dir)

    def update_lr(self, g_lr, d_lr, c_lr):
        """Decay learning rates of the generator and discriminator."""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr
        for param_group in self.c_optimizer.param_groups:
            param_group['lr'] = c_lr

    def reset_grad(self):
        """Reset the gradientgradient buffers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()
        self.c_optimizer.zero_grad()

    def denorm(self, x):
        """Convert the range from [-1, 1] to [0, 1]."""
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def gradient_penalty(self, y, x):
        """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2."""
        weight = torch.ones(y.size()).to(self.device)
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=weight,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]

        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
        return torch.mean((dydx_l2norm - 1)**2)

    def label2onehot(self, labels, dim):
        """Convert label indices to one-hot vectors."""
        batch_size = labels.size(0)
        out = torch.zeros(batch_size, dim)
        out[np.arange(batch_size), labels.long()] = 1
        return out

    def sample_spk_c(self, size):
        """Sample `size` random speaker indices plus their one-hot encodings."""
        indices = np.random.randint(0, self.num_speakers, size=size)
        # to_categorical turns integer indices into one-hot rows.
        onehots = to_categorical(indices, self.num_speakers)
        return torch.LongTensor(indices), torch.FloatTensor(onehots)

    def classification_loss(self, logit, target):
        """Compute softmax cross entropy loss."""
        return F.cross_entropy(logit, target)

    def load_wav(self, wavfile, sr=16000):
        """Load `wavfile` as mono audio at `sr` Hz and pad it for the vocoder.

        Returns the waveform padded by wav_padding (frame_period=5 ms,
        length a multiple of 4 frames).
        """
        wav, _ = librosa.load(wavfile, sr=sr, mono=True)
        # Fix: propagate the requested sample rate instead of hard-coding 16000,
        # so callers passing a different `sr` get padding consistent with it.
        return wav_padding(wav, sr=sr, frame_period=5, multiple=4)

    def train(self):
        """Train StarGAN."""
        # Set data loader.
        train_loader = self.train_loader
        data_iter = iter(train_loader)

        # Read a batch of testdata
        test_wavfiles = self.test_loader.get_batch_test_data(batch_size=4)
        test_wavs = [self.load_wav(wavfile) for wavfile in test_wavfiles]

        # Determine whether do copysynthesize when first do training-time conversion test.
        cpsyn_flag = [True, False][0]
        # f0, timeaxis, sp, ap = world_decompose(wav = wav, fs = sampling_rate, frame_period = frame_period)

        # Learning rate cache for decaying.
        g_lr = self.g_lr
        d_lr = self.d_lr
        c_lr = self.c_lr

        # Start training from scratch or resume training.
        start_iters = 0
        if self.resume_iters:
            print("resuming step %d ..." % self.resume_iters)
            start_iters = self.resume_iters
            self.restore_model(self.resume_iters)

        # Start training.
        print('Start training...')
        start_time = time.time()
        for i in range(start_iters, self.num_iters):
            # =================================================================================== #
            #                             1. Preprocess input data                                #
            # =================================================================================== #

            # Fetch labels.
            try:
                mc_real, spk_label_org, spk_c_org = next(data_iter)
            except:
                data_iter = iter(train_loader)
                mc_real, spk_label_org, spk_c_org = next(data_iter)

            mc_real.unsqueeze_(1)  # (B, D, T) -> (B, 1, D, T) for conv2d

            # Generate target domain labels randomly.
            # spk_label_trg: int,   spk_c_trg:one-hot representation
            spk_label_trg, spk_c_trg = self.sample_spk_c(mc_real.size(0))

            mc_real = mc_real.to(self.device)  # Input mc.
            spk_label_org = spk_label_org.to(
                self.device)  # Original spk labels.
            spk_c_org = spk_c_org.to(self.device)  # Original spk one-hot.
            spk_label_trg = spk_label_trg.to(self.device)  # Target spk labels.
            spk_c_trg = spk_c_trg.to(self.device)  # Target spk one-hot.

            # =================================================================================== #
            #                             2. Train the Domain Classifier                           #
            # =================================================================================== #

            # Compute real classification loss.
            cls_real = self.classifier(mc_real)
            cls_loss = self.classification_loss(cls_real, spk_label_org)

            # Backwards and optimize
            self.reset_grad()
            cls_loss.backward()
            self.c_optimizer.step()

            # Logging.
            loss = {}
            loss['C/c_loss'] = cls_loss.item()

            # =================================================================================== #
            #                             3. Train the Discriminator                              #
            # =================================================================================== #

            # Compute loss with real mc feats.
            d_out_src = self.discriminator(mc_real, spk_c_org)
            mc_fake = self.generator(mc_real, spk_c_trg)
            d_out_fake = self.discriminator(mc_fake.detach(), spk_c_trg)
            d_loss = F.binary_cross_entropy_with_logits(d_out_fake, torch.zeros_like(d_out_fake, dtype=torch.float)) + \
                F.binary_cross_entropy_with_logits(d_out_src, torch.ones_like(d_out_src, dtype=torch.float))

            # Compute classification loss.
            d_out_cls = self.classifier(mc_fake)
            d_loss_cls = self.classification_loss(d_out_cls, spk_label_trg)

            # Compute loss for gradient penalty.
            alpha = torch.rand(mc_real.size(0), 1, 1, 1).to(self.device)
            x_hat = (alpha * mc_real.data +
                     (1 - alpha) * mc_fake.data).requires_grad_(True)
            d_out_src = self.discriminator(x_hat, spk_c_trg)
            d_loss_gp = self.gradient_penalty(d_out_src, x_hat)

            # Backward and optimize.
            d_loss = d_loss + self.lambda_cls * d_loss_cls + self.lambda_gp * d_loss_gp
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Logging.
            loss['D/loss_gp'] = d_loss_gp.item()
            loss['D/loss'] = d_loss.item()

            # =================================================================================== #
            #                               4. Train the generator                                #
            # =================================================================================== #

            if (i + 1) % self.n_critic == 0:
                # Original-to-target domain.
                mc_fake = self.generator(mc_real, spk_c_trg)
                g_out_src = self.discriminator(mc_fake, spk_c_trg)
                g_loss_fake = -torch.mean(g_out_src)

                # Classification loss.
                g_out_cls = self.classifier(mc_fake)
                g_loss_cls = self.classification_loss(g_out_cls, spk_label_trg)

                # Target-to-original domain. Cycle-consistent.
                mc_reconst = self.generator(mc_fake, spk_c_org)
                g_loss_rec = torch.mean(torch.abs(mc_real - mc_reconst))

                # Original-to-original, Id mapping loss. Mapping
                mc_fake_id = self.generator(mc_real, spk_c_org)
                g_loss_id = torch.mean(torch.abs(mc_real - mc_fake_id))

                # Backward and optimize.
                g_loss = g_loss_fake \
                    + self.lambda_rec * g_loss_rec \
                    + self.lambda_cls * g_loss_cls \
                    + self.lambda_id * g_loss_id

                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # Logging.
                loss['G/loss_fake'] = g_loss_fake.item()
                loss['G/loss_rec'] = g_loss_rec.item()
                loss['G/loss_cls'] = g_loss_cls.item()
                loss['G/loss'] = g_loss.item()

            # =================================================================================== #
            #                                 5. Miscellaneous                                    #
            # =================================================================================== #

            # Print out training information.
            if (i + 1) % self.log_step == 0:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]".format(
                    et, i + 1, self.num_iters)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, i + 1)

            if (i + 1) % self.sample_step == 0:
                sampling_rate = 16000
                num_mcep = 36
                frame_period = 5
                with torch.no_grad():
                    for idx, wav in tqdm(enumerate(test_wavs)):
                        wav_name = basename(test_wavfiles[idx])
                        # print(wav_name)
                        f0, timeaxis, sp, ap = world_decompose(
                            wav=wav,
                            fs=sampling_rate,
                            frame_period=frame_period)
                        f0_converted = pitch_conversion(
                            f0=f0,
                            mean_log_src=self.test_loader.logf0s_mean_src,
                            std_log_src=self.test_loader.logf0s_std_src,
                            mean_log_target=self.test_loader.logf0s_mean_trg,
                            std_log_target=self.test_loader.logf0s_std_trg)
                        coded_sp = world_encode_spectral_envelop(
                            sp=sp, fs=sampling_rate, dim=num_mcep)

                        coded_sp_norm = (coded_sp -
                                         self.test_loader.mcep_mean_src
                                         ) / self.test_loader.mcep_std_src
                        coded_sp_norm_tensor = torch.FloatTensor(
                            coded_sp_norm.T).unsqueeze_(0).unsqueeze_(1).to(
                                self.device)
                        conds = torch.FloatTensor(
                            self.test_loader.spk_c_trg).to(self.device)
                        # print(conds.size())
                        coded_sp_converted_norm = self.generator(
                            coded_sp_norm_tensor, conds).data.cpu().numpy()
                        coded_sp_converted = np.squeeze(
                            coded_sp_converted_norm
                        ).T * self.test_loader.mcep_std_trg + self.test_loader.mcep_mean_trg
                        coded_sp_converted = np.ascontiguousarray(
                            coded_sp_converted)
                        # decoded_sp_converted = world_decode_spectral_envelop(coded_sp = coded_sp_converted, fs = sampling_rate)
                        wav_transformed = world_speech_synthesis(
                            f0=f0_converted,
                            coded_sp=coded_sp_converted,
                            ap=ap,
                            fs=sampling_rate,
                            frame_period=frame_period)

                        librosa.output.write_wav(
                            join(
                                self.sample_dir,
                                str(i + 1) + '-' + wav_name.split('.')[0] +
                                '-vcto-{}'.format(self.test_loader.trg_spk) +
                                '.wav'), wav_transformed, sampling_rate)
                        if cpsyn_flag:
                            wav_cpsyn = world_speech_synthesis(
                                f0=f0,
                                coded_sp=coded_sp,
                                ap=ap,
                                fs=sampling_rate,
                                frame_period=frame_period)
                            librosa.output.write_wav(
                                join(self.sample_dir, 'cpsyn-' + wav_name),
                                wav_cpsyn, sampling_rate)
                    cpsyn_flag = False

            # Save model checkpoints.
            if (i + 1) % self.model_save_step == 0:
                g_path = os.path.join(self.model_save_dir,
                                      '{}-G.ckpt'.format(i + 1))
                d_path = os.path.join(self.model_save_dir,
                                      '{}-D.ckpt'.format(i + 1))
                c_path = os.path.join(self.model_save_dir,
                                      '{}-C.ckpt'.format(i + 1))

                torch.save(self.generator.state_dict(), g_path)
                torch.save(self.discriminator.state_dict(), d_path)
                torch.save(self.classifier.state_dict(), c_path)
                print('Saved model checkpoints into {}...'.format(
                    self.model_save_dir))

            # Decay learning rates.
            if (i + 1) % self.lr_update_step == 0 and (i + 1) > (
                    self.num_iters - self.num_iters_decay):
                g_lr -= (self.g_lr / float(self.num_iters_decay))
                d_lr -= (self.d_lr / float(self.num_iters_decay))
                c_lr -= (self.c_lr / float(self.num_iters_decay))
                self.update_lr(g_lr, d_lr, c_lr)
                print('Decayed learning rates, g_lr: {}, d_lr: {}, c_lr: {}.'.
                      format(g_lr, d_lr, c_lr))
Пример #3
0
class Solver(object):
    """Trainer/tester for a StarGAN-style voice-conversion model.

    Owns the generator ``G``, discriminator ``D`` and domain classifier ``C``
    built in :meth:`build_model`, and drives adversarial training
    (:meth:`train`), periodic sample conversion, checkpointing, learning-rate
    decay and final conversion of held-out utterances (:meth:`test`).
    """

    def __init__(self, data_loader, config):
        """Store configuration, then build the models (and tensorboard).

        Args:
            data_loader: iterable yielding batches of
                ``(features, speaker_index, one_hot_label)``.
            config: namespace carrying all hyper-parameters and directories.
        """
        self.config = config
        self.data_loader = data_loader

        # Model configurations: loss weights.
        self.lambda_cycle = config.lambda_cycle
        self.lambda_cls = config.lambda_cls
        self.lambda_identity = config.lambda_identity

        # Training configurations.
        self.data_dir = config.data_dir
        self.test_dir = config.test_dir
        self.batch_size = config.batch_size
        self.num_iters = config.num_iters
        self.num_iters_decay = config.num_iters_decay
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.c_lr = config.c_lr
        # Number of discriminator updates per generator update.
        self.n_critic = config.n_critic
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.resume_iters = config.resume_iters

        # Test configurations.
        self.test_iters = config.test_iters
        # trg_speaker arrives as a string literal such as "['p229', 'p232']".
        self.trg_speaker = ast.literal_eval(config.trg_speaker)
        self.src_speaker = config.src_speaker

        # Miscellaneous.
        self.use_tensorboard = config.use_tensorboard
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        # One-hot encoder fitted on the global speaker list.
        self.spk_enc = LabelBinarizer().fit(speakers)

        # Directories.
        self.log_dir = config.log_dir
        self.sample_dir = config.sample_dir
        self.model_save_dir = config.model_save_dir
        self.result_dir = config.result_dir

        # Step sizes.
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.lr_update_step = config.lr_update_step

        # Build the model and tensorboard.
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

    def build_model(self):
        """Instantiate G/D/C, their Adam optimizers, and move them to device."""
        self.G = Generator()
        self.D = Discriminator()
        self.C = DomainClassifier()

        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr,
                                            [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr,
                                            [self.beta1, self.beta2])
        self.c_optimizer = torch.optim.Adam(self.C.parameters(), self.c_lr,
                                            [self.beta1, self.beta2])

        self.print_network(self.G, 'G')
        self.print_network(self.D, 'D')
        self.print_network(self.C, 'C')

        self.G.to(self.device)
        self.D.to(self.device)
        self.C.to(self.device)

    def print_network(self, model, name):
        """Print the network structure, its name, and its parameter count."""
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))

    def build_tensorboard(self):
        """Build a tensorboard logger."""
        # Imported lazily so tensorboard is only required when enabled.
        from logger import Logger
        self.logger = Logger(self.log_dir)

    def update_lr(self, g_lr, d_lr, c_lr):
        """Set new learning rates on the G, D and C optimizers."""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr
        for param_group in self.c_optimizer.param_groups:
            param_group['lr'] = c_lr

    def train(self):
        """Run the adversarial training loop for ``num_iters`` iterations."""
        # Learning-rate caches for linear decay.
        g_lr = self.g_lr
        d_lr = self.d_lr
        c_lr = self.c_lr

        # Resume from a checkpoint if requested.  (Previously this branch was
        # `pass`, so --resume_iters silently restarted training from scratch.)
        start_iters = 0
        if self.resume_iters:
            print('Resuming training from step {}...'.format(
                self.resume_iters))
            self.restore_model(self.resume_iters)
            start_iters = self.resume_iters

        norm = Normalizer()
        data_iter = iter(self.data_loader)
        # Stateless criterion: instantiate once, not every iteration.
        ce_loss = nn.CrossEntropyLoss()

        print('Start training......')
        start_time = datetime.now()

        for i in range(start_iters, self.num_iters):
            # =============================================================== #
            #                   1. Preprocess input data                      #
            # =============================================================== #
            # Fetch real features and labels; restart the loader each epoch.
            try:
                x_real, speaker_idx_org, label_org = next(data_iter)
            except StopIteration:
                data_iter = iter(self.data_loader)
                x_real, speaker_idx_org, label_org = next(data_iter)

            # Generate target-domain labels by permuting the batch.
            rand_idx = torch.randperm(label_org.size(0))
            label_trg = label_org[rand_idx]
            speaker_idx_trg = speaker_idx_org[rand_idx]

            x_real = x_real.to(self.device)  # Input features.
            label_org = label_org.to(
                self.device)  # Original domain one-hot labels.
            label_trg = label_trg.to(
                self.device)  # Target domain one-hot labels.
            speaker_idx_org = speaker_idx_org.to(
                self.device)  # Original domain indices.
            speaker_idx_trg = speaker_idx_trg.to(
                self.device)  # Target domain indices.

            # =============================================================== #
            #              2. Train the classifier and discriminator          #
            # =============================================================== #
            # Domain-classification loss on real features.
            cls_real = self.C(x_real)
            cls_loss_real = ce_loss(input=cls_real, target=speaker_idx_org)

            self.reset_grad()
            cls_loss_real.backward()
            self.c_optimizer.step()

            # Logging.
            loss = {}
            loss['C/C_loss'] = cls_loss_real.item()

            # Real/fake adversarial losses for D.
            out_r = self.D(x_real, label_org)
            x_fake = self.G(x_real, label_trg)
            out_f = self.D(x_fake.detach(), label_trg)
            d_loss_t = F.binary_cross_entropy_with_logits(
                input=out_f,
                target=torch.zeros_like(out_f, dtype=torch.float)) + \
                F.binary_cross_entropy_with_logits(
                    input=out_r,
                    target=torch.ones_like(out_r, dtype=torch.float))

            out_cls = self.C(x_fake)
            d_loss_cls = ce_loss(input=out_cls, target=speaker_idx_trg)

            # Gradient penalty on a random interpolate of real and fake.
            alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
            x_hat = (alpha * x_real.data +
                     (1 - alpha) * x_fake.data).requires_grad_(True)
            out_src = self.D(x_hat, label_trg)
            d_loss_gp = self.gradient_penalty(out_src, x_hat)

            # NOTE(review): the gradient-penalty weight is hard-coded to 5
            # here rather than read from config — confirm this is intended.
            d_loss = d_loss_t + self.lambda_cls * d_loss_cls + 5 * d_loss_gp

            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            loss['D/D_loss'] = d_loss.item()

            # =============================================================== #
            #        3. Train the generator (every n_critic D steps)          #
            # =============================================================== #
            if (i + 1) % self.n_critic == 0:
                # Original-to-target domain (adversarial).
                x_fake = self.G(x_real, label_trg)
                g_out_src = self.D(x_fake, label_trg)
                g_loss_fake = F.binary_cross_entropy_with_logits(
                    input=g_out_src,
                    target=torch.ones_like(g_out_src, dtype=torch.float))

                # Domain-classification loss on the fake.
                out_cls = self.C(x_fake)
                g_loss_cls = ce_loss(input=out_cls, target=speaker_idx_trg)

                # Target-to-original domain (cycle consistency).
                x_reconst = self.G(x_fake, label_org)
                g_loss_rec = F.l1_loss(x_reconst, x_real)

                # Original-to-original domain (identity mapping).
                x_fake_iden = self.G(x_real, label_org)
                id_loss = F.l1_loss(x_fake_iden, x_real)

                # Backward and optimize.
                g_loss = (g_loss_fake +
                          self.lambda_cycle * g_loss_rec +
                          self.lambda_cls * g_loss_cls +
                          self.lambda_identity * id_loss)

                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # Logging.
                loss['G/loss_fake'] = g_loss_fake.item()
                loss['G/loss_rec'] = g_loss_rec.item()
                loss['G/loss_cls'] = g_loss_cls.item()
                loss['G/loss_id'] = id_loss.item()
                loss['G/g_loss'] = g_loss.item()

            # =============================================================== #
            #    4. Miscellaneous: logging, sampling, checkpoints, decay      #
            # =============================================================== #
            # Print out training information.
            if (i + 1) % self.log_step == 0:
                et = str(datetime.now() - start_time)[:-7]  # drop microseconds
                log = "Elapsed [{}], Iteration [{}/{}]".format(
                    et, i + 1, self.num_iters)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, i + 1)

            # Convert a few held-out utterances for monitoring.
            if (i + 1) % self.sample_step == 0:
                with torch.no_grad():
                    d, speaker = TestSet(self.test_dir).test_data()
                    target = random.choice(
                        [x for x in speakers if x != speaker])
                    label_t = np.asarray([self.spk_enc.transform([target])[0]])

                    for filename, content in d.items():
                        wav = self._convert_utterance(content, norm, speaker,
                                                      target, label_t)
                        # Include the source filename so successive utterances
                        # do not overwrite one another.
                        name = f'{speaker}-{target}_iter{i + 1}_{filename}'
                        path = os.path.join(self.sample_dir, name)
                        print(f'[save]:{path}')
                        librosa.output.write_wav(path, wav, SAMPLE_RATE)

            # Save model checkpoints.
            if (i + 1) % self.model_save_step == 0:
                G_path = os.path.join(self.model_save_dir,
                                      '{}-G.ckpt'.format(i + 1))
                D_path = os.path.join(self.model_save_dir,
                                      '{}-D.ckpt'.format(i + 1))
                C_path = os.path.join(self.model_save_dir,
                                      '{}-C.ckpt'.format(i + 1))
                torch.save(self.G.state_dict(), G_path)
                torch.save(self.D.state_dict(), D_path)
                torch.save(self.C.state_dict(), C_path)
                print('Saved model checkpoints into {}...'.format(
                    self.model_save_dir))

            # Decay learning rates linearly over the last num_iters_decay steps.
            if (i + 1) % self.lr_update_step == 0 and (i + 1) > (
                    self.num_iters - self.num_iters_decay):
                g_lr -= (self.g_lr / float(self.num_iters_decay))
                d_lr -= (self.d_lr / float(self.num_iters_decay))
                c_lr -= (self.c_lr / float(self.num_iters_decay))
                self.update_lr(g_lr, d_lr, c_lr)
                print('Decayed learning rates, g_lr: {}, d_lr: {}, c_lr: {}.'
                      .format(g_lr, d_lr, c_lr))

    def _convert_utterance(self, content, norm, speaker, target, label_t):
        """Convert one utterance's features to ``target`` and synthesize audio.

        Shared by the in-training sampling branch and :meth:`test` (the two
        previously carried duplicate copies of this pipeline).

        Args:
            content: dict with keys ``'f0'``, ``'ap'`` and ``'coded_sp_norm'``.
            norm: ``Normalizer`` used for de-normalization and pitch mapping.
            speaker: source speaker id.
            target: target speaker id.
            label_t: one-hot target label, shape ``(1, n_speakers)``.

        Returns:
            The synthesized waveform as a numpy array.
        """
        f0 = content['f0']
        ap = content['ap']
        sp_norm_pad = self.pad_coded_sp(content['coded_sp_norm'])

        # Conditioning label is constant over all segments; build it once.
        cond = torch.FloatTensor(label_t).to(self.device)

        convert_result = []
        for start_idx in range(0, sp_norm_pad.shape[1] - FRAMES + 1, FRAMES):
            one_seg = sp_norm_pad[:, start_idx:start_idx + FRAMES]
            one_seg = torch.FloatTensor(one_seg).to(self.device)
            one_seg = one_seg.view(1, 1, one_seg.size(0), one_seg.size(1))
            one_set_return = self.G(one_seg, cond).data.cpu().numpy()
            one_set_return = np.squeeze(one_set_return)
            one_set_return = norm.backward_process(one_set_return, target)
            convert_result.append(one_set_return)

        # Re-assemble segments and trim the padding back off.
        convert_con = np.concatenate(convert_result, axis=1)
        convert_con = convert_con[:, 0:content['coded_sp_norm'].shape[1]]
        contigu = np.ascontiguousarray(convert_con.T, dtype=np.float64)
        decoded_sp = decode_spectral_envelope(contigu,
                                              SAMPLE_RATE,
                                              fft_size=FFTSIZE)
        f0_converted = norm.pitch_conversion(f0, speaker, target)
        return synthesize(f0_converted, decoded_sp, ap, SAMPLE_RATE)

    def gradient_penalty(self, y, x):
        """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2."""
        weight = torch.ones(y.size()).to(self.device)
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=weight,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]

        # Per-sample L2 norm over all remaining dimensions.
        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
        return torch.mean((dydx_l2norm - 1)**2)

    def reset_grad(self):
        """Reset the gradient buffers of all three optimizers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()
        self.c_optimizer.zero_grad()

    def restore_model(self, resume_iters):
        """Restore G, D and C weights saved at step ``resume_iters``."""
        print(
            'Loading the trained models from step {}...'.format(resume_iters))
        G_path = os.path.join(self.model_save_dir,
                              '{}-G.ckpt'.format(resume_iters))
        D_path = os.path.join(self.model_save_dir,
                              '{}-D.ckpt'.format(resume_iters))
        C_path = os.path.join(self.model_save_dir,
                              '{}-C.ckpt'.format(resume_iters))
        # map_location keeps CPU-only machines able to load GPU checkpoints.
        self.G.load_state_dict(
            torch.load(G_path, map_location=lambda storage, loc: storage))
        self.D.load_state_dict(
            torch.load(D_path, map_location=lambda storage, loc: storage))
        self.C.load_state_dict(
            torch.load(C_path, map_location=lambda storage, loc: storage))

    @staticmethod
    def pad_coded_sp(coded_sp_norm):
        """Zero-pad the frame axis up to a positive multiple of ``FRAMES``.

        Previously a full extra FRAMES of silence was appended even when the
        length was already an exact multiple of FRAMES; that wasted one
        generator pass per utterance (the output was trimmed away anyway).
        """
        f_len = coded_sp_norm.shape[1]
        if f_len >= FRAMES:
            # Pad only the remainder; 0 when already aligned.
            pad_length = (FRAMES - f_len % FRAMES) % FRAMES
        else:
            pad_length = FRAMES - f_len

        sp_norm_pad = np.hstack(
            (coded_sp_norm, np.zeros((coded_sp_norm.shape[0], pad_length))))
        return sp_norm_pad

    def test(self):
        """Translate speech using StarGAN: convert ``src_speaker`` to each target."""
        # Load the trained generator.
        self.restore_model(self.test_iters)
        norm = Normalizer()

        # Set data loader.
        d, speaker = TestSet(self.test_dir).test_data(self.src_speaker)
        targets = self.trg_speaker

        for target in targets:
            print(target)
            assert target in speakers
            label_t = np.asarray([self.spk_enc.transform([target])[0]])

            with torch.no_grad():
                for filename, content in d.items():
                    wav = self._convert_utterance(content, norm, speaker,
                                                  target, label_t)
                    # Include the source filename so successive utterances
                    # do not overwrite one another.
                    name = f'{speaker}-{target}_iter{self.test_iters}_{filename}'
                    path = os.path.join(self.result_dir, name)
                    print(f'[save]:{path}')
                    librosa.output.write_wav(path, wav, SAMPLE_RATE)
Пример #4
0
class Solver(object):
    """docstring for Solver."""
    def __init__(self, data_loader, config):
        """Store the configuration, resolve the compute device, and build
        the networks (plus the tensorboard logger when enabled)."""
        self.config = config
        self.data_loader = data_loader

        # Loss weights, training/test hyper-parameters, step sizes and
        # directories are copied verbatim from the config object.
        for attr in ('lambda_cycle', 'lambda_cls', 'lambda_identity',
                     'sigma_d',
                     'data_dir', 'test_dir', 'batch_size', 'num_iters',
                     'num_iters_decay', 'g_lr', 'd_lr', 'c_lr', 'n_critic',
                     'beta1', 'beta2', 'resume_iters',
                     'test_iters', 'src_style',
                     'use_tensorboard',
                     'log_dir', 'sample_dir', 'model_save_dir', 'result_dir',
                     'log_step', 'sample_step', 'model_save_step',
                     'lr_update_step'):
            setattr(self, attr, getattr(config, attr))

        # The target style list arrives as a string literal (e.g. "['A','B']");
        # literal_eval parses it safely.
        self.trg_style = ast.literal_eval(config.trg_style)

        # Prefer the first CUDA device when one is available.
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        # One-hot encoder fitted on the global list of style labels.
        self.stls_enc = LabelBinarizer().fit(styles)

        # Build the model and tensorboard.
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

    def build_model(self):
        """Instantiate G, D and the domain classifier C, create one Adam
        optimizer per network, print their summaries, and move them to the
        configured device."""
        self.G = Generator()
        self.D = Discriminator()
        self.C = DomainClassifier()

        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr,
                                            [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr,
                                            [self.beta1, self.beta2])
        self.c_optimizer = torch.optim.Adam(self.C.parameters(), self.c_lr,
                                            [self.beta1, self.beta2])

        # Summarize each network, then place it on the target device.
        for tag, net in (('G', self.G), ('D', self.D), ('C', self.C)):
            self.print_network(net, tag)

        for net in (self.G, self.D, self.C):
            net.to(self.device)

    def print_network(self, model, name):
        """Print the model structure, its name, and its parameter count."""
        total = sum(p.numel() for p in model.parameters())
        print(model)
        print(name)
        print("The number of parameters: {}".format(total))

    def build_tensorboard(self):
        """Build a tensorboard logger."""
        # Imported lazily so the logger dependency is only needed when
        # tensorboard logging is actually enabled.
        import logger as logger_module
        self.logger = logger_module.Logger(self.log_dir)

    def update_lr(self, g_lr, d_lr, c_lr):
        """Decay learning rates of the generator and discriminator and classifier."""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr
        for param_group in self.c_optimizer.param_groups:
            param_group['lr'] = c_lr

    def train(self):
        """Run the StarGAN training loop.

        Each iteration trains the domain classifier C and discriminator D;
        every ``n_critic`` iterations the generator G is also updated with
        adversarial, cycle, classification and identity losses.  Logging,
        sample generation, checkpointing and learning-rate decay each happen
        on their configured step intervals.
        """
        # Learning rate cache for decaying.
        g_lr = self.g_lr
        d_lr = self.d_lr
        c_lr = self.c_lr

        # Resume training from a saved checkpoint when requested.  The
        # original left this branch as ``pass``, silently restarting from
        # scratch even though restore_model() exists.
        start_iters = 0
        if self.resume_iters:
            start_iters = self.resume_iters
            self.restore_model(self.resume_iters)

        data_iter = iter(self.data_loader)

        # The criterion is stateless -- build it once, not once per iteration.
        CELoss = nn.CrossEntropyLoss()

        print('Start training......')
        start_time = datetime.now()

        for i in range(start_iters, self.num_iters):
            # =================================================================================== #
            #                             1. Preprocess input data                                #
            # =================================================================================== #
            # Fetch real samples and labels; restart the iterator when the
            # epoch is exhausted (narrowed from a bare ``except:`` that could
            # mask unrelated loader errors).
            try:
                x_real, style_idx_org, label_org = next(data_iter)
            except StopIteration:
                data_iter = iter(self.data_loader)
                x_real, style_idx_org, label_org = next(data_iter)

            # Gaussian noise added to D's inputs for robustness improvement.
            gaussian_noise = self.sigma_d * torch.randn(x_real.size())

            # Generate target domain labels randomly by permuting the batch.
            rand_idx = torch.randperm(label_org.size(0))
            label_trg = label_org[rand_idx]
            style_idx_trg = style_idx_org[rand_idx]

            x_real = x_real.to(self.device)  # Input samples.
            label_org = label_org.to(
                self.device)  # Original domain one-hot labels.
            label_trg = label_trg.to(
                self.device)  # Target domain one-hot labels.
            style_idx_org = style_idx_org.to(
                self.device)  # Original domain indices.
            style_idx_trg = style_idx_trg.to(
                self.device)  # Target domain indices.
            gaussian_noise = gaussian_noise.to(
                self.device)  # Noise for the discriminator inputs.

            # =================================================================================== #
            #                             2. Train the discriminator                              #
            # =================================================================================== #
            # Train the domain classifier C on real frames.
            cls_real = self.C(x_real)
            cls_loss_real = CELoss(input=cls_real, target=style_idx_org)

            self.reset_grad()
            cls_loss_real.backward()
            self.c_optimizer.step()

            # Logging.
            loss = {}
            loss['C/C_loss'] = cls_loss_real.item()

            # LSGAN objective: push D(real) toward 1 and D(fake) toward 0.
            out_r = self.D(x_real + gaussian_noise, label_org)
            x_fake = self.G(x_real, label_trg)
            out_f = self.D(x_fake + gaussian_noise, label_trg)
            d_loss_t = F.mse_loss(input=out_f,
                                  target=torch.zeros_like(
                                      out_f, dtype=torch.float)) + \
                F.mse_loss(input=out_r,
                           target=torch.ones_like(out_r, dtype=torch.float))

            # Domain-classification loss on generated frames.
            out_cls = self.C(x_fake)
            d_loss_cls = CELoss(input=out_cls, target=style_idx_trg)

            d_loss = d_loss_t + self.lambda_cls * d_loss_cls

            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            loss['D/D_loss'] = d_loss.item()

            # =================================================================================== #
            #                               3. Train the generator                                #
            # =================================================================================== #
            if (i + 1) % self.n_critic == 0:
                # Original-to-target domain.
                x_fake = self.G(x_real, label_trg)
                g_out_src = self.D(x_fake + gaussian_noise, label_trg)
                g_loss_fake = F.mse_loss(input=g_out_src,
                                         target=torch.ones_like(
                                             g_out_src, dtype=torch.float))

                # NOTE(review): this scores C on x_real against the *source*
                # labels, not C(x_fake) vs the target labels -- kept as-is to
                # preserve behavior, but worth confirming against the paper.
                out_cls = self.C(x_real)
                g_loss_cls = CELoss(input=out_cls, target=style_idx_org)

                # Target-to-original domain (cycle consistency).
                x_reconst = self.G(x_fake, label_org)
                g_loss_rec = F.l1_loss(x_reconst, x_real)

                # Original-to-original domain (identity mapping).
                x_fake_iden = self.G(x_real, label_org)
                id_loss = F.l1_loss(x_fake_iden, x_real)

                # Backward and optimize.
                g_loss = g_loss_fake + self.lambda_cycle * g_loss_rec + \
                    self.lambda_cls * g_loss_cls + \
                    self.lambda_identity * id_loss

                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # Logging.
                loss['G/loss_fake'] = g_loss_fake.item()
                loss['G/loss_rec'] = g_loss_rec.item()
                loss['G/loss_cls'] = g_loss_cls.item()
                loss['G/loss_id'] = id_loss.item()
                loss['G/g_loss'] = g_loss.item()

            # =================================================================================== #
            #                                 4. Miscellaneous                                    #
            # =================================================================================== #
            # Print out training information.
            if (i + 1) % self.log_step == 0:
                et = datetime.now() - start_time
                et = str(et)[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]".format(
                    et, i + 1, self.num_iters)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, i + 1)

            # Translate fixed samples for debugging.
            if (i + 1) % self.sample_step == 0:
                with torch.no_grad():
                    d, style = TestSet(self.test_dir).test_data()

                    label_o = self.stls_enc.transform([style])[0]
                    label_o = np.asarray([label_o])

                    # Pick a random target style different from the source.
                    target = random.choice([x for x in styles if x != style])
                    label_t = self.stls_enc.transform([target])[0]
                    label_t = np.asarray([label_t])

                    for filename, content in d.items():

                        filename = filename.split('.')[0]

                        one_seg = torch.FloatTensor(content).to(self.device)
                        one_seg = one_seg.view(1, one_seg.size(0),
                                               one_seg.size(1),
                                               one_seg.size(2))
                        l_t = torch.FloatTensor(label_t).to(self.device)

                        one_set_transfer = self.G(one_seg, l_t)

                        l_o = torch.FloatTensor(label_o).to(self.device)

                        one_set_cycle = self.G(one_set_transfer,
                                               l_o).data.cpu().numpy()
                        one_set_transfer = one_set_transfer.data.cpu().numpy()

                        # Threshold the continuous outputs into piano-roll
                        # binaries.
                        one_set_transfer_binary = to_binary(
                            one_set_transfer, 0.5)
                        one_set_cycle_binary = to_binary(one_set_cycle, 0.5)

                        one_set_transfer_binary = one_set_transfer_binary.reshape(
                            -1, one_set_transfer_binary.shape[2],
                            one_set_transfer_binary.shape[3],
                            one_set_transfer_binary.shape[1])
                        one_set_cycle_binary = one_set_cycle_binary.reshape(
                            -1, one_set_cycle_binary.shape[2],
                            one_set_cycle_binary.shape[3],
                            one_set_cycle_binary.shape[1])

                        print(one_set_transfer_binary.shape,
                              one_set_cycle_binary.shape)

                        # BUG FIX: ``filename`` was computed but never used in
                        # the output names, so every file in a sample step
                        # overwrote the previous one.
                        name_origin = f'{style}-{target}_iter{i+1}_{filename}_origin'
                        name_transfer = f'{style}-{target}_iter{i+1}_{filename}_transfer'
                        name_cycle = f'{style}-{target}_iter{i+1}_{filename}_cycle'

                        path_samples_per_iter = os.path.join(
                            self.sample_dir, f'iter{i+1}')

                        if not os.path.exists(path_samples_per_iter):
                            os.makedirs(path_samples_per_iter)

                        path_origin = os.path.join(path_samples_per_iter,
                                                   name_origin)
                        path_transfer = os.path.join(path_samples_per_iter,
                                                     name_transfer)
                        path_cycle = os.path.join(path_samples_per_iter,
                                                  name_cycle)

                        print(
                            f'[save]:{path_origin},{path_transfer},{path_cycle}'
                        )

                        save_midis(
                            content.reshape(1, content.shape[1],
                                            content.shape[2],
                                            content.shape[0]),
                            '{}.mid'.format(path_origin))
                        save_midis(one_set_transfer_binary,
                                   '{}.mid'.format(path_transfer))
                        save_midis(one_set_cycle_binary,
                                   '{}.mid'.format(path_cycle))

            # Save model checkpoints.
            if (i + 1) % self.model_save_step == 0:
                G_path = os.path.join(self.model_save_dir,
                                      '{}-G.ckpt'.format(i + 1))
                D_path = os.path.join(self.model_save_dir,
                                      '{}-D.ckpt'.format(i + 1))
                C_path = os.path.join(self.model_save_dir,
                                      '{}-C.ckpt'.format(i + 1))
                torch.save(self.G.state_dict(), G_path)
                torch.save(self.D.state_dict(), D_path)
                torch.save(self.C.state_dict(), C_path)
                print('Saved model checkpoints into {}...'.format(
                    self.model_save_dir))

            # Decay learning rates linearly over the final num_iters_decay
            # iterations.
            if (i + 1) % self.lr_update_step == 0 and (i + 1) > (
                    self.num_iters - self.num_iters_decay):
                g_lr -= (self.g_lr / float(self.num_iters_decay))
                d_lr -= (self.d_lr / float(self.num_iters_decay))
                c_lr -= (self.c_lr / float(self.num_iters_decay))
                self.update_lr(g_lr, d_lr, c_lr)
                print('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(
                    g_lr, d_lr))

    def reset_grad(self):
        """Reset the gradient buffers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()
        self.c_optimizer.zero_grad()

    def restore_model(self, resume_iters):
        """Restore the trained generator and discriminator."""
        print(
            'Loading the trained models from step {}...'.format(resume_iters))
        # Checkpoints are saved as '<step>-<tag>.ckpt' under model_save_dir;
        # map_location keeps CPU-only loading working for GPU-saved weights.
        for net, tag in ((self.G, 'G'), (self.D, 'D'), (self.C, 'C')):
            ckpt_path = os.path.join(self.model_save_dir,
                                     '{}-{}.ckpt'.format(resume_iters, tag))
            net.load_state_dict(
                torch.load(ckpt_path,
                           map_location=lambda storage, loc: storage))

    def test(self):
        """Translate music using the trained StarGAN generator.

        Loads the checkpoint at ``self.test_iters``, converts every item of
        the test set from the source style to each requested target style,
        and writes the original and transferred piano rolls as MIDI files
        under ``self.result_dir``.
        """
        # Load the trained generator.
        self.restore_model(self.test_iters)

        # Set data loader.
        d, style = TestSet(self.test_dir).test_data(self.src_style)
        targets = self.trg_style

        for target in targets:
            print(target)
            # Explicit validation instead of ``assert`` (asserts are stripped
            # under python -O).
            if target not in styles:
                raise ValueError(
                    'Unknown target style: {}'.format(target))
            label_t = self.stls_enc.transform([target])[0]
            label_t = np.asarray([label_t])

            with torch.no_grad():

                for filename, content in d.items():

                    filename = filename.split('.')[0]

                    one_seg = torch.FloatTensor(content).to(self.device)
                    one_seg = one_seg.view(1, one_seg.size(0), one_seg.size(1),
                                           one_seg.size(2))
                    l_t = torch.FloatTensor(label_t).to(self.device)

                    one_set_transfer = self.G(one_seg, l_t).cpu().numpy()

                    # Threshold the continuous output into a binary roll.
                    one_set_transfer_binary = to_binary(one_set_transfer, 0.5)

                    one_set_transfer_binary = one_set_transfer_binary.reshape(
                        -1, one_set_transfer_binary.shape[2],
                        one_set_transfer_binary.shape[3],
                        one_set_transfer_binary.shape[1])

                    # BUG FIX: the original referenced the undefined loop
                    # variable ``i`` here (NameError at runtime) and ignored
                    # ``filename``, so every file overwrote the same path.
                    # Use the checkpoint step and the input filename instead.
                    name_origin = f'{style}-{target}_iter{self.test_iters}_{filename}_origin'
                    name_transfer = f'{style}-{target}_iter{self.test_iters}_{filename}_transfer'

                    path = os.path.join(self.result_dir,
                                        f'iter{self.test_iters}')
                    # Ensure the output directory exists before writing.
                    os.makedirs(path, exist_ok=True)

                    path_origin = os.path.join(path, name_origin)
                    path_transfer = os.path.join(path, name_transfer)

                    print(f'[save]:{path_origin},{path_transfer}')

                    save_midis(
                        content.reshape(1, content.shape[1], content.shape[2],
                                        content.shape[0]),
                        '{}.mid'.format(path_origin))
                    save_midis(one_set_transfer_binary,
                               '{}.mid'.format(path_transfer))