class CycleGANModel(BaseModel):
    """CycleGAN adapted to audio: two WaveGAN generators and, at train time,
    two WaveGAN discriminators translating between domains A and B.

    The naming convention differs from the CycleGAN paper:
    code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X).
    """

    def name(self):
        """Return the model's display name."""
        return 'CycleGANModel'

    def initialize(self, opt):
        """Build networks, losses, optimizers and LR schedulers from ``opt``."""
        BaseModel.initialize(self, opt)
        # load/define networks
        self.netG_A = WaveGANGenerator(
            model_size=opt.model_size, ngpus=opt.ngpus,
            latent_dim=opt.latent_dim, alpha=opt.alpha,
            post_proc_filt_len=opt.post_proc_filt_len)
        self.netG_B = WaveGANGenerator(
            model_size=opt.model_size, ngpus=opt.ngpus,
            latent_dim=opt.latent_dim, alpha=opt.alpha,
            post_proc_filt_len=opt.post_proc_filt_len)
        if self.isTrain:
            # NOTE(review): the original computed an unused
            # `use_sigmoid = opt.gan_loss != 'lsgan'` here but never passed it
            # to the discriminators; removed as dead code.
            self.netD_A = WaveGANDiscriminator(
                model_size=opt.model_size, ngpus=opt.ngpus,
                shift_factor=opt.shift_factor, alpha=opt.alpha,
                batch_shuffle=opt.batch_shuffle)
            self.netD_B = WaveGANDiscriminator(
                model_size=opt.model_size, ngpus=opt.ngpus,
                shift_factor=opt.shift_factor, alpha=opt.alpha,
                batch_shuffle=opt.batch_shuffle)
        if self.isTrain:
            # Pools of previously generated audio: discriminators are trained
            # against a history of fakes, not only the latest batch.
            self.fake_A_pool = AudioPool(opt.pool_size)
            self.fake_B_pool = AudioPool(opt.pool_size)
            # define loss functions
            self.criterionGAN = GANLoss(loss_type=opt.gan_loss,
                                        tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers (both generators share one optimizer)
            self.optimizer_G = torch.optim.Adam(
                itertools.chain(self.netG_A.parameters(),
                                self.netG_B.parameters()),
                lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizer_D_A = torch.optim.Adam(
                self.netD_A.parameters(),
                lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizer_D_B = torch.optim.Adam(
                self.netD_B.parameters(),
                lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizers = [self.optimizer_G,
                               self.optimizer_D_A,
                               self.optimizer_D_B]
            self.schedulers = [get_scheduler(optimizer, opt)
                               for optimizer in self.optimizers]
        LOGGER.info('---------- Networks initialized -------------')
        print_network(self.netG_A)
        print_network(self.netG_B)
        if self.isTrain:
            print_network(self.netD_A)
            print_network(self.netD_B)
        LOGGER.info('-----------------------------------------------')

    def set_input(self, input):
        """Unpack a data-loader batch and move it to the GPU if configured.

        ``opt.which_direction == 'AtoB'`` selects which side of the pair is
        the source domain.
        """
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        if len(self.gpu_ids) > 0:
            # FIX: the original used `async=True`, which was renamed to
            # `non_blocking` in PyTorch 0.4 and is a SyntaxError on
            # Python 3.7+ (`async` became a reserved keyword).
            input_A = input_A.cuda(self.gpu_ids[0], non_blocking=True)
            input_B = input_B.cuda(self.gpu_ids[0], non_blocking=True)
        self.input_A = input_A
        self.input_B = input_B
        self.audio_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Wrap the current inputs in Variables for the training graph."""
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)

    def test(self):
        """Run both translation directions for evaluation, storing tensors.

        NOTE(review): `volatile=True` has no effect on PyTorch >= 0.4; wrap
        these calls in `torch.no_grad()` if inference memory matters there.
        """
        real_A = Variable(self.input_A, volatile=True)
        fake_B = self.netG_A(real_A)
        self.rec_A = self.netG_B(fake_B).data
        self.fake_B = fake_B.data

        real_B = Variable(self.input_B, volatile=True)
        fake_A = self.netG_B(real_B)
        self.rec_B = self.netG_A(fake_A).data
        self.fake_A = fake_A.data

    # get audio paths
    def get_audio_paths(self):
        """Return the file paths of the current source-domain batch."""
        return self.audio_paths

    def backward_D_basic(self, netD, real, fake):
        """Standard discriminator update: real -> True, fake -> False.

        The fake is detached so no generator gradients are computed.
        Returns the (already back-propagated) combined loss.
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_wp(self, netD, real, fake):
        """Gradient-penalty term for WGAN-GP training; back-propagates it."""
        loss_D_wp = calc_gradient_penalty(netD, real, fake,
                                          self.opt.batchSize,
                                          self.opt.lambda_wp,
                                          use_cuda=len(self.gpu_ids) > 0)
        loss_D_wp.backward()
        return loss_D_wp

    def backward_D_A(self):
        """Update D_A against real_B and pooled fake_B samples."""
        fake_B = self.fake_B_pool.query(self.fake_B)
        loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
        # FIX: `.data[0]` raises on 0-dim loss tensors in PyTorch >= 0.5;
        # `.item()` is the supported scalar accessor (available since 0.4).
        self.loss_D_A = loss_D_A.item()
        if self.opt.gan_loss == 'wgan-wp':
            loss_D_wp_A = self.backward_D_wp(self.netD_A, self.real_B, fake_B)
            self.loss_D_wp_A = loss_D_wp_A.item()
        else:
            self.loss_D_wp_A = 0

    def backward_D_B(self):
        """Update D_B against real_A and pooled fake_A samples."""
        fake_A = self.fake_A_pool.query(self.fake_A)
        loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
        self.loss_D_B = loss_D_B.item()
        if self.opt.gan_loss == 'wgan-wp':
            loss_D_wp_B = self.backward_D_wp(self.netD_B, self.real_A, fake_A)
            self.loss_D_wp_B = loss_D_wp_B.item()
        else:
            self.loss_D_wp_B = 0

    def backward_G(self):
        """Generator update: GAN + cycle (+ optional identity) losses."""
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            idt_A = self.netG_A(self.real_B)
            loss_idt_A = (self.criterionIdt(idt_A, self.real_B)
                          * lambda_B * lambda_idt)
            # G_B should be identity if real_A is fed.
            idt_B = self.netG_B(self.real_A)
            loss_idt_B = (self.criterionIdt(idt_B, self.real_A)
                          * lambda_A * lambda_idt)

            self.idt_A = idt_A.data
            self.idt_B = idt_B.data
            self.loss_idt_A = loss_idt_A.item()
            self.loss_idt_B = loss_idt_B.item()
        else:
            loss_idt_A = 0
            loss_idt_B = 0
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        fake_B = self.netG_A(self.real_A)
        pred_fake = self.netD_A(fake_B)
        loss_G_A = self.criterionGAN(pred_fake, True)
        # GAN loss D_B(G_B(B))
        fake_A = self.netG_B(self.real_B)
        pred_fake = self.netD_B(fake_A)
        loss_G_B = self.criterionGAN(pred_fake, True)
        # Forward cycle loss
        rec_A = self.netG_B(fake_B)
        loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A
        # Backward cycle loss
        rec_B = self.netG_A(fake_A)
        loss_cycle_B = self.criterionCycle(rec_B, self.real_B) * lambda_B
        # combined loss
        loss_G = (loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B
                  + loss_idt_A + loss_idt_B)
        loss_G.backward()

        self.fake_B = fake_B.data
        self.fake_A = fake_A.data
        self.rec_A = rec_A.data
        self.rec_B = rec_B.data

        self.loss_G_A = loss_G_A.item()
        self.loss_G_B = loss_G_B.item()
        self.loss_cycle_A = loss_cycle_A.item()
        self.loss_cycle_B = loss_cycle_B.item()

    def optimize_parameters(self):
        """One full optimization step: generators first, then D_A, then D_B."""
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        """Return an OrderedDict of the latest scalar losses for logging."""
        ret_errors = OrderedDict([('D_A', self.loss_D_A),
                                  ('G_A', self.loss_G_A),
                                  ('Cyc_A', self.loss_cycle_A),
                                  ('D_B', self.loss_D_B),
                                  ('G_B', self.loss_G_B),
                                  ('Cyc_B', self.loss_cycle_B)])
        if self.opt.lambda_identity > 0.0:
            ret_errors['idt_A'] = self.loss_idt_A
            ret_errors['idt_B'] = self.loss_idt_B
        if self.opt.gan_loss == 'wgan-wp':
            ret_errors['WP_A'] = self.loss_D_wp_A
            ret_errors['WP_B'] = self.loss_D_wp_B
        return ret_errors

    def get_current_audibles(self):
        """Return an OrderedDict of the current real/fake/reconstructed audio."""
        real_A = tensor2audio(self.input_A)
        fake_B = tensor2audio(self.fake_B)
        rec_A = tensor2audio(self.rec_A)
        real_B = tensor2audio(self.input_B)
        fake_A = tensor2audio(self.fake_A)
        rec_B = tensor2audio(self.rec_B)
        ret_visuals = OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                   ('rec_A', rec_A), ('real_B', real_B),
                                   ('fake_A', fake_A), ('rec_B', rec_B)])
        if self.opt.isTrain and self.opt.lambda_identity > 0.0:
            ret_visuals['idt_A'] = tensor2audio(self.idt_A)
            ret_visuals['idt_B'] = tensor2audio(self.idt_B)
        return ret_visuals

    def save(self, label):
        """Checkpoint all four networks under the given label."""
        self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
        self.save_network(self.netD_A, 'D_A', label, self.gpu_ids)
        self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
        self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)
# Dir audio_dir = args['audio_dir'] output_dir = args['output_dir'] # =============Network=============== netG = WaveGANGenerator(model_size=model_size, ngpus=ngpus, latent_dim=latent_dim, upsample=True) netD = WaveGANDiscriminator(model_size=model_size, ngpus=ngpus) if cuda: netG = torch.nn.DataParallel(netG).cuda() netD = torch.nn.DataParallel(netD).cuda() # "Two time-scale update rule"(TTUR) to update netD 4x faster than netG. optimizerG = optim.Adam(netG.parameters(), lr=args['learning_rate'], betas=(args['beta1'], args['beta2'])) optimizerD = optim.Adam(netD.parameters(), lr=args['learning_rate'], betas=(args['beta1'], args['beta2'])) # Sample noise used for generated output. sample_noise = torch.randn(args['sample_size'], latent_dim) if cuda: sample_noise = sample_noise.cuda() sample_noise_Var = autograd.Variable(sample_noise, requires_grad=False) # Save config. LOGGER.info('Saving configurations...') config_path = os.path.join(model_dir, 'config.json') with open(config_path, 'w') as f: json.dump(args, f)