def update_core(self):
    opt_g = self.get_optimizer('opt_g')
    opt_f = self.get_optimizer('opt_f')
    opt_x = self.get_optimizer('opt_x')
    opt_y = self.get_optimizer('opt_y')
    self._iter += 1

    # learning rate decay (TODO: also handle weight_decay_rate of AdamW)
    if self.is_new_epoch and self.epoch >= self.args.lrdecay_start:
        decay_step = self.init_alpha / self.args.lrdecay_period
        for opt in (opt_g, opt_f, opt_x, opt_y):
            if opt.alpha > decay_step:
                opt.alpha -= decay_step

    # get mini-batch
    batch_x = self.get_iterator('main').next()
    batch_y = self.get_iterator('train_B').next()
    x = Variable(self.converter(batch_x, self.device))
    y = Variable(self.converter(batch_y, self.device))

    ### generator
    # X => Y => X
    x_y = self.gen_g(losses.add_noise(x, sigma=self.args.noise))  # noise injection
    if self.args.conditional_discriminator:
        x_y_copy = Variable(self._buffer_y.query(F.concat([x, x_y]).data))
    else:
        x_y_copy = Variable(self._buffer_y.query(x_y.data))
    x_y_x = self.gen_f(x_y)
    loss_cycle_x = losses.loss_avg(x_y_x, x, ksize=self.args.cycle_ksize, norm='l1')
    loss_gen_g_adv = 0
    if self.args.gen_start < self._iter:
        if self.args.conditional_discriminator:
            if self.args.wgan:
                loss_gen_g_adv = -F.average(self.dis_y(F.concat([x, x_y])))
            else:
                loss_gen_g_adv = losses.loss_func_comp(self.dis_y(F.concat([x, x_y])), 1.0)
        else:
            if self.args.wgan:
                loss_gen_g_adv = -F.average(self.dis_y(x_y))
            else:
                loss_gen_g_adv = losses.loss_func_comp(self.dis_y(x_y), 1.0)

    # Y => X => Y
    loss_gen_f_adv = 0
    y_x = self.gen_f(losses.add_noise(y, sigma=self.args.noise))  # noise injection
    if self.args.conditional_discriminator:
        y_x_copy = Variable(self._buffer_x.query(F.concat([y, y_x]).data))
    else:
        y_x_copy = Variable(self._buffer_x.query(y_x.data))
    y_x_y = self.gen_g(y_x)
    loss_cycle_y = losses.loss_avg(y_x_y, y, ksize=self.args.cycle_ksize, norm='l1')
    if self.args.gen_start < self._iter:
        if self.args.conditional_discriminator:
            if self.args.wgan:
                loss_gen_f_adv = -F.average(self.dis_x(F.concat([y, y_x])))
            else:
                loss_gen_f_adv = losses.loss_func_comp(self.dis_x(F.concat([y, y_x])), 1.0)
        else:
            if self.args.wgan:
                loss_gen_f_adv = -F.average(self.dis_x(y_x))
            else:
                loss_gen_f_adv = losses.loss_func_comp(self.dis_x(y_x), 1.0)

    ## total loss for generators
    loss_gen = (self.args.lambda_dis_y * loss_gen_g_adv
                + self.args.lambda_dis_x * loss_gen_f_adv
                + self.args.lambda_A * loss_cycle_x
                + self.args.lambda_B * loss_cycle_y)

    ## idempotence: applying f (resp. g) to its own output should change nothing
    if self.args.lambda_idempotence > 0:
        loss_idem_x = F.mean_absolute_error(y_x, self.gen_f(y_x))
        loss_idem_y = F.mean_absolute_error(x_y, self.gen_g(x_y))
        loss_gen = loss_gen + self.args.lambda_idempotence * (loss_idem_x + loss_idem_y)
        if self.report_start < self._iter:
            chainer.report({'loss_idem': loss_idem_x}, self.gen_f)
            chainer.report({'loss_idem': loss_idem_y}, self.gen_g)
    ## f (resp. g) should leave images already in its target domain unchanged
    if self.args.lambda_domain > 0:
        loss_dom_x = F.mean_absolute_error(x, self.gen_f(x))
        loss_dom_y = F.mean_absolute_error(y, self.gen_g(y))
        if self._iter < self.args.warmup:
            loss_gen = loss_gen + max(self.args.lambda_domain, 1.0) * (loss_dom_x + loss_dom_y)
        else:
            loss_gen = loss_gen + self.args.lambda_domain * (loss_dom_x + loss_dom_y)
        if self.report_start < self._iter:
            chainer.report({'loss_dom': loss_dom_x}, self.gen_f)
            chainer.report({'loss_dom': loss_dom_y}, self.gen_g)

    ## images before/after conversion should look similar in terms of perceptual loss
    if self.args.lambda_identity_x > 0:
        loss_id_x = losses.loss_perceptual(x, x_y, self.vgg)
        loss_gen = loss_gen + self.args.lambda_identity_x * loss_id_x
        if self.report_start < self._iter:
            chainer.report({'loss_id': 1e-3 * loss_id_x}, self.gen_g)
    if self.args.lambda_identity_y > 0:
        loss_id_y = losses.loss_perceptual(y, y_x, self.vgg)
        loss_gen = loss_gen + self.args.lambda_identity_y * loss_id_y
        if self.report_start < self._iter:
            chainer.report({'loss_id': 1e-3 * loss_id_y}, self.gen_f)

    ## warm-up
    if self._iter < self.args.warmup:
        loss_gen = loss_gen + F.mean_squared_error(x, self.gen_f(x))
        loss_gen = loss_gen + F.mean_squared_error(y, self.gen_g(y))
        # loss_gen = loss_gen + losses.loss_avg(y, y_x, ksize=self.args.id_ksize, norm='l2')
        # loss_gen = loss_gen + losses.loss_avg(x, x_y, ksize=self.args.id_ksize, norm='l2')

    ## background should be preserved
    if self.args.lambda_air > 0:
        loss_air_x = losses.loss_range_comp(x, x_y, 0.9, norm='l2')
        loss_air_y = losses.loss_range_comp(y, y_x, 0.9, norm='l2')
        loss_gen = loss_gen + self.args.lambda_air * (loss_air_x + loss_air_y)
        if self.report_start < self._iter:
            chainer.report({'loss_air': 0.1 * loss_air_x}, self.gen_g)
            chainer.report({'loss_air': 0.1 * loss_air_y}, self.gen_f)
    ## comparison of images before/after conversion in the gradient domain
    if self.args.lambda_grad > 0:
        loss_grad_x = losses.loss_grad(x, x_y, self.args.grad_norm)
        loss_grad_y = losses.loss_grad(y, y_x, self.args.grad_norm)
        loss_gen = loss_gen + self.args.lambda_grad * (loss_grad_x + loss_grad_y)
        if self.report_start < self._iter:
            chainer.report({'loss_grad': loss_grad_x}, self.gen_g)
            chainer.report({'loss_grad': loss_grad_y}, self.gen_f)
    ## total variation
    if self.args.lambda_tv > 0:
        loss_tv = losses.total_variation(x_y, self.args.tv_tau)
        loss_gen = loss_gen + self.args.lambda_tv * loss_tv
        if self.report_start < self._iter:
            chainer.report({'loss_tv': loss_tv}, self.gen_g)

    if self.report_start < self._iter:
        chainer.report({'loss_cycle_x': loss_cycle_x}, self.gen_f)
        chainer.report({'loss_adv': loss_gen_g_adv}, self.gen_g)
        chainer.report({'loss_cycle_y': loss_cycle_y}, self.gen_g)
        chainer.report({'loss_adv': loss_gen_f_adv}, self.gen_f)

    self.gen_f.cleargrads()
    self.gen_g.cleargrads()
    loss_gen.backward()
    opt_f.update()
    opt_g.update()

    ### discriminator
    for t in range(self.args.n_critics):
        if self.args.wgan:  ## synthesised -, real +
            loss_dis_y_fake = F.average(self.dis_y(x_y_copy))
            eps = self.xp.random.uniform(
                0, 1, size=len(batch_y)).astype(self.xp.float32)[:, None, None, None]
            if self.args.conditional_discriminator:
                loss_dis_y_real = -F.average(self.dis_y(F.concat([x, y])))
                # interpolate between the real pair and the buffered fake pair;
                # channel order must match the real input fed to dis_y above
                y_mid = eps * F.concat([x, y]) + (1.0 - eps) * x_y_copy
            else:
                loss_dis_y_real = -F.average(self.dis_y(y))
                y_mid = eps * y + (1.0 - eps) * x_y_copy
            # gradient penalty
            gd_y, = chainer.grad([self.dis_y(y_mid)], [y_mid], enable_double_backprop=True)
            gd_y = F.sqrt(F.batch_l2_norm_squared(gd_y) + 1e-6)
            loss_dis_y_gp = F.mean_squared_error(gd_y, self.xp.ones_like(gd_y.data))
            if self.report_start < self._iter:
                chainer.report({'loss_real': loss_dis_y_real}, self.dis_y)
                chainer.report({'loss_fake': loss_dis_y_fake}, self.dis_y)
                chainer.report({'loss_gp': self.args.lambda_wgan_gp * loss_dis_y_gp}, self.dis_y)
            loss_dis_y = (loss_dis_y_real + loss_dis_y_fake
                          + self.args.lambda_wgan_gp * loss_dis_y_gp)
            self.dis_y.cleargrads()
            loss_dis_y.backward()
            opt_y.update()
            ## discriminator for B=>A
            loss_dis_x_fake = F.average(self.dis_x(y_x_copy))
            if self.args.conditional_discriminator:
                loss_dis_x_real = -F.average(
                    self.dis_x(losses.add_noise(F.concat([y, x]), sigma=self.args.noise)))
                # channel order must match the real input fed to dis_x above
                x_mid = eps * F.concat([y, x]) + (1.0 - eps) * y_x_copy
            else:
                loss_dis_x_real = -F.average(self.dis_x(x))
                x_mid = eps * x + (1.0 - eps) * y_x_copy
            # gradient penalty
            gd_x, = chainer.grad([self.dis_x(x_mid)], [x_mid], enable_double_backprop=True)
            gd_x = F.sqrt(F.batch_l2_norm_squared(gd_x) + 1e-6)
            loss_dis_x_gp = F.mean_squared_error(gd_x, self.xp.ones_like(gd_x.data))
            if self.report_start < self._iter:
                chainer.report({'loss_real': loss_dis_x_real}, self.dis_x)
                chainer.report({'loss_fake': loss_dis_x_fake}, self.dis_x)
                chainer.report({'loss_gp': self.args.lambda_wgan_gp * loss_dis_x_gp}, self.dis_x)
            loss_dis_x = (loss_dis_x_real + loss_dis_x_fake
                          + self.args.lambda_wgan_gp * loss_dis_x_gp)
            self.dis_x.cleargrads()
            loss_dis_x.backward()
            opt_x.update()
        else:
            ## discriminator for A=>B (real: 1, fake: 0)
            loss_dis_y_fake = losses.loss_func_comp(self.dis_y(x_y_copy), 0.0, self.args.dis_jitter)
            if self.args.conditional_discriminator:
                loss_dis_y_real = losses.loss_func_comp(
                    self.dis_y(F.concat([x, y])), 1.0, self.args.dis_jitter)
            else:
                loss_dis_y_real = losses.loss_func_comp(self.dis_y(y), 1.0, self.args.dis_jitter)
            loss_dis_y = (loss_dis_y_fake + loss_dis_y_real) * 0.5
            self.dis_y.cleargrads()
            loss_dis_y.backward()
            opt_y.update()
            ## discriminator for B=>A
            loss_dis_x_fake = losses.loss_func_comp(self.dis_x(y_x_copy), 0.0, self.args.dis_jitter)
            if self.args.conditional_discriminator:
                loss_dis_x_real = losses.loss_func_comp(
                    self.dis_x(F.concat([y, x])), 1.0, self.args.dis_jitter)
            else:
                loss_dis_x_real = losses.loss_func_comp(self.dis_x(x), 1.0, self.args.dis_jitter)
            loss_dis_x = (loss_dis_x_fake + loss_dis_x_real) * 0.5
            self.dis_x.cleargrads()
            loss_dis_x.backward()
            opt_x.update()
            if self.report_start < self._iter:
                chainer.report({'loss_real': loss_dis_x_real}, self.dis_x)
                chainer.report({'loss_fake': loss_dis_x_fake}, self.dis_x)
                chainer.report({'loss_real': loss_dis_y_real}, self.dis_y)
                chainer.report({'loss_fake': loss_dis_y_fake}, self.dis_y)

        # prepare next images for the remaining critic iterations
        if t < self.args.n_critics - 1:
            x_y_copy = Variable(self.xp.concatenate(
                random.sample(self._buffer_y.images, self.args.batch_size)))
            y_x_copy = Variable(self.xp.concatenate(
                random.sample(self._buffer_x.images, self.args.batch_size)))
            batch_x = self.get_iterator('main').next()
            batch_y = self.get_iterator('train_B').next()
            x = Variable(self.converter(batch_x, self.device))
            y = Variable(self.converter(batch_y, self.device))
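
# The `_buffer_x.query` / `_buffer_y.query` calls above feed the discriminators
# from a history of previously generated images (the CycleGAN image-pool trick),
# and the critic loop resamples from `.images` directly. A minimal sketch of such
# a pool, assuming this interface; the actual class in this repo may differ.
import numpy as np

class ImagePool(object):
    def __init__(self, size=50):
        self.size = size
        self.images = []          # history of generated images, each (1, C, H, W)

    def query(self, batch):
        out = []
        for img in batch:
            img = img[None]       # keep the batch axis
            if len(self.images) < self.size:
                self.images.append(img)
                out.append(img)
            elif np.random.rand() < 0.5:
                # return a stored image and replace it with the fresh one
                idx = np.random.randint(len(self.images))
                out.append(self.images[idx])
                self.images[idx] = img
            else:
                out.append(img)
        return np.concatenate(out)
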
def update_core(self):
    opt_enc_x = self.get_optimizer('opt_enc_x')
    opt_dec_x = self.get_optimizer('opt_dec_x')
    opt_enc_y = self.get_optimizer('opt_enc_y')
    opt_dec_y = self.get_optimizer('opt_dec_y')
    opt_x = self.get_optimizer('opt_x')
    opt_y = self.get_optimizer('opt_y')
    opt_z = self.get_optimizer('opt_z')

    # get mini-batch
    batch_x = self.get_iterator('main').next()
    batch_y = self.get_iterator('train_B').next()
    x = Variable(self.converter(batch_x, self.args.gpu[0]))
    y = Variable(self.converter(batch_y, self.args.gpu[0]))

    # encode to latent (X, Y => Z)
    x_z = self.enc_x(losses.add_noise(x, sigma=self.args.noise))
    y_z = self.enc_y(losses.add_noise(y, sigma=self.args.noise))
    loss_gen = 0

    ## regularisation on the latent space
    if self.args.lambda_reg > 0:
        loss_reg_enc_y = losses.loss_func_reg(y_z[-1], 'l2')
        loss_reg_enc_x = losses.loss_func_reg(x_z[-1], 'l2')
        loss_gen = loss_gen + self.args.lambda_reg * (loss_reg_enc_x + loss_reg_enc_y)
        chainer.report({'loss_reg': loss_reg_enc_x}, self.enc_x)
        chainer.report({'loss_reg': loss_reg_enc_y}, self.enc_y)

    ## discriminator for the latent space: the distribution of enc_x outputs
    ## should look the same as that of enc_y.
    # since z is a list (for U-Net skip connections), we use only the last layer's output
    if self.args.lambda_dis_z > 0:
        if self.args.dis_wgan:
            loss_enc_x_adv = -F.average(self.dis_z(x_z[-1]))
            loss_enc_y_adv = F.average(self.dis_z(y_z[-1]))
        else:
            loss_enc_x_adv = losses.loss_func_comp(self.dis_z(x_z[-1]), 1.0)
            loss_enc_y_adv = losses.loss_func_comp(self.dis_z(y_z[-1]), 0.0)
        loss_gen = loss_gen + self.args.lambda_dis_z * (loss_enc_x_adv + loss_enc_y_adv)
        chainer.report({'loss_adv': loss_enc_x_adv}, self.enc_x)
        chainer.report({'loss_adv': loss_enc_y_adv}, self.enc_y)

    # cycle for X=>Z=>X (autoencoder)
    x_x = self.dec_x(x_z)
    loss_cycle_xzx = F.mean_absolute_error(x_x, x)
    chainer.report({'loss_cycle': loss_cycle_xzx}, self.enc_x)
    # cycle for Y=>Z=>Y (autoencoder)
    y_y = self.dec_y(y_z)
    loss_cycle_yzy = F.mean_absolute_error(y_y, y)
    chainer.report({'loss_cycle': loss_cycle_yzy}, self.enc_y)
    loss_gen = loss_gen + self.args.lambda_Az * loss_cycle_xzx + self.args.lambda_Bz * loss_cycle_yzy

    ## decode from latent Z => Y, X
    x_y = self.dec_y(x_z)
    y_x = self.dec_x(y_z)
    # cycle for X=>Z=>Y=>Z=>X (Z=>Y=>Z does not work well)
    x_y_x = self.dec_x(self.enc_y(x_y))
    loss_cycle_x = F.mean_absolute_error(x_y_x, x)
    chainer.report({'loss_cycle': loss_cycle_x}, self.dec_x)
    # cycle for Y=>Z=>X=>Z=>Y
    y_x_y = self.dec_y(self.enc_x(y_x))
    loss_cycle_y = F.mean_absolute_error(y_x_y, y)
    chainer.report({'loss_cycle': loss_cycle_y}, self.dec_y)
    loss_gen = loss_gen + self.args.lambda_A * loss_cycle_x + self.args.lambda_B * loss_cycle_y

    ## adversarial for Y
    if self.args.lambda_dis_y > 0:
        x_y_copy = Variable(self._buffer_y.query(x_y.data))
        if self.args.dis_wgan:
            loss_dec_y_adv = -F.average(self.dis_y(x_y))
        else:
            loss_dec_y_adv = losses.loss_func_comp(self.dis_y(x_y), 1.0)
        loss_gen = loss_gen + self.args.lambda_dis_y * loss_dec_y_adv
        chainer.report({'loss_adv': loss_dec_y_adv}, self.dec_y)
    ## adversarial for X
    if self.args.lambda_dis_x > 0:
        y_x_copy = Variable(self._buffer_x.query(y_x.data))
        if self.args.dis_wgan:
            loss_dec_x_adv = -F.average(self.dis_x(y_x))
        else:
            loss_dec_x_adv = losses.loss_func_comp(self.dis_x(y_x), 1.0)
        loss_gen = loss_gen + self.args.lambda_dis_x * loss_dec_x_adv
        chainer.report({'loss_adv': loss_dec_x_adv}, self.dec_x)

    ## idempotence
    if self.args.lambda_idempotence > 0:
        loss_idem_x = F.mean_absolute_error(y_x, self.dec_x(self.enc_y(y_x)))
        loss_idem_y = F.mean_absolute_error(x_y, self.dec_y(self.enc_x(x_y)))
        loss_gen = loss_gen + self.args.lambda_idempotence * (loss_idem_x + loss_idem_y)
        chainer.report({'loss_idem': loss_idem_x}, self.dec_x)
        chainer.report({'loss_idem': loss_idem_y}, self.dec_y)
    # Y=>X shouldn't change X (and X=>Y shouldn't change Y)
    if self.args.lambda_domain > 0:
        loss_dom_x = F.mean_absolute_error(x, self.dec_x(self.enc_y(x)))
        loss_dom_y = F.mean_absolute_error(y, self.dec_y(self.enc_x(y)))
        loss_gen = loss_gen + self.args.lambda_domain * (loss_dom_x + loss_dom_y)
        chainer.report({'loss_dom': loss_dom_x}, self.dec_x)
        chainer.report({'loss_dom': loss_dom_y}, self.dec_y)

    ## images before/after conversion should look similar in terms of perceptual loss
    if self.args.lambda_identity_x > 0:
        loss_id_x = losses.loss_perceptual(
            x, x_y, self.vgg, layer=self.args.perceptual_layer, grey=self.args.grey)
        loss_gen = loss_gen + self.args.lambda_identity_x * loss_id_x
        chainer.report({'loss_id': 1e-3 * loss_id_x}, self.enc_x)
    if self.args.lambda_identity_y > 0:
        loss_id_y = losses.loss_perceptual(
            y, y_x, self.vgg, layer=self.args.perceptual_layer, grey=self.args.grey)
        loss_gen = loss_gen + self.args.lambda_identity_y * loss_id_y
        chainer.report({'loss_id': 1e-3 * loss_id_y}, self.enc_y)

    ## background (pixels with value -1) should be preserved
    if self.args.lambda_air > 0:
        loss_air_x = losses.loss_comp_low(x, x_y, self.args.air_threshold, norm='l2')
        loss_air_y = losses.loss_comp_low(y, y_x, self.args.air_threshold, norm='l2')
        loss_gen = loss_gen + self.args.lambda_air * (loss_air_x + loss_air_y)
        chainer.report({'loss_air': loss_air_x}, self.dec_y)
        chainer.report({'loss_air': loss_air_y}, self.dec_x)
    ## images before/after conversion should look similar in the gradient domain
    if self.args.lambda_grad > 0:
        loss_grad_x = losses.loss_grad(x, x_y)
        loss_grad_y = losses.loss_grad(y, y_x)
        loss_gen = loss_gen + self.args.lambda_grad * (loss_grad_x + loss_grad_y)
        chainer.report({'loss_grad': loss_grad_x}, self.dec_y)
        chainer.report({'loss_grad': loss_grad_y}, self.dec_x)
    ## total variation (only for X=>Y)
    if self.args.lambda_tv > 0:
        loss_tv = losses.total_variation(x_y, tau=self.args.tv_tau, method=self.args.tv_method)
        if self.args.imgtype == "dcm" and self.args.num_slices > 1:
            loss_tv += losses.total_variation_ch(x_y)
        loss_gen = loss_gen + self.args.lambda_tv * loss_tv
        chainer.report({'loss_tv': loss_tv}, self.dec_y)

    ## back-propagate
    self.enc_x.cleargrads()
    self.dec_x.cleargrads()
    self.enc_y.cleargrads()
    self.dec_y.cleargrads()
    loss_gen.backward()
    opt_enc_x.update(loss=loss_gen)
    opt_dec_x.update(loss=loss_gen)
    if not self.args.single_encoder:
        opt_enc_y.update(loss=loss_gen)
    opt_dec_y.update(loss=loss_gen)

    ##########################################
    ## discriminators
    if self.args.dis_wgan:  ## synthesised -, real +
        eps = self.xp.random.uniform(
            0, 1, size=len(batch_y)).astype(self.xp.float32)[:, None, None, None]
        if self.args.lambda_dis_y > 0:
            ## discriminator for X=>Y
            loss_dis_y = F.average(self.dis_y(x_y_copy) - self.dis_y(y))
            y_mid = eps * y + (1.0 - eps) * x_y_copy
            # gradient penalty
            gd_y, = chainer.grad([self.dis_y(y_mid)], [y_mid], enable_double_backprop=True)
            gd_y = F.sqrt(F.batch_l2_norm_squared(gd_y) + 1e-6)
            loss_dis_y_gp = F.mean_squared_error(gd_y, self.xp.ones_like(gd_y.data))
            chainer.report({'loss_dis': loss_dis_y}, self.dis_y)
            chainer.report({'loss_gp': self.args.lambda_wgan_gp * loss_dis_y_gp}, self.dis_y)
            loss_dis_y = loss_dis_y + self.args.lambda_wgan_gp * loss_dis_y_gp
            self.dis_y.cleargrads()
            loss_dis_y.backward()
            opt_y.update(loss=loss_dis_y)
        if self.args.lambda_dis_x > 0:
            ## discriminator for Y=>X
            loss_dis_x = F.average(self.dis_x(y_x_copy) - self.dis_x(x))
            x_mid = eps * x + (1.0 - eps) * y_x_copy
            # gradient penalty
            gd_x, = chainer.grad([self.dis_x(x_mid)], [x_mid], enable_double_backprop=True)
            gd_x = F.sqrt(F.batch_l2_norm_squared(gd_x) + 1e-6)
            loss_dis_x_gp = F.mean_squared_error(gd_x, self.xp.ones_like(gd_x.data))
            chainer.report({'loss_dis': loss_dis_x}, self.dis_x)
            chainer.report({'loss_gp': self.args.lambda_wgan_gp * loss_dis_x_gp}, self.dis_x)
            loss_dis_x = loss_dis_x + self.args.lambda_wgan_gp * loss_dis_x_gp
            self.dis_x.cleargrads()
            loss_dis_x.backward()
            opt_x.update(loss=loss_dis_x)
        ## discriminator for latent: X=>Z is -, while Y=>Z is +
        if self.args.lambda_dis_z > 0:
            loss_dis_z = F.average(self.dis_z(x_z[-1]) - self.dis_z(y_z[-1]))
            z_mid = eps * x_z[-1] + (1.0 - eps) * y_z[-1]
            # gradient penalty
            gd_z, = chainer.grad([self.dis_z(z_mid)], [z_mid], enable_double_backprop=True)
            gd_z = F.sqrt(F.batch_l2_norm_squared(gd_z) + 1e-6)
            loss_dis_z_gp = F.mean_squared_error(gd_z, self.xp.ones_like(gd_z.data))
            chainer.report({'loss_dis': loss_dis_z}, self.dis_z)
            chainer.report({'loss_gp': self.args.lambda_wgan_gp * loss_dis_z_gp}, self.dis_z)
            loss_dis_z = loss_dis_z + self.args.lambda_wgan_gp * loss_dis_z_gp
            self.dis_z.cleargrads()
            loss_dis_z.backward()
            opt_z.update(loss=loss_dis_z)
    else:  ## LSGAN
        if self.args.lambda_dis_y > 0:
            ## discriminator for X=>Y (real: 1, fake: 0)
            disy_fake = self.dis_y(x_y_copy)
            loss_dis_y_fake = losses.loss_func_comp(disy_fake, 0.0, self.args.dis_jitter)
            disy_real = self.dis_y(y)
            loss_dis_y_real = losses.loss_func_comp(disy_real, 1.0, self.args.dis_jitter)
            if self.args.dis_reg_weighting > 0:  ## regularisation
                loss_dis_y_reg = (F.average(F.absolute(disy_real[:, 1, :, :]))
                                  + F.average(F.absolute(disy_fake[:, 1, :, :])))
            else:
                loss_dis_y_reg = 0
            chainer.report({'loss_reg': loss_dis_y_reg}, self.dis_y)
            loss_dis_y_gp = 0
            chainer.report({'loss_fake': loss_dis_y_fake}, self.dis_y)
            chainer.report({'loss_real': loss_dis_y_real}, self.dis_y)
            loss_dis_y = ((loss_dis_y_fake + loss_dis_y_real) * 0.5
                          + self.args.dis_reg_weighting * loss_dis_y_reg
                          + self.args.lambda_wgan_gp * loss_dis_y_gp)
            self.dis_y.cleargrads()
            loss_dis_y.backward()
            opt_y.update(loss=loss_dis_y)
        if self.args.lambda_dis_x > 0:
            ## discriminator for Y=>X
            disx_fake = self.dis_x(y_x_copy)
            loss_dis_x_fake = losses.loss_func_comp(disx_fake, 0.0, self.args.dis_jitter)
            disx_real = self.dis_x(x)
            loss_dis_x_real = losses.loss_func_comp(disx_real, 1.0, self.args.dis_jitter)
            if self.args.dis_reg_weighting > 0:  ## regularisation
                loss_dis_x_reg = (F.average(F.absolute(disx_fake[:, 1, :, :]))
                                  + F.average(F.absolute(disx_real[:, 1, :, :])))
            else:
                loss_dis_x_reg = 0
            chainer.report({'loss_reg': loss_dis_x_reg}, self.dis_x)
            loss_dis_x_gp = 0
            chainer.report({'loss_fake': loss_dis_x_fake}, self.dis_x)
            chainer.report({'loss_real': loss_dis_x_real}, self.dis_x)
            loss_dis_x = ((loss_dis_x_fake + loss_dis_x_real) * 0.5
                          + self.args.dis_reg_weighting * loss_dis_x_reg
                          + self.args.lambda_wgan_gp * loss_dis_x_gp)
            self.dis_x.cleargrads()
            loss_dis_x.backward()
            opt_x.update(loss=loss_dis_x)
        ## discriminator for latent: X=>Z is 0.0 while Y=>Z is 1.0
        if self.args.lambda_dis_z > 0:
            disz_xz = self.dis_z(x_z[-1])
            loss_dis_z_x = losses.loss_func_comp(disz_xz, 0.0, self.args.dis_jitter)
            disz_yz = self.dis_z(y_z[-1])
            loss_dis_z_y = losses.loss_func_comp(disz_yz, 1.0, self.args.dis_jitter)
            if self.args.dis_reg_weighting > 0:  ## regularisation
                loss_dis_z_reg = (F.average(F.absolute(disz_xz[:, 1, :, :]))
                                  + F.average(F.absolute(disz_yz[:, 1, :, :])))
            else:
                loss_dis_z_reg = 0
            chainer.report({'loss_x': loss_dis_z_x}, self.dis_z)
            chainer.report({'loss_y': loss_dis_z_y}, self.dis_z)
            chainer.report({'loss_reg': loss_dis_z_reg}, self.dis_z)
            loss_dis_z = ((loss_dis_z_x + loss_dis_z_y) * 0.5
                          + self.args.dis_reg_weighting * loss_dis_z_reg)
            self.dis_z.cleargrads()
            loss_dis_z.backward()
            opt_z.update(loss=loss_dis_z)
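
# The WGAN-GP term appears three times above (for dis_x, dis_y and dis_z) with
# the same structure. A self-contained sketch of that penalty, factored out for
# clarity; `gradient_penalty` and the toy critic below are illustrative names,
# not this repo's API. The 1e-6 inside the sqrt mirrors the stabiliser used above.
import numpy as np
import chainer
import chainer.functions as F
import chainer.links as L

def gradient_penalty(critic, real, fake, xp=np):
    # one interpolation coefficient per sample, broadcast over (C, H, W)
    eps = xp.random.uniform(0, 1, size=(real.shape[0], 1, 1, 1)).astype(xp.float32)
    mid = eps * real + (1.0 - eps) * fake   # random point between real and fake
    gd, = chainer.grad([critic(mid)], [mid], enable_double_backprop=True)
    norm = F.sqrt(F.batch_l2_norm_squared(gd) + 1e-6)
    # penalise deviation of the critic's gradient norm from 1
    return F.mean_squared_error(norm, xp.ones_like(norm.array))

# usage with a toy one-layer critic:
#   critic = L.Convolution2D(1, 1, ksize=3, pad=1)
#   real = chainer.Variable(np.random.randn(4, 1, 8, 8).astype(np.float32))
#   fake = chainer.Variable(np.random.randn(4, 1, 8, 8).astype(np.float32))
#   gp = gradient_penalty(critic, real, fake)
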
def update_core(self):
    optimizer_sd = self.get_optimizer('main')
    optimizer_enc = self.get_optimizer('enc')
    optimizer_dec = self.get_optimizer('dec')
    optimizer_dis = self.get_optimizer('dis')
    xp = self.seed.xp
    step = self.iteration % self.args.iter
    osem_step = step % self.args.osem

    if step == 0:
        batch = self.get_iterator('main').next()
        self.prImg, self.rev, self.patient_id, self.slice = self.converter(batch, self.device)
        print(self.prImg.shape)
        self.n_reconst += 1
        self.recon_freq = 1
        # initialise the seed image
        if ".npy" in self.args.model_image:
            self.seed.W.array = xp.reshape(
                xp.load(self.args.model_image),
                (1, 1, self.args.crop_height, self.args.crop_width))
        elif ".dcm" in self.args.model_image:
            ref_dicom = dicom.read_file(self.args.model_image, force=True)
            img = xp.array(ref_dicom.pixel_array + ref_dicom.RescaleIntercept)
            img = (2 * (xp.clip(img, self.args.HU_base, self.args.HU_base + self.args.HU_range)
                        - self.args.HU_base) / self.args.HU_range - 1.0).astype(np.float32)
            self.seed.W.array = xp.reshape(
                img, (1, 1, self.args.crop_height, self.args.crop_width))
        else:
            # initializers.Uniform(scale=0.5)(self.seed.W.array)
            initializers.HeNormal()(self.seed.W.array)
        self.initial_seed = self.seed.W.array.copy()
        # print(xp.min(self.initial_seed), xp.max(self.initial_seed), xp.mean(self.initial_seed))

    ## for seed array
    arr = self.seed()
    HU = self.var2HU(arr)
    raw = self.HU2raw(HU)
    self.seed.cleargrads()
    loss_seed = Variable(xp.array([0.0], dtype=np.float32))

    # data consistency: conjugate correction using the system matrix
    if self.args.lambda_sd > 0:
        self.seed.W.grad = xp.zeros_like(self.seed.W.array)
        loss_sd = 0
        for i in range(len(self.prImg)):
            if self.rev[i]:
                rec_sd = F.exp(-F.sparse_matmul(
                    self.prMats[osem_step], F.reshape(raw[i, :, ::-1, ::-1], (-1, 1))))
            else:
                rec_sd = F.exp(-F.sparse_matmul(
                    self.prMats[osem_step], F.reshape(raw[i], (-1, 1))))
            if self.args.log:
                loss_sd += F.mean_squared_error(F.log(rec_sd), F.log(self.prImg[i][osem_step]))
            else:
                loss_sd += F.mean_squared_error(rec_sd, self.prImg[i][osem_step])
            if self.args.system_matrix:
                gd = F.sparse_matmul(
                    self.conjMats[osem_step], rec_sd - self.prImg[i][osem_step], transa=True)
                if self.rev[i]:
                    self.seed.W.grad[i] -= self.args.lambda_sd * F.reshape(
                        gd, (1, self.args.crop_height, self.args.crop_width)).array[:, ::-1, ::-1]  # / logrep.shape[0] ?
                else:
                    self.seed.W.grad[i] -= self.args.lambda_sd * F.reshape(
                        gd, (1, self.args.crop_height, self.args.crop_width)).array  # / logrep.shape[0] ?
        if not self.args.system_matrix:
            (self.args.lambda_sd * loss_sd).backward()
        chainer.report({'loss_sd': loss_sd / len(self.prImg)}, self.seed)

    if self.args.lambda_tvs > 0:
        loss_tvs = losses.total_variation(arr, tau=self.args.tv_tau, method=self.args.tv_method)
        loss_seed += self.args.lambda_tvs * loss_tvs
        chainer.report({'loss_tvs': loss_tvs}, self.seed)
    if self.args.lambda_advs > 0:
        L_advs = F.average((self.dis(arr) - 1.0) ** 2)
        loss_seed += self.args.lambda_advs * L_advs
        chainer.report({'loss_advs': L_advs}, self.seed)

    ## generator output
    arr_n = losses.add_noise(arr, self.args.noise_gen)
    if self.args.no_train_seed:
        arr_n.unchain()
    if not self.args.decoder_only:
        arr_n = self.encoder(arr_n)
    gen = self.decoder(arr_n)  # range = [-1, 1]

    ## generator loss
    loss_gen = Variable(xp.array([0.0], dtype=np.float32))
    plan, plan_ae = None, None
    if self.args.lambda_ae1 > 0 or self.args.lambda_ae2 > 0:
        plan = losses.add_noise(
            Variable(self.converter(self.get_iterator('planct').next(), self.device)),
            self.args.noise_dis)
        plan_enc = self.encoder(plan)
        plan_ae = self.decoder(plan_enc)
        loss_ae1 = F.mean_absolute_error(plan, plan_ae)
        loss_ae2 = F.mean_squared_error(plan, plan_ae)
        if self.args.lambda_reg > 0:
            loss_reg_ae = losses.loss_func_reg(plan_enc[-1], 'l2')
            chainer.report({'loss_reg_ae': loss_reg_ae}, self.seed)
            loss_gen += self.args.lambda_reg * loss_reg_ae
        loss_gen += self.args.lambda_ae1 * loss_ae1 + self.args.lambda_ae2 * loss_ae2
        chainer.report({'loss_ae1': loss_ae1}, self.seed)
        chainer.report({'loss_ae2': loss_ae2}, self.seed)
    if self.args.lambda_tv > 0:
        L_tv = losses.total_variation(gen, tau=self.args.tv_tau, method=self.args.tv_method)
        loss_gen += self.args.lambda_tv * L_tv
        chainer.report({'loss_tv': L_tv}, self.seed)
    if self.args.lambda_adv > 0:
        L_adv = F.average((self.dis(gen) - 1.0) ** 2)
        loss_gen += self.args.lambda_adv * L_adv
        chainer.report({'loss_adv': L_adv}, self.seed)
    ## regularisation on the latent space
    if self.args.lambda_reg > 0:
        loss_reg = losses.loss_func_reg(arr_n[-1], 'l2')
        chainer.report({'loss_reg': loss_reg}, self.seed)
        loss_gen += self.args.lambda_reg * loss_reg

    self.encoder.cleargrads()
    self.decoder.cleargrads()
    loss_gen.backward()
    loss_seed.backward()
    chainer.report({'loss_gen': loss_gen}, self.seed)
    optimizer_enc.update()
    optimizer_dec.update()
    optimizer_sd.update()
    chainer.report({'grad_sd': F.average(F.absolute(self.seed.W.grad))}, self.seed)
    if hasattr(self.decoder, 'latent_fc'):
        chainer.report({'grad_gen': F.average(F.absolute(self.decoder.latent_fc.W.grad))}, self.seed)

    # reconstruction consistency for the NN output
    if (step % self.recon_freq == 0) and self.args.lambda_nn > 0:
        self.encoder.cleargrads()
        self.decoder.cleargrads()
        self.seed.cleargrads()
        gen.grad = xp.zeros_like(gen.array)
        HU_nn = self.var2HU(gen)
        raw_nn = self.HU2raw(HU_nn)
        loss_nn = 0
        for i in range(len(self.prImg)):
            if self.rev[i]:
                rec_nn = F.exp(-F.sparse_matmul(
                    self.prMats[osem_step], F.reshape(raw_nn[i, :, ::-1, ::-1], (-1, 1))))
            else:
                rec_nn = F.exp(-F.sparse_matmul(
                    self.prMats[osem_step], F.reshape(raw_nn[i], (-1, 1))))
            loss_nn += F.mean_squared_error(rec_nn, self.prImg[i][osem_step])
            if self.args.system_matrix:
                gd_nn = F.sparse_matmul(
                    rec_nn - self.prImg[i][osem_step], self.conjMats[osem_step], transa=True)
                if self.rev[i]:
                    gen.grad[i] -= self.args.lambda_nn * F.reshape(
                        gd_nn, (1, self.args.crop_height, self.args.crop_width)).array[:, ::-1, ::-1]
                else:
                    gen.grad[i] -= self.args.lambda_nn * F.reshape(
                        gd_nn, (1, self.args.crop_height, self.args.crop_width)).array
        chainer.report({'loss_nn': loss_nn / len(self.prImg)}, self.seed)
        if self.args.system_matrix:
            gen.backward()
        else:
            (self.args.lambda_nn * loss_nn).backward()
        if not self.args.no_train_seed:
            optimizer_sd.update()
        if not self.args.no_train_enc:
            optimizer_enc.update()
        if not self.args.no_train_dec:
            optimizer_dec.update()
        if self.seed.W.grad is not None:
            chainer.report({'grad_sd_consistency': F.average(F.absolute(self.seed.W.grad))}, self.seed)
        if hasattr(self.decoder, 'latent_fc'):
            chainer.report({'grad_gen_consistency': F.average(F.absolute(self.decoder.latent_fc.W.grad))}, self.seed)
        elif hasattr(self.decoder, 'ul'):
            chainer.report({'grad_gen_consistency': F.average(F.absolute(self.decoder.ul.c1.c.W.grad))}, self.seed)

    chainer.report({'seed_diff': F.mean_absolute_error(self.initial_seed, self.seed.W)
                    / F.mean_absolute_error(self.initial_seed, xp.zeros_like(self.initial_seed))}, self.seed)

    # clip seed to [-1, 1]
    if self.args.clip:
        self.seed.W.array = xp.clip(self.seed.W.array, a_min=-1.0, a_max=1.0)

    # adjust the consistency-loss update frequency
    self.recon_freq = max(1, int(round(
        self.args.max_reconst_freq * (step - self.args.reconst_freq_decay_start)
        / (self.args.iter + 1 - self.args.reconst_freq_decay_start))))

    ## for discriminator
    fake = None
    if self.args.dis_freq > 0 and ((step + 1) % self.args.dis_freq == 0) \
            and (self.args.lambda_gan + self.args.lambda_adv + self.args.lambda_advs > 0):
        # get mini-batch
        if plan is None:
            plan = self.converter(self.get_iterator('planct').next(), self.device)
            plan = losses.add_noise(Variable(plan), self.args.noise_dis)
        # create fake
        if self.args.lambda_gan > 0:
            if self.args.decoder_only:
                fake_seed = xp.random.uniform(
                    -1, 1, (1, self.args.latent_dim)).astype(np.float32)
            else:
                fake_seed = self.encoder(xp.random.uniform(
                    -1, 1, (1, 1, self.args.crop_height, self.args.crop_width)).astype(np.float32))
            fake = self.decoder(fake_seed)
            # decoder
            self.decoder.cleargrads()
            loss_gan = F.average((self.dis(fake) - 1.0) ** 2)
            chainer.report({'loss_gan': loss_gan}, self.seed)
            loss_gan *= self.args.lambda_gan
            loss_gan.backward()
            optimizer_dec.update(loss=loss_gan)
            fake_copy = self._buffer.query(fake.array)
        if self.args.lambda_nn > 0:
            fake_copy = self._buffer.query(self.converter(self.get_iterator('mvct').next(), self.device))
            if (step + 1) % (self.args.iter // 30):
                fake_copy = Variable(self._buffer.query(gen.array))
        # discriminator
        L_real = F.average((self.dis(plan) - 1.0) ** 2)
        L_fake = F.average(self.dis(fake_copy) ** 2)
        loss_dis = 0.5 * (L_real + L_fake)
        self.dis.cleargrads()
        loss_dis.backward()
        optimizer_dis.update()
        chainer.report({'loss_dis': (L_real + L_fake) / 2}, self.seed)

    ## visualisation
    if ((self.iteration + 1) % self.args.vis_freq == 0) or ((step + 1) == self.args.iter):
        for i in range(self.args.batchsize):
            outlist = []
            if not self.args.no_train_seed and not self.args.decoder_only:
                outlist.append((self.seed()[i], "0sd"))
            if plan_ae is not None:
                outlist.append((plan[i], '2pl'))
                outlist.append((plan_ae[i], '3ae'))
            if self.args.lambda_nn > 0 or self.args.lambda_adv > 0:
                if self.args.decoder_only:
                    gen_img = self.decoder([self.seed()])
                else:
                    gen_img = self.decoder(self.encoder(self.seed()))
                outlist.append((gen_img[i], '1gn'))
            if fake is not None:
                outlist.append((fake[i], '4fa'))
            for out, typ in outlist:
                out.to_cpu()
                HU = (((out + 1) / 2 * self.args.HU_range) + self.args.HU_base).array  # [-1000=air, 0=water, >1000=bone]
                print("type: ", typ, "HU:", np.min(HU), np.mean(HU), np.max(HU))
                # visimg = np.clip((out.array + 1) / 2, 0, 1) * 255.0
                b, r = -self.args.HU_range_vis // 2, self.args.HU_range_vis
                visimg = (np.clip(HU, b, b + r) - b) / r * 255.0
                fn = 'n{:0>5}_iter{:0>6}_p{}_z{}_{}'.format(
                    self.n_reconst, step + 1, self.patient_id[i], self.slice[i], typ)
                write_image(np.uint8(visimg), os.path.join(self.args.out, fn + '.jpg'))
                if (step + 1) == self.args.iter or (not self.args.no_save_dcm):
                    # np.save(os.path.join(self.args.out, fn + '.npy'), HU[0])
                    write_dicom(os.path.join(self.args.out, fn + '.dcm'), HU[0])
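
# The data-consistency terms above (loss_sd / loss_nn) compare measured
# projections with re-projections of the current image through a sparse system
# matrix, using exp(-A @ x) as the Beer-Lambert forward model. A dense numpy
# stand-in for one OSEM subset, assuming prMat has shape (n_rays, n_pixels) and
# `measured` holds the matching projection data; all names here are illustrative.
import numpy as np

def projection_consistency(prMat, raw, measured, log_domain=False):
    rec = np.exp(-prMat @ raw.reshape(-1, 1))   # forward projection of the image
    if log_domain:
        return np.mean((np.log(rec) - np.log(measured)) ** 2)
    return np.mean((rec - measured) ** 2)

# the manual gradient injection above corresponds to back-projecting the
# residual with the transposed conjugate matrix:
#   grad = conjMat.T @ (rec - measured)
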
def update_core(self):
    opt_enc_x = self.get_optimizer('opt_enc_x')
    opt_dec_x = self.get_optimizer('opt_dec_x')
    opt_enc_y = self.get_optimizer('opt_enc_y')
    opt_dec_y = self.get_optimizer('opt_dec_y')
    opt_x = self.get_optimizer('opt_x')
    opt_y = self.get_optimizer('opt_y')
    opt_z = self.get_optimizer('opt_z')
    self._iter += 1

    # learning rate decay
    if self.is_new_epoch and self.epoch >= self.args.lrdecay_start:
        decay_step = self.init_alpha / self.args.lrdecay_period
        # print('lr decay', decay_step)
        for opt in (opt_enc_x, opt_dec_x, opt_enc_y, opt_dec_y, opt_y, opt_x, opt_z):
            if opt.alpha > decay_step:
                opt.alpha -= decay_step

    # get mini-batch
    batch_x = self.get_iterator('main').next()
    batch_y = self.get_iterator('train_B').next()
    x = Variable(self.converter(batch_x, self.device))
    y = Variable(self.converter(batch_y, self.device))

    # to latent
    x_z = self.enc_x(losses.add_noise(x, sigma=self.args.noise))  # noise injection
    y_z = self.enc_y(losses.add_noise(y, sigma=self.args.noise))
    loss_gen = 0

    ## regularisation on the latent space
    if self.args.lambda_reg > 0:
        loss_reg_enc_y = losses.loss_func_reg(y_z[-1], 'l2')
        loss_reg_enc_x = losses.loss_func_reg(x_z[-1], 'l2')
        loss_gen = loss_gen + self.args.lambda_reg * (loss_reg_enc_x + loss_reg_enc_y)
        if self.report_start < self._iter:
            chainer.report({'loss_reg': loss_reg_enc_x}, self.enc_x)
            chainer.report({'loss_reg': loss_reg_enc_y}, self.enc_y)

    ## discriminator for the latent space: the distribution of enc_x outputs
    ## should look the same as that of enc_y
    if self.args.lambda_dis_z > 0 and self._iter > self.args.dis_z_start:
        x_z_copy = Variable(self._buffer_xz.query(x_z[-1].data))
        loss_enc_x_adv = losses.loss_func_comp(self.dis_z(x_z[-1]), 1.0)
        loss_gen = loss_gen + self.args.lambda_dis_z * loss_enc_x_adv
        if self.report_start < self._iter:
            chainer.report({'loss_adv': loss_enc_x_adv}, self.enc_x)

    # cycle for X=>Z=>X
    x_z_clean = x_z[-1].data.copy()
    x_z[-1] = losses.add_noise(x_z[-1], sigma=self.args.noise_z)
    x_x = self.dec_x(x_z)
    loss_cycle_xzx = losses.loss_avg(x_x, x, ksize=self.args.cycle_ksize, norm='l1')
    if self.report_start < self._iter:
        chainer.report({'loss_cycle': loss_cycle_xzx}, self.enc_x)
    # cycle for Y=>Z=>Y
    y_z_clean = y_z[-1].data.copy()
    y_z[-1] = losses.add_noise(y_z[-1], sigma=self.args.noise_z)
    y_y = self.dec_y(y_z)
    loss_cycle_yzy = losses.loss_avg(y_y, y, ksize=self.args.cycle_ksize, norm='l1')
    if self.report_start < self._iter:
        chainer.report({'loss_cycle': loss_cycle_yzy}, self.enc_y)
    loss_gen = loss_gen + self.args.lambda_A * loss_cycle_xzx + self.args.lambda_B * loss_cycle_yzy

    ## conversion
    x_y = self.dec_y(x_z)
    y_x = self.dec_x(y_z)
    # cycle for (X=>)Z=>Y=>Z
    x_y_z = self.enc_y(x_y)
    loss_cycle_zyz = F.mean_absolute_error(x_y_z[-1], x_z_clean)
    if self.report_start < self._iter:
        chainer.report({'loss_cycle': loss_cycle_zyz}, self.dec_y)
    # cycle for (Y=>)Z=>X=>Z
    y_x_z = self.enc_x(y_x)
    loss_cycle_zxz = F.mean_absolute_error(y_x_z[-1], y_z_clean)
    if self.report_start < self._iter:
        chainer.report({'loss_cycle': loss_cycle_zxz}, self.dec_x)
    loss_gen = loss_gen + self.args.lambda_A * loss_cycle_zxz + self.args.lambda_B * loss_cycle_zyz

    ## adversarial for Y
    if self.args.lambda_dis_y > 0:
        if self.args.conditional_discriminator:
            x_y_copy = Variable(self._buffer_y.query(F.concat([x, x_y]).data))
            loss_dec_y_adv = losses.loss_func_comp(self.dis_y(F.concat([x, x_y])), 1.0)
        else:
            x_y_copy = Variable(self._buffer_y.query(x_y.data))
            loss_dec_y_adv = losses.loss_func_comp(self.dis_y(x_y), 1.0)
        loss_gen = loss_gen + self.args.lambda_dis_y * loss_dec_y_adv
        if self.report_start < self._iter:
            chainer.report({'loss_adv': loss_dec_y_adv}, self.dec_y)
    ## adversarial for X
    if self.args.lambda_dis_x > 0:
        if self.args.conditional_discriminator:
            y_x_copy = Variable(self._buffer_x.query(F.concat([y, y_x]).data))
            loss_dec_x_adv = losses.loss_func_comp(self.dis_x(F.concat([y, y_x])), 1.0)
        else:
            y_x_copy = Variable(self._buffer_x.query(y_x.data))
            loss_dec_x_adv = losses.loss_func_comp(self.dis_x(y_x), 1.0)
        loss_gen = loss_gen + self.args.lambda_dis_x * loss_dec_x_adv
        if self.report_start < self._iter:
            chainer.report({'loss_adv': loss_dec_x_adv}, self.dec_x)

    ## images before/after conversion should look similar in terms of perceptual loss
    if self.args.lambda_identity_x > 0:
        loss_id_x = losses.loss_perceptual(x, x_y, self.vgg)
        loss_gen = loss_gen + self.args.lambda_identity_x * loss_id_x
        if self.report_start < self._iter:
            chainer.report({'loss_id': 1e-3 * loss_id_x}, self.enc_x)
    if self.args.lambda_identity_y > 0:
        loss_id_y = losses.loss_perceptual(y, y_x, self.vgg)
        loss_gen = loss_gen + self.args.lambda_identity_y * loss_id_y
        if self.report_start < self._iter:
            chainer.report({'loss_id': 1e-3 * loss_id_y}, self.enc_y)

    ## warm-up
    if self._iter < self.args.warmup:
        loss_gen += losses.loss_avg(y, y_x, ksize=self.args.id_ksize, norm='l2')
        loss_gen += losses.loss_avg(x, x_y, ksize=self.args.id_ksize, norm='l2')

    ## air (background, value -1) should be preserved
    if self.args.lambda_air > 0:
        loss_air_x = losses.loss_range_comp(x, x_y, 0.9, norm='l2')
        loss_air_y = losses.loss_range_comp(y, y_x, 0.9, norm='l2')
        loss_gen = loss_gen + self.args.lambda_air * (loss_air_x + loss_air_y)
        if self.report_start < self._iter:
            chainer.report({'loss_air': 0.1 * loss_air_x}, self.dec_y)
            chainer.report({'loss_air': 0.1 * loss_air_y}, self.dec_x)
    ## images before/after conversion should look similar in the gradient domain
    if self.args.lambda_grad > 0:
        loss_grad_x = losses.loss_grad(x, x_y)
        loss_grad_y = losses.loss_grad(y, y_x)
        loss_gen = loss_gen + self.args.lambda_grad * (loss_grad_x + loss_grad_y)
        if self.report_start < self._iter:
            chainer.report({'loss_grad': loss_grad_x}, self.dec_y)
            chainer.report({'loss_grad': loss_grad_y}, self.dec_x)
    ## total variation
    if self.args.lambda_tv > 0:
        loss_tv = losses.total_variation(x_y, self.args.tv_tau)
        loss_gen = loss_gen + self.args.lambda_tv * loss_tv
        if self.report_start < self._iter:
            chainer.report({'loss_tv': loss_tv}, self.dec_y)

    ## back-propagate
    self.enc_x.cleargrads()
    self.dec_x.cleargrads()
    self.enc_y.cleargrads()
    self.dec_y.cleargrads()
    loss_gen.backward()
    opt_enc_x.update()
    opt_dec_x.update()
    if not self.args.single_encoder:
        opt_enc_y.update()
    opt_dec_y.update()

    ## discriminator for Y
    if self.args.lambda_dis_y > 0:
        loss_dis_y_fake = losses.loss_func_comp(self.dis_y(x_y_copy), 0.0, self.args.dis_jitter)
        if self.args.conditional_discriminator:
            loss_dis_y_real = losses.loss_func_comp(
                self.dis_y(F.concat([x, y])), 1.0, self.args.dis_jitter)
        else:
            loss_dis_y_real = losses.loss_func_comp(self.dis_y(y), 1.0, self.args.dis_jitter)
        loss_dis_y = (loss_dis_y_fake + loss_dis_y_real) * 0.5
        if self.report_start < self._iter:
            chainer.report({'loss_fake': loss_dis_y_fake}, self.dis_y)
            chainer.report({'loss_real': loss_dis_y_real}, self.dis_y)
        self.dis_y.cleargrads()
        loss_dis_y.backward()
        opt_y.update()
    ## discriminator for X
    if self.args.lambda_dis_x > 0:
        loss_dis_x_fake = losses.loss_func_comp(self.dis_x(y_x_copy), 0.0, self.args.dis_jitter)
        if self.args.conditional_discriminator:
            loss_dis_x_real = losses.loss_func_comp(
                self.dis_x(F.concat([y, x])), 1.0, self.args.dis_jitter)
        else:
            loss_dis_x_real = losses.loss_func_comp(self.dis_x(x), 1.0, self.args.dis_jitter)
        loss_dis_x = (loss_dis_x_fake + loss_dis_x_real) * 0.5
        if self.report_start < self._iter:
            chainer.report({'loss_fake': loss_dis_x_fake}, self.dis_x)
            chainer.report({'loss_real': loss_dis_x_real}, self.dis_x)
        self.dis_x.cleargrads()
        loss_dis_x.backward()
        opt_x.update()
    ## discriminator for the latent
    if self.args.lambda_dis_z > 0 and self._iter > self.args.dis_z_start:
        loss_dis_z_x = losses.loss_func_comp(self.dis_z(x_z_copy), 0.0, self.args.dis_jitter)
        loss_dis_z_y = losses.loss_func_comp(self.dis_z(y_z[-1]), 1.0, self.args.dis_jitter)
        loss_dis_z = (loss_dis_z_x + loss_dis_z_y) * 0.5
        if self.report_start < self._iter:
            chainer.report({'loss_x': loss_dis_z_x}, self.dis_z)
            chainer.report({'loss_y': loss_dis_z_y}, self.dis_z)
        self.dis_z.cleargrads()
        loss_dis_z.backward()
        opt_z.update()
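
# Both CycleGAN-style updaters above decay Adam's alpha by one fixed step at
# each epoch boundary after `lrdecay_start`, stopping once alpha would fall to
# or below one step. A sketch of the equivalent closed form (argument names
# follow self.args above; this helper itself is illustrative):
def decayed_alpha(init_alpha, epoch, lrdecay_start, lrdecay_period):
    step = init_alpha / lrdecay_period
    n = max(0, epoch - lrdecay_start + 1)    # number of decay epochs applied so far
    return max(init_alpha - n * step, step)  # alpha floors at one step

# e.g. with init_alpha=2e-4, lrdecay_start=25, lrdecay_period=25, alpha shrinks
# by 8e-6 per epoch and reaches its floor of 8e-6 after 24 decay epochs.
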