# Example no. 1 (score: 0)
class Solver(object):
    """训练操作的说明"""

    #  参数为数据加载器与配置
    def __init__(self, data_loader, config):
        """Copy configuration values, set up the device and label encoder,
        then build the models (and optionally the tensorboard logger).

        Args:
            data_loader: training loader; train() unpacks each batch as
                (features, speaker_idx, one_hot_label).
            config: parsed argparse-style configuration namespace.
        """
        # Keep references to the raw config and the training data loader.
        self.config = config
        self.data_loader = data_loader
        # Model configuration.
        # Weights for the cycle, domain-classification and identity losses.
        self.lambda_cycle = config.lambda_cycle
        self.lambda_cls = config.lambda_cls
        self.lambda_identity = config.lambda_identity

        # Training configuration.
        # Training data directory.
        self.data_dir = config.data_dir
        # Test data directory.
        self.test_dir = config.test_dir
        # Mini-batch size.
        self.batch_size = config.batch_size
        # Total number of discriminator training iterations.
        self.num_iters = config.num_iters
        # Number of iterations over which the learning rates decay.
        self.num_iters_decay = config.num_iters_decay
        # Generator learning rate.
        self.g_lr = config.g_lr
        # Discriminator learning rate.
        self.d_lr = config.d_lr
        # Domain-classifier learning rate.
        self.c_lr = config.c_lr
        # Number of D updates per G update.
        self.n_critic = config.n_critic
        # Adam optimizer beta1/beta2 coefficients.
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        # Iteration to resume training from.
        # NOTE(review): train() checks this flag but never actually resumes.
        self.resume_iters = config.resume_iters

        # Test configuration.
        # Checkpoint iteration to load for testing.
        self.test_iters = config.test_iters
        # ast.literal_eval safely parses the literal held in the config string.
        # Target speaker(s) for conversion.
        self.trg_speaker = ast.literal_eval(config.trg_speaker)
        # Source speaker.
        self.src_speaker = config.src_speaker

        # Miscellaneous.
        # Whether to log with tensorboard.
        self.use_tensorboard = config.use_tensorboard
        # Compute on the first GPU when available, otherwise CPU.
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        # Fit a binarizer over the module-level speakers list so speaker
        # names can be turned into one-hot label vectors.
        self.spk_enc = LabelBinarizer().fit(speakers)
        # Directories.
        # Log directory.
        self.log_dir = config.log_dir
        # Converted-sample output directory.
        self.sample_dir = config.sample_dir
        # Model checkpoint directory.
        self.model_save_dir = config.model_save_dir
        # Test-result output directory.
        self.result_dir = config.result_dir

        # Step sizes.
        # Console/tensorboard logging interval.
        self.log_step = config.log_step
        # Sample-conversion interval.
        self.sample_step = config.sample_step
        # Checkpoint-saving interval.
        self.model_save_step = config.model_save_step
        # Learning-rate decay interval.
        self.lr_update_step = config.lr_update_step

        # Build the models and, if enabled, the tensorboard logger.
        self.build_model()
        if self.use_tensorboard:
            # Use the tensorboard logger.
            self.build_tensorboard()

    # 赋值三个模型器
    def build_model(self):
        """Instantiate the generator, discriminator and domain classifier,
        build one Adam optimizer per network, print their summaries and
        move the networks to the training device."""
        self.G = Generator()
        self.D = Discriminator()
        self.C = DomainClassifier()

        # One Adam optimizer per network.  The betas pair is passed
        # positionally as [beta1, beta2]; Adam's remaining defaults
        # (eps=1e-8, weight_decay=0) are kept.

        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr,
                                            [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr,
                                            [self.beta1, self.beta2])
        self.c_optimizer = torch.optim.Adam(self.C.parameters(), self.c_lr,
                                            [self.beta1, self.beta2])

        self.print_network(self.G, 'G')
        self.print_network(self.D, 'D')
        self.print_network(self.C, 'C')
        # Move the networks onto the selected device (GPU when available).
        self.G.to(self.device)
        self.D.to(self.device)
        self.C.to(self.device)

    def print_network(self, model, name):
        """Print the model, its name and its total scalar-parameter count."""
        # Sum the element counts of every parameter tensor.
        num_params = sum(p.numel() for p in model.parameters())
        print(model)
        print(name)
        print("参数的个数为:{}".format(num_params))

    # 使用tensorboard
    def build_tensorboard(self):
        """Create the tensorboard logger writing into self.log_dir."""
        # Imported lazily so the logger module is only needed when enabled.
        from logger import Logger
        self.logger = Logger(self.log_dir)

    def update_lr(self, g_lr, d_lr, c_lr):
        """生成器、判别器和域分类器的衰减学习率"""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr
        for param_group in self.c_optimizer.param_groups:
            param_group['lr'] = c_lr

    # 训练方法
    def train(self):
        """Run the full training loop.

        Per iteration: train the domain classifier C on real data, train the
        discriminator D (adversarial + classification + gradient-penalty
        losses), and every ``n_critic`` iterations train the generator G
        (adversarial, cycle, classification and identity losses).  Also
        handles logging, periodic sample conversion, checkpointing and
        learning-rate decay.
        """
        # Learning-rate caches used for linear decay.
        g_lr = self.g_lr
        d_lr = self.d_lr
        c_lr = self.c_lr
        # Start from iteration 0.
        start_iters = 0
        # NOTE(review): resuming is not implemented — resume_iters is ignored.
        if self.resume_iters:
            pass
        # Project-specific feature normalizer.
        norm = Normalizer()
        # Manual iterator over the loader so it can be restarted mid-loop.
        data_iter = iter(self.data_loader)
        print('开始训练......')
        # Wall-clock start, used for elapsed-time logging.
        start_time = datetime.now()
        for i in range(start_iters, self.num_iters):
            # =================================================================================== #
            #                             1. Preprocess input data                                #
            # =================================================================================== #
            # Fetch real features with their speaker indices and one-hot labels.
            try:
                x_real, speaker_idx_org, label_org = next(data_iter)
            except:
                # Loader exhausted (StopIteration): restart it.  NOTE(review):
                # the bare except also swallows genuine loader errors.
                data_iter = iter(self.data_loader)
                x_real, speaker_idx_org, label_org = next(data_iter)

            # Build target-domain labels by randomly permuting the batch.
            rand_idx = torch.randperm(label_org.size(0))
            label_trg = label_org[rand_idx]
            speaker_idx_trg = speaker_idx_org[rand_idx]
            # Move all tensors to the training device.
            x_real = x_real.to(self.device)  # input features
            label_org = label_org.to(self.device)  # source one-hot labels
            label_trg = label_trg.to(self.device)  # target one-hot labels
            speaker_idx_org = speaker_idx_org.to(self.device)  # source speaker indices
            speaker_idx_trg = speaker_idx_trg.to(self.device)  # target speaker indices

            # =================================================================================== #
            #                             2. Train the discriminator                              #
            # =================================================================================== #
            # Softmax cross-entropy for the domain classifier.
            CELoss = nn.CrossEntropyLoss()
            # Classification loss on real data.
            cls_real = self.C(x_real)
            cls_loss_real = CELoss(input=cls_real, target=speaker_idx_org)
            # Clear gradient buffers, backprop, update only C.
            self.reset_grad()
            cls_loss_real.backward()
            self.c_optimizer.step()
            # Scalar losses collected for logging.
            loss = {}
            # item() extracts the Python scalar from a one-element tensor.
            loss['C/C_loss'] = cls_loss_real.item()

            # D's verdict on real data with its source label.
            out_r = self.D(x_real, label_org)
            # Fake sample conditioned on the target label.
            x_fake = self.G(x_real, label_trg)
            # detach() stops gradients from reaching G while training D.
            out_f = self.D(x_fake.detach(), label_trg)
            # Adversarial loss on logits: fake -> 0, real -> 1.  The
            # *_with_logits variant applies the sigmoid internally, so the
            # raw discriminator outputs are passed straight in.
            d_loss_t = F.binary_cross_entropy_with_logits(input=out_f,target=torch.zeros_like(out_f, dtype=torch.float)) + \
                F.binary_cross_entropy_with_logits(input=out_r, target=torch.ones_like(out_r, dtype=torch.float))
            # Domain-classification loss on the generated sample.
            out_cls = self.C(x_fake)
            d_loss_cls = CELoss(input=out_cls, target=speaker_idx_trg)

            # Gradient penalty (WGAN-GP style): take a random interpolation
            # x_hat between real and fake samples and penalize the gradient
            # norm of D at x_hat (see gradient_penalty()).
            alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
            x_hat = (alpha * x_real.data +
                     (1 - alpha) * x_fake.data).requires_grad_(True)
            out_src = self.D(x_hat, label_trg)
            d_loss_gp = self.gradient_penalty(out_src, x_hat)
            # Total D loss; the gradient-penalty weight is hard-coded to 5.
            d_loss = d_loss_t + self.lambda_cls * d_loss_cls + 5 * d_loss_gp
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # loss['D/d_loss_t'] = d_loss_t.item()
            # loss['D/loss_cls'] = d_loss_cls.item()
            # loss['D/D_gp'] = d_loss_gp.item()
            loss['D/D_loss'] = d_loss.item()

            # =================================================================================== #
            #                             3. Train the generator                                  #
            # =================================================================================== #
            # G is updated only once every n_critic D updates.
            if (i + 1) % self.n_critic == 0:
                # Source-to-target mapping: G must fool D on the fake sample.
                x_fake = self.G(x_real, label_trg)
                g_out_src = self.D(x_fake, label_trg)
                g_loss_fake = F.binary_cross_entropy_with_logits(
                    input=g_out_src,
                    target=torch.ones_like(g_out_src, dtype=torch.float))
                # Domain-classification loss on real data for G's objective.
                out_cls = self.C(x_real)
                g_loss_cls = CELoss(input=out_cls, target=speaker_idx_org)

                # Target-to-source mapping: cycle-consistency (L1) between
                # the reconstruction and the original input.
                x_reconst = self.G(x_fake, label_org)
                g_loss_rec = F.l1_loss(x_reconst, x_real)

                # Source-to-source mapping (identity loss): converting to the
                # source label should reproduce the input.
                x_fake_iden = self.G(x_real, label_org)
                id_loss = F.l1_loss(x_fake_iden, x_real)

                # Weighted total generator loss.
                g_loss = g_loss_fake + self.lambda_cycle * g_loss_rec +\
                 self.lambda_cls * g_loss_cls + self.lambda_identity * id_loss
                # Clear buffers, backprop, update G.
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # Record the individual generator losses.
                loss['G/loss_fake'] = g_loss_fake.item()
                loss['G/loss_rec'] = g_loss_rec.item()
                loss['G/loss_cls'] = g_loss_cls.item()
                loss['G/loss_id'] = id_loss.item()
                loss['G/g_loss'] = g_loss.item()
            # =================================================================================== #
            #                             4. Miscellaneous                                        #
            # =================================================================================== #
            # Periodic console / tensorboard logging.
            if (i + 1) % self.log_step == 0:
                et = datetime.now() - start_time
                # Drop the trailing microseconds from the elapsed-time string.
                et = str(et)[:-7]
                log = "耗时:[{}], 迭代次数:[{}/{}]".format(et, i + 1, self.num_iters)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)
                if self.use_tensorboard:
                    for tag, value in loss.items():
                        # Mirror the scalars into tensorboard.
                        self.logger.scalar_summary(tag, value, i + 1)

            # Periodically convert a fixed test utterance for debugging.
            if (i + 1) % self.sample_step == 0:
                # Inference only: no autograd bookkeeping inside no_grad().
                with torch.no_grad():
                    # Pick a random test speaker's data.
                    d, speaker = TestSet(self.test_dir).test_data()
                    # Choose a conversion target different from the source.
                    target = random.choice(
                        [x for x in speakers if x != speaker])
                    # One-hot target label, shaped (1, n_speakers).
                    label_t = self.spk_enc.transform([target])[0]
                    label_t = np.asarray([label_t])
                    # Walk the per-file feature dictionaries.
                    for filename, content in d.items():
                        f0 = content['f0']
                        ap = content['ap']
                        # Zero-pad the spectral envelope to a FRAMES multiple.
                        sp_norm_pad = self.pad_coded_sp(
                            content['coded_sp_norm'])

                        convert_result = []
                        # Convert FRAMES-sized segments one at a time.
                        for start_idx in range(
                                0, sp_norm_pad.shape[1] - FRAMES + 1, FRAMES):
                            one_seg = sp_norm_pad[:,
                                                  start_idx:start_idx + FRAMES]

                            one_seg = torch.FloatTensor(one_seg).to(
                                self.device)
                            # (D, T) -> (1, 1, D, T) for the conv2d generator.
                            one_seg = one_seg.view(1, 1, one_seg.size(0),
                                                   one_seg.size(1))
                            l = torch.FloatTensor(label_t)
                            one_seg = one_seg.to(self.device)
                            l = l.to(self.device)
                            one_set_return = self.G(one_seg,
                                                    l).data.cpu().numpy()
                            one_set_return = np.squeeze(one_set_return)
                            # Undo the target speaker's normalization.
                            one_set_return = norm.backward_process(
                                one_set_return, target)
                            convert_result.append(one_set_return)

                        # Stitch segments and trim to the original length.
                        convert_con = np.concatenate(convert_result, axis=1)
                        convert_con = convert_con[:,
                                                  0:content['coded_sp_norm'].
                                                  shape[1]]
                        contigu = np.ascontiguousarray(convert_con.T,
                                                       dtype=np.float64)
                        # WORLD vocoder resynthesis of the converted features.
                        decoded_sp = decode_spectral_envelope(contigu,
                                                              SAMPLE_RATE,
                                                              fft_size=FFTSIZE)
                        f0_converted = norm.pitch_conversion(
                            f0, speaker, target)
                        wav = synthesize(f0_converted, decoded_sp, ap,
                                         SAMPLE_RATE)

                        name = f'{speaker}-{target}_iter{i+1}_(unknown)'
                        path = os.path.join(self.sample_dir, name)
                        print(f'[save]:{path}')
                        # NOTE(review): librosa.output was removed in
                        # librosa 0.8 — this pins an older librosa.
                        librosa.output.write_wav(path, wav, SAMPLE_RATE)

            # Periodically save model checkpoints.
            if (i + 1) % self.model_save_step == 0:
                G_path = os.path.join(self.model_save_dir,
                                      '{}-G.ckpt'.format(i + 1))
                D_path = os.path.join(self.model_save_dir,
                                      '{}-D.ckpt'.format(i + 1))
                C_path = os.path.join(self.model_save_dir,
                                      '{}-C.ckpt'.format(i + 1))
                torch.save(self.G.state_dict(), G_path)
                torch.save(self.D.state_dict(), D_path)
                torch.save(self.C.state_dict(), C_path)
                print('Saved model checkpoints into {}...'.format(
                    self.model_save_dir))

            # Linearly decay the learning rates over the final
            # num_iters_decay iterations.
            if (i + 1) % self.lr_update_step == 0 and (i + 1) > (
                    self.num_iters - self.num_iters_decay):
                g_lr -= (self.g_lr / float(self.num_iters_decay))
                d_lr -= (self.d_lr / float(self.num_iters_decay))
                c_lr -= (self.c_lr / float(self.num_iters_decay))
                self.update_lr(g_lr, d_lr, c_lr)
                print('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(
                    g_lr, d_lr))

    def gradient_penalty(self, y, x):
        """计算梯度惩罚: (L2_norm(dy/dx) - 1)**2."""
        # 根据标签的纬度创建出权重矩阵
        weight = torch.ones(y.size()).to(self.device)
        # torch.autograd.grad计算并返回outputs对inputs的梯度dy/dx
        dydx = torch.autograd.grad(
            outputs=y,
            inputs=x,
            # 雅可比向量积中的“向量”,用来将梯度向量转换为梯度标量,并可以衡量y梯度的各个数据的权重
            grad_outputs=weight,
            retain_graph=True,
            # 保持计算的导数图
            create_graph=True,
            only_inputs=True)[0]
        # 将导数变为二维的数据并保持行数
        dydx = dydx.view(dydx.size(0), -1)
        # 计算导数的L2范数
        dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
        return torch.mean((dydx_l2norm - 1)**2)

    def reset_grad(self):
        """重置梯度变化缓冲区"""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()
        self.c_optimizer.zero_grad()

    def restore_model(self, resume_iters):
        """Load G/D/C weights saved at iteration ``resume_iters``.

        Checkpoints are read from self.model_save_dir as
        '{iter}-G.ckpt' / '{iter}-D.ckpt' / '{iter}-C.ckpt'.
        """
        print('从{}步骤开始加载训练过的模型...'.format(resume_iters))
        G_path = os.path.join(self.model_save_dir,
                              '{}-G.ckpt'.format(resume_iters))
        D_path = os.path.join(self.model_save_dir,
                              '{}-D.ckpt'.format(resume_iters))
        C_path = os.path.join(self.model_save_dir,
                              '{}-C.ckpt'.format(resume_iters))
        # map_location keeps loaded tensors on CPU regardless of where
        # they were saved; .to(self.device) was done in build_model.
        self.G.load_state_dict(
            torch.load(G_path, map_location=lambda storage, loc: storage))
        self.D.load_state_dict(
            torch.load(D_path, map_location=lambda storage, loc: storage))
        self.C.load_state_dict(
            torch.load(C_path, map_location=lambda storage, loc: storage))

    @staticmethod
    def pad_coded_sp(coded_sp_norm):
        """Right-pad a (coeffs, frames) array with zero frames.

        Pads the frame axis up to the next multiple of FRAMES.  Note that
        when the length is already an exact multiple, a full extra block of
        FRAMES zero frames is appended (pad_length = FRAMES - f_len % FRAMES
        is FRAMES, not 0, at mod == 0) — identical to the original
        two-branch computation.
        """
        f_len = coded_sp_norm.shape[1]
        pad_length = FRAMES - (f_len % FRAMES)
        padding = np.zeros((coded_sp_norm.shape[0], pad_length))
        return np.hstack((coded_sp_norm, padding))

    def test(self):
        """Convert the source speaker's test utterances to each configured
        target speaker with the trained generator and write out the wavs."""
        # Load the generator/discriminator/classifier trained to test_iters.
        self.restore_model(self.test_iters)
        norm = Normalizer()

        # Test data for the configured source speaker.
        d, speaker = TestSet(self.test_dir).test_data(self.src_speaker)
        targets = self.trg_speaker

        for target in targets:
            print(target)
            # Every requested target must be a known speaker.
            assert target in speakers
            # One-hot label for the target speaker, shaped (1, n_speakers).
            label_t = self.spk_enc.transform([target])[0]
            label_t = np.asarray([label_t])

            # Inference only: no autograd bookkeeping.
            with torch.no_grad():

                for filename, content in d.items():
                    f0 = content['f0']
                    ap = content['ap']
                    # Zero-pad the spectral envelope to a FRAMES multiple.
                    sp_norm_pad = self.pad_coded_sp(content['coded_sp_norm'])

                    convert_result = []
                    # Convert one FRAMES-sized segment at a time.
                    for start_idx in range(0,
                                           sp_norm_pad.shape[1] - FRAMES + 1,
                                           FRAMES):
                        one_seg = sp_norm_pad[:, start_idx:start_idx + FRAMES]

                        one_seg = torch.FloatTensor(one_seg).to(self.device)
                        # (D, T) -> (1, 1, D, T) for the conv2d generator.
                        one_seg = one_seg.view(1, 1, one_seg.size(0),
                                               one_seg.size(1))
                        l = torch.FloatTensor(label_t)
                        one_seg = one_seg.to(self.device)
                        l = l.to(self.device)
                        one_set_return = self.G(one_seg, l).data.cpu().numpy()
                        one_set_return = np.squeeze(one_set_return)
                        # Undo the target speaker's feature normalization.
                        one_set_return = norm.backward_process(
                            one_set_return, target)
                        convert_result.append(one_set_return)

                    # Stitch segments and trim to the original frame count.
                    convert_con = np.concatenate(convert_result, axis=1)
                    convert_con = convert_con[:, 0:content['coded_sp_norm'].
                                              shape[1]]
                    contigu = np.ascontiguousarray(convert_con.T,
                                                   dtype=np.float64)
                    # WORLD vocoder resynthesis of the converted features.
                    decoded_sp = decode_spectral_envelope(contigu,
                                                          SAMPLE_RATE,
                                                          fft_size=FFTSIZE)
                    f0_converted = norm.pitch_conversion(f0, speaker, target)
                    wav = synthesize(f0_converted, decoded_sp, ap, SAMPLE_RATE)

                    name = f'{speaker}-{target}_iter{self.test_iters}_(unknown)'
                    path = os.path.join(self.result_dir, name)
                    print(f'[保存]:{path}')
                    librosa.output.write_wav(path, wav, SAMPLE_RATE)
