class Solver(object):
    """Train and test a StarGAN-VC style voice-conversion model.

    Holds a generator ``G``, a discriminator ``D`` and a speaker (domain)
    classifier ``C``, each with its own Adam optimizer, and implements the
    training loop, checkpointing and conversion of test utterances.
    """

    # Weight of the gradient-penalty term in the discriminator loss.
    _LAMBDA_GP = 5

    def __init__(self, data_loader, config):
        """Copy hyper-parameters from ``config`` and build the networks.

        Args:
            data_loader: iterable of ``(x_real, speaker_idx, label)``
                training batches.
            config: object exposing every attribute read below
                (presumably an ``argparse.Namespace``).
        """
        self.config = config
        self.data_loader = data_loader

        # Model configuration: weights of the cycle-consistency,
        # domain-classification and identity-mapping losses.
        self.lambda_cycle = config.lambda_cycle
        self.lambda_cls = config.lambda_cls
        self.lambda_identity = config.lambda_identity

        # Training configuration.
        self.data_dir = config.data_dir
        self.test_dir = config.test_dir
        self.batch_size = config.batch_size
        self.num_iters = config.num_iters  # total training iterations
        # Number of final iterations over which the lr decays linearly.
        self.num_iters_decay = config.num_iters_decay
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.c_lr = config.c_lr
        self.n_critic = config.n_critic  # D updates per G update
        self.beta1 = config.beta1  # Adam beta1
        self.beta2 = config.beta2  # Adam beta2
        self.resume_iters = config.resume_iters  # checkpoint step to resume from

        # Test configuration. ``trg_speaker`` is a Python-literal string
        # (presumably a list of speaker names) parsed safely with literal_eval.
        self.test_iters = config.test_iters
        self.trg_speaker = ast.literal_eval(config.trg_speaker)
        self.src_speaker = config.src_speaker

        # Miscellaneous.
        self.use_tensorboard = config.use_tensorboard
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        # One-hot encoder fitted on the module-level ``speakers`` list.
        self.spk_enc = LabelBinarizer().fit(speakers)

        # Directories.
        self.log_dir = config.log_dir
        self.sample_dir = config.sample_dir
        self.model_save_dir = config.model_save_dir
        self.result_dir = config.result_dir

        # Step intervals.
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.lr_update_step = config.lr_update_step

        # Build the networks and (optionally) the tensorboard logger.
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

    def build_model(self):
        """Create G, D, C and one Adam optimizer per network, then move them to the device."""
        self.G = Generator()
        self.D = Discriminator()
        self.C = DomainClassifier()

        betas = [self.beta1, self.beta2]
        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, betas)
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, betas)
        self.c_optimizer = torch.optim.Adam(self.C.parameters(), self.c_lr, betas)

        self.print_network(self.G, 'G')
        self.print_network(self.D, 'D')
        self.print_network(self.C, 'C')

        self.G.to(self.device)
        self.D.to(self.device)
        self.C.to(self.device)

    def print_network(self, model, name):
        """Print ``model``, its ``name`` and its total parameter count."""
        num_params = sum(p.numel() for p in model.parameters())
        print(model)
        print(name)
        print("参数的个数为:{}".format(num_params))

    def build_tensorboard(self):
        """Create a tensorboard ``Logger`` writing to ``log_dir``."""
        from logger import Logger
        self.logger = Logger(self.log_dir)

    def update_lr(self, g_lr, d_lr, c_lr):
        """Set the learning rate of every param group of the G, D and C optimizers."""
        for optimizer, lr in ((self.g_optimizer, g_lr),
                              (self.d_optimizer, d_lr),
                              (self.c_optimizer, c_lr)):
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    def train(self):
        """Run the training loop from ``resume_iters`` (or 0) up to ``num_iters``.

        Every iteration updates C on real data and D on real/fake data; G is
        updated every ``n_critic`` iterations. Logging, sample conversion,
        checkpointing and learning-rate decay run on their step intervals.
        """
        # Current learning rates; decayed linearly near the end of training.
        g_lr, d_lr, c_lr = self.g_lr, self.d_lr, self.c_lr

        # Resume from saved checkpoints when requested (previously this was a
        # bare ``pass``, so ``resume_iters`` was silently ignored).
        # NOTE(review): only network weights are restored; optimizer state and
        # any lr decay already applied before the checkpoint are not.
        start_iters = 0
        if self.resume_iters:
            start_iters = self.resume_iters
            self.restore_model(self.resume_iters)

        norm = Normalizer()
        ce_loss = nn.CrossEntropyLoss()
        data_iter = iter(self.data_loader)

        print('开始训练......')
        start_time = datetime.now()
        for i in range(start_iters, self.num_iters):
            step = i + 1

            # ------------------------------------------------------------ #
            # 1. Preprocess input data                                     #
            # ------------------------------------------------------------ #
            # Restart the loader once an epoch is exhausted.
            try:
                x_real, speaker_idx_org, label_org = next(data_iter)
            except StopIteration:
                data_iter = iter(self.data_loader)
                x_real, speaker_idx_org, label_org = next(data_iter)

            # Random target domains: a permutation of the batch's own labels.
            rand_idx = torch.randperm(label_org.size(0))
            label_trg = label_org[rand_idx]
            speaker_idx_trg = speaker_idx_org[rand_idx]

            x_real = x_real.to(self.device)
            label_org = label_org.to(self.device)  # source one-hot labels
            label_trg = label_trg.to(self.device)  # target one-hot labels
            speaker_idx_org = speaker_idx_org.to(self.device)  # source indices
            speaker_idx_trg = speaker_idx_trg.to(self.device)  # target indices

            # ------------------------------------------------------------ #
            # 2. Train the domain classifier C and the discriminator D    #
            # ------------------------------------------------------------ #
            cls_real = self.C(x_real)
            cls_loss_real = ce_loss(input=cls_real, target=speaker_idx_org)
            self.reset_grad()
            cls_loss_real.backward()
            self.c_optimizer.step()

            loss = {'C/C_loss': cls_loss_real.item()}

            out_r = self.D(x_real, label_org)
            x_fake = self.G(x_real, label_trg)
            # detach() keeps the D update from back-propagating into G.
            out_f = self.D(x_fake.detach(), label_trg)
            d_loss_t = (
                F.binary_cross_entropy_with_logits(
                    input=out_f,
                    target=torch.zeros_like(out_f, dtype=torch.float))
                + F.binary_cross_entropy_with_logits(
                    input=out_r,
                    target=torch.ones_like(out_r, dtype=torch.float)))
            # NOTE: C is a separate network, so this term has no gradient
            # w.r.t. D's parameters and only affects the logged D loss;
            # detach() skips a useless backward pass through G.
            out_cls = self.C(x_fake.detach())
            d_loss_cls = ce_loss(input=out_cls, target=speaker_idx_trg)

            # Gradient penalty on random interpolates of real and fake samples.
            alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
            x_hat = (alpha * x_real.detach()
                     + (1 - alpha) * x_fake.detach()).requires_grad_(True)
            out_src = self.D(x_hat, label_trg)
            d_loss_gp = self.gradient_penalty(out_src, x_hat)

            d_loss = (d_loss_t + self.lambda_cls * d_loss_cls
                      + self._LAMBDA_GP * d_loss_gp)
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()
            loss['D/D_loss'] = d_loss.item()

            # ------------------------------------------------------------ #
            # 3. Train the generator G                                    #
            # ------------------------------------------------------------ #
            if step % self.n_critic == 0:
                # Source -> target: fool D and be classified as the target.
                x_fake = self.G(x_real, label_trg)
                g_out_src = self.D(x_fake, label_trg)
                g_loss_fake = F.binary_cross_entropy_with_logits(
                    input=g_out_src,
                    target=torch.ones_like(g_out_src, dtype=torch.float))
                # Bug fix: score the *converted* sample against the *target*
                # speaker. The previous code used C(x_real) with the source
                # labels, which does not depend on G and so gave G no
                # classification signal.
                out_cls = self.C(x_fake)
                g_loss_cls = ce_loss(input=out_cls, target=speaker_idx_trg)

                # Target -> source: cycle-consistency (L1) loss.
                x_reconst = self.G(x_fake, label_org)
                g_loss_rec = F.l1_loss(x_reconst, x_real)

                # Source -> source: identity-mapping (L1) loss.
                x_fake_iden = self.G(x_real, label_org)
                id_loss = F.l1_loss(x_fake_iden, x_real)

                g_loss = (g_loss_fake + self.lambda_cycle * g_loss_rec
                          + self.lambda_cls * g_loss_cls
                          + self.lambda_identity * id_loss)
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                loss['G/loss_fake'] = g_loss_fake.item()
                loss['G/loss_rec'] = g_loss_rec.item()
                loss['G/loss_cls'] = g_loss_cls.item()
                loss['G/loss_id'] = id_loss.item()
                loss['G/g_loss'] = g_loss.item()

            # ------------------------------------------------------------ #
            # 4. Logging, sampling, checkpoints, lr decay                 #
            # ------------------------------------------------------------ #
            if step % self.log_step == 0:
                # Drop the fractional seconds. str(timedelta) omits them when
                # they are zero, so split rather than slice a fixed width.
                et = str(datetime.now() - start_time).split('.', 1)[0]
                log = "耗时:[{}], 迭代次数:[{}/{}]".format(et, step, self.num_iters)
                log += ''.join(", {}: {:.4f}".format(tag, value)
                               for tag, value in loss.items())
                print(log)
                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, step)

            if step % self.sample_step == 0:
                self._sample_conversions(step, norm)

            if step % self.model_save_step == 0:
                self._save_checkpoints(step)

            # Linear decay during the final ``num_iters_decay`` iterations.
            if step % self.lr_update_step == 0 and step > (
                    self.num_iters - self.num_iters_decay):
                g_lr -= (self.g_lr / float(self.num_iters_decay))
                d_lr -= (self.d_lr / float(self.num_iters_decay))
                c_lr -= (self.c_lr / float(self.num_iters_decay))
                self.update_lr(g_lr, d_lr, c_lr)
                print('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(
                    g_lr, d_lr))

    def _sample_conversions(self, step, norm):
        """Convert a random test speaker to a random other speaker and save wavs to sample_dir."""
        with torch.no_grad():
            d, speaker = TestSet(self.test_dir).test_data()
            target = random.choice([x for x in speakers if x != speaker])
            label_t = np.asarray([self.spk_enc.transform([target])[0]])
            for filename, content in d.items():
                wav = self._convert_utterance(content, label_t, speaker,
                                              target, norm)
                name = f'{speaker}-{target}_iter{step}_{filename}'
                path = os.path.join(self.sample_dir, name)
                print(f'[save]:{path}')
                # NOTE(review): librosa.output was removed in librosa 0.8;
                # this call requires an older librosa.
                librosa.output.write_wav(path, wav, SAMPLE_RATE)

    def _save_checkpoints(self, step):
        """Save G, D and C weights as ``{step}-{G,D,C}.ckpt`` in model_save_dir."""
        for net, tag in ((self.G, 'G'), (self.D, 'D'), (self.C, 'C')):
            path = os.path.join(self.model_save_dir,
                                '{}-{}.ckpt'.format(step, tag))
            torch.save(net.state_dict(), path)
        print('Saved model checkpoints into {}...'.format(
            self.model_save_dir))

    def _convert_utterance(self, content, label_t, speaker, target, norm):
        """Convert one utterance to ``target``'s voice and return the waveform.

        ``content`` must provide ``'f0'``, ``'ap'`` and ``'coded_sp_norm'``.
        The spectral features are padded, converted by G in FRAMES-wide
        segments, de-normalized, trimmed back to the original length and
        resynthesized. Intended to be called under ``torch.no_grad()``.
        """
        coded_sp_norm = content['coded_sp_norm']
        sp_norm_pad = self.pad_coded_sp(coded_sp_norm)
        label = torch.FloatTensor(label_t).to(self.device)

        converted_segments = []
        for start_idx in range(0, sp_norm_pad.shape[1] - FRAMES + 1, FRAMES):
            seg = torch.FloatTensor(
                sp_norm_pad[:, start_idx:start_idx + FRAMES]).to(self.device)
            seg = seg.view(1, 1, seg.size(0), seg.size(1))
            converted = np.squeeze(self.G(seg, label).detach().cpu().numpy())
            converted_segments.append(norm.backward_process(converted, target))

        convert_con = np.concatenate(converted_segments, axis=1)
        convert_con = convert_con[:, 0:coded_sp_norm.shape[1]]
        contigu = np.ascontiguousarray(convert_con.T, dtype=np.float64)
        decoded_sp = decode_spectral_envelope(contigu, SAMPLE_RATE,
                                              fft_size=FFTSIZE)
        f0_converted = norm.pitch_conversion(content['f0'], speaker, target)
        return synthesize(f0_converted, decoded_sp, content['ap'], SAMPLE_RATE)

    def gradient_penalty(self, y, x):
        """Return mean((||dy/dx||_2 - 1)**2) over the batch (WGAN-GP style).

        ``x`` must require grad and ``y`` must be computed from it. The graph
        is kept (``create_graph=True``) so the penalty can itself be
        back-propagated.
        """
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=torch.ones_like(y),
                                   retain_graph=True,
                                   create_graph=True)[0]
        # Flatten everything but the batch dimension.
        dydx = dydx.reshape(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx ** 2, dim=1))
        return torch.mean((dydx_l2norm - 1) ** 2)

    def reset_grad(self):
        """Zero the gradient buffers of all three optimizers."""
        for optimizer in (self.g_optimizer, self.d_optimizer, self.c_optimizer):
            optimizer.zero_grad()

    def restore_model(self, resume_iters):
        """Load G, D and C weights saved at step ``resume_iters``."""
        print('从{}步骤开始加载训练过的模型...'.format(resume_iters))
        for net, tag in ((self.G, 'G'), (self.D, 'D'), (self.C, 'C')):
            path = os.path.join(self.model_save_dir,
                                '{}-{}.ckpt'.format(resume_iters, tag))
            # Load onto CPU storage; load_state_dict copies into the
            # parameters wherever they live.
            net.load_state_dict(
                torch.load(path, map_location=lambda storage, loc: storage))

    @staticmethod
    def pad_coded_sp(coded_sp_norm):
        """Right-pad ``coded_sp_norm`` with zero columns to a multiple of FRAMES.

        An input whose length is already a multiple of FRAMES still gets a
        full extra block of FRAMES zero columns; the conversion code trims
        its output back to the original length.
        """
        f_len = coded_sp_norm.shape[1]
        # Equals both branches of the former if/elif (f_len % FRAMES == f_len
        # whenever f_len < FRAMES).
        pad_length = FRAMES - f_len % FRAMES
        padding = np.zeros((coded_sp_norm.shape[0], pad_length))
        return np.hstack((coded_sp_norm, padding))

    def test(self):
        """Convert ``src_speaker``'s test utterances to every target speaker.

        Loads the checkpoints saved at ``test_iters`` and writes one wav per
        (target, utterance) pair into ``result_dir``.

        Raises:
            ValueError: if a target speaker is not in ``speakers``.
        """
        self.restore_model(self.test_iters)
        norm = Normalizer()
        d, speaker = TestSet(self.test_dir).test_data(self.src_speaker)
        for target in self.trg_speaker:
            print(target)
            if target not in speakers:
                raise ValueError('unknown target speaker: {!r}'.format(target))
            label_t = np.asarray([self.spk_enc.transform([target])[0]])
            with torch.no_grad():
                for filename, content in d.items():
                    wav = self._convert_utterance(content, label_t, speaker,
                                                  target, norm)
                    name = f'{speaker}-{target}_iter{self.test_iters}_{filename}'
                    path = os.path.join(self.result_dir, name)
                    print(f'[保存]:{path}')
                    # NOTE(review): librosa.output was removed in librosa 0.8.
                    librosa.output.write_wav(path, wav, SAMPLE_RATE)
