def update_core(self):
    """One training iteration of the CycleGAN-style updater.

    Trains generators gen_g (X=>Y) and gen_f (Y=>X) with cycle-consistency,
    adversarial and several optional auxiliary losses, then trains the two
    discriminators dis_y / dis_x for ``n_critics`` steps (WGAN-GP or LSGAN).
    NOTE(review): appears to override chainer.training Updater.update_core —
    confirm against the enclosing class (not visible in this chunk).
    """
    opt_g = self.get_optimizer('opt_g')
    opt_f = self.get_optimizer('opt_f')
    opt_x = self.get_optimizer('opt_x')
    opt_y = self.get_optimizer('opt_y')
    self._iter += 1
    # learning rate decay: TODO: weight_decay_rate of AdamW
    # Linear decay of Adam's alpha once past lrdecay_start, floored so alpha
    # never goes below one decay step.
    if self.is_new_epoch and self.epoch >= self.args.lrdecay_start:
        decay_step = self.init_alpha / self.args.lrdecay_period
        if opt_g.alpha > decay_step:
            opt_g.alpha -= decay_step
        if opt_f.alpha > decay_step:
            opt_f.alpha -= decay_step
        if opt_x.alpha > decay_step:
            opt_x.alpha -= decay_step
        if opt_y.alpha > decay_step:
            opt_y.alpha -= decay_step
    # get mini-batch
    batch_x = self.get_iterator('main').next()
    batch_y = self.get_iterator('train_B').next()
    x = Variable(self.converter(batch_x, self.device))
    y = Variable(self.converter(batch_y, self.device))

    ### generator
    # X => Y => X  (noise injected into the generator input)
    x_y = self.gen_g(losses.add_noise(x, sigma=self.args.noise))
    # Fakes shown to the discriminator are drawn through an image buffer
    # (history of generated images); conditional mode pairs input with output.
    if self.args.conditional_discriminator:
        x_y_copy = Variable(self._buffer_y.query(F.concat([x, x_y]).data))
    else:
        x_y_copy = Variable(self._buffer_y.query(x_y.data))
    x_y_x = self.gen_f(x_y)
    loss_cycle_x = losses.loss_avg(x_y_x, x, ksize=self.args.cycle_ksize,
                                   norm='l1')
    loss_gen_g_adv = 0
    # Adversarial loss for the generators is only enabled after gen_start
    # iterations (warm start with cycle losses only).
    if self.args.gen_start < self._iter:
        if self.args.conditional_discriminator:
            if self.args.wgan:
                loss_gen_g_adv = -F.average(self.dis_y(F.concat([x, x_y])))
            else:
                loss_gen_g_adv = losses.loss_func_comp(
                    self.dis_y(F.concat([x, x_y])), 1.0)
        else:
            if self.args.wgan:
                loss_gen_g_adv = -F.average(self.dis_y(x_y))
            else:
                loss_gen_g_adv = losses.loss_func_comp(
                    self.dis_y(x_y), 1.0)

    # Y => X => Y
    loss_gen_f_adv = 0
    y_x = self.gen_f(losses.add_noise(
        y, sigma=self.args.noise))  # noise injection
    if self.args.conditional_discriminator:
        y_x_copy = Variable(self._buffer_x.query(F.concat([y, y_x]).data))
    else:
        y_x_copy = Variable(self._buffer_x.query(y_x.data))
    y_x_y = self.gen_g(y_x)
    loss_cycle_y = losses.loss_avg(y_x_y, y, ksize=self.args.cycle_ksize,
                                   norm='l1')
    if self.args.gen_start < self._iter:
        if self.args.conditional_discriminator:
            if self.args.wgan:
                loss_gen_f_adv = -F.average(self.dis_x(F.concat([y, y_x])))
            else:
                loss_gen_f_adv = losses.loss_func_comp(
                    self.dis_x(F.concat([y, y_x])), 1.0)
        else:
            if self.args.wgan:
                loss_gen_f_adv = -F.average(self.dis_x(y_x))
            else:
                loss_gen_f_adv = losses.loss_func_comp(
                    self.dis_x(y_x), 1.0)

    ## total loss for generators
    loss_gen = (self.args.lambda_dis_y * loss_gen_g_adv +
                self.args.lambda_dis_x * loss_gen_f_adv) + (
                    self.args.lambda_A * loss_cycle_x +
                    self.args.lambda_B * loss_cycle_y)

    ## idempotence: applying a generator to its own output shouldn't change it
    if self.args.lambda_idempotence > 0:
        loss_idem_x = F.mean_absolute_error(y_x, self.gen_f(y_x))
        loss_idem_y = F.mean_absolute_error(x_y, self.gen_g(x_y))
        loss_gen = loss_gen + self.args.lambda_idempotence * (loss_idem_x +
                                                              loss_idem_y)
        if self.report_start < self._iter:
            chainer.report({'loss_idem': loss_idem_x}, self.gen_f)
            chainer.report({'loss_idem': loss_idem_y}, self.gen_g)

    ## domain preservation: a generator applied to its own target domain
    ## should act as identity (weight boosted to at least 1.0 during warmup)
    if self.args.lambda_domain > 0:
        loss_dom_x = F.mean_absolute_error(x, self.gen_f(x))
        loss_dom_y = F.mean_absolute_error(y, self.gen_g(y))
        if self._iter < self.args.warmup:
            loss_gen = loss_gen + max(self.args.lambda_domain, 1.0) * (
                loss_dom_x + loss_dom_y)
        else:
            loss_gen = loss_gen + self.args.lambda_domain * (loss_dom_x +
                                                             loss_dom_y)
        if self.report_start < self._iter:
            chainer.report({'loss_dom': loss_dom_x}, self.gen_f)
            chainer.report({'loss_dom': loss_dom_y}, self.gen_g)

    ## images before/after conversion should look similar in terms of perceptual loss
    if self.args.lambda_identity_x > 0:
        loss_id_x = losses.loss_perceptual(x, x_y, self.vgg)
        loss_gen = loss_gen + self.args.lambda_identity_x * loss_id_x
        if self.report_start < self._iter:
            # 1e-3 scaling only affects the reported value, not the loss
            chainer.report({'loss_id': 1e-3 * loss_id_x}, self.gen_g)
    if self.args.lambda_identity_y > 0:
        loss_id_y = losses.loss_perceptual(y, y_x, self.vgg)
        loss_gen = loss_gen + self.args.lambda_identity_y * loss_id_y
        if self.report_start < self._iter:
            chainer.report({'loss_id': 1e-3 * loss_id_y}, self.gen_f)

    ## warm-up: pin generators near identity for the first `warmup` iterations
    if self._iter < self.args.warmup:
        loss_gen = loss_gen + F.mean_squared_error(x, self.gen_f(x))
        loss_gen = loss_gen + F.mean_squared_error(y, self.gen_g(y))
        # loss_gen = loss_gen + losses.loss_avg(y,y_x, ksize=self.args.id_ksize, norm='l2')
        # loss_gen = loss_gen + losses.loss_avg(x,x_y, ksize=self.args.id_ksize, norm='l2')

    ## background should be preserved
    if self.args.lambda_air > 0:
        loss_air_x = losses.loss_range_comp(x, x_y, 0.9, norm='l2')
        loss_air_y = losses.loss_range_comp(y, y_x, 0.9, norm='l2')
        loss_gen = loss_gen + self.args.lambda_air * (loss_air_x + loss_air_y)
        if self.report_start < self._iter:
            chainer.report({'loss_air': 0.1 * loss_air_x}, self.gen_g)
            chainer.report({'loss_air': 0.1 * loss_air_y}, self.gen_f)

    ## comparison of images before/after conversion in the gradient domain
    if self.args.lambda_grad > 0:
        loss_grad_x = losses.loss_grad(x, x_y, self.args.grad_norm)
        loss_grad_y = losses.loss_grad(y, y_x, self.args.grad_norm)
        loss_gen = loss_gen + self.args.lambda_grad * (loss_grad_x +
                                                       loss_grad_y)
        if self.report_start < self._iter:
            chainer.report({'loss_grad': loss_grad_x}, self.gen_g)
            chainer.report({'loss_grad': loss_grad_y}, self.gen_f)

    ## total variation (applied to the X=>Y output only)
    if self.args.lambda_tv > 0:
        loss_tv = losses.total_variation(x_y, self.args.tv_tau)
        loss_gen = loss_gen + self.args.lambda_tv * loss_tv
        if self.report_start < self._iter:
            chainer.report({'loss_tv': loss_tv}, self.gen_g)

    if self.report_start < self._iter:
        chainer.report({'loss_cycle_x': loss_cycle_x}, self.gen_f)
        chainer.report({'loss_adv': loss_gen_g_adv}, self.gen_g)
        chainer.report({'loss_cycle_y': loss_cycle_y}, self.gen_g)
        chainer.report({'loss_adv': loss_gen_f_adv}, self.gen_f)

    self.gen_f.cleargrads()
    self.gen_g.cleargrads()
    loss_gen.backward()
    opt_f.update()
    opt_g.update()

    ### discriminator (n_critics updates per generator update)
    for t in range(self.args.n_critics):
        if self.args.wgan:
            ## WGAN critic loss: synthesised -, real +
            loss_dis_y_fake = F.average(self.dis_y(x_y_copy))
            # per-sample interpolation coefficient for the gradient penalty
            eps = self.xp.random.uniform(0, 1, size=len(batch_y)).astype(
                self.xp.float32)[:, None, None, None]
            if self.args.conditional_discriminator:
                loss_dis_y_real = -F.average(self.dis_y(F.concat([x, y])))
                # NOTE(review): interpolation uses F.concat([y, x]) while the
                # real term uses F.concat([x, y]) — channel order differs;
                # looks like a bug, confirm against the discriminator input
                # convention before changing.
                y_mid = eps * F.concat([y, x]) + (1.0 - eps) * x_y_copy
            else:
                loss_dis_y_real = -F.average(self.dis_y(y))
                y_mid = eps * y + (1.0 - eps) * x_y_copy
            # gradient penalty (WGAN-GP): push |grad D| towards 1 on the
            # interpolated samples
            gd_y, = chainer.grad([self.dis_y(y_mid)], [y_mid],
                                 enable_double_backprop=True)
            gd_y = F.sqrt(F.batch_l2_norm_squared(gd_y) + 1e-6)
            loss_dis_y_gp = F.mean_squared_error(
                gd_y, self.xp.ones_like(gd_y.data))
            if self.report_start < self._iter:
                chainer.report({'loss_real': loss_dis_y_real}, self.dis_y)
                chainer.report({'loss_fake': loss_dis_y_fake}, self.dis_y)
                chainer.report(
                    {'loss_gp': self.args.lambda_wgan_gp * loss_dis_y_gp},
                    self.dis_y)
            loss_dis_y = (loss_dis_y_real + loss_dis_y_fake +
                          self.args.lambda_wgan_gp * loss_dis_y_gp)
            self.dis_y.cleargrads()
            loss_dis_y.backward()
            opt_y.update()
            ## discriminator for B=>A
            loss_dis_x_fake = F.average(self.dis_x(y_x_copy))
            if self.args.conditional_discriminator:
                # real pair gets the same input-noise injection as training
                loss_dis_x_real = -F.average(
                    self.dis_x(
                        losses.add_noise(F.concat([y, x]),
                                         sigma=self.args.noise)))
                # NOTE(review): concat order [x, y] here vs [y, x] in the real
                # term above — same ordering concern as for y_mid.
                x_mid = eps * F.concat([x, y]) + (1.0 - eps) * y_x_copy
            else:
                loss_dis_x_real = -F.average(self.dis_x(x))
                x_mid = eps * x + (1.0 - eps) * y_x_copy
            # gradient penalty
            gd_x, = chainer.grad([self.dis_x(x_mid)], [x_mid],
                                 enable_double_backprop=True)
            gd_x = F.sqrt(F.batch_l2_norm_squared(gd_x) + 1e-6)
            loss_dis_x_gp = F.mean_squared_error(
                gd_x, self.xp.ones_like(gd_x.data))
            if self.report_start < self._iter:
                chainer.report({'loss_real': loss_dis_x_real}, self.dis_x)
                chainer.report({'loss_fake': loss_dis_x_fake}, self.dis_x)
                chainer.report(
                    {'loss_gp': self.args.lambda_wgan_gp * loss_dis_x_gp},
                    self.dis_x)
            loss_dis_x = (loss_dis_x_real + loss_dis_x_fake +
                          self.args.lambda_wgan_gp * loss_dis_x_gp)
            self.dis_x.cleargrads()
            loss_dis_x.backward()
            opt_x.update()
        else:
            ## LSGAN discriminator for A=>B (real:1, fake:0)
            loss_dis_y_fake = losses.loss_func_comp(
                self.dis_y(x_y_copy), 0.0, self.args.dis_jitter)
            if self.args.conditional_discriminator:
                loss_dis_y_real = losses.loss_func_comp(
                    self.dis_y(F.concat([x, y])), 1.0, self.args.dis_jitter)
            else:
                loss_dis_y_real = losses.loss_func_comp(
                    self.dis_y(y), 1.0, self.args.dis_jitter)
            loss_dis_y = (loss_dis_y_fake + loss_dis_y_real) * 0.5
            self.dis_y.cleargrads()
            loss_dis_y.backward()
            opt_y.update()
            ## discriminator for B=>A
            loss_dis_x_fake = losses.loss_func_comp(
                self.dis_x(y_x_copy), 0.0, self.args.dis_jitter)
            if self.args.conditional_discriminator:
                loss_dis_x_real = losses.loss_func_comp(
                    self.dis_x(F.concat([y, x])), 1.0, self.args.dis_jitter)
            else:
                loss_dis_x_real = losses.loss_func_comp(
                    self.dis_x(x), 1.0, self.args.dis_jitter)
            loss_dis_x = (loss_dis_x_fake + loss_dis_x_real) * 0.5
            self.dis_x.cleargrads()
            loss_dis_x.backward()
            opt_x.update()
            if self.report_start < self._iter:
                chainer.report({'loss_real': loss_dis_x_real}, self.dis_x)
                chainer.report({'loss_fake': loss_dis_x_fake}, self.dis_x)
                chainer.report({'loss_real': loss_dis_y_real}, self.dis_y)
                chainer.report({'loss_fake': loss_dis_y_fake}, self.dis_y)
        # prepare next images: resample fakes from the history buffers and
        # fetch fresh real mini-batches for the remaining critic steps
        if (t < self.args.n_critics - 1):
            x_y_copy = Variable(
                self.xp.concatenate(
                    random.sample(self._buffer_y.images,
                                  self.args.batch_size)))
            y_x_copy = Variable(
                self.xp.concatenate(
                    random.sample(self._buffer_x.images,
                                  self.args.batch_size)))
            batch_x = self.get_iterator('main').next()
            batch_y = self.get_iterator('train_B').next()
            x = Variable(self.converter(batch_x, self.device))
            y = Variable(self.converter(batch_y, self.device))