class Solver(object):
    """Solver for training and testing StarGAN."""
    def __init__(self, train_loader, test_loader, config):
        """Keep the data loaders, copy every configuration value onto the
        solver, then build the models (and the tensorboard logger)."""

        # Data loaders.
        self.train_loader = train_loader
        self.test_loader = test_loader

        # Copy scalar configuration straight across: model hyper-parameters,
        # training schedule, test settings, directories and step sizes.
        for attr in ('sampling_rate', 'num_speakers', 'lambda_cls',
                     'lambda_rec', 'lambda_gp', 'lambda_id', 'batch_size',
                     'num_iters', 'num_iters_decay', 'g_lr', 'd_lr', 'c_lr',
                     'n_critic', 'beta1', 'beta2', 'resume_iters',
                     'test_iters', 'use_tensorboard', 'log_dir', 'sample_dir',
                     'model_save_dir', 'log_step', 'sample_step',
                     'model_save_step', 'lr_update_step'):
            setattr(self, attr, getattr(config, attr))

        # Train on GPU when available, otherwise CPU.
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        # Build the model and tensorboard.
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

    def build_model(self):
        """Create the generator, discriminator and domain classifier, build
        one Adam optimizer per network, print their summaries and move the
        networks onto the training device."""
        self.generator = Generator()
        self.discriminator = Discriminator(num_speakers=self.num_speakers)
        self.classifier = DomainClassifier()

        self.g_optimizer = torch.optim.Adam(self.generator.parameters(),
                                            self.g_lr,
                                            [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.discriminator.parameters(),
                                            self.d_lr,
                                            [self.beta1, self.beta2])
        # BUG FIX: this optimizer previously received the *discriminator's*
        # parameters, so c_optimizer.step() never updated the domain
        # classifier even though the classification loss backpropagated
        # through it.  It must optimize self.classifier (the sibling Solver
        # implementation in this file correctly uses its classifier's
        # parameters here as well).
        self.c_optimizer = torch.optim.Adam(self.classifier.parameters(),
                                            self.c_lr,
                                            [self.beta1, self.beta2])

        self.print_network(self.generator, 'Generator')
        self.print_network(self.discriminator, 'Discriminator')
        self.print_network(self.classifier, 'Domain Classifier')

        self.generator.to(self.device)
        self.discriminator.to(self.device)
        self.classifier.to(self.device)

    def print_network(self, model, name):
        """Print the model, its name and its total scalar-parameter count."""
        total = sum(p.numel() for p in model.parameters())
        print(model)
        print(name)
        print("The number of parameters: {}".format(total))

    def restore_model(self, resume_iters):
        """Restore generator, discriminator and classifier weights saved at
        iteration ``resume_iters`` (files '{iter}-G.ckpt' / '-D.ckpt' /
        '-C.ckpt' under self.model_save_dir)."""
        print(
            'Loading the trained models from step {}...'.format(resume_iters))
        g_path = os.path.join(self.model_save_dir,
                              '{}-G.ckpt'.format(resume_iters))
        d_path = os.path.join(self.model_save_dir,
                              '{}-D.ckpt'.format(resume_iters))
        c_path = os.path.join(self.model_save_dir,
                              '{}-C.ckpt'.format(resume_iters))

        # map_location keeps loaded tensors on CPU regardless of the device
        # they were saved from.
        self.generator.load_state_dict(
            torch.load(g_path, map_location=lambda storage, loc: storage))
        self.discriminator.load_state_dict(
            torch.load(d_path, map_location=lambda storage, loc: storage))
        self.classifier.load_state_dict(
            torch.load(c_path, map_location=lambda storage, loc: storage))

    def build_tensorboard(self):
        """Create the tensorboard logger writing into self.log_dir."""
        # Imported lazily so the logger module is only needed when enabled.
        from logger import Logger
        self.logger = Logger(self.log_dir)

    def update_lr(self, g_lr, d_lr, c_lr):
        """Decay learning rates of the generator and discriminator."""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr
        for param_group in self.c_optimizer.param_groups:
            param_group['lr'] = c_lr

    def reset_grad(self):
        """Reset the gradientgradient buffers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()
        self.c_optimizer.zero_grad()

    def denorm(self, x):
        """Map values from [-1, 1] into [0, 1], clamping (in place) to that range."""
        rescaled = (x + 1) / 2
        return rescaled.clamp_(0, 1)

    def gradient_penalty(self, y, x):
        """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2."""
        weight = torch.ones(y.size()).to(self.device)
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=weight,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]

        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
        return torch.mean((dydx_l2norm - 1)**2)

    def label2onehot(self, labels, dim):
        """Convert integer label indices into one-hot float vectors of width ``dim``."""
        batch_size = labels.size(0)
        one_hot = torch.zeros(batch_size, dim)
        # Set a single 1 per row at the label's column.
        one_hot[torch.arange(batch_size), labels.long()] = 1
        return one_hot

    def sample_spk_c(self, size):
        """Draw ``size`` random speaker indices and their one-hot encodings.

        Returns (LongTensor of indices, FloatTensor of one-hot rows).
        """
        indices = np.random.randint(0, self.num_speakers, size=size)
        one_hot = to_categorical(indices, self.num_speakers)
        return torch.LongTensor(indices), torch.FloatTensor(one_hot)

    def classification_loss(self, logit, target):
        """Softmax cross-entropy between raw logits and integer class targets."""
        loss = F.cross_entropy(logit, target)
        return loss

    def load_wav(self, wavfile, sr=16000):
        """Load a mono wav at ``sr`` and pad it for WORLD analysis
        (frame period 5 ms, length a multiple of 4 frames)."""
        audio, _ = librosa.load(wavfile, sr=sr, mono=True)
        return wav_padding(audio, sr=16000, frame_period=5, multiple=4)

    def train(self):
        """Train StarGAN."""
        # Set data loader.
        train_loader = self.train_loader
        data_iter = iter(train_loader)

        # Read a batch of testdata
        test_wavfiles = self.test_loader.get_batch_test_data(batch_size=4)
        test_wavs = [self.load_wav(wavfile) for wavfile in test_wavfiles]

        # Determine whether do copysynthesize when first do training-time conversion test.
        cpsyn_flag = [True, False][0]
        # f0, timeaxis, sp, ap = world_decompose(wav = wav, fs = sampling_rate, frame_period = frame_period)

        # Learning rate cache for decaying.
        g_lr = self.g_lr
        d_lr = self.d_lr
        c_lr = self.c_lr

        # Start training from scratch or resume training.
        start_iters = 0
        if self.resume_iters:
            print("resuming step %d ..." % self.resume_iters)
            start_iters = self.resume_iters
            self.restore_model(self.resume_iters)

        # Start training.
        print('Start training...')
        start_time = time.time()
        for i in range(start_iters, self.num_iters):
            # =================================================================================== #
            #                             1. Preprocess input data                                #
            # =================================================================================== #

            # Fetch labels.
            try:
                mc_real, spk_label_org, spk_c_org = next(data_iter)
            except:
                data_iter = iter(train_loader)
                mc_real, spk_label_org, spk_c_org = next(data_iter)

            mc_real.unsqueeze_(1)  # (B, D, T) -> (B, 1, D, T) for conv2d

            # Generate target domain labels randomly.
            # spk_label_trg: int,   spk_c_trg:one-hot representation
            spk_label_trg, spk_c_trg = self.sample_spk_c(mc_real.size(0))

            mc_real = mc_real.to(self.device)  # Input mc.
            spk_label_org = spk_label_org.to(
                self.device)  # Original spk labels.
            spk_c_org = spk_c_org.to(self.device)  # Original spk one-hot.
            spk_label_trg = spk_label_trg.to(self.device)  # Target spk labels.
            spk_c_trg = spk_c_trg.to(self.device)  # Target spk one-hot.

            # =================================================================================== #
            #                             2. Train the Domain Classifier                           #
            # =================================================================================== #

            # Compute real classification loss.
            cls_real = self.classifier(mc_real)
            cls_loss = self.classification_loss(cls_real, spk_label_org)

            # Backwards and optimize
            self.reset_grad()
            cls_loss.backward()
            self.c_optimizer.step()

            # Logging.
            loss = {}
            loss['C/c_loss'] = cls_loss.item()

            # =================================================================================== #
            #                             3. Train the Discriminator                              #
            # =================================================================================== #

            # Compute loss with real mc feats.
            d_out_src = self.discriminator(mc_real, spk_c_org)
            mc_fake = self.generator(mc_real, spk_c_trg)
            d_out_fake = self.discriminator(mc_fake.detach(), spk_c_trg)
            d_loss = F.binary_cross_entropy_with_logits(d_out_fake, torch.zeros_like(d_out_fake, dtype=torch.float)) + \
                F.binary_cross_entropy_with_logits(d_out_src, torch.ones_like(d_out_src, dtype=torch.float))

            # Compute classification loss.
            d_out_cls = self.classifier(mc_fake)
            d_loss_cls = self.classification_loss(d_out_cls, spk_label_trg)

            # Compute loss for gradient penalty.
            alpha = torch.rand(mc_real.size(0), 1, 1, 1).to(self.device)
            x_hat = (alpha * mc_real.data +
                     (1 - alpha) * mc_fake.data).requires_grad_(True)
            d_out_src = self.discriminator(x_hat, spk_c_trg)
            d_loss_gp = self.gradient_penalty(d_out_src, x_hat)

            # Backward and optimize.
            d_loss = d_loss + self.lambda_cls * d_loss_cls + self.lambda_gp * d_loss_gp
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Logging.
            loss['D/loss_gp'] = d_loss_gp.item()
            loss['D/loss'] = d_loss.item()

            # =================================================================================== #
            #                               4. Train the generator                                #
            # =================================================================================== #

            if (i + 1) % self.n_critic == 0:
                # Original-to-target domain.
                mc_fake = self.generator(mc_real, spk_c_trg)
                g_out_src = self.discriminator(mc_fake, spk_c_trg)
                g_loss_fake = -torch.mean(g_out_src)

                # Classification loss.
                g_out_cls = self.classifier(mc_fake)
                g_loss_cls = self.classification_loss(g_out_cls, spk_label_trg)

                # Target-to-original domain. Cycle-consistent.
                mc_reconst = self.generator(mc_fake, spk_c_org)
                g_loss_rec = torch.mean(torch.abs(mc_real - mc_reconst))

                # Original-to-original, Id mapping loss. Mapping
                mc_fake_id = self.generator(mc_real, spk_c_org)
                g_loss_id = torch.mean(torch.abs(mc_real - mc_fake_id))

                # Backward and optimize.
                g_loss = g_loss_fake \
                    + self.lambda_rec * g_loss_rec \
                    + self.lambda_cls * g_loss_cls \
                    + self.lambda_id * g_loss_id

                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # Logging.
                loss['G/loss_fake'] = g_loss_fake.item()
                loss['G/loss_rec'] = g_loss_rec.item()
                loss['G/loss_cls'] = g_loss_cls.item()
                loss['G/loss'] = g_loss.item()

            # =================================================================================== #
            #                                 5. Miscellaneous                                    #
            # =================================================================================== #

            # Print out training information.
            if (i + 1) % self.log_step == 0:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]".format(
                    et, i + 1, self.num_iters)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, i + 1)

            if (i + 1) % self.sample_step == 0:
                sampling_rate = 16000
                num_mcep = 36
                frame_period = 5
                with torch.no_grad():
                    for idx, wav in tqdm(enumerate(test_wavs)):
                        wav_name = basename(test_wavfiles[idx])
                        # print(wav_name)
                        f0, timeaxis, sp, ap = world_decompose(
                            wav=wav,
                            fs=sampling_rate,
                            frame_period=frame_period)
                        f0_converted = pitch_conversion(
                            f0=f0,
                            mean_log_src=self.test_loader.logf0s_mean_src,
                            std_log_src=self.test_loader.logf0s_std_src,
                            mean_log_target=self.test_loader.logf0s_mean_trg,
                            std_log_target=self.test_loader.logf0s_std_trg)
                        coded_sp = world_encode_spectral_envelop(
                            sp=sp, fs=sampling_rate, dim=num_mcep)

                        coded_sp_norm = (coded_sp -
                                         self.test_loader.mcep_mean_src
                                         ) / self.test_loader.mcep_std_src
                        coded_sp_norm_tensor = torch.FloatTensor(
                            coded_sp_norm.T).unsqueeze_(0).unsqueeze_(1).to(
                                self.device)
                        conds = torch.FloatTensor(
                            self.test_loader.spk_c_trg).to(self.device)
                        # print(conds.size())
                        coded_sp_converted_norm = self.generator(
                            coded_sp_norm_tensor, conds).data.cpu().numpy()
                        coded_sp_converted = np.squeeze(
                            coded_sp_converted_norm
                        ).T * self.test_loader.mcep_std_trg + self.test_loader.mcep_mean_trg
                        coded_sp_converted = np.ascontiguousarray(
                            coded_sp_converted)
                        # decoded_sp_converted = world_decode_spectral_envelop(coded_sp = coded_sp_converted, fs = sampling_rate)
                        wav_transformed = world_speech_synthesis(
                            f0=f0_converted,
                            coded_sp=coded_sp_converted,
                            ap=ap,
                            fs=sampling_rate,
                            frame_period=frame_period)

                        librosa.output.write_wav(
                            join(
                                self.sample_dir,
                                str(i + 1) + '-' + wav_name.split('.')[0] +
                                '-vcto-{}'.format(self.test_loader.trg_spk) +
                                '.wav'), wav_transformed, sampling_rate)
                        if cpsyn_flag:
                            wav_cpsyn = world_speech_synthesis(
                                f0=f0,
                                coded_sp=coded_sp,
                                ap=ap,
                                fs=sampling_rate,
                                frame_period=frame_period)
                            librosa.output.write_wav(
                                join(self.sample_dir, 'cpsyn-' + wav_name),
                                wav_cpsyn, sampling_rate)
                    cpsyn_flag = False

            # Save model checkpoints.
            if (i + 1) % self.model_save_step == 0:
                g_path = os.path.join(self.model_save_dir,
                                      '{}-G.ckpt'.format(i + 1))
                d_path = os.path.join(self.model_save_dir,
                                      '{}-D.ckpt'.format(i + 1))
                c_path = os.path.join(self.model_save_dir,
                                      '{}-C.ckpt'.format(i + 1))

                torch.save(self.generator.state_dict(), g_path)
                torch.save(self.discriminator.state_dict(), d_path)
                torch.save(self.classifier.state_dict(), c_path)
                print('Saved model checkpoints into {}...'.format(
                    self.model_save_dir))

            # Decay learning rates.
            if (i + 1) % self.lr_update_step == 0 and (i + 1) > (
                    self.num_iters - self.num_iters_decay):
                g_lr -= (self.g_lr / float(self.num_iters_decay))
                d_lr -= (self.d_lr / float(self.num_iters_decay))
                c_lr -= (self.c_lr / float(self.num_iters_decay))
                self.update_lr(g_lr, d_lr, c_lr)
                print('Decayed learning rates, g_lr: {}, d_lr: {}, c_lr: {}.'.
                      format(g_lr, d_lr, c_lr))
Ejemplo n.º 3
0
class Solver(object):
    """StarGAN-VC solver.

    Builds the generator (G), discriminator (D) and domain classifier (C),
    runs the adversarial training loop (with cycle/classification/identity
    losses), periodically converts held-out utterances for monitoring, and
    checkpoints/restores the three networks.
    """

    def __init__(self, data_loader, config):
        """Cache hyperparameters/paths and build the networks.

        Args:
            data_loader: iterable yielding (features, speaker_index, one_hot_label)
                batches.
            config: argparse-style namespace with all hyperparameters,
                directories and step sizes referenced below.
        """
        self.config = config
        self.data_loader = data_loader

        # Model configurations: weights of the cycle-consistency,
        # domain-classification and identity-mapping loss terms.
        self.lambda_cycle = config.lambda_cycle
        self.lambda_cls = config.lambda_cls
        self.lambda_identity = config.lambda_identity

        # Training configurations.
        self.data_dir = config.data_dir
        self.test_dir = config.test_dir
        self.batch_size = config.batch_size
        self.num_iters = config.num_iters
        self.num_iters_decay = config.num_iters_decay
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.c_lr = config.c_lr
        # Number of D/C updates per G update (WGAN-style critic schedule).
        self.n_critic = config.n_critic
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        # Iteration to resume training from (0/None means train from scratch).
        self.resume_iters = config.resume_iters

        # Test configurations.
        self.test_iters = config.test_iters
        # trg_speaker arrives as a string literal (e.g. "['SF1','TM1']").
        self.trg_speaker = ast.literal_eval(config.trg_speaker)
        self.src_speaker = config.src_speaker

        # Miscellaneous.
        self.use_tensorboard = config.use_tensorboard
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        # One-hot encoder over the global `speakers` list.
        self.spk_enc = LabelBinarizer().fit(speakers)
        # Directories.
        self.log_dir = config.log_dir
        self.sample_dir = config.sample_dir
        self.model_save_dir = config.model_save_dir
        self.result_dir = config.result_dir

        # Step sizes (in iterations) for logging, sampling, saving, LR decay.
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.lr_update_step = config.lr_update_step

        # Build the model and tensorboard.
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

    def build_model(self):
        """Instantiate G/D/C, their Adam optimizers, and move them to device."""
        self.G = Generator()
        self.D = Discriminator()
        self.C = DomainClassifier()

        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr,
                                            [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr,
                                            [self.beta1, self.beta2])
        self.c_optimizer = torch.optim.Adam(self.C.parameters(), self.c_lr,
                                            [self.beta1, self.beta2])

        self.print_network(self.G, 'G')
        self.print_network(self.D, 'D')
        self.print_network(self.C, 'C')

        self.G.to(self.device)
        self.D.to(self.device)
        self.C.to(self.device)

    def print_network(self, model, name):
        """Print out the network architecture, name and parameter count."""
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))

    def build_tensorboard(self):
        """Build a tensorboard logger."""
        from logger import Logger
        self.logger = Logger(self.log_dir)

    def update_lr(self, g_lr, d_lr, c_lr):
        """Decay learning rates of the generator, discriminator and classifier."""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr
        for param_group in self.c_optimizer.param_groups:
            param_group['lr'] = c_lr

    def train(self):
        """Run the adversarial training loop for `num_iters` iterations."""
        # Learning rate cache for decaying.
        g_lr = self.g_lr
        d_lr = self.d_lr
        c_lr = self.c_lr

        # Start training from scratch or resume training.
        # FIX: previously `if self.resume_iters: pass` silently ignored the
        # resume request; actually restore the checkpoints and skip ahead.
        start_iters = 0
        if self.resume_iters:
            print("resuming step %d ..." % self.resume_iters)
            start_iters = self.resume_iters
            self.restore_model(self.resume_iters)

        norm = Normalizer()
        data_iter = iter(self.data_loader)

        # Hoisted out of the loop: the criterion is stateless, so there is no
        # need to re-construct it every iteration.
        CELoss = nn.CrossEntropyLoss()

        print('Start training......')
        start_time = datetime.now()

        for i in range(start_iters, self.num_iters):
            # =================================================================================== #
            #                             1. Preprocess input data                                #
            # =================================================================================== #
            # Fetch real frames and labels; restart the iterator when the
            # epoch is exhausted. FIX: catch only StopIteration instead of a
            # bare `except:` that would also mask genuine loader errors.
            try:
                x_real, speaker_idx_org, label_org = next(data_iter)
            except StopIteration:
                data_iter = iter(self.data_loader)
                x_real, speaker_idx_org, label_org = next(data_iter)

            # Generate target domain labels randomly by permuting the batch.
            rand_idx = torch.randperm(label_org.size(0))
            label_trg = label_org[rand_idx]
            speaker_idx_trg = speaker_idx_org[rand_idx]

            x_real = x_real.to(self.device)  # Input features.
            label_org = label_org.to(
                self.device)  # Original domain one-hot labels.
            label_trg = label_trg.to(
                self.device)  # Target domain one-hot labels.
            speaker_idx_org = speaker_idx_org.to(
                self.device)  # Original domain labels
            speaker_idx_trg = speaker_idx_trg.to(
                self.device)  # Target domain labels

            # =================================================================================== #
            #                             2. Train the discriminator                              #
            # =================================================================================== #
            # Train C on real frames only (cross-entropy on source speaker).
            cls_real = self.C(x_real)
            cls_loss_real = CELoss(input=cls_real, target=speaker_idx_org)

            self.reset_grad()
            cls_loss_real.backward()
            self.c_optimizer.step()
            # Logging.
            loss = {}
            loss['C/C_loss'] = cls_loss_real.item()

            out_r = self.D(x_real, label_org)
            # Compute loss with fake audio frame. `detach()` stops gradients
            # from flowing into G during the D update.
            x_fake = self.G(x_real, label_trg)
            out_f = self.D(x_fake.detach(), label_trg)
            d_loss_t = F.binary_cross_entropy_with_logits(input=out_f,target=torch.zeros_like(out_f, dtype=torch.float)) + \
                F.binary_cross_entropy_with_logits(input=out_r, target=torch.ones_like(out_r, dtype=torch.float))

            out_cls = self.C(x_fake)
            d_loss_cls = CELoss(input=out_cls, target=speaker_idx_trg)

            # Compute loss for gradient penalty on random interpolates.
            alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
            x_hat = (alpha * x_real.data +
                     (1 - alpha) * x_fake.data).requires_grad_(True)
            out_src = self.D(x_hat, label_trg)
            d_loss_gp = self.gradient_penalty(out_src, x_hat)

            # NOTE: the gradient-penalty weight (lambda_gp) is hard-coded to 5.
            d_loss = d_loss_t + self.lambda_cls * d_loss_cls + 5 * d_loss_gp

            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            loss['D/D_loss'] = d_loss.item()

            # =================================================================================== #
            #                               3. Train the generator                                #
            # =================================================================================== #
            # G is updated once every n_critic D updates.
            if (i + 1) % self.n_critic == 0:
                # Original-to-target domain: fool D into judging fakes real.
                x_fake = self.G(x_real, label_trg)
                g_out_src = self.D(x_fake, label_trg)
                g_loss_fake = F.binary_cross_entropy_with_logits(
                    input=g_out_src,
                    target=torch.ones_like(g_out_src, dtype=torch.float))

                out_cls = self.C(x_fake)
                g_loss_cls = CELoss(input=out_cls, target=speaker_idx_trg)

                # Target-to-original domain (cycle consistency).
                x_reconst = self.G(x_fake, label_org)
                g_loss_rec = F.l1_loss(x_reconst, x_real)

                # Original-to-Original domain (identity mapping).
                x_fake_iden = self.G(x_real, label_org)
                id_loss = F.l1_loss(x_fake_iden, x_real)

                # Backward and optimize.
                g_loss = g_loss_fake + self.lambda_cycle * g_loss_rec +\
                 self.lambda_cls * g_loss_cls + self.lambda_identity * id_loss

                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # Logging.
                loss['G/loss_fake'] = g_loss_fake.item()
                loss['G/loss_rec'] = g_loss_rec.item()
                loss['G/loss_cls'] = g_loss_cls.item()
                loss['G/loss_id'] = id_loss.item()
                loss['G/g_loss'] = g_loss.item()
            # =================================================================================== #
            #                                 4. Miscellaneous                                    #
            # =================================================================================== #
            # Print out training information.
            if (i + 1) % self.log_step == 0:
                et = datetime.now() - start_time
                et = str(et)[:-7]  # drop microseconds from the elapsed time
                log = "Elapsed [{}], Iteration [{}/{}]".format(
                    et, i + 1, self.num_iters)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, i + 1)

            # Convert a random test utterance for qualitative monitoring.
            if (i + 1) % self.sample_step == 0:
                with torch.no_grad():
                    d, speaker = TestSet(self.test_dir).test_data()
                    # Pick any target speaker different from the source.
                    target = random.choice(
                        [x for x in speakers if x != speaker])
                    label_t = self.spk_enc.transform([target])[0]
                    label_t = np.asarray([label_t])

                    for filename, content in d.items():
                        f0 = content['f0']
                        ap = content['ap']
                        sp_norm_pad = self.pad_coded_sp(
                            content['coded_sp_norm'])

                        # Convert the padded spectrogram FRAMES columns at a
                        # time, then stitch the segments back together.
                        convert_result = []
                        for start_idx in range(
                                0, sp_norm_pad.shape[1] - FRAMES + 1, FRAMES):
                            one_seg = sp_norm_pad[:,
                                                  start_idx:start_idx + FRAMES]

                            one_seg = torch.FloatTensor(one_seg).to(
                                self.device)
                            one_seg = one_seg.view(1, 1, one_seg.size(0),
                                                   one_seg.size(1))
                            l = torch.FloatTensor(label_t)
                            one_seg = one_seg.to(self.device)
                            l = l.to(self.device)
                            one_set_return = self.G(one_seg,
                                                    l).data.cpu().numpy()
                            one_set_return = np.squeeze(one_set_return)
                            one_set_return = norm.backward_process(
                                one_set_return, target)
                            convert_result.append(one_set_return)

                        convert_con = np.concatenate(convert_result, axis=1)
                        # Trim the padding back to the original frame count.
                        convert_con = convert_con[:,
                                                  0:content['coded_sp_norm'].
                                                  shape[1]]
                        contigu = np.ascontiguousarray(convert_con.T,
                                                       dtype=np.float64)
                        decoded_sp = decode_spectral_envelope(contigu,
                                                              SAMPLE_RATE,
                                                              fft_size=FFTSIZE)
                        f0_converted = norm.pitch_conversion(
                            f0, speaker, target)
                        wav = synthesize(f0_converted, decoded_sp, ap,
                                         SAMPLE_RATE)

                        name = f'{speaker}-{target}_iter{i+1}_(unknown)'
                        path = os.path.join(self.sample_dir, name)
                        print(f'[save]:{path}')
                        # NOTE(review): librosa.output.write_wav was removed in
                        # librosa 0.8 — pin librosa<0.8 or switch to soundfile.
                        librosa.output.write_wav(path, wav, SAMPLE_RATE)

            # Save model checkpoints.
            if (i + 1) % self.model_save_step == 0:
                G_path = os.path.join(self.model_save_dir,
                                      '{}-G.ckpt'.format(i + 1))
                D_path = os.path.join(self.model_save_dir,
                                      '{}-D.ckpt'.format(i + 1))
                C_path = os.path.join(self.model_save_dir,
                                      '{}-C.ckpt'.format(i + 1))
                torch.save(self.G.state_dict(), G_path)
                torch.save(self.D.state_dict(), D_path)
                torch.save(self.C.state_dict(), C_path)
                print('Saved model checkpoints into {}...'.format(
                    self.model_save_dir))

            # Decay learning rates linearly over the last num_iters_decay steps.
            if (i + 1) % self.lr_update_step == 0 and (i + 1) > (
                    self.num_iters - self.num_iters_decay):
                g_lr -= (self.g_lr / float(self.num_iters_decay))
                d_lr -= (self.d_lr / float(self.num_iters_decay))
                c_lr -= (self.c_lr / float(self.num_iters_decay))
                self.update_lr(g_lr, d_lr, c_lr)
                # FIX: c_lr is decayed too, so report it in the log message.
                print('Decayed learning rates, g_lr: {}, d_lr: {}, c_lr: {}.'.
                      format(g_lr, d_lr, c_lr))

    def gradient_penalty(self, y, x):
        """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2."""
        weight = torch.ones(y.size()).to(self.device)
        # create_graph=True so the penalty itself is differentiable and can
        # be backpropagated through during the discriminator update.
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=weight,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]

        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
        return torch.mean((dydx_l2norm - 1)**2)

    def reset_grad(self):
        """Reset the gradient buffers of all three optimizers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()
        self.c_optimizer.zero_grad()

    def restore_model(self, resume_iters):
        """Restore the trained generator, discriminator and classifier."""
        print(
            'Loading the trained models from step {}...'.format(resume_iters))
        G_path = os.path.join(self.model_save_dir,
                              '{}-G.ckpt'.format(resume_iters))
        D_path = os.path.join(self.model_save_dir,
                              '{}-D.ckpt'.format(resume_iters))
        C_path = os.path.join(self.model_save_dir,
                              '{}-C.ckpt'.format(resume_iters))
        # map_location keeps CPU-only machines able to load GPU checkpoints.
        self.G.load_state_dict(
            torch.load(G_path, map_location=lambda storage, loc: storage))
        self.D.load_state_dict(
            torch.load(D_path, map_location=lambda storage, loc: storage))
        self.C.load_state_dict(
            torch.load(C_path, map_location=lambda storage, loc: storage))

    @staticmethod
    def pad_coded_sp(coded_sp_norm):
        """Zero-pad a (dim, frames) spectrogram to a multiple of FRAMES.

        NOTE(review): when the frame count is already a multiple of FRAMES
        this pads a full extra FRAMES block; the conversion code trims back
        to the original length afterwards, so the extra block is discarded.
        """
        f_len = coded_sp_norm.shape[1]
        if f_len >= FRAMES:
            pad_length = FRAMES - (f_len - (f_len // FRAMES) * FRAMES)
        elif f_len < FRAMES:
            pad_length = FRAMES - f_len

        sp_norm_pad = np.hstack(
            (coded_sp_norm, np.zeros((coded_sp_norm.shape[0], pad_length))))
        return sp_norm_pad

    def test(self):
        """Translate speech from src_speaker to every target via StarGAN."""
        # Load the trained generator.
        self.restore_model(self.test_iters)
        norm = Normalizer()

        # Set data loader.
        d, speaker = TestSet(self.test_dir).test_data(self.src_speaker)
        targets = self.trg_speaker

        for target in targets:
            print(target)
            assert target in speakers
            label_t = self.spk_enc.transform([target])[0]
            label_t = np.asarray([label_t])

            with torch.no_grad():

                for filename, content in d.items():
                    f0 = content['f0']
                    ap = content['ap']
                    sp_norm_pad = self.pad_coded_sp(content['coded_sp_norm'])

                    convert_result = []
                    for start_idx in range(0,
                                           sp_norm_pad.shape[1] - FRAMES + 1,
                                           FRAMES):
                        one_seg = sp_norm_pad[:, start_idx:start_idx + FRAMES]

                        one_seg = torch.FloatTensor(one_seg).to(self.device)
                        one_seg = one_seg.view(1, 1, one_seg.size(0),
                                               one_seg.size(1))
                        l = torch.FloatTensor(label_t)
                        one_seg = one_seg.to(self.device)
                        l = l.to(self.device)
                        one_set_return = self.G(one_seg, l).data.cpu().numpy()
                        one_set_return = np.squeeze(one_set_return)
                        one_set_return = norm.backward_process(
                            one_set_return, target)
                        convert_result.append(one_set_return)

                    convert_con = np.concatenate(convert_result, axis=1)
                    # Trim padding back to the source spectrogram's length.
                    convert_con = convert_con[:, 0:content['coded_sp_norm'].
                                              shape[1]]
                    contigu = np.ascontiguousarray(convert_con.T,
                                                   dtype=np.float64)
                    decoded_sp = decode_spectral_envelope(contigu,
                                                          SAMPLE_RATE,
                                                          fft_size=FFTSIZE)
                    f0_converted = norm.pitch_conversion(f0, speaker, target)
                    wav = synthesize(f0_converted, decoded_sp, ap, SAMPLE_RATE)

                    name = f'{speaker}-{target}_iter{self.test_iters}_(unknown)'
                    path = os.path.join(self.result_dir, name)
                    print(f'[save]:{path}')
                    librosa.output.write_wav(path, wav, SAMPLE_RATE)
Ejemplo n.º 4
0
# DataLoaders over the pre-built datasets; drop_last keeps every batch at a
# fixed size of 1024 samples.
train_dataloader = DataLoader(train_data,
                              batch_size=1024,
                              shuffle=True,
                              drop_last=True)
test_dataloader = DataLoader(test_data,
                             batch_size=1024,
                             shuffle=True,
                             drop_last=True)

# -------------------------------------- Training Stage ------------------------------------------- #

# eps threshold passed to ReduceLROnPlateau: decays smaller than this are skipped.
precision = 1e-8

feature_extractor = FeatureExtractor().cuda()
label_predictor = LabelPredictor().cuda()
domain_classifier = DomainClassifier().cuda()

class_criterion = nn.CrossEntropyLoss()
domain_criterion = nn.CrossEntropyLoss()

optimizer_F = optim.Adam(feature_extractor.parameters())
optimizer_C = optim.Adam(label_predictor.parameters())
# BUG FIX: optimizer_D previously re-used label_predictor.parameters(), which
# left the domain classifier's weights permanently untrained. It must update
# the domain classifier.
optimizer_D = optim.Adam(domain_classifier.parameters())

scheduler_F = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer_F,
                                                   mode='min',
                                                   factor=0.1,
                                                   patience=8,
                                                   verbose=True,
                                                   eps=precision)
scheduler_C = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer_C,
#
# total_users_test_dataset = [list_total_user_test_data, list_total_user_test_labels]
# print('Test dataset formatted finished.')
#
# np.save("formatted_datasets/saved_total_users_test_dataset_xu.npy", total_users_test_dataset)
# print('Train dataset saved.')

# Load the pre-formatted evaluation set; it was saved as an object array, so
# allow_pickle is required.
test_dataset = np.load(
    "formatted_datasets/saved_total_users_test_dataset_xu_8_subjects.npy",
    encoding="bytes",
    allow_pickle=True)
print(np.shape(test_dataset))

# -------------------------------------------- Load Model ----------------------------------------------------- #
feature_extractor = FeatureExtractor().cuda()
domain_classifier = DomainClassifier().cuda()
label_predictor = LabelPredictor().cuda()

# Restore the trained weights for each of the three networks.
for network, ckpt_path in (
    (feature_extractor, r'saved_model\feature_extractor_CE_8_subjects.pkl'),
    (domain_classifier, r'saved_model\domain_classifier_CE_8_subjects.pkl'),
    (label_predictor, r'saved_model\label_predictor_CE_8_subjects.pkl'),
):
    network.load_state_dict(torch.load(ckpt_path))

# Echo each architecture so the log records exactly what was loaded.
for network in (feature_extractor, domain_classifier, label_predictor):
    print(network)

print('Model loaded.')
Ejemplo n.º 6
0
def main():
    """Entry point: parse args, build source/target datasets and loaders,
    optionally restore weights, train, then evaluate on the target domain."""
    args = parse_arguments()

    # Echo the run configuration so logs are self-describing.
    print("=== Argument Setting ===")
    print("src: " + args.src)
    print("tgt: " + args.tgt)
    print("alpha: " + str(args.alpha))
    print("seed: " + str(args.seed))
    print("train_seed: " + str(args.train_seed))
    print("model_type: " + str(args.model))
    print("max_seq_length: " + str(args.max_seq_length))
    print("batch_size: " + str(args.batch_size))
    print("num_epochs: " + str(args.num_epochs))
    set_seed(args.train_seed)

    use_roberta = args.model == 'roberta'
    tokenizer = (RobertaTokenizer.from_pretrained('roberta-base')
                 if use_roberta else
                 BertTokenizer.from_pretrained('bert-base-uncased'))

    def read_domain(domain):
        # blog/airline ship as CSV; every other domain is the XML review pair.
        if domain == 'blog':
            return CSV2Array(os.path.join('data', domain, 'blog.csv'))
        if domain == 'airline':
            return CSV2Array(os.path.join('data', domain, 'airline.csv'))
        return XML2Array(os.path.join('data', domain, 'negative.review'),
                         os.path.join('data', domain, 'positive.review'))

    # preprocess data
    print("=== Processing datasets ===")
    src_x, src_y = read_domain(args.src)
    src_x, src_test_x, src_y, src_test_y = train_test_split(
        src_x, src_y, test_size=0.2, stratify=src_y, random_state=args.seed)

    tgt_x, tgt_y = read_domain(args.tgt)
    tgt_train_x, _, tgt_train_y, _ = train_test_split(
        tgt_x, tgt_y, test_size=0.2, stratify=tgt_y, random_state=args.seed)

    # Tokenize every split with the featurizer matching the chosen model.
    featurize = (roberta_convert_examples_to_features
                 if use_roberta else convert_examples_to_features)
    src_features = featurize(src_x, src_y, args.max_seq_length, tokenizer)
    src_test_features = featurize(src_test_x, src_test_y,
                                  args.max_seq_length, tokenizer)
    tgt_features = featurize(tgt_train_x, tgt_train_y,
                             args.max_seq_length, tokenizer)
    tgt_all_features = featurize(tgt_x, tgt_y, args.max_seq_length, tokenizer)

    # load dataset
    src_data_loader = get_data_loader(src_features, args.batch_size)
    src_data_loader_eval = get_data_loader(src_test_features, args.batch_size)
    tgt_data_loader = get_data_loader(tgt_features, args.batch_size)
    tgt_data_loader_all = get_data_loader(tgt_all_features, args.batch_size)

    # load models
    if args.model == 'bert':
        encoder = BertEncoder()
        cls_classifier = BertClassifier()
        dom_classifier = DomainClassifier()
    elif args.model == 'distilbert':
        encoder = DistilBertEncoder()
        cls_classifier = BertClassifier()
        dom_classifier = DomainClassifier()
    else:
        encoder = RobertaEncoder()
        cls_classifier = RobertaClassifier()
        dom_classifier = RobertaDomainClassifier()

    # Optionally restore previously-saved weights for all three components.
    if args.load:
        encoder = init_model(encoder, restore=param.encoder_path)
        cls_classifier = init_model(cls_classifier,
                                    restore=param.cls_classifier_path)
        dom_classifier = init_model(dom_classifier,
                                    restore=param.dom_classifier_path)
    else:
        encoder = init_model(encoder)
        cls_classifier = init_model(cls_classifier)
        dom_classifier = init_model(dom_classifier)

    print("=== Start Training ===")
    if args.train:
        encoder, cls_classifier, dom_classifier = train(
            args, encoder, cls_classifier, dom_classifier, src_data_loader,
            src_data_loader_eval, tgt_data_loader, tgt_data_loader_all)

    print("=== Evaluating classifier for encoded target domain ===")
    print(">>> after training <<<")
    evaluate(encoder, cls_classifier, tgt_data_loader_all)
Ejemplo n.º 7
0
def main():
    """Train a StarGAN-VC style converter: generator G, critic D, classifier C.

    Pairs two tf.data pipelines over the same speaker tfrecords as (x, y)
    batches, restores the latest checkpoint when present, then runs
    WGAN-style training: four D/C updates for every G update, each with a
    cosine learning-rate schedule.
    """
    args = parser.parse_args()
    # One .tfrecord per speaker; the count sizes the one-hot speaker labels.
    filenames = glob.glob(os.path.join(args.dataset, '*.tfrecord'))
    speaker_count = len(filenames)
    x_dataset = dataset_builder.build(filenames, prefetch=128, batch=8)
    y_dataset = dataset_builder.build(filenames, prefetch=128, batch=8)
    D, G, C = Discriminator(), Generator(), DomainClassifier()

    # Learning rates live in tf.Variables so the cosine schedule can update
    # them in place from within the train-step closures.
    G_lr = tf.Variable(0.0, dtype=tf.float64)
    G_optimizer = tf.keras.optimizers.Adam(G_lr, beta_1=0.5, beta_2=0.999)
    D_lr = tf.Variable(0.0, dtype=tf.float64)
    D_optimizer = tf.keras.optimizers.Adam(D_lr, beta_1=0.5, beta_2=0.999)
    C_lr = tf.Variable(0.0, dtype=tf.float64)
    C_optimizer = tf.keras.optimizers.Adam(C_lr, beta_1=0.5, beta_2=0.999)

    # Label smoothing regularizes the domain-classification targets.
    ce_loss = tf.keras.losses.CategoricalCrossentropy(from_logits=False,
                                                      label_smoothing=0.1)
    l1_loss = tf.keras.losses.MeanAbsoluteError()

    G_metric = tf.keras.metrics.Mean(name='G_loss')
    D_metric = tf.keras.metrics.Mean(name='D_loss')
    C_metric = tf.keras.metrics.Mean(name='C_loss')

    summary_writer = tf.summary.create_file_writer(args.logdir)

    def train_G(real_x, real_x_attr, real_y, real_y_attr, step):
        """One generator update: adversarial + cycle + class + identity losses."""
        with tf.GradientTape() as tape:
            fake_y = G(real_x, real_y_attr, True)
            reconst_x = G(fake_y, real_x_attr, True)
            fake_y_d = D(fake_y, real_y_attr, False)
            fake_y_c = C(fake_y, False)
            # Identity mapping: converting x to its own attributes should be
            # close to a no-op.
            fake_x = G(real_x, real_x_attr, True)
            # Wasserstein generator objective: maximize the critic score.
            gan_loss = tf.reduce_mean(-1 * fake_y_d)
            cycle_loss = l1_loss(real_x, reconst_x)
            cls_loss = ce_loss(real_y_attr, fake_y_c)
            identity_loss = l1_loss(real_x, fake_x)
            loss = gan_loss + 3 * cycle_loss + 2 * cls_loss + 2 * identity_loss
            G_gradients = tape.gradient(loss, G.trainable_variables)
            G_optimizer.apply_gradients(zip(G_gradients,
                                            G.trainable_variables))
        G_lr.assign(
            ops.cosine_lr(step, learning_rate, total_steps,
                          warmup_learning_rate, warmup_steps))
        tf.summary.scalar('loss_G/gan_loss',
                          gan_loss,
                          step=G_optimizer.iterations)
        tf.summary.scalar('loss_G/cycle_loss',
                          cycle_loss,
                          step=G_optimizer.iterations)
        tf.summary.scalar('loss_G/cls_loss',
                          cls_loss,
                          step=G_optimizer.iterations)
        tf.summary.scalar('loss_G/identity_loss',
                          identity_loss,
                          step=G_optimizer.iterations)
        G_metric(loss)
        return loss

    def train_D(real_x, real_x_attr, real_y, real_y_attr, step):
        """One critic update: WGAN loss plus a gradient penalty at fake_y."""
        with tf.GradientTape() as tape:
            fake_y = G(real_x, real_y_attr, False)
            fake_y_d = D(fake_y, real_y_attr, True)
            real_y_d = D(real_y, real_y_attr, True)
            gan_loss = 0.5 * tf.reduce_mean(fake_y_d - real_y_d)

            # Gradient penalty evaluated at the generated sample (rather than
            # the usual real/fake interpolation).
            with tf.GradientTape() as t:
                t.watch(fake_y)
                pred = D(fake_y, real_y_attr, True)
            grad = t.gradient(pred, [fake_y])[0]
            slopes = tf.sqrt(tf.reduce_sum(tf.square(grad), axis=[1, 2, 3]))
            gradient_penalty = tf.reduce_mean((slopes - 1.)**2)

            loss = gan_loss + 10 * gradient_penalty
            D_gradients = tape.gradient(loss, D.trainable_variables)
            D_optimizer.apply_gradients(zip(D_gradients,
                                            D.trainable_variables))
        D_lr.assign(
            ops.cosine_lr(step, learning_rate, total_steps,
                          warmup_learning_rate, warmup_steps))
        tf.summary.scalar('loss_D/gan_loss',
                          gan_loss,
                          step=D_optimizer.iterations)
        tf.summary.scalar('loss_D/negative_critic_loss',
                          -1 * gan_loss,
                          step=D_optimizer.iterations)
        tf.summary.scalar('loss_D/gp',
                          gradient_penalty,
                          step=D_optimizer.iterations)
        D_metric(loss)
        return loss

    def train_C(real_x, real_x_attr, real_y, real_y_attr, step):
        """One domain-classifier update on real samples only."""
        with tf.GradientTape() as tape:
            real_y_c = C(real_y, True)
            loss = ce_loss(real_y_attr, real_y_c)
            C_gradients = tape.gradient(loss, C.trainable_variables)
            C_optimizer.apply_gradients(zip(C_gradients,
                                            C.trainable_variables))
        C_lr.assign(
            ops.cosine_lr(step, learning_rate, total_steps,
                          warmup_learning_rate, warmup_steps))
        C_metric(loss)
        return loss

    ckpt = tf.train.Checkpoint(D=D,
                               G=G,
                               C=C,
                               G_optimizer=G_optimizer,
                               D_optimizer=D_optimizer,
                               C_optimizer=C_optimizer)
    latest_ckpt = tf.train.latest_checkpoint(args.logdir)
    if latest_ckpt is not None:
        # Warm-up passes on dummy data create every model and optimizer-slot
        # variable so ckpt.restore(...).assert_consumed() can succeed.
        dummy_wav = np.zeros((1, 36, 512, 1), dtype=np.float32)
        dummy_speaker_onehot = np.zeros((1, speaker_count), dtype=np.float32)
        # BUG FIX: these three calls were missing the required `step`
        # argument, so restoring from a checkpoint raised TypeError.
        train_D(dummy_wav, dummy_speaker_onehot, dummy_wav,
                dummy_speaker_onehot, 0)
        train_C(dummy_wav, dummy_speaker_onehot, dummy_wav,
                dummy_speaker_onehot, 0)
        train_G(dummy_wav, dummy_speaker_onehot, dummy_wav,
                dummy_speaker_onehot, 0)
        ckpt.restore(latest_ckpt).assert_consumed()
    ckpt_mgr = tf.train.CheckpointManager(ckpt, args.logdir, max_to_keep=5)

    with summary_writer.as_default():
        G_metric.reset_states()
        D_metric.reset_states()
        C_metric.reset_states()

        step = 0
        mcep_normalizer = McepNormalizer('./train/norm_dict.pkl')
        for features in tqdm(zip(x_dataset, y_dataset), total=total_steps):
            x_feature, y_feature = features

            # Speaker ids are 1-based in the tfrecords; shift to 0-based
            # before one-hot encoding.
            x_id = np.asarray(x_feature['speaker_id'], dtype=np.int32)
            x_id_onehot = tf.one_hot(x_id - 1, speaker_count)
            x = mcep_normalizer.batch_mcep_norm(x_feature['mcep'], x_id)
            x = tf.expand_dims(x, axis=-1)

            y_id = np.asarray(y_feature['speaker_id'], dtype=np.int32)
            y_id_onehot = tf.one_hot(y_id - 1, speaker_count)
            y = mcep_normalizer.batch_mcep_norm(y_feature['mcep'], y_id)
            y = tf.expand_dims(y, axis=-1)

            # n_critic schedule: one G update every 5th step; D and C
            # updates on all other steps.
            if (step + 1) % 5 == 0:
                loss_G = train_G(x, x_id_onehot, y, y_id_onehot, step)
                tf.summary.scalar('loss_G',
                                  loss_G,
                                  step=G_optimizer.iterations)
            else:
                loss_D = train_D(x, x_id_onehot, y, y_id_onehot, step)
                loss_C = train_C(x, x_id_onehot, y, y_id_onehot, step)
                tf.summary.scalar('loss_D',
                                  loss_D,
                                  step=D_optimizer.iterations)
                tf.summary.scalar('loss_C',
                                  loss_C,
                                  step=C_optimizer.iterations)
            if (step + 1) % 100 == 0:
                ckpt_mgr.save()
                template = 'Steps {}, G Loss: {}, D Loss: {}, C Loss: {}'
                log_msg = template.format(step, G_metric.result(),
                                          D_metric.result(), C_metric.result())
                print(log_msg)
            if step >= total_steps:
                break
            step += 1
        # BUG FIX: corrected the 'Trainin' typo and actually print the final
        # summary, which was previously formatted but never emitted.
        print('Finish Training, saving ckpt...')
        template = 'Steps {}, G Loss: {}, D Loss: {}, C Loss: {}'
        log_msg = template.format(step, G_metric.result(), D_metric.result(),
                                  C_metric.result())
        print(log_msg)
        ckpt_mgr.save()
        print('Done.')
Ejemplo n.º 8
0
                            # NOTE(review): truncated snippet — the opening
                            # of the `source_dataset = ImgDataset(...)` call
                            # these arguments belong to is missing.
                            './real_or_drawing',
                            train=True)
print('dataset done')
# Target-domain dataset: same folder, test split (train=False).
target_dataset = ImgDataset(source_transform,
                            target_transform,
                            './real_or_drawing',
                            train=False)
print('dataset done 2')

source_dataloader = DataLoader(source_dataset, batch_size=32, shuffle=True)
target_dataloader = DataLoader(target_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(target_dataset, batch_size=128, shuffle=False)

# DaNN components: shared feature extractor plus class and domain heads.
feature_extractor = FeatureExtractor().cuda()
label_predictor = LabelPredictor().cuda()
domain_classifier = DomainClassifier().cuda()

class_criterion = nn.CrossEntropyLoss()
# Binary source-vs-target discrimination loss for the domain head.
domain_criterion = nn.BCEWithLogitsLoss()

optimizer_F = optim.Adam(feature_extractor.parameters())  # originally 1e-3
optimizer_C = optim.Adam(label_predictor.parameters())
optimizer_D = optim.Adam(domain_classifier.parameters())


# NOTE(review): truncated below — the docstring is never closed and the
# function body is missing from this snippet.
def train_epoch(source_dataloader, target_dataloader, lamb):
    '''
      Args:
        source_dataloader: dataloader for the source data
        target_dataloader: dataloader for the target data
        lamb: coefficient controlling the adversarial loss term.
Ejemplo n.º 9
0
class Solver(object):
    """Trainer/tester for a StarGAN-style symbolic-music style-transfer model.

    Owns the generator G, discriminator D and domain classifier C together
    with their Adam optimizers, and implements the training loop, periodic
    sampling, checkpointing, learning-rate decay and test-time translation.
    """

    def __init__(self, data_loader, config):
        """Cache all configuration values, then build networks (and logger).

        Args:
            data_loader: iterable yielding (x_real, style_idx, one_hot_label)
                batches.
            config: namespace carrying every hyper-parameter referenced below.
        """
        self.config = config
        self.data_loader = data_loader

        # Model configurations: loss weights and discriminator input noise.
        self.lambda_cycle = config.lambda_cycle
        self.lambda_cls = config.lambda_cls
        self.lambda_identity = config.lambda_identity
        self.sigma_d = config.sigma_d  # std-dev of gaussian noise fed to D

        # Training configurations.
        self.data_dir = config.data_dir
        self.test_dir = config.test_dir
        self.batch_size = config.batch_size
        self.num_iters = config.num_iters              # total D iterations
        self.num_iters_decay = config.num_iters_decay  # lr-decay window
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.c_lr = config.c_lr
        self.n_critic = config.n_critic  # D updates per G update
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.resume_iters = config.resume_iters

        # Test configurations.
        self.test_iters = config.test_iters
        # trg_style arrives as a string literal, e.g. "['jazz']".
        self.trg_style = ast.literal_eval(config.trg_style)
        self.src_style = config.src_style

        # Miscellaneous.
        self.use_tensorboard = config.use_tensorboard
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        # One-hot encoder fitted on the module-level `styles` list.
        self.stls_enc = LabelBinarizer().fit(styles)
        # Directories.
        self.log_dir = config.log_dir
        self.sample_dir = config.sample_dir
        self.model_save_dir = config.model_save_dir
        self.result_dir = config.result_dir

        # Step sizes for logging / sampling / checkpointing / lr decay.
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.lr_update_step = config.lr_update_step

        # Build the model and tensorboard.
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

    def build_model(self):
        """Instantiate G/D/C, their optimizers, and move them to the device."""
        self.G = Generator()
        self.D = Discriminator()
        self.C = DomainClassifier()

        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr,
                                            [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr,
                                            [self.beta1, self.beta2])
        self.c_optimizer = torch.optim.Adam(self.C.parameters(), self.c_lr,
                                            [self.beta1, self.beta2])

        self.print_network(self.G, 'G')
        self.print_network(self.D, 'D')
        self.print_network(self.C, 'C')

        self.G.to(self.device)
        self.D.to(self.device)
        self.C.to(self.device)

    def print_network(self, model, name):
        """Print out the network information and its total parameter count."""
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))

    def build_tensorboard(self):
        """Build a tensorboard logger."""
        from logger import Logger
        self.logger = Logger(self.log_dir)

    def update_lr(self, g_lr, d_lr, c_lr):
        """Decay learning rates of the generator and discriminator and classifier."""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr
        for param_group in self.c_optimizer.param_groups:
            param_group['lr'] = c_lr

    def train(self):
        """Run the full adversarial training loop with sampling/checkpointing."""
        # Learning rate cache for decaying.
        g_lr = self.g_lr
        d_lr = self.d_lr
        c_lr = self.c_lr

        start_iters = 0
        if self.resume_iters:
            # BUG FIX: resuming was a silent no-op (`pass`); actually restore
            # the checkpoint and continue from the requested iteration.
            start_iters = self.resume_iters
            self.restore_model(self.resume_iters)

        data_iter = iter(self.data_loader)
        # Hoisted out of the loop: the criterion is stateless.
        CELoss = nn.CrossEntropyLoss()

        print('Start training......')
        start_time = datetime.now()

        for i in range(start_iters, self.num_iters):
            # =================================================================================== #
            #                             1. Preprocess input data                                #
            # =================================================================================== #
            # Fetch real images and labels; restart the iterator when the
            # epoch is exhausted.
            try:
                x_real, style_idx_org, label_org = next(data_iter)
            except StopIteration:
                data_iter = iter(self.data_loader)
                x_real, style_idx_org, label_org = next(data_iter)

            # Gaussian noise for robustness improvement of the discriminator.
            gaussian_noise = self.sigma_d * torch.randn(x_real.size())

            # Generate target domain labels randomly by permuting the batch.
            rand_idx = torch.randperm(label_org.size(0))
            label_trg = label_org[rand_idx]
            style_idx_trg = style_idx_org[rand_idx]

            x_real = x_real.to(self.device)  # Input images.
            label_org = label_org.to(
                self.device)  # Original domain one-hot labels.
            label_trg = label_trg.to(
                self.device)  # Target domain one-hot labels.
            style_idx_org = style_idx_org.to(
                self.device)  # Original domain labels
            style_idx_trg = style_idx_trg.to(
                self.device)  # Target domain labels
            gaussian_noise = gaussian_noise.to(
                self.device)  # gaussian noise for discriminators
            # =================================================================================== #
            #                             2. Train the discriminator                              #
            # =================================================================================== #
            # Train the domain classifier C on real samples first.
            cls_real = self.C(x_real)
            cls_loss_real = CELoss(input=cls_real, target=style_idx_org)

            self.reset_grad()
            cls_loss_real.backward()
            self.c_optimizer.step()
            # Logging.
            loss = {}
            loss['C/C_loss'] = cls_loss_real.item()

            # LSGAN-style discriminator loss: real -> 1, fake -> 0.
            out_r = self.D(x_real + gaussian_noise, label_org)
            # Compute loss with fake audio frame.
            x_fake = self.G(x_real, label_trg)
            out_f = self.D(x_fake + gaussian_noise, label_trg)
            d_loss_t = F.mse_loss(input=out_f,target=torch.zeros_like(out_f, dtype=torch.float)) + \
                F.mse_loss(input=out_r, target=torch.ones_like(out_r, dtype=torch.float))

            out_cls = self.C(x_fake)
            d_loss_cls = CELoss(input=out_cls, target=style_idx_trg)

            d_loss = d_loss_t + self.lambda_cls * d_loss_cls

            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            loss['D/D_loss'] = d_loss.item()

            # =================================================================================== #
            #                               3. Train the generator                                #
            # =================================================================================== #
            # Update G only once every n_critic discriminator updates.
            if (i + 1) % self.n_critic == 0:
                # Original-to-target domain.
                x_fake = self.G(x_real, label_trg)
                g_out_src = self.D(x_fake + gaussian_noise, label_trg)
                g_loss_fake = F.mse_loss(input=g_out_src,
                                         target=torch.ones_like(
                                             g_out_src, dtype=torch.float))

                out_cls = self.C(x_real)
                g_loss_cls = CELoss(input=out_cls, target=style_idx_org)

                # Target-to-original domain (cycle consistency).
                x_reconst = self.G(x_fake, label_org)
                g_loss_rec = F.l1_loss(x_reconst, x_real)

                # Original-to-Original domain (identity mapping).
                x_fake_iden = self.G(x_real, label_org)
                id_loss = F.l1_loss(x_fake_iden, x_real)

                # Backward and optimize.
                g_loss = g_loss_fake + self.lambda_cycle * g_loss_rec +\
                 self.lambda_cls * g_loss_cls + self.lambda_identity * id_loss

                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # Logging.
                loss['G/loss_fake'] = g_loss_fake.item()
                loss['G/loss_rec'] = g_loss_rec.item()
                loss['G/loss_cls'] = g_loss_cls.item()
                loss['G/loss_id'] = id_loss.item()
                loss['G/g_loss'] = g_loss.item()
            # =================================================================================== #
            #                                 4. Miscellaneous                                    #
            # =================================================================================== #
            # Print out training information.
            if (i + 1) % self.log_step == 0:
                et = datetime.now() - start_time
                et = str(et)[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]".format(
                    et, i + 1, self.num_iters)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, i + 1)

            # Translate fixed images for debugging.
            if (i + 1) % self.sample_step == 0:
                with torch.no_grad():
                    d, style = TestSet(self.test_dir).test_data()

                    label_o = self.stls_enc.transform([style])[0]
                    label_o = np.asarray([label_o])

                    # Pick any style different from the source as the target.
                    target = random.choice([x for x in styles if x != style])
                    label_t = self.stls_enc.transform([target])[0]
                    label_t = np.asarray([label_t])

                    for filename, content in d.items():

                        filename = filename.split('.')[0]

                        one_seg = torch.FloatTensor(content).to(self.device)
                        one_seg = one_seg.view(1, one_seg.size(0),
                                               one_seg.size(1),
                                               one_seg.size(2))
                        l_t = torch.FloatTensor(label_t)
                        one_seg = one_seg.to(self.device)
                        l_t = l_t.to(self.device)

                        one_set_transfer = self.G(one_seg, l_t)

                        l_o = torch.FloatTensor(label_o)
                        l_o = l_o.to(self.device)

                        one_set_cycle = self.G(
                            one_set_transfer.to(self.device),
                            l_o).data.cpu().numpy()
                        one_set_transfer = one_set_transfer.data.cpu().numpy()

                        one_set_transfer_binary = to_binary(
                            one_set_transfer, 0.5)
                        one_set_cycle_binary = to_binary(one_set_cycle, 0.5)

                        one_set_transfer_binary = one_set_transfer_binary.reshape(
                            -1, one_set_transfer_binary.shape[2],
                            one_set_transfer_binary.shape[3],
                            one_set_transfer_binary.shape[1])
                        one_set_cycle_binary = one_set_cycle_binary.reshape(
                            -1, one_set_cycle_binary.shape[2],
                            one_set_cycle_binary.shape[3],
                            one_set_cycle_binary.shape[1])

                        print(one_set_transfer_binary.shape,
                              one_set_cycle_binary.shape)

                        # BUG FIX: `filename` was computed above but never
                        # used, so every file in the sample set produced the
                        # same output names and overwrote the previous one.
                        name_origin = f'{style}-{target}_iter{i+1}_{filename}_origin'
                        name_transfer = f'{style}-{target}_iter{i+1}_{filename}_transfer'
                        name_cycle = f'{style}-{target}_iter{i+1}_{filename}_cycle'

                        path_samples_per_iter = os.path.join(
                            self.sample_dir, f'iter{i+1}')

                        if not os.path.exists(path_samples_per_iter):
                            os.makedirs(path_samples_per_iter)

                        path_origin = os.path.join(path_samples_per_iter,
                                                   name_origin)
                        path_transfer = os.path.join(path_samples_per_iter,
                                                     name_transfer)
                        path_cycle = os.path.join(path_samples_per_iter,
                                                  name_cycle)

                        print(
                            f'[save]:{path_origin},{path_transfer},{path_cycle}'
                        )

                        save_midis(
                            content.reshape(1, content.shape[1],
                                            content.shape[2],
                                            content.shape[0]),
                            '{}.mid'.format(path_origin))
                        save_midis(one_set_transfer_binary,
                                   '{}.mid'.format(path_transfer))
                        save_midis(one_set_cycle_binary,
                                   '{}.mid'.format(path_cycle))

            # Save model checkpoints.
            if (i + 1) % self.model_save_step == 0:
                G_path = os.path.join(self.model_save_dir,
                                      '{}-G.ckpt'.format(i + 1))
                D_path = os.path.join(self.model_save_dir,
                                      '{}-D.ckpt'.format(i + 1))
                C_path = os.path.join(self.model_save_dir,
                                      '{}-C.ckpt'.format(i + 1))
                torch.save(self.G.state_dict(), G_path)
                torch.save(self.D.state_dict(), D_path)
                torch.save(self.C.state_dict(), C_path)
                print('Saved model checkpoints into {}...'.format(
                    self.model_save_dir))

            # Decay learning rates linearly over the last num_iters_decay steps.
            if (i + 1) % self.lr_update_step == 0 and (i + 1) > (
                    self.num_iters - self.num_iters_decay):
                g_lr -= (self.g_lr / float(self.num_iters_decay))
                d_lr -= (self.d_lr / float(self.num_iters_decay))
                c_lr -= (self.c_lr / float(self.num_iters_decay))
                self.update_lr(g_lr, d_lr, c_lr)
                print('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(
                    g_lr, d_lr))

    def reset_grad(self):
        """Reset the gradient buffers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()
        self.c_optimizer.zero_grad()

    def restore_model(self, resume_iters):
        """Restore the trained generator, discriminator and classifier."""
        print(
            'Loading the trained models from step {}...'.format(resume_iters))
        G_path = os.path.join(self.model_save_dir,
                              '{}-G.ckpt'.format(resume_iters))
        D_path = os.path.join(self.model_save_dir,
                              '{}-D.ckpt'.format(resume_iters))
        C_path = os.path.join(self.model_save_dir,
                              '{}-C.ckpt'.format(resume_iters))
        self.G.load_state_dict(
            torch.load(G_path, map_location=lambda storage, loc: storage))
        self.D.load_state_dict(
            torch.load(D_path, map_location=lambda storage, loc: storage))
        self.C.load_state_dict(
            torch.load(C_path, map_location=lambda storage, loc: storage))

    def test(self):
        """Translate the test set to each target style using the trained G."""
        # Load the trained generator (and D/C) from the requested step.
        self.restore_model(self.test_iters)

        # Set data loader.
        d, style = TestSet(self.test_dir).test_data(self.src_style)
        targets = self.trg_style

        for target in targets:
            print(target)
            assert target in styles
            label_t = self.stls_enc.transform([target])[0]
            label_t = np.asarray([label_t])

            with torch.no_grad():

                for filename, content in d.items():

                    filename = filename.split('.')[0]

                    one_seg = torch.FloatTensor(content).to(self.device)
                    one_seg = one_seg.view(1, one_seg.size(0), one_seg.size(1),
                                           one_seg.size(2))
                    l_t = torch.FloatTensor(label_t)
                    one_seg = one_seg.to(self.device)
                    l_t = l_t.to(self.device)

                    one_set_transfer = self.G(one_seg, l_t).cpu().numpy()

                    one_set_transfer_binary = to_binary(one_set_transfer, 0.5)

                    one_set_transfer_binary = one_set_transfer_binary.reshape(
                        -1, one_set_transfer_binary.shape[2],
                        one_set_transfer_binary.shape[3],
                        one_set_transfer_binary.shape[1])

                    # BUG FIX: these names previously referenced the loop
                    # variable `i` from train(), which is undefined here and
                    # raised NameError; key outputs on the checkpoint step
                    # and the source file instead.
                    name_origin = f'{style}-{target}_iter{self.test_iters}_{filename}_origin'
                    name_transfer = f'{style}-{target}_iter{self.test_iters}_{filename}_transfer'

                    path = os.path.join(self.result_dir,
                                        f'iter{self.test_iters}')
                    # Ensure the output directory exists before writing.
                    if not os.path.exists(path):
                        os.makedirs(path)

                    path_origin = os.path.join(path, name_origin)
                    path_transfer = os.path.join(path, name_transfer)

                    print(f'[save]:{path_origin},{path_transfer}')

                    save_midis(
                        content.reshape(1, content.shape[1], content.shape[2],
                                        content.shape[0]),
                        '{}.mid'.format(path_origin))
                    save_midis(one_set_transfer_binary,
                               '{}.mid'.format(path_transfer))
Ejemplo n.º 10
0
    # NOTE(review): truncated snippet — the opening of this
    # transforms.Compose([...]) list is missing above.
    # Finally, convert to a Tensor for the model to consume.
    transforms.ToTensor(),
])

# Labeled source-domain and unlabeled target-domain image folders.
source_dataset = ImageFolder('real_or_drawing/train_data',
                             transform=source_transform)
target_dataset = ImageFolder('real_or_drawing/test_data',
                             transform=target_transform)

source_dataloader = DataLoader(source_dataset, batch_size=32, shuffle=True)
target_dataloader = DataLoader(target_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(target_dataset, batch_size=128, shuffle=False)

# DaNN components: shared feature extractor plus class and domain heads.
feature_extractor = FeatureExtractor().cuda()
label_predictor = LabelPredictor().cuda()
domain_classifier = DomainClassifier().cuda()

class_criterion = nn.CrossEntropyLoss()
# Binary source-vs-target discrimination loss for the domain head.
domain_criterion = nn.BCEWithLogitsLoss()

optimizer_F = optim.Adam(feature_extractor.parameters())  # originally 1e-3
optimizer_C = optim.Adam(label_predictor.parameters())
optimizer_D = optim.Adam(domain_classifier.parameters())


# NOTE(review): truncated below — the docstring is never closed and the
# function body is missing from this snippet.
def train_epoch(source_dataloader, target_dataloader, lamb):
    '''
      Args:
        source_dataloader: dataloader for the source data
        target_dataloader: dataloader for the target data
        lamb: coefficient controlling the adversarial loss term.
# ===== Ejemplo n.º 11 (scraped example separator; original vote count: 0) =====
    transforms.Grayscale(),
    # 縮放: 因為source data是32x32,我們將target data的28x28放大成32x32。
    transforms.Resize((32, 32)),
    # 水平翻轉 (Augmentation)
    transforms.RandomHorizontalFlip(),
    # 旋轉15度內 (Augmentation),旋轉後空的地方補0
    transforms.RandomRotation(15),
    # 最後轉成Tensor供model使用。
    transforms.ToTensor(),
])