class Solver(object):
    """Train and test a StarGAN-VC style voice-conversion model.

    Holds a generator ``G``, a discriminator ``D`` and a speaker (domain)
    classifier ``C``, each with its own Adam optimizer, and implements the
    training loop, checkpointing and conversion of test utterances.
    """

    # Weight of the gradient-penalty term in the discriminator loss.
    _LAMBDA_GP = 5

    def __init__(self, data_loader, config):
        """Copy hyper-parameters from ``config`` and build the networks.

        Args:
            data_loader: iterable of ``(x_real, speaker_idx, label)``
                training batches.
            config: object exposing every attribute read below
                (presumably an ``argparse.Namespace``).
        """
        self.config = config
        self.data_loader = data_loader

        # Model configurations: weights of the cycle-consistency,
        # domain-classification and identity-mapping losses.
        self.lambda_cycle = config.lambda_cycle
        self.lambda_cls = config.lambda_cls
        self.lambda_identity = config.lambda_identity

        # Training configurations.
        self.data_dir = config.data_dir
        self.test_dir = config.test_dir
        self.batch_size = config.batch_size
        self.num_iters = config.num_iters  # total training iterations
        # Number of final iterations over which the lr decays linearly.
        self.num_iters_decay = config.num_iters_decay
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.c_lr = config.c_lr
        self.n_critic = config.n_critic  # D updates per G update
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.resume_iters = config.resume_iters  # checkpoint step to resume from

        # Test configurations. ``trg_speaker`` is a Python-literal string
        # (presumably a list of speaker names) parsed safely with literal_eval.
        self.test_iters = config.test_iters
        self.trg_speaker = ast.literal_eval(config.trg_speaker)
        self.src_speaker = config.src_speaker

        # Miscellaneous.
        self.use_tensorboard = config.use_tensorboard
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        # One-hot encoder fitted on the module-level ``speakers`` list.
        self.spk_enc = LabelBinarizer().fit(speakers)

        # Directories.
        self.log_dir = config.log_dir
        self.sample_dir = config.sample_dir
        self.model_save_dir = config.model_save_dir
        self.result_dir = config.result_dir

        # Step size.
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.lr_update_step = config.lr_update_step

        # Build the model and tensorboard.
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

    def build_model(self):
        """Create G, D, C and one Adam optimizer per network, then move them to the device."""
        self.G = Generator()
        self.D = Discriminator()
        self.C = DomainClassifier()

        betas = [self.beta1, self.beta2]
        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, betas)
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, betas)
        self.c_optimizer = torch.optim.Adam(self.C.parameters(), self.c_lr, betas)

        self.print_network(self.G, 'G')
        self.print_network(self.D, 'D')
        self.print_network(self.C, 'C')

        self.G.to(self.device)
        self.D.to(self.device)
        self.C.to(self.device)

    def print_network(self, model, name):
        """Print out the network, its name and its total parameter count."""
        num_params = sum(p.numel() for p in model.parameters())
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))

    def build_tensorboard(self):
        """Build a tensorboard logger writing to ``log_dir``."""
        from logger import Logger
        self.logger = Logger(self.log_dir)

    def update_lr(self, g_lr, d_lr, c_lr):
        """Set the learning rate of every param group of the G, D and C optimizers."""
        for optimizer, lr in ((self.g_optimizer, g_lr),
                              (self.d_optimizer, d_lr),
                              (self.c_optimizer, c_lr)):
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    def train(self):
        """Run the training loop from ``resume_iters`` (or 0) up to ``num_iters``.

        Every iteration updates C on real data and D on real/fake data; G is
        updated every ``n_critic`` iterations. Logging, sample conversion,
        checkpointing and learning-rate decay run on their step intervals.
        """
        # Learning rate cache for decaying.
        g_lr, d_lr, c_lr = self.g_lr, self.d_lr, self.c_lr

        # Resume from saved checkpoints when requested (previously this was a
        # bare ``pass``, so ``resume_iters`` was silently ignored).
        # NOTE(review): only network weights are restored; optimizer state and
        # any lr decay already applied before the checkpoint are not.
        start_iters = 0
        if self.resume_iters:
            start_iters = self.resume_iters
            self.restore_model(self.resume_iters)

        norm = Normalizer()
        ce_loss = nn.CrossEntropyLoss()
        data_iter = iter(self.data_loader)

        print('Start training......')
        start_time = datetime.now()
        for i in range(start_iters, self.num_iters):
            step = i + 1

            # ------------------------------------------------------------ #
            # 1. Preprocess input data                                     #
            # ------------------------------------------------------------ #
            # Restart the loader once an epoch is exhausted.
            try:
                x_real, speaker_idx_org, label_org = next(data_iter)
            except StopIteration:
                data_iter = iter(self.data_loader)
                x_real, speaker_idx_org, label_org = next(data_iter)

            # Random target domains: a permutation of the batch's own labels.
            rand_idx = torch.randperm(label_org.size(0))
            label_trg = label_org[rand_idx]
            speaker_idx_trg = speaker_idx_org[rand_idx]

            x_real = x_real.to(self.device)  # input features
            label_org = label_org.to(self.device)  # source one-hot labels
            label_trg = label_trg.to(self.device)  # target one-hot labels
            speaker_idx_org = speaker_idx_org.to(self.device)  # source indices
            speaker_idx_trg = speaker_idx_trg.to(self.device)  # target indices

            # ------------------------------------------------------------ #
            # 2. Train the domain classifier C and the discriminator D    #
            # ------------------------------------------------------------ #
            cls_real = self.C(x_real)
            cls_loss_real = ce_loss(input=cls_real, target=speaker_idx_org)
            self.reset_grad()
            cls_loss_real.backward()
            self.c_optimizer.step()

            loss = {'C/C_loss': cls_loss_real.item()}

            out_r = self.D(x_real, label_org)
            x_fake = self.G(x_real, label_trg)
            # detach() keeps the D update from back-propagating into G.
            out_f = self.D(x_fake.detach(), label_trg)
            d_loss_t = (
                F.binary_cross_entropy_with_logits(
                    input=out_f,
                    target=torch.zeros_like(out_f, dtype=torch.float))
                + F.binary_cross_entropy_with_logits(
                    input=out_r,
                    target=torch.ones_like(out_r, dtype=torch.float)))
            # NOTE: C is a separate network, so this term has no gradient
            # w.r.t. D's parameters and only affects the logged D loss;
            # detach() skips a useless backward pass through G.
            out_cls = self.C(x_fake.detach())
            d_loss_cls = ce_loss(input=out_cls, target=speaker_idx_trg)

            # Gradient penalty on random interpolates of real and fake samples.
            alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
            x_hat = (alpha * x_real.detach()
                     + (1 - alpha) * x_fake.detach()).requires_grad_(True)
            out_src = self.D(x_hat, label_trg)
            d_loss_gp = self.gradient_penalty(out_src, x_hat)

            d_loss = (d_loss_t + self.lambda_cls * d_loss_cls
                      + self._LAMBDA_GP * d_loss_gp)
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()
            loss['D/D_loss'] = d_loss.item()

            # ------------------------------------------------------------ #
            # 3. Train the generator G                                    #
            # ------------------------------------------------------------ #
            if step % self.n_critic == 0:
                # Original-to-target domain.
                x_fake = self.G(x_real, label_trg)
                g_out_src = self.D(x_fake, label_trg)
                g_loss_fake = F.binary_cross_entropy_with_logits(
                    input=g_out_src,
                    target=torch.ones_like(g_out_src, dtype=torch.float))
                out_cls = self.C(x_fake)
                g_loss_cls = ce_loss(input=out_cls, target=speaker_idx_trg)

                # Target-to-original domain (cycle consistency).
                x_reconst = self.G(x_fake, label_org)
                g_loss_rec = F.l1_loss(x_reconst, x_real)

                # Original-to-original domain (identity).
                x_fake_iden = self.G(x_real, label_org)
                id_loss = F.l1_loss(x_fake_iden, x_real)

                g_loss = (g_loss_fake + self.lambda_cycle * g_loss_rec
                          + self.lambda_cls * g_loss_cls
                          + self.lambda_identity * id_loss)
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                loss['G/loss_fake'] = g_loss_fake.item()
                loss['G/loss_rec'] = g_loss_rec.item()
                loss['G/loss_cls'] = g_loss_cls.item()
                loss['G/loss_id'] = id_loss.item()
                loss['G/g_loss'] = g_loss.item()

            # ------------------------------------------------------------ #
            # 4. Logging, sampling, checkpoints, lr decay                 #
            # ------------------------------------------------------------ #
            if step % self.log_step == 0:
                # Drop the fractional seconds. str(timedelta) omits them when
                # they are zero, so split rather than slice a fixed width.
                et = str(datetime.now() - start_time).split('.', 1)[0]
                log = "Elapsed [{}], Iteration [{}/{}]".format(
                    et, step, self.num_iters)
                log += ''.join(", {}: {:.4f}".format(tag, value)
                               for tag, value in loss.items())
                print(log)
                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, step)

            if step % self.sample_step == 0:
                self._sample_conversions(step, norm)

            if step % self.model_save_step == 0:
                self._save_checkpoints(step)

            # Linear decay during the final ``num_iters_decay`` iterations.
            if step % self.lr_update_step == 0 and step > (
                    self.num_iters - self.num_iters_decay):
                g_lr -= (self.g_lr / float(self.num_iters_decay))
                d_lr -= (self.d_lr / float(self.num_iters_decay))
                c_lr -= (self.c_lr / float(self.num_iters_decay))
                self.update_lr(g_lr, d_lr, c_lr)
                print('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(
                    g_lr, d_lr))

    def _sample_conversions(self, step, norm):
        """Convert a random test speaker to a random other speaker and save wavs to sample_dir."""
        with torch.no_grad():
            d, speaker = TestSet(self.test_dir).test_data()
            target = random.choice([x for x in speakers if x != speaker])
            label_t = np.asarray([self.spk_enc.transform([target])[0]])
            for filename, content in d.items():
                wav = self._convert_utterance(content, label_t, speaker,
                                              target, norm)
                name = f'{speaker}-{target}_iter{step}_{filename}'
                path = os.path.join(self.sample_dir, name)
                print(f'[save]:{path}')
                # NOTE(review): librosa.output was removed in librosa 0.8;
                # this call requires an older librosa.
                librosa.output.write_wav(path, wav, SAMPLE_RATE)

    def _save_checkpoints(self, step):
        """Save G, D and C weights as ``{step}-{G,D,C}.ckpt`` in model_save_dir."""
        for net, tag in ((self.G, 'G'), (self.D, 'D'), (self.C, 'C')):
            path = os.path.join(self.model_save_dir,
                                '{}-{}.ckpt'.format(step, tag))
            torch.save(net.state_dict(), path)
        print('Saved model checkpoints into {}...'.format(
            self.model_save_dir))

    def _convert_utterance(self, content, label_t, speaker, target, norm):
        """Convert one utterance to ``target``'s voice and return the waveform.

        ``content`` must provide ``'f0'``, ``'ap'`` and ``'coded_sp_norm'``.
        The spectral features are padded, converted by G in FRAMES-wide
        segments, de-normalized, trimmed back to the original length and
        resynthesized. Intended to be called under ``torch.no_grad()``.
        """
        coded_sp_norm = content['coded_sp_norm']
        sp_norm_pad = self.pad_coded_sp(coded_sp_norm)
        label = torch.FloatTensor(label_t).to(self.device)

        converted_segments = []
        for start_idx in range(0, sp_norm_pad.shape[1] - FRAMES + 1, FRAMES):
            seg = torch.FloatTensor(
                sp_norm_pad[:, start_idx:start_idx + FRAMES]).to(self.device)
            seg = seg.view(1, 1, seg.size(0), seg.size(1))
            converted = np.squeeze(self.G(seg, label).detach().cpu().numpy())
            converted_segments.append(norm.backward_process(converted, target))

        convert_con = np.concatenate(converted_segments, axis=1)
        convert_con = convert_con[:, 0:coded_sp_norm.shape[1]]
        contigu = np.ascontiguousarray(convert_con.T, dtype=np.float64)
        decoded_sp = decode_spectral_envelope(contigu, SAMPLE_RATE,
                                              fft_size=FFTSIZE)
        f0_converted = norm.pitch_conversion(content['f0'], speaker, target)
        return synthesize(f0_converted, decoded_sp, content['ap'], SAMPLE_RATE)

    def gradient_penalty(self, y, x):
        """Compute gradient penalty: mean((L2_norm(dy/dx) - 1)**2) over the batch.

        ``x`` must require grad and ``y`` must be computed from it. The graph
        is kept (``create_graph=True``) so the penalty can itself be
        back-propagated.
        """
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=torch.ones_like(y),
                                   retain_graph=True,
                                   create_graph=True)[0]
        # Flatten everything but the batch dimension.
        dydx = dydx.reshape(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx ** 2, dim=1))
        return torch.mean((dydx_l2norm - 1) ** 2)

    def reset_grad(self):
        """Reset the gradient buffers of all three optimizers."""
        for optimizer in (self.g_optimizer, self.d_optimizer, self.c_optimizer):
            optimizer.zero_grad()

    def restore_model(self, resume_iters):
        """Restore the trained G, D and C weights saved at step ``resume_iters``."""
        print(
            'Loading the trained models from step {}...'.format(resume_iters))
        for net, tag in ((self.G, 'G'), (self.D, 'D'), (self.C, 'C')):
            path = os.path.join(self.model_save_dir,
                                '{}-{}.ckpt'.format(resume_iters, tag))
            # Load onto CPU storage; load_state_dict copies into the
            # parameters wherever they live.
            net.load_state_dict(
                torch.load(path, map_location=lambda storage, loc: storage))

    @staticmethod
    def pad_coded_sp(coded_sp_norm):
        """Right-pad ``coded_sp_norm`` with zero columns to a multiple of FRAMES.

        An input whose length is already a multiple of FRAMES still gets a
        full extra block of FRAMES zero columns; the conversion code trims
        its output back to the original length.
        """
        f_len = coded_sp_norm.shape[1]
        # Equals both branches of the former if/elif (f_len % FRAMES == f_len
        # whenever f_len < FRAMES).
        pad_length = FRAMES - f_len % FRAMES
        padding = np.zeros((coded_sp_norm.shape[0], pad_length))
        return np.hstack((coded_sp_norm, padding))

    def test(self):
        """Translate ``src_speaker``'s test speech to every target speaker.

        Loads the checkpoints saved at ``test_iters`` and writes one wav per
        (target, utterance) pair into ``result_dir``.

        Raises:
            ValueError: if a target speaker is not in ``speakers``.
        """
        # Load the trained networks.
        self.restore_model(self.test_iters)
        norm = Normalizer()

        d, speaker = TestSet(self.test_dir).test_data(self.src_speaker)
        for target in self.trg_speaker:
            print(target)
            if target not in speakers:
                raise ValueError('unknown target speaker: {!r}'.format(target))
            label_t = np.asarray([self.spk_enc.transform([target])[0]])
            with torch.no_grad():
                for filename, content in d.items():
                    wav = self._convert_utterance(content, label_t, speaker,
                                                  target, norm)
                    name = f'{speaker}-{target}_iter{self.test_iters}_{filename}'
                    path = os.path.join(self.result_dir, name)
                    print(f'[save]:{path}')
                    # NOTE(review): librosa.output was removed in librosa 0.8.
                    librosa.output.write_wav(path, wav, SAMPLE_RATE)