def update_core(self):
    """One training iteration of the encoder/decoder (shared-latent) updater.

    Encodes X and Y into a common latent space Z (enc_x/enc_y), decodes with
    dec_x/dec_y, and trains: autoencoder cycles (X=>Z=>X, Y=>Z=>Y), full
    translation cycles (X=>Z=>Y=>Z=>X and vice versa), image discriminators
    dis_x/dis_y, and a latent discriminator dis_z, plus optional auxiliary
    losses.  Discriminators use WGAN-GP or LSGAN depending on args.dis_wgan.

    Fixes over the previous revision:
    - The WGAN dis_z branch reported ``loss_dis_y_gp`` to ``self.dis_y``
      instead of ``loss_dis_z_gp`` to ``self.dis_z`` (copy-paste bug).
    - The WGAN dis_z branch was guarded by ``and t == 0`` but no ``t`` is
      defined in this method (leftover from an n_critics loop), raising
      NameError whenever it was reached; the dead guard is dropped so dis_z
      is updated once per call, consistent with dis_x/dis_y.
    """
    opt_enc_x = self.get_optimizer('opt_enc_x')
    opt_dec_x = self.get_optimizer('opt_dec_x')
    opt_enc_y = self.get_optimizer('opt_enc_y')
    opt_dec_y = self.get_optimizer('opt_dec_y')
    opt_x = self.get_optimizer('opt_x')
    opt_y = self.get_optimizer('opt_y')
    opt_z = self.get_optimizer('opt_z')
    # get mini-batch
    batch_x = self.get_iterator('main').next()
    batch_y = self.get_iterator('train_B').next()
    x = Variable(self.converter(batch_x, self.args.gpu[0]))
    y = Variable(self.converter(batch_y, self.args.gpu[0]))
    # encode to latent (X,Y => Z); noise injected into the encoder input
    x_z = self.enc_x(losses.add_noise(x, sigma=self.args.noise))
    y_z = self.enc_y(losses.add_noise(y, sigma=self.args.noise))

    loss_gen = 0
    ## regularisation on the latent space
    if self.args.lambda_reg > 0:
        loss_reg_enc_y = losses.loss_func_reg(y_z[-1], 'l2')
        loss_reg_enc_x = losses.loss_func_reg(x_z[-1], 'l2')
        loss_gen = loss_gen + self.args.lambda_reg * (loss_reg_enc_x +
                                                      loss_reg_enc_y)
        chainer.report({'loss_reg': loss_reg_enc_x}, self.enc_x)
        chainer.report({'loss_reg': loss_reg_enc_y}, self.enc_y)

    ## discriminator for the latent space: distribution of image of enc_x
    ## should look same as that of enc_y.
    # Since z is a list (for u-net), we use only the output of the last layer.
    if self.args.lambda_dis_z > 0:
        if self.args.dis_wgan:
            loss_enc_x_adv = -F.average(self.dis_z(x_z[-1]))
            loss_enc_y_adv = F.average(self.dis_z(y_z[-1]))
        else:
            loss_enc_x_adv = losses.loss_func_comp(self.dis_z(x_z[-1]), 1.0)
            loss_enc_y_adv = losses.loss_func_comp(self.dis_z(y_z[-1]), 0.0)
        loss_gen = loss_gen + self.args.lambda_dis_z * (loss_enc_x_adv +
                                                        loss_enc_y_adv)
        chainer.report({'loss_adv': loss_enc_x_adv}, self.enc_x)
        chainer.report({'loss_adv': loss_enc_y_adv}, self.enc_y)

    # cycle for X=>Z=>X (Autoencoder)
    x_x = self.dec_x(x_z)
    loss_cycle_xzx = F.mean_absolute_error(x_x, x)
    chainer.report({'loss_cycle': loss_cycle_xzx}, self.enc_x)
    # cycle for Y=>Z=>Y (Autoencoder)
    y_y = self.dec_y(y_z)
    loss_cycle_yzy = F.mean_absolute_error(y_y, y)
    chainer.report({'loss_cycle': loss_cycle_yzy}, self.enc_y)
    loss_gen = (loss_gen + self.args.lambda_Az * loss_cycle_xzx +
                self.args.lambda_Bz * loss_cycle_yzy)

    ## decode from latent Z => Y,X
    x_y = self.dec_y(x_z)
    y_x = self.dec_x(y_z)
    # cycle for X=>Z=>Y=>Z=>X (Z=>Y=>Z does not work well)
    x_y_x = self.dec_x(self.enc_y(x_y))
    loss_cycle_x = F.mean_absolute_error(x_y_x, x)
    chainer.report({'loss_cycle': loss_cycle_x}, self.dec_x)
    # cycle for Y=>Z=>X=>Z=>Y
    y_x_y = self.dec_y(self.enc_x(y_x))
    loss_cycle_y = F.mean_absolute_error(y_x_y, y)
    chainer.report({'loss_cycle': loss_cycle_y}, self.dec_y)
    loss_gen = (loss_gen + self.args.lambda_A * loss_cycle_x +
                self.args.lambda_B * loss_cycle_y)

    ## adversarial for Y (fakes for the discriminator go through a history
    ## buffer; x_y_copy is reused by the discriminator update below)
    if self.args.lambda_dis_y > 0:
        x_y_copy = Variable(self._buffer_y.query(x_y.data))
        if self.args.dis_wgan:
            loss_dec_y_adv = -F.average(self.dis_y(x_y))
        else:
            loss_dec_y_adv = losses.loss_func_comp(self.dis_y(x_y), 1.0)
        loss_gen = loss_gen + self.args.lambda_dis_y * loss_dec_y_adv
        chainer.report({'loss_adv': loss_dec_y_adv}, self.dec_y)
    ## adversarial for X
    if self.args.lambda_dis_x > 0:
        y_x_copy = Variable(self._buffer_x.query(y_x.data))
        if self.args.dis_wgan:
            loss_dec_x_adv = -F.average(self.dis_x(y_x))
        else:
            loss_dec_x_adv = losses.loss_func_comp(self.dis_x(y_x), 1.0)
        loss_gen = loss_gen + self.args.lambda_dis_x * loss_dec_x_adv
        chainer.report({'loss_adv': loss_dec_x_adv}, self.dec_x)

    ## idempotence: converting an already-converted image shouldn't change it
    if self.args.lambda_idempotence > 0:
        loss_idem_x = F.mean_absolute_error(y_x, self.dec_x(self.enc_y(y_x)))
        loss_idem_y = F.mean_absolute_error(x_y, self.dec_y(self.enc_x(x_y)))
        loss_gen = loss_gen + self.args.lambda_idempotence * (loss_idem_x +
                                                              loss_idem_y)
        chainer.report({'loss_idem': loss_idem_x}, self.dec_x)
        chainer.report({'loss_idem': loss_idem_y}, self.dec_y)
    # Y => X shouldn't change X (domain preservation)
    if self.args.lambda_domain > 0:
        loss_dom_x = F.mean_absolute_error(x, self.dec_x(self.enc_y(x)))
        loss_dom_y = F.mean_absolute_error(y, self.dec_y(self.enc_x(y)))
        loss_gen = loss_gen + self.args.lambda_domain * (loss_dom_x +
                                                         loss_dom_y)
        chainer.report({'loss_dom': loss_dom_x}, self.dec_x)
        chainer.report({'loss_dom': loss_dom_y}, self.dec_y)

    ## images before/after conversion should look similar in terms of perceptual loss
    if self.args.lambda_identity_x > 0:
        loss_id_x = losses.loss_perceptual(
            x, x_y, self.vgg,
            layer=self.args.perceptual_layer, grey=self.args.grey)
        loss_gen = loss_gen + self.args.lambda_identity_x * loss_id_x
        # 1e-3 scaling only affects the reported value, not the loss
        chainer.report({'loss_id': 1e-3 * loss_id_x}, self.enc_x)
    if self.args.lambda_identity_y > 0:
        loss_id_y = losses.loss_perceptual(
            y, y_x, self.vgg,
            layer=self.args.perceptual_layer, grey=self.args.grey)
        loss_gen = loss_gen + self.args.lambda_identity_y * loss_id_y
        chainer.report({'loss_id': 1e-3 * loss_id_y}, self.enc_y)

    ## background (pixels with value -1) should be preserved
    if self.args.lambda_air > 0:
        loss_air_x = losses.loss_comp_low(x, x_y, self.args.air_threshold,
                                          norm='l2')
        loss_air_y = losses.loss_comp_low(y, y_x, self.args.air_threshold,
                                          norm='l2')
        loss_gen = loss_gen + self.args.lambda_air * (loss_air_x + loss_air_y)
        chainer.report({'loss_air': loss_air_x}, self.dec_y)
        chainer.report({'loss_air': loss_air_y}, self.dec_x)
    ## images before/after conversion should look similar in the gradient domain
    if self.args.lambda_grad > 0:
        loss_grad_x = losses.loss_grad(x, x_y)
        loss_grad_y = losses.loss_grad(y, y_x)
        loss_gen = loss_gen + self.args.lambda_grad * (loss_grad_x +
                                                       loss_grad_y)
        chainer.report({'loss_grad': loss_grad_x}, self.dec_y)
        chainer.report({'loss_grad': loss_grad_y}, self.dec_x)
    ## total variation (only for X -> Y)
    if self.args.lambda_tv > 0:
        loss_tv = losses.total_variation(x_y, tau=self.args.tv_tau,
                                         method=self.args.tv_method)
        # extra across-channel TV for multi-slice DICOM volumes
        if self.args.imgtype == "dcm" and self.args.num_slices > 1:
            loss_tv += losses.total_variation_ch(x_y)
        loss_gen = loss_gen + self.args.lambda_tv * loss_tv
        chainer.report({'loss_tv': loss_tv}, self.dec_y)

    ## back propagate through encoders/decoders
    self.enc_x.cleargrads()
    self.dec_x.cleargrads()
    self.enc_y.cleargrads()
    self.dec_y.cleargrads()
    loss_gen.backward()
    opt_enc_x.update(loss=loss_gen)
    opt_dec_x.update(loss=loss_gen)
    if not self.args.single_encoder:
        # with a single shared encoder, enc_y is not trained separately
        opt_enc_y.update(loss=loss_gen)
    opt_dec_y.update(loss=loss_gen)

    ##########################################
    ## discriminator updates
    if self.args.dis_wgan:
        ## WGAN-GP: synthesised -, real +
        # per-sample interpolation coefficient for the gradient penalty
        eps = self.xp.random.uniform(0, 1, size=len(batch_y)).astype(
            self.xp.float32)[:, None, None, None]
        if self.args.lambda_dis_y > 0:
            ## discriminator for X=>Y
            loss_dis_y = F.average(self.dis_y(x_y_copy) - self.dis_y(y))
            y_mid = eps * y + (1.0 - eps) * x_y_copy
            # gradient penalty: push |grad D| towards 1 on interpolates
            gd_y, = chainer.grad([self.dis_y(y_mid)], [y_mid],
                                 enable_double_backprop=True)
            gd_y = F.sqrt(F.batch_l2_norm_squared(gd_y) + 1e-6)
            loss_dis_y_gp = F.mean_squared_error(
                gd_y, self.xp.ones_like(gd_y.data))
            chainer.report({'loss_dis': loss_dis_y}, self.dis_y)
            chainer.report(
                {'loss_gp': self.args.lambda_wgan_gp * loss_dis_y_gp},
                self.dis_y)
            loss_dis_y = loss_dis_y + self.args.lambda_wgan_gp * loss_dis_y_gp
            self.dis_y.cleargrads()
            loss_dis_y.backward()
            opt_y.update(loss=loss_dis_y)
        if self.args.lambda_dis_x > 0:
            ## discriminator for B=>A
            loss_dis_x = F.average(self.dis_x(y_x_copy) - self.dis_x(x))
            x_mid = eps * x + (1.0 - eps) * y_x_copy
            # gradient penalty
            gd_x, = chainer.grad([self.dis_x(x_mid)], [x_mid],
                                 enable_double_backprop=True)
            gd_x = F.sqrt(F.batch_l2_norm_squared(gd_x) + 1e-6)
            loss_dis_x_gp = F.mean_squared_error(
                gd_x, self.xp.ones_like(gd_x.data))
            chainer.report({'loss_dis': loss_dis_x}, self.dis_x)
            chainer.report(
                {'loss_gp': self.args.lambda_wgan_gp * loss_dis_x_gp},
                self.dis_x)
            loss_dis_x = loss_dis_x + self.args.lambda_wgan_gp * loss_dis_x_gp
            self.dis_x.cleargrads()
            loss_dis_x.backward()
            opt_x.update(loss=loss_dis_x)
        ## discriminator for latent: X -> Z is - while Y -> Z is +
        # (original had a dead `and t == 0` guard here; `t` was never defined
        # in this method and would have raised NameError)
        if self.args.lambda_dis_z > 0:
            loss_dis_z = F.average(
                self.dis_z(x_z[-1]) - self.dis_z(y_z[-1]))
            z_mid = eps * x_z[-1] + (1.0 - eps) * y_z[-1]
            # gradient penalty
            gd_z, = chainer.grad([self.dis_z(z_mid)], [z_mid],
                                 enable_double_backprop=True)
            gd_z = F.sqrt(F.batch_l2_norm_squared(gd_z) + 1e-6)
            loss_dis_z_gp = F.mean_squared_error(
                gd_z, self.xp.ones_like(gd_z.data))
            chainer.report({'loss_dis': loss_dis_z}, self.dis_z)
            # fixed: previously reported loss_dis_y_gp to self.dis_y
            chainer.report(
                {'loss_gp': self.args.lambda_wgan_gp * loss_dis_z_gp},
                self.dis_z)
            loss_dis_z = loss_dis_z + self.args.lambda_wgan_gp * loss_dis_z_gp
            self.dis_z.cleargrads()
            loss_dis_z.backward()
            opt_z.update(loss=loss_dis_z)
    else:
        ## LSGAN
        if self.args.lambda_dis_y > 0:
            ## discriminator for A=>B (real:1, fake:0)
            disy_fake = self.dis_y(x_y_copy)
            loss_dis_y_fake = losses.loss_func_comp(
                disy_fake, 0.0, self.args.dis_jitter)
            disy_real = self.dis_y(y)
            loss_dis_y_real = losses.loss_func_comp(
                disy_real, 1.0, self.args.dis_jitter)
            if self.args.dis_reg_weighting > 0:  ## regularization
                # penalise the magnitude of the discriminator's 2nd channel
                loss_dis_y_reg = (
                    F.average(F.absolute(disy_real[:, 1, :, :])) +
                    F.average(F.absolute(disy_fake[:, 1, :, :])))
            else:
                loss_dis_y_reg = 0
            chainer.report({'loss_reg': loss_dis_y_reg}, self.dis_y)
            loss_dis_y_gp = 0
            chainer.report({'loss_fake': loss_dis_y_fake}, self.dis_y)
            chainer.report({'loss_real': loss_dis_y_real}, self.dis_y)
            loss_dis_y = (
                loss_dis_y_fake + loss_dis_y_real
            ) * 0.5 + self.args.dis_reg_weighting * loss_dis_y_reg + \
                self.args.lambda_wgan_gp * loss_dis_y_gp
            self.dis_y.cleargrads()
            loss_dis_y.backward()
            opt_y.update(loss=loss_dis_y)
        if self.args.lambda_dis_x > 0:
            ## discriminator for B=>A
            disx_fake = self.dis_x(y_x_copy)
            loss_dis_x_fake = losses.loss_func_comp(
                disx_fake, 0.0, self.args.dis_jitter)
            disx_real = self.dis_x(x)
            loss_dis_x_real = losses.loss_func_comp(
                disx_real, 1.0, self.args.dis_jitter)
            if self.args.dis_reg_weighting > 0:  ## regularization
                loss_dis_x_reg = (
                    F.average(F.absolute(disx_fake[:, 1, :, :])) +
                    F.average(F.absolute(disx_real[:, 1, :, :])))
            else:
                loss_dis_x_reg = 0
            chainer.report({'loss_reg': loss_dis_x_reg}, self.dis_x)
            loss_dis_x_gp = 0
            chainer.report({'loss_fake': loss_dis_x_fake}, self.dis_x)
            chainer.report({'loss_real': loss_dis_x_real}, self.dis_x)
            loss_dis_x = (
                loss_dis_x_fake + loss_dis_x_real
            ) * 0.5 + self.args.dis_reg_weighting * loss_dis_x_reg + \
                self.args.lambda_wgan_gp * loss_dis_x_gp
            self.dis_x.cleargrads()
            loss_dis_x.backward()
            opt_x.update(loss=loss_dis_x)
        ## discriminator for latent: X -> Z is 0.0 while Y -> Z is 1.0
        if self.args.lambda_dis_z > 0:
            disz_xz = self.dis_z(x_z[-1])
            loss_dis_z_x = losses.loss_func_comp(disz_xz, 0.0,
                                                 self.args.dis_jitter)
            disz_yz = self.dis_z(y_z[-1])
            loss_dis_z_y = losses.loss_func_comp(disz_yz, 1.0,
                                                 self.args.dis_jitter)
            if self.args.dis_reg_weighting > 0:  ## regularization
                loss_dis_z_reg = (
                    F.average(F.absolute(disz_xz[:, 1, :, :])) +
                    F.average(F.absolute(disz_yz[:, 1, :, :])))
            else:
                loss_dis_z_reg = 0
            chainer.report({'loss_x': loss_dis_z_x}, self.dis_z)
            chainer.report({'loss_y': loss_dis_z_y}, self.dis_z)
            chainer.report({'loss_reg': loss_dis_z_reg}, self.dis_z)
            loss_dis_z = (
                loss_dis_z_x + loss_dis_z_y
            ) * 0.5 + self.args.dis_reg_weighting * loss_dis_z_reg
            self.dis_z.cleargrads()
            loss_dis_z.backward()
            opt_z.update(loss=loss_dis_z)