# Fix RNG seeds for reproducibility, then pick the GPU when available.
set_seed(208)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Encoder plus two heads, all moved to the chosen device.
feature_extractor, label_predictor, domain_classifier = (
    CNN_VAE().to(device),
    LabelPredictor().to(device),
    DomainClassifier().to(device),
)

# Label head trains with cross-entropy; domain head with binary logit loss.
class_criterion = nn.CrossEntropyLoss()
domain_criterion = nn.BCEWithLogitsLoss()

# Source = labelled train split; target = unlabelled test split.
source_dataset = ImageFolder('real_or_drawing/train_data',
                             transform=source_transform)
target_dataset = ImageFolder('real_or_drawing/test_data',
                             transform=target_transform)

# 32-sample shuffled batches for training; ordered 128-sample batches for eval.
source_dataloader = DataLoader(source_dataset, batch_size=32, shuffle=True)
target_dataloader = DataLoader(target_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(target_dataset, batch_size=128, shuffle=False)

train(source_dataloader,
      target_dataloader,
# ===== Ejemplo n.º 12 (scraped example separator; original vote count: 0) =====
    # 水平翻轉 (Augmentation)
    transforms.RandomHorizontalFlip(),
    # 旋轉15度內 (Augmentation),旋轉後空的地方補0
    transforms.RandomRotation(15),#, fill=(0,)),
    # 最後轉成Tensor供model使用。
    transforms.ToTensor(),
])