# NOTE(review): this block looks like a fragment pasted from a separate
# domain-adaptation script; every dataset/model name below is defined elsewhere.
print('dataset done 2')  # progress/debug print
# Source and target training loaders shuffle; the evaluation loader keeps order.
source_dataloader = DataLoader(source_dataset, batch_size=32, shuffle=True)
target_dataloader = DataLoader(target_dataset, batch_size=32, shuffle=True)
# NOTE(review): evaluation reuses target_dataset -- presumably intentional
# (the model is evaluated on the target domain); confirm.
test_dataloader = DataLoader(target_dataset, batch_size=128, shuffle=False)

# All three networks are placed on the GPU via .cuda(), so CUDA is required.
feature_extractor = FeatureExtractor().cuda()
label_predictor = LabelPredictor().cuda()
domain_classifier = DomainClassifier().cuda()

# Cross-entropy for class labels; BCE-with-logits for the domain label
# (presumably a binary source-vs-target target).
class_criterion = nn.CrossEntropyLoss()
domain_criterion = nn.BCEWithLogitsLoss()

# One Adam optimizer per sub-network, all with PyTorch's default hyper-parameters.
optimizer_F = optim.Adam(feature_extractor.parameters()) # originally 1e-3 (also Adam's default lr)
optimizer_C = optim.Adam(label_predictor.parameters())
optimizer_D = optim.Adam(domain_classifier.parameters())