def evaluate(self):
    """Visualise one test batch in both directions and return cycle metrics.

    Renders a grid (input / converted / cycle-reconstructed) for a batch of
    X and a batch of Y, saves it as ``count{NNNN}.jpg`` under ``self.vis_out``,
    and returns a dict of scalar cycle/identity losses under the ``myval/``
    prefix.  Supports both model layouts: (enc_x, enc_y, dec_x, dec_y) when
    more than two targets are registered, else (gen_g, gen_f).
    NOTE(review): this reads ``self.single_encoder`` while the updaters read
    ``self.args.single_encoder`` — confirm which attribute this evaluator
    actually carries.
    """
    batch_x = self._iterators['main'].next()
    batch_y = self._iterators['testB'].next()
    models = self._targets
    if self.eval_hook:
        self.eval_hook(self)
    # one row per image, three columns: original / converted / reconstructed
    fig = plt.figure(figsize=(9, 3 * (len(batch_x) + len(batch_y))))
    gs = gridspec.GridSpec(len(batch_x) + len(batch_y), 3,
                           wspace=0.1, hspace=0.1)
    x = Variable(self.converter(batch_x, self.device))
    y = Variable(self.converter(batch_y, self.device))
    with chainer.using_config('train', False):
        with chainer.function.no_backprop_mode():
            if len(models) > 2:
                # encoder/decoder layout
                x_y = models['dec_y'](models['enc_x'](x))
                if self.single_encoder:
                    x_y_x = models['dec_x'](models['enc_x'](x_y))
                else:
                    x_y_x = models['dec_x'](
                        models['enc_x'](x))  ## autoencoder
                    #x_y_x = models['dec_x'](models['enc_y'](x_y))
            else:
                # (gen_g, gen_f) layout
                x_y = models['gen_g'](x)
                x_y_x = models['gen_f'](x_y)
    # for i, var in enumerate([x, x_y]):
    for i, var in enumerate([x, x_y, x_y_x]):
        imgs = postprocess(var).astype(np.float32)
        for j in range(len(imgs)):
            ax = fig.add_subplot(gs[j, i])
            # single-channel images are shown in greyscale
            if imgs[j].shape[2] == 1:
                ax.imshow(imgs[j, :, :, 0], interpolation='none',
                          cmap='gray', vmin=0, vmax=1)
            else:
                ax.imshow(imgs[j], interpolation='none', vmin=0, vmax=1)
            ax.set_xticks([])
            ax.set_yticks([])
    with chainer.using_config('train', False):
        with chainer.function.no_backprop_mode():
            if len(models) > 2:
                if self.single_encoder:
                    y_x = models['dec_x'](models['enc_x'](y))
                else:
                    y_x = models['dec_x'](models['enc_y'](y))
                # y_x_y = models['dec_y'](models['enc_y'](y)) ## autoencoder
                y_x_y = models['dec_y'](models['enc_x'](y_x))
            else:  # (gen_g, gen_f)
                y_x = models['gen_f'](y)
                y_x_y = models['gen_g'](y_x)
    # for i, var in enumerate([y, y_y]):
    for i, var in enumerate([y, y_x, y_x_y]):
        imgs = postprocess(var).astype(np.float32)
        for j in range(len(imgs)):
            # Y rows are placed below the X rows in the same grid
            ax = fig.add_subplot(gs[j + len(batch_x), i])
            if imgs[j].shape[2] == 1:
                ax.imshow(imgs[j, :, :, 0], interpolation='none',
                          cmap='gray', vmin=0, vmax=1)
            else:
                ax.imshow(imgs[j], interpolation='none', vmin=0, vmax=1)
            ax.set_xticks([])
            ax.set_yticks([])
    gs.tight_layout(fig)
    plt.savefig(os.path.join(self.vis_out,
                             'count{:0>4}.jpg'.format(self.count)),
                dpi=200)
    self.count += 1
    plt.close()
    # scalar metrics on the last computed conversions
    cycle_y_l1 = F.mean_absolute_error(y, y_x_y)
    cycle_y_l2 = F.mean_squared_error(y, y_x_y)
    cycle_x_l1 = F.mean_absolute_error(x, x_y_x)
    id_xy_grad = losses.loss_grad(x, x_y)
    id_xy_l1 = F.mean_absolute_error(x, x_y)
    result = {
        "myval/cycle_y_l1": cycle_y_l1,
        "myval/cycle_y_l2": cycle_y_l2,
        "myval/cycle_x_l1": cycle_x_l1,
        "myval/id_xy_grad": id_xy_grad,
        "myval/id_xy_l1": id_xy_l1
    }
    return result