# Evaluate on the target (test) split, in dataset order (shuffle=False).
target_dataset = ImageFolder('real_or_drawing/test_data',
                             transform=target_transform)
test_dataloader = DataLoader(target_dataset, batch_size=128, shuffle=False)

# Restore the trained encoder and classifier from their checkpoint files.
feature_extractor = FeatureExtractor().cuda()
feature_extractor.load_state_dict(torch.load('strong1_extractor_model_1000.bin'))
label_predictor = LabelPredictor().cuda()
label_predictor.load_state_dict(torch.load('strong1_predictor_model_1000.bin'))
# NOTE(review): the domain classifier is instantiated but no checkpoint is
# loaded for it — it is unused at inference time.
domain_classifier = DomainClassifier().cuda()

# Switch to inference mode (disables dropout / batch-norm running updates).
feature_extractor.eval()
label_predictor.eval()

# One empty bucket per predicted class, for classes 0-9.
label_dict = {label: [] for label in range(10)}
for i, (test_data, _) in enumerate(test_dataloader):
    test_data = test_data.cuda()
    class_logits = label_predictor(feature_extractor(test_data))
    x = torch.argmax(class_logits, dim=1).cpu().detach().numpy()
    _y = torch.max(class_logits, dim=1)
    y = _y[0].cpu().detach().numpy()
    data = test_data.cpu().detach().numpy()