def train_epoch(source_dataloader, target_dataloader, lamb):
    '''
      Args:
        source_dataloader: dataloader for the source-domain data.
        target_dataloader: dataloader for the target-domain data.
        lamb: coefficient that scales the adversarial (domain) loss.
    '''

    # D loss: loss of the Domain Classifier.
    # F loss: loss of the Feature Extractor & Label Predictor.
    # total_hit: number of correct predictions so far; total_num: number of samples seen so far.
    running_D_loss, running_F_loss = 0.0, 0.0
    total_hit, total_num = 0.0, 0.0
    # NOTE(review): the body ends here in this file -- as written the function
    # only initializes these accumulators and implicitly returns None. The
    # training loop appears to have been truncated; restore it before use.
class Solver(object):
    """Solver for training and testing StarGAN."""

    def __init__(self, train_loader, test_loader, config):
        """Initialize configurations.

        Args:
            train_loader: iterable of training batches; ``train`` unpacks each
                batch as ``(mc_real, spk_label_org, spk_c_org)``.
            test_loader: object providing ``get_batch_test_data`` plus the
                source/target statistics read during sample conversion.
            config: namespace holding every hyper-parameter read below.
        """
        # Data loader.
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.sampling_rate = config.sampling_rate

        # Model configurations.
        self.num_speakers = config.num_speakers
        self.lambda_cls = config.lambda_cls
        self.lambda_rec = config.lambda_rec
        self.lambda_gp = config.lambda_gp

        # Training configurations.
        self.batch_size = config.batch_size
        self.num_iters = config.num_iters
        self.num_iters_decay = config.num_iters_decay
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.c_lr = config.c_lr
        self.n_critic = config.n_critic
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.resume_iters = config.resume_iters

        # Test configurations.
        self.test_iters = config.test_iters

        # Miscellaneous.
        self.use_tensorboard = config.use_tensorboard
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Directories.
        self.log_dir = config.log_dir
        self.sample_dir = config.sample_dir
        self.model_save_dir = config.model_save_dir

        # Step size.
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.lr_update_step = config.lr_update_step

        # Build the model and tensorboard.
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

    def build_model(self):
        """Create the generator, discriminator and domain classifier, one Adam
        optimizer for each, and move the networks to ``self.device``."""
        self.G = Generator(num_speakers=self.num_speakers)
        self.D = Discriminator(num_speakers=self.num_speakers)
        self.C = DomainClassifier()

        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])
        self.c_optimizer = torch.optim.Adam(self.C.parameters(), self.c_lr, [self.beta1, self.beta2])

        self.print_network(self.G, 'G')
        self.print_network(self.D, 'D')
        self.print_network(self.C, 'C')

        self.G.to(self.device)
        self.D.to(self.device)
        self.C.to(self.device)

    def print_network(self, model, name):
        """Print out the network information and its total parameter count."""
        num_params = sum(p.numel() for p in model.parameters())
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))

    def restore_model(self, resume_iters):
        """Load the G, D and C weights saved at iteration ``resume_iters``.

        Tensors are mapped to CPU storage while loading, then copied into the
        existing modules by ``load_state_dict``.
        """
        print('Loading the trained models from step {}...'.format(resume_iters))
        G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(resume_iters))
        D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(resume_iters))
        C_path = os.path.join(self.model_save_dir, '{}-C.ckpt'.format(resume_iters))
        self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
        self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))
        self.C.load_state_dict(torch.load(C_path, map_location=lambda storage, loc: storage))

    def build_tensorboard(self):
        """Build a tensorboard logger."""
        from logger import Logger
        self.logger = Logger(self.log_dir)

    def update_lr(self, g_lr, d_lr, c_lr):
        """Set the learning rates of the G, D and C optimizers."""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr
        for param_group in self.c_optimizer.param_groups:
            param_group['lr'] = c_lr

    def reset_grad(self):
        """Reset the gradient buffers of all three optimizers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()
        self.c_optimizer.zero_grad()

    def denorm(self, x):
        """Convert the range from [-1, 1] to [0, 1] (clamped, in place)."""
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def gradient_penalty(self, y, x):
        """Compute gradient penalty: mean over the batch of (L2_norm(dy/dx) - 1)**2."""
        weight = torch.ones(y.size()).to(self.device)
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=weight,
                                   retain_graph=True,
                                   create_graph=True)[0]

        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx ** 2, dim=1))
        return torch.mean((dydx_l2norm - 1) ** 2)

    def label2onehot(self, labels, dim):
        """Convert label indices to one-hot vectors of width ``dim``."""
        batch_size = labels.size(0)
        out = torch.zeros(batch_size, dim)
        out[np.arange(batch_size), labels.long()] = 1
        return out

    def sample_spk_c(self, size):
        """Draw ``size`` random target speaker indices.

        Returns:
            ``(indices as LongTensor, to_categorical encoding as FloatTensor)``.
        """
        spk_c = np.random.randint(0, self.num_speakers, size=size)
        spk_c_cat = to_categorical(spk_c, self.num_speakers)
        return torch.LongTensor(spk_c), torch.FloatTensor(spk_c_cat)

    def classification_loss(self, logit, target):
        """Compute softmax cross entropy loss."""
        return F.cross_entropy(logit, target)

    def load_wav(self, wavfile, sr=16000):
        """Load ``wavfile`` as mono audio resampled to ``sr`` and pad it with
        ``wav_padding``."""
        wav, _ = librosa.load(wavfile, sr=sr, mono=True)
        # Forward the caller's sample rate; this used to be hard-coded to
        # 16000 regardless of ``sr``.
        return wav_padding(wav, sr=sr, frame_period=5, multiple=4)

    def train(self):
        """Train StarGAN.

        Each iteration updates the domain classifier C on real features, then
        the discriminator D (Wasserstein loss with gradient penalty), and
        every ``n_critic`` iterations the generator G. Logging, sample
        conversion, checkpointing and linear learning-rate decay run on their
        configured step intervals.
        """
        # Set data loader.
        train_loader = self.train_loader
        data_iter = iter(train_loader)

        # Read a batch of test data used for periodic sample conversion.
        test_wavfiles = self.test_loader.get_batch_test_data(batch_size=4)
        test_wavs = [self.load_wav(wavfile) for wavfile in test_wavfiles]

        # Also write copy-synthesized references (analysis -> resynthesis,
        # no conversion) the first time samples are converted.
        cpsyn_flag = True

        # Learning rate cache for decaying.
        g_lr = self.g_lr
        d_lr = self.d_lr
        c_lr = self.c_lr

        # Start training from scratch or resume training.
        start_iters = 0
        if self.resume_iters:
            print("resuming step %d ..." % self.resume_iters)
            start_iters = self.resume_iters
            self.restore_model(self.resume_iters)

        # Start training.
        print('Start training...')
        start_time = time.time()
        for i in range(start_iters, self.num_iters):
            # =========================================================== #
            #                  1. Preprocess input data                   #
            # =========================================================== #
            try:
                mc_real, spk_label_org, spk_c_org = next(data_iter)
            except StopIteration:
                # Loader exhausted: start a new pass over the training data.
                data_iter = iter(train_loader)
                mc_real, spk_label_org, spk_c_org = next(data_iter)

            mc_real.unsqueeze_(1)  # (B, D, T) -> (B, 1, D, T) for conv2d

            # Generate target domain labels randomly.
            # spk_label_trg: int, spk_c_trg: one-hot representation
            spk_label_trg, spk_c_trg = self.sample_spk_c(mc_real.size(0))

            mc_real = mc_real.to(self.device)              # Input mc.
            spk_label_org = spk_label_org.to(self.device)  # Original spk labels.
            spk_c_org = spk_c_org.to(self.device)          # Original spk acc conditioning.
            spk_label_trg = spk_label_trg.to(self.device)  # Target spk labels for classification loss for G.
            spk_c_trg = spk_c_trg.to(self.device)          # Target spk conditioning.

            # =========================================================== #
            #                 2. Train the discriminator                  #
            # =========================================================== #
            # Compute loss with real mc feats.
            out_src = self.D(mc_real, spk_c_org)
            d_loss_real = -torch.mean(out_src)

            # Train the domain classifier on real features. Its gradient
            # buffers are cleared first: the D and G backward passes of the
            # previous iteration also write into C's .grad, and those stale
            # values used to be folded into this update.
            self.c_optimizer.zero_grad()
            out_cls_spks = self.C(mc_real)
            c_loss_cls_spks = self.classification_loss(out_cls_spks, spk_label_org)
            c_loss_cls_spks.backward()
            self.c_optimizer.step()

            # Compute loss with fake mc feats.
            mc_fake = self.G(mc_real, spk_c_trg)
            out_src = self.D(mc_fake.detach(), spk_c_trg)
            d_loss_fake = torch.mean(out_src)

            # Compute loss for gradient penalty.
            alpha = torch.rand(mc_real.size(0), 1, 1, 1).to(self.device)
            x_hat = (alpha * mc_real.data + (1 - alpha) * mc_fake.data).requires_grad_(True)
            out_src = self.D(x_hat, spk_c_trg)
            d_loss_gp = self.gradient_penalty(out_src, x_hat)

            # Classifier loss on converted features. It does not depend on D's
            # parameters, so it only shifts the reported D loss; the input is
            # detached so d_loss.backward() skips a wasted pass through G.
            out_cls_spks = self.C(mc_fake.detach())
            d_loss_cls_spks = self.classification_loss(out_cls_spks, spk_label_trg)

            # Backward and optimize.
            d_loss = (d_loss_real + d_loss_fake
                      + self.lambda_cls * d_loss_cls_spks
                      + self.lambda_gp * d_loss_gp)
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Logging.
            loss = {}
            loss['D/loss_real'] = d_loss_real.item()
            loss['D/loss_fake'] = d_loss_fake.item()
            loss['D/loss_cls_spks'] = d_loss_cls_spks.item()
            loss['D/loss_gp'] = d_loss_gp.item()
            loss['C/loss_cls_spks'] = c_loss_cls_spks.item()

            # =========================================================== #
            #                   3. Train the generator                    #
            # =========================================================== #
            if (i + 1) % self.n_critic == 0:
                # Original-to-target domain.
                mc_fake = self.G(mc_real, spk_c_trg)
                out_src = self.D(mc_fake, spk_c_trg)
                g_loss_fake = -torch.mean(out_src)

                # Converted features should be classified as the target
                # speaker. (This used to classify mc_real against the source
                # labels, a term with no gradient w.r.t. G.)
                out_cls_spks = self.C(mc_fake)
                g_loss_cls_spks = self.classification_loss(out_cls_spks, spk_label_trg)

                # Target-to-original domain.
                mc_reconst = self.G(mc_fake, spk_c_org)
                g_loss_rec = torch.mean(torch.abs(mc_real - mc_reconst))

                # Original-to-Original domain(identity).
                mc_reconst_id = self.G(mc_real, spk_c_org)
                g_loss_id_rec = torch.mean(torch.abs(mc_real - mc_reconst_id))

                # Backward and optimize.
                g_loss = (g_loss_fake + self.lambda_rec * g_loss_rec
                          + self.lambda_cls * g_loss_cls_spks
                          + self.lambda_rec * g_loss_id_rec)
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # Logging.
                loss['G/loss_fake'] = g_loss_fake.item()
                loss['G/loss_rec'] = g_loss_rec.item()
                loss['G/loss_cls_spks'] = g_loss_cls_spks.item()
                loss['G/loss_id'] = g_loss_id_rec.item()

            # =========================================================== #
            #                      4. Miscellaneous                       #
            # =========================================================== #
            # Print out training information.
            if (i + 1) % self.log_step == 0:
                et = time.time() - start_time
                # Drop fractional seconds; split('.') also copes with
                # whole-second durations, whose str() has no '.xxxxxx' tail.
                et = str(datetime.timedelta(seconds=et)).split('.')[0]
                log = "Elapsed [{}], Iteration [{}/{}]".format(et, i + 1, self.num_iters)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, i + 1)

            if (i + 1) % self.sample_step == 0:
                self._convert_test_samples(i + 1, test_wavs, test_wavfiles, cpsyn_flag)
                cpsyn_flag = False

            # Save model checkpoints.
            if (i + 1) % self.model_save_step == 0:
                G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(i + 1))
                D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i + 1))
                C_path = os.path.join(self.model_save_dir, '{}-C.ckpt'.format(i + 1))
                torch.save(self.G.state_dict(), G_path)
                torch.save(self.D.state_dict(), D_path)
                torch.save(self.C.state_dict(), C_path)
                print('Saved model checkpoints into {}...'.format(self.model_save_dir))

            # Decay learning rates.
            if (i + 1) % self.lr_update_step == 0 and (i + 1) > (self.num_iters - self.num_iters_decay):
                g_lr -= (self.g_lr / float(self.num_iters_decay))
                d_lr -= (self.d_lr / float(self.num_iters_decay))
                c_lr -= (self.c_lr / float(self.num_iters_decay))
                self.update_lr(g_lr, d_lr, c_lr)
                print('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))

    def _convert_test_samples(self, step, test_wavs, test_wavfiles, write_cpsyn):
        """Convert the cached test utterances with the current G and write wavs.

        Args:
            step: 1-based training iteration, used as the output file prefix.
            test_wavs: padded waveforms produced by ``load_wav``.
            test_wavfiles: paths matching ``test_wavs`` (used for file names).
            write_cpsyn: when True, also write copy-synthesized references.
        """
        sampling_rate = 16000
        num_mcep = 36
        frame_period = 5
        loader = self.test_loader
        with torch.no_grad():
            # The target conditioning is the same for every test utterance.
            conds = torch.FloatTensor(loader.spk_c_trg).to(self.device)
            for idx, wav in tqdm(enumerate(test_wavs)):
                wav_name = basename(test_wavfiles[idx])
                f0, _timeaxis, sp, ap = world_decompose(wav=wav, fs=sampling_rate, frame_period=frame_period)
                f0_converted = pitch_conversion(
                    f0=f0,
                    mean_log_src=loader.logf0s_mean_src, std_log_src=loader.logf0s_std_src,
                    mean_log_target=loader.logf0s_mean_trg, std_log_target=loader.logf0s_std_trg)
                coded_sp = world_encode_spectral_envelop(sp=sp, fs=sampling_rate, dim=num_mcep)

                # Normalize with source statistics, convert, de-normalize with
                # target statistics.
                coded_sp_norm = (coded_sp - loader.mcep_mean_src) / loader.mcep_std_src
                coded_sp_norm_tensor = torch.FloatTensor(coded_sp_norm.T).unsqueeze_(0).unsqueeze_(1).to(self.device)
                coded_sp_converted_norm = self.G(coded_sp_norm_tensor, conds).cpu().numpy()
                coded_sp_converted = np.squeeze(coded_sp_converted_norm).T * loader.mcep_std_trg + loader.mcep_mean_trg
                coded_sp_converted = np.ascontiguousarray(coded_sp_converted)
                wav_transformed = world_speech_synthesis(f0=f0_converted, coded_sp=coded_sp_converted,
                                                         ap=ap, fs=sampling_rate, frame_period=frame_period)

                librosa.output.write_wav(
                    join(self.sample_dir, str(step) + '-' + wav_name.split('.')[0] + '-vcto-{}'.format(loader.trg_spk) + '.wav'),
                    wav_transformed, sampling_rate)
                if write_cpsyn:
                    wav_cpsyn = world_speech_synthesis(f0=f0, coded_sp=coded_sp,
                                                       ap=ap, fs=sampling_rate, frame_period=frame_period)
                    librosa.output.write_wav(join(self.sample_dir, 'cpsyn-' + wav_name), wav_cpsyn, sampling_rate)
class Solver(object):
    """Train and test a StarGAN-style generator that converts between styles.

    Holds a generator ``G``, a discriminator ``D`` and a style classifier
    ``C``, each with its own Adam optimizer.
    """

    def __init__(self, data_loader, config):
        """Store configuration values, then build the networks (and a
        tensorboard logger when ``config.use_tensorboard`` is set).

        Args:
            data_loader: iterable of training batches; ``train`` unpacks each
                batch as ``(x_real, style_idx_org, label_org)``.
            config: namespace holding every hyper-parameter read below;
                ``config.trg_style`` is parsed with ``ast.literal_eval``.
        """
        self.config = config
        self.data_loader = data_loader

        # Model configurations.
        self.lambda_cycle = config.lambda_cycle
        self.lambda_cls = config.lambda_cls
        self.lambda_identity = config.lambda_identity
        # Scale of the Gaussian noise added to every discriminator input.
        self.sigma_d = config.sigma_d

        # Training configurations.
        self.data_dir = config.data_dir
        self.test_dir = config.test_dir
        self.batch_size = config.batch_size
        self.num_iters = config.num_iters
        self.num_iters_decay = config.num_iters_decay
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.c_lr = config.c_lr
        self.n_critic = config.n_critic
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.resume_iters = config.resume_iters

        # Test configurations.
        self.test_iters = config.test_iters
        self.trg_style = ast.literal_eval(config.trg_style)
        self.src_style = config.src_style

        # Miscellaneous.
        self.use_tensorboard = config.use_tensorboard
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.stls_enc = LabelBinarizer().fit(styles)

        # Directories.
        self.log_dir = config.log_dir
        self.sample_dir = config.sample_dir
        self.model_save_dir = config.model_save_dir
        self.result_dir = config.result_dir

        # Step size.
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.lr_update_step = config.lr_update_step

        # Build the model and tensorboard.
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

    def build_model(self):
        """Create G, D and C, one Adam optimizer for each, and move the
        networks to ``self.device``."""
        self.G = Generator()
        self.D = Discriminator()
        self.C = DomainClassifier()

        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])
        self.c_optimizer = torch.optim.Adam(self.C.parameters(), self.c_lr, [self.beta1, self.beta2])

        self.print_network(self.G, 'G')
        self.print_network(self.D, 'D')
        self.print_network(self.C, 'C')

        self.G.to(self.device)
        self.D.to(self.device)
        self.C.to(self.device)

    def print_network(self, model, name):
        """Print out the network information and its total parameter count."""
        num_params = sum(p.numel() for p in model.parameters())
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))

    def build_tensorboard(self):
        """Build a tensorboard logger."""
        from logger import Logger
        self.logger = Logger(self.log_dir)

    def update_lr(self, g_lr, d_lr, c_lr):
        """Set the learning rates of the G, D and C optimizers."""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr
        for param_group in self.c_optimizer.param_groups:
            param_group['lr'] = c_lr

    def train(self):
        """Train C, D and G.

        Each iteration updates the style classifier C on real data, then the
        discriminator D with a mean-squared-error GAN loss (real -> 1,
        fake -> 0) on noise-perturbed inputs, and every ``n_critic``
        iterations the generator G (adversarial + cycle + classification +
        identity losses). Logging, MIDI sample export, checkpointing and
        linear learning-rate decay run on their configured step intervals.
        """
        # Learning rate cache for decaying.
        g_lr = self.g_lr
        d_lr = self.d_lr
        c_lr = self.c_lr

        # Start training from scratch or resume training. (This branch used
        # to be a bare `pass`, so config.resume_iters was silently ignored and
        # training restarted at iteration 0 with fresh weights.)
        start_iters = 0
        if self.resume_iters:
            start_iters = self.resume_iters
            self.restore_model(self.resume_iters)

        # Stateless loss module: build it once instead of every iteration.
        ce_loss = nn.CrossEntropyLoss()

        data_iter = iter(self.data_loader)
        print('Start training......')
        start_time = datetime.now()
        for i in range(start_iters, self.num_iters):
            # =========================================================== #
            #                  1. Preprocess input data                   #
            # =========================================================== #
            try:
                x_real, style_idx_org, label_org = next(data_iter)
            except StopIteration:
                # Loader exhausted: start a new pass over the training data.
                data_iter = iter(self.data_loader)
                x_real, style_idx_org, label_org = next(data_iter)

            # Gaussian noise added to the discriminator inputs for robustness.
            gaussian_noise = self.sigma_d * torch.randn(x_real.size())

            # Generate target domain labels by shuffling the batch's own labels.
            rand_idx = torch.randperm(label_org.size(0))
            label_trg = label_org[rand_idx]
            style_idx_trg = style_idx_org[rand_idx]

            x_real = x_real.to(self.device)                  # Input data.
            label_org = label_org.to(self.device)            # Original domain one-hot labels.
            label_trg = label_trg.to(self.device)            # Target domain one-hot labels.
            style_idx_org = style_idx_org.to(self.device)    # Original domain labels.
            style_idx_trg = style_idx_trg.to(self.device)    # Target domain labels.
            gaussian_noise = gaussian_noise.to(self.device)  # Noise for the discriminator.

            # =========================================================== #
            #          2. Train the classifier and discriminator          #
            # =========================================================== #
            cls_real = self.C(x_real)
            cls_loss_real = ce_loss(input=cls_real, target=style_idx_org)
            self.reset_grad()
            cls_loss_real.backward()
            self.c_optimizer.step()

            # Logging.
            loss = {}
            loss['C/C_loss'] = cls_loss_real.item()

            out_r = self.D(x_real + gaussian_noise, label_org)

            # Compute loss with fake data. G's output is detached: the D
            # update needs gradients w.r.t. D only.
            x_fake = self.G(x_real, label_trg)
            out_f = self.D(x_fake.detach() + gaussian_noise, label_trg)
            d_loss_t = (F.mse_loss(input=out_f, target=torch.zeros_like(out_f, dtype=torch.float))
                        + F.mse_loss(input=out_r, target=torch.ones_like(out_r, dtype=torch.float)))

            # Classifier loss on fake data; it does not depend on D's
            # parameters, so it only shifts the reported D loss.
            out_cls = self.C(x_fake.detach())
            d_loss_cls = ce_loss(input=out_cls, target=style_idx_trg)

            d_loss = d_loss_t + self.lambda_cls * d_loss_cls
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            loss['D/D_loss'] = d_loss.item()

            # =========================================================== #
            #                   3. Train the generator                    #
            # =========================================================== #
            if (i + 1) % self.n_critic == 0:
                # Original-to-target domain.
                x_fake = self.G(x_real, label_trg)
                g_out_src = self.D(x_fake + gaussian_noise, label_trg)
                g_loss_fake = F.mse_loss(input=g_out_src,
                                         target=torch.ones_like(g_out_src, dtype=torch.float))

                # Converted data should be classified as the target style.
                # (This used to classify x_real against the source labels, a
                # term with no gradient w.r.t. G.)
                out_cls = self.C(x_fake)
                g_loss_cls = ce_loss(input=out_cls, target=style_idx_trg)

                # Target-to-original domain.
                x_reconst = self.G(x_fake, label_org)
                g_loss_rec = F.l1_loss(x_reconst, x_real)

                # Original-to-Original domain(identity).
                x_fake_iden = self.G(x_real, label_org)
                id_loss = F.l1_loss(x_fake_iden, x_real)

                # Backward and optimize.
                g_loss = (g_loss_fake + self.lambda_cycle * g_loss_rec
                          + self.lambda_cls * g_loss_cls
                          + self.lambda_identity * id_loss)
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # Logging.
                loss['G/loss_fake'] = g_loss_fake.item()
                loss['G/loss_rec'] = g_loss_rec.item()
                loss['G/loss_cls'] = g_loss_cls.item()
                loss['G/loss_id'] = id_loss.item()
                loss['G/g_loss'] = g_loss.item()

            # =========================================================== #
            #                      4. Miscellaneous                       #
            # =========================================================== #
            # Print out training information.
            if (i + 1) % self.log_step == 0:
                # Drop fractional seconds; split('.') also copes with
                # whole-second durations, whose str() has no '.xxxxxx' tail.
                et = str(datetime.now() - start_time).split('.')[0]
                log = "Elapsed [{}], Iteration [{}/{}]".format(et, i + 1, self.num_iters)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, i + 1)

            # Translate fixed samples for debugging.
            if (i + 1) % self.sample_step == 0:
                self._save_debug_samples(i + 1)

            # Save model checkpoints.
            if (i + 1) % self.model_save_step == 0:
                G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(i + 1))
                D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i + 1))
                C_path = os.path.join(self.model_save_dir, '{}-C.ckpt'.format(i + 1))
                torch.save(self.G.state_dict(), G_path)
                torch.save(self.D.state_dict(), D_path)
                torch.save(self.C.state_dict(), C_path)
                print('Saved model checkpoints into {}...'.format(self.model_save_dir))

            # Decay learning rates.
            if (i + 1) % self.lr_update_step == 0 and (i + 1) > (self.num_iters - self.num_iters_decay):
                g_lr -= (self.g_lr / float(self.num_iters_decay))
                d_lr -= (self.d_lr / float(self.num_iters_decay))
                c_lr -= (self.c_lr / float(self.num_iters_decay))
                self.update_lr(g_lr, d_lr, c_lr)
                print('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))

    @staticmethod
    def _to_channels_last(arr):
        """Return a contiguous copy of a 4-D array with axis 1 moved to the end
        (``(N, C, H, W)`` -> ``(N, H, W, C)``).

        Replaces the former ``reshape(-1, H, W, C)``, which reinterprets the
        memory layout instead of moving the axis and is only equivalent when
        ``C == 1``.
        """
        return np.ascontiguousarray(np.moveaxis(np.asarray(arr), 1, -1))

    def _save_debug_samples(self, step):
        """Translate the test set to a random other style and back with the
        current G, and save origin / transfer / cycle MIDI files under
        ``sample_dir/iter{step}``.

        Args:
            step: 1-based training iteration used in file and folder names.
        """
        with torch.no_grad():
            d, style = TestSet(self.test_dir).test_data()
            label_o = self.stls_enc.transform([style])[0]
            label_o = np.asarray([label_o])
            target = random.choice([x for x in styles if x != style])
            label_t = self.stls_enc.transform([target])[0]
            label_t = np.asarray([label_t])
            # Conditioning vectors are shared by every file; build them once.
            l_t = torch.FloatTensor(label_t).to(self.device)
            l_o = torch.FloatTensor(label_o).to(self.device)
            path_samples_per_iter = os.path.join(self.sample_dir, f'iter{step}')

            for filename, content in d.items():
                filename = filename.split('.')[0]
                one_seg = torch.FloatTensor(content).to(self.device)
                # Add a leading batch axis to the 3-D sample.
                one_seg = one_seg.view(1, one_seg.size(0), one_seg.size(1), one_seg.size(2))

                one_set_transfer = self.G(one_seg, l_t)
                one_set_cycle = self.G(one_set_transfer, l_o).cpu().numpy()
                one_set_transfer = one_set_transfer.cpu().numpy()

                one_set_transfer_binary = self._to_channels_last(to_binary(one_set_transfer, 0.5))
                one_set_cycle_binary = self._to_channels_last(to_binary(one_set_cycle, 0.5))
                print(one_set_transfer_binary.shape, one_set_cycle_binary.shape)

                name_origin = f'{style}-{target}_iter{step}_{filename}_origin'
                name_transfer = f'{style}-{target}_iter{step}_{filename}_transfer'
                name_cycle = f'{style}-{target}_iter{step}_{filename}_cycle'
                os.makedirs(path_samples_per_iter, exist_ok=True)
                path_origin = os.path.join(path_samples_per_iter, name_origin)
                path_transfer = os.path.join(path_samples_per_iter, name_transfer)
                path_cycle = os.path.join(path_samples_per_iter, name_cycle)
                print(f'[save]:{path_origin},{path_transfer},{path_cycle}')
                save_midis(self._to_channels_last(np.asarray(content)[np.newaxis]),
                           '{}.mid'.format(path_origin))
                save_midis(one_set_transfer_binary, '{}.mid'.format(path_transfer))
                save_midis(one_set_cycle_binary, '{}.mid'.format(path_cycle))

    def reset_grad(self):
        """Reset the gradient buffers of all three optimizers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()
        self.c_optimizer.zero_grad()

    def restore_model(self, resume_iters):
        """Load the G, D and C weights saved at iteration ``resume_iters``.

        Tensors are mapped to CPU storage while loading, then copied into the
        existing modules by ``load_state_dict``.
        """
        print('Loading the trained models from step {}...'.format(resume_iters))
        G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(resume_iters))
        D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(resume_iters))
        C_path = os.path.join(self.model_save_dir, '{}-C.ckpt'.format(resume_iters))
        self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
        self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))
        self.C.load_state_dict(torch.load(C_path, map_location=lambda storage, loc: storage))

    def test(self):
        """Translate the source-style test set to every configured target style.

        Restores the checkpoint saved at ``self.test_iters`` and writes origin
        and transfer MIDI files under ``result_dir/iter{test_iters}``.

        Raises:
            AssertionError: if a requested target is not in ``styles``.
        """
        # Load the trained generator.
        self.restore_model(self.test_iters)

        # Set data loader.
        d, style = TestSet(self.test_dir).test_data(self.src_style)
        targets = self.trg_style
        # This used to interpolate `i + 1`, a name that does not exist in this
        # method (NameError); the restored iteration is the meaningful tag.
        step = self.test_iters
        path = os.path.join(self.result_dir, f'iter{step}')

        for target in targets:
            print(target)
            assert target in styles
            label_t = self.stls_enc.transform([target])[0]
            label_t = np.asarray([label_t])
            l_t = torch.FloatTensor(label_t).to(self.device)

            with torch.no_grad():
                for filename, content in d.items():
                    filename = filename.split('.')[0]
                    one_seg = torch.FloatTensor(content).to(self.device)
                    # Add a leading batch axis to the 3-D sample.
                    one_seg = one_seg.view(1, one_seg.size(0), one_seg.size(1), one_seg.size(2))
                    one_set_transfer = self.G(one_seg, l_t).cpu().numpy()
                    one_set_transfer_binary = self._to_channels_last(to_binary(one_set_transfer, 0.5))

                    name_origin = f'{style}-{target}_iter{step}_{filename}_origin'
                    name_transfer = f'{style}-{target}_iter{step}_{filename}_transfer'
                    # The per-iteration folder was never created before writing.
                    os.makedirs(path, exist_ok=True)
                    path_origin = os.path.join(path, name_origin)
                    path_transfer = os.path.join(path, name_transfer)
                    print(f'[save]:{path_origin},{path_transfer}')
                    save_midis(self._to_channels_last(np.asarray(content)[np.newaxis]),
                               '{}.mid'.format(path_origin))
                    save_midis(one_set_transfer_binary, '{}.mid'.format(path_transfer))