def update_core(self):
    """One training iteration of the encoder/decoder updater with latent noise.

    Variant that (a) decays all seven optimizers' learning rates, (b) injects
    noise into the latent code (``noise_z``) before decoding while keeping a
    clean copy for the latent-cycle losses, and (c) gates the latent
    discriminator dis_z behind ``dis_z_start`` iterations using a history
    buffer (``self._buffer_xz``).  Discriminators here are LSGAN-style only.
    """
    opt_enc_x = self.get_optimizer('opt_enc_x')
    opt_dec_x = self.get_optimizer('opt_dec_x')
    opt_enc_y = self.get_optimizer('opt_enc_y')
    opt_dec_y = self.get_optimizer('opt_dec_y')
    opt_x = self.get_optimizer('opt_x')
    opt_y = self.get_optimizer('opt_y')
    opt_z = self.get_optimizer('opt_z')
    self._iter += 1
    # linear learning rate decay after lrdecay_start, floored at one step
    if self.is_new_epoch and self.epoch >= self.args.lrdecay_start:
        decay_step = self.init_alpha / self.args.lrdecay_period
        # print('lr decay', decay_step)
        if opt_enc_x.alpha > decay_step:
            opt_enc_x.alpha -= decay_step
        if opt_dec_x.alpha > decay_step:
            opt_dec_x.alpha -= decay_step
        if opt_enc_y.alpha > decay_step:
            opt_enc_y.alpha -= decay_step
        if opt_dec_y.alpha > decay_step:
            opt_dec_y.alpha -= decay_step
        if opt_y.alpha > decay_step:
            opt_y.alpha -= decay_step
        if opt_x.alpha > decay_step:
            opt_x.alpha -= decay_step
        if opt_z.alpha > decay_step:
            opt_z.alpha -= decay_step
    # get mini-batch
    batch_x = self.get_iterator('main').next()
    batch_y = self.get_iterator('train_B').next()
    x = Variable(self.converter(batch_x, self.device))
    y = Variable(self.converter(batch_y, self.device))
    # to latent
    x_z = self.enc_x(losses.add_noise(
        x, sigma=self.args.noise))  # noise injection
    y_z = self.enc_y(losses.add_noise(y, sigma=self.args.noise))

    loss_gen = 0
    ## regularisation on the latent space
    if self.args.lambda_reg > 0:
        loss_reg_enc_y = losses.loss_func_reg(y_z[-1], 'l2')
        loss_reg_enc_x = losses.loss_func_reg(x_z[-1], 'l2')
        loss_gen = loss_gen + self.args.lambda_reg * (loss_reg_enc_x +
                                                      loss_reg_enc_y)
        if self.report_start < self._iter:
            chainer.report({'loss_reg': loss_reg_enc_x}, self.enc_x)
            chainer.report({'loss_reg': loss_reg_enc_y}, self.enc_y)
    ## discriminator for the latent space: distribution of image of enc_x
    ## should look same as that of enc_y (enabled after dis_z_start iters)
    if self.args.lambda_dis_z > 0 and self._iter > self.args.dis_z_start:
        # stash the current latent in the history buffer for the dis_z update
        x_z_copy = Variable(self._buffer_xz.query(x_z[-1].data))
        loss_enc_x_adv = losses.loss_func_comp(self.dis_z(x_z[-1]), 1.0)
        loss_gen = loss_gen + self.args.lambda_dis_z * loss_enc_x_adv
        if self.report_start < self._iter:
            chainer.report({'loss_adv': loss_enc_x_adv}, self.enc_x)

    # cycle for X=>Z=>X
    # Keep a clean (pre-noise) copy of the latent for the Z-cycle losses,
    # then overwrite the last latent in place with a noised version — every
    # decode below therefore sees the noised latent.
    x_z_clean = x_z[-1].data.copy()
    x_z[-1] = losses.add_noise(x_z[-1], sigma=self.args.noise_z)
    x_x = self.dec_x(x_z)
    loss_cycle_xzx = losses.loss_avg(x_x, x, ksize=self.args.cycle_ksize,
                                     norm='l1')
    if self.report_start < self._iter:
        chainer.report({'loss_cycle': loss_cycle_xzx}, self.enc_x)
    # cycle for Y=>Z=>Y
    y_z_clean = y_z[-1].data.copy()
    y_z[-1] = losses.add_noise(y_z[-1], sigma=self.args.noise_z)
    y_y = self.dec_y(y_z)
    loss_cycle_yzy = losses.loss_avg(y_y, y, ksize=self.args.cycle_ksize,
                                     norm='l1')
    if self.report_start < self._iter:
        chainer.report({'loss_cycle': loss_cycle_yzy}, self.enc_y)
    loss_gen = (loss_gen + self.args.lambda_A * loss_cycle_xzx +
                self.args.lambda_B * loss_cycle_yzy)

    ## conversion
    x_y = self.dec_y(x_z)
    y_x = self.dec_x(y_z)
    # cycle for (X=>)Z=>Y=>Z: re-encoded latent should match the clean one
    x_y_z = self.enc_y(x_y)
    loss_cycle_zyz = F.mean_absolute_error(x_y_z[-1], x_z_clean)
    if self.report_start < self._iter:
        chainer.report({'loss_cycle': loss_cycle_zyz}, self.dec_y)
    # cycle for (Y=>)Z=>X=>Z
    y_x_z = self.enc_x(y_x)
    loss_cycle_zxz = F.mean_absolute_error(y_x_z[-1], y_z_clean)
    if self.report_start < self._iter:
        chainer.report({'loss_cycle': loss_cycle_zxz}, self.dec_x)
    loss_gen = (loss_gen + self.args.lambda_A * loss_cycle_zxz +
                self.args.lambda_B * loss_cycle_zyz)

    ## adversarial for Y (x_y_copy reused by the dis_y update below)
    if self.args.lambda_dis_y > 0:
        if self.args.conditional_discriminator:
            x_y_copy = Variable(
                self._buffer_y.query(F.concat([x, x_y]).data))
            loss_dec_y_adv = losses.loss_func_comp(
                self.dis_y(F.concat([x, x_y])), 1.0)
        else:
            x_y_copy = Variable(self._buffer_y.query(x_y.data))
            loss_dec_y_adv = losses.loss_func_comp(self.dis_y(x_y), 1.0)
        loss_gen = loss_gen + self.args.lambda_dis_y * loss_dec_y_adv
        if self.report_start < self._iter:
            chainer.report({'loss_adv': loss_dec_y_adv}, self.dec_y)
    ## adversarial for X
    if self.args.lambda_dis_x > 0:
        if self.args.conditional_discriminator:
            y_x_copy = Variable(
                self._buffer_x.query(F.concat([y, y_x]).data))
            loss_dec_x_adv = losses.loss_func_comp(
                self.dis_x(F.concat([y, y_x])), 1.0)
        else:
            y_x_copy = Variable(self._buffer_x.query(y_x.data))
            loss_dec_x_adv = losses.loss_func_comp(self.dis_x(y_x), 1.0)
        loss_gen = loss_gen + self.args.lambda_dis_x * loss_dec_x_adv
        if self.report_start < self._iter:
            chainer.report({'loss_adv': loss_dec_x_adv}, self.dec_x)

    ## images before/after conversion should look similar in terms of perceptual loss
    if self.args.lambda_identity_x > 0:
        loss_id_x = losses.loss_perceptual(x, x_y, self.vgg)
        loss_gen = loss_gen + self.args.lambda_identity_x * loss_id_x
        if self.report_start < self._iter:
            # 1e-3 scaling only affects the reported value, not the loss
            chainer.report({'loss_id': 1e-3 * loss_id_x}, self.enc_x)
    if self.args.lambda_identity_y > 0:
        loss_id_y = losses.loss_perceptual(y, y_x, self.vgg)
        loss_gen = loss_gen + self.args.lambda_identity_y * loss_id_y
        if self.report_start < self._iter:
            chainer.report({'loss_id': 1e-3 * loss_id_y}, self.enc_y)
    ## warm-up: pin conversions near identity for the first `warmup` iters
    if self._iter < self.args.warmup:
        loss_gen += losses.loss_avg(y, y_x, ksize=self.args.id_ksize,
                                    norm='l2')
        loss_gen += losses.loss_avg(x, x_y, ksize=self.args.id_ksize,
                                    norm='l2')
    ## air should be -1 (background preservation)
    if self.args.lambda_air > 0:
        loss_air_x = losses.loss_range_comp(x, x_y, 0.9, norm='l2')
        loss_air_y = losses.loss_range_comp(y, y_x, 0.9, norm='l2')
        loss_gen = loss_gen + self.args.lambda_air * (loss_air_x + loss_air_y)
        if self.report_start < self._iter:
            chainer.report({'loss_air': 0.1 * loss_air_x}, self.dec_y)
            chainer.report({'loss_air': 0.1 * loss_air_y}, self.dec_x)
    ## images before/after conversion should look similar in the gradient domain
    if self.args.lambda_grad > 0:
        loss_grad_x = losses.loss_grad(x, x_y)
        loss_grad_y = losses.loss_grad(y, y_x)
        loss_gen = loss_gen + self.args.lambda_grad * (loss_grad_x +
                                                       loss_grad_y)
        if self.report_start < self._iter:
            chainer.report({'loss_grad': loss_grad_x}, self.dec_y)
            chainer.report({'loss_grad': loss_grad_y}, self.dec_x)
    ## total variation (X=>Y output only)
    if self.args.lambda_tv > 0:
        loss_tv = losses.total_variation(x_y, self.args.tv_tau)
        loss_gen = loss_gen + self.args.lambda_tv * loss_tv
        if self.report_start < self._iter:
            chainer.report({'loss_tv': loss_tv}, self.dec_y)

    ## back propagate
    self.enc_x.cleargrads()
    self.dec_x.cleargrads()
    self.enc_y.cleargrads()
    self.dec_y.cleargrads()
    loss_gen.backward()
    opt_enc_x.update()
    opt_dec_x.update()
    if not self.args.single_encoder:
        # with a single shared encoder, enc_y is not trained separately
        opt_enc_y.update()
    opt_dec_y.update()

    ## discriminator for Y
    if self.args.lambda_dis_y > 0:
        loss_dis_y_fake = losses.loss_func_comp(self.dis_y(x_y_copy), 0.0,
                                                self.args.dis_jitter)
        if self.args.conditional_discriminator:
            loss_dis_y_real = losses.loss_func_comp(
                self.dis_y(F.concat([x, y])), 1.0, self.args.dis_jitter)
        else:
            loss_dis_y_real = losses.loss_func_comp(
                self.dis_y(y), 1.0, self.args.dis_jitter)
        loss_dis_y = (loss_dis_y_fake + loss_dis_y_real) * 0.5
        if self.report_start < self._iter:
            chainer.report({'loss_fake': loss_dis_y_fake}, self.dis_y)
            chainer.report({'loss_real': loss_dis_y_real}, self.dis_y)
        self.dis_y.cleargrads()
        loss_dis_y.backward()
        opt_y.update()
    ## discriminator for X
    if self.args.lambda_dis_x > 0:
        loss_dis_x_fake = losses.loss_func_comp(self.dis_x(y_x_copy), 0.0,
                                                self.args.dis_jitter)
        if self.args.conditional_discriminator:
            loss_dis_x_real = losses.loss_func_comp(
                self.dis_x(F.concat([y, x])), 1.0, self.args.dis_jitter)
        else:
            loss_dis_x_real = losses.loss_func_comp(
                self.dis_x(x), 1.0, self.args.dis_jitter)
        loss_dis_x = (loss_dis_x_fake + loss_dis_x_real) * 0.5
        if self.report_start < self._iter:
            chainer.report({'loss_fake': loss_dis_x_fake}, self.dis_x)
            chainer.report({'loss_real': loss_dis_x_real}, self.dis_x)
        self.dis_x.cleargrads()
        loss_dis_x.backward()
        opt_x.update()
    ## discriminator for latent (same gating as the encoder-side adv loss,
    ## so x_z_copy is guaranteed to be defined here)
    if self.args.lambda_dis_z > 0 and self._iter > self.args.dis_z_start:
        loss_dis_z_x = losses.loss_func_comp(self.dis_z(x_z_copy), 0.0,
                                             self.args.dis_jitter)
        loss_dis_z_y = losses.loss_func_comp(self.dis_z(y_z[-1]), 1.0,
                                             self.args.dis_jitter)
        loss_dis_z = (loss_dis_z_x + loss_dis_z_y) * 0.5
        if self.report_start < self._iter:
            chainer.report({'loss_x': loss_dis_z_x}, self.dis_z)
            chainer.report({'loss_y': loss_dis_z_y}, self.dis_z)
        self.dis_z.cleargrads()
        loss_dis_z.backward()
        opt_z.update()