Code Example #1
    def update_core(self):
        opt_g = self.get_optimizer('opt_g')
        opt_f = self.get_optimizer('opt_f')
        opt_x = self.get_optimizer('opt_x')
        opt_y = self.get_optimizer('opt_y')

        self._iter += 1
        # learning rate decay: TODO: weight_decay_rate of AdamW
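        # linear decay: from epoch lrdecay_start on, alpha is reduced by
        # init_alpha / lrdecay_period at every new epoch, so the learning rate
        # decays towards zero over lrdecay_period epochs (the guards below keep
        # alpha positive)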
        if self.is_new_epoch and self.epoch >= self.args.lrdecay_start:
            decay_step = self.init_alpha / self.args.lrdecay_period
            if opt_g.alpha > decay_step:
                opt_g.alpha -= decay_step
            if opt_f.alpha > decay_step:
                opt_f.alpha -= decay_step
            if opt_x.alpha > decay_step:
                opt_x.alpha -= decay_step
            if opt_y.alpha > decay_step:
                opt_y.alpha -= decay_step

        # get mini-batch
        batch_x = self.get_iterator('main').next()
        batch_y = self.get_iterator('train_B').next()
        x = Variable(self.converter(batch_x, self.device))
        y = Variable(self.converter(batch_y, self.device))

        ### generator
        # X => Y => X
        x_y = self.gen_g(losses.add_noise(x, sigma=self.args.noise))
        if self.args.conditional_discriminator:
            x_y_copy = Variable(self._buffer_y.query(F.concat([x, x_y]).data))
        else:
            x_y_copy = Variable(self._buffer_y.query(x_y.data))
        x_y_x = self.gen_f(x_y)
        loss_cycle_x = losses.loss_avg(x_y_x,
                                       x,
                                       ksize=self.args.cycle_ksize,
                                       norm='l1')

        loss_gen_g_adv = 0
        if self.args.gen_start < self._iter:
            if self.args.conditional_discriminator:
                if self.args.wgan:
                    loss_gen_g_adv = -F.average(self.dis_y(F.concat([x, x_y])))
                else:
                    loss_gen_g_adv = losses.loss_func_comp(
                        self.dis_y(F.concat([x, x_y])), 1.0)
            else:
                if self.args.wgan:
                    loss_gen_g_adv = -F.average(self.dis_y(x_y))
                else:
                    loss_gen_g_adv = losses.loss_func_comp(
                        self.dis_y(x_y), 1.0)

        # Y => X => Y
        loss_gen_f_adv = 0
        y_x = self.gen_f(losses.add_noise(
            y, sigma=self.args.noise))  # noise injection
        if self.args.conditional_discriminator:
            y_x_copy = Variable(self._buffer_x.query(F.concat([y, y_x]).data))
        else:
            y_x_copy = Variable(self._buffer_x.query(y_x.data))
        y_x_y = self.gen_g(y_x)
        loss_cycle_y = losses.loss_avg(y_x_y,
                                       y,
                                       ksize=self.args.cycle_ksize,
                                       norm='l1')
        if self.args.gen_start < self._iter:
            if self.args.conditional_discriminator:
                if self.args.wgan:
                    loss_gen_f_adv = -F.average(self.dis_x(F.concat([y, y_x])))
                else:
                    loss_gen_f_adv = losses.loss_func_comp(
                        self.dis_x(F.concat([y, y_x])), 1.0)
            else:
                if self.args.wgan:
                    loss_gen_f_adv = -F.average(self.dis_x(y_x))
                else:
                    loss_gen_f_adv = losses.loss_func_comp(
                        self.dis_x(y_x), 1.0)

        ## total loss for generators
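        # i.e. the usual CycleGAN generator objective
        #   L_gen = lambda_dis_y * L_adv(g) + lambda_dis_x * L_adv(f)
        #           + lambda_A * ||f(g(x)) - x||_1 + lambda_B * ||g(f(y)) - y||_1
        # (loss_avg presumably average-pools with cycle_ksize before the L1 norm)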
        loss_gen = (self.args.lambda_dis_y * loss_gen_g_adv +
                    self.args.lambda_dis_x * loss_gen_f_adv) + (
                        self.args.lambda_A * loss_cycle_x +
                        self.args.lambda_B * loss_cycle_y)

        ## idempotence: applying a generator to its own output should change nothing, i.e. f(f(y)) ≈ f(y) and g(g(x)) ≈ g(x)
        if self.args.lambda_idempotence > 0:
            loss_idem_x = F.mean_absolute_error(y_x, self.gen_f(y_x))
            loss_idem_y = F.mean_absolute_error(x_y, self.gen_g(x_y))
            loss_gen = loss_gen + self.args.lambda_idempotence * (loss_idem_x +
                                                                  loss_idem_y)
            if self.report_start < self._iter:
                chainer.report({'loss_idem': loss_idem_x}, self.gen_f)
                chainer.report({'loss_idem': loss_idem_y}, self.gen_g)
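        # domain preservation: f should leave an image that is already in X
        # unchanged, and likewise g for an image already in Y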
        if self.args.lambda_domain > 0:
            loss_dom_x = F.mean_absolute_error(x, self.gen_f(x))
            loss_dom_y = F.mean_absolute_error(y, self.gen_g(y))
            if self._iter < self.args.warmup:
                loss_gen = loss_gen + max(self.args.lambda_domain,
                                          1.0) * (loss_dom_x + loss_dom_y)
            else:
                loss_gen = loss_gen + self.args.lambda_domain * (loss_dom_x +
                                                                 loss_dom_y)
            if self.report_start < self._iter:
                chainer.report({'loss_dom': loss_dom_x}, self.gen_f)
                chainer.report({'loss_dom': loss_dom_y}, self.gen_g)

        ## images before/after conversion should look similar in terms of perceptual loss
        if self.args.lambda_identity_x > 0:
            loss_id_x = losses.loss_perceptual(x, x_y, self.vgg)
            loss_gen = loss_gen + self.args.lambda_identity_x * loss_id_x
            if self.report_start < self._iter:
                chainer.report({'loss_id': 1e-3 * loss_id_x}, self.gen_g)
        if self.args.lambda_identity_y > 0:
            loss_id_y = losses.loss_perceptual(y, y_x, self.vgg)
            loss_gen = loss_gen + self.args.lambda_identity_y * loss_id_y
            if self.report_start < self._iter:
                chainer.report({'loss_id': 1e-3 * loss_id_y}, self.gen_f)

        ## warm-up
        if self._iter < self.args.warmup:
            loss_gen = loss_gen + F.mean_squared_error(x, self.gen_f(x))
            loss_gen = loss_gen + F.mean_squared_error(y, self.gen_g(y))
#            loss_gen = loss_gen + losses.loss_avg(y,y_x, ksize=self.args.id_ksize, norm='l2')
#            loss_gen = loss_gen + losses.loss_avg(x,x_y, ksize=self.args.id_ksize, norm='l2')

        ## background should be preserved
        if self.args.lambda_air > 0:
            loss_air_x = losses.loss_range_comp(x, x_y, 0.9, norm='l2')
            loss_air_y = losses.loss_range_comp(y, y_x, 0.9, norm='l2')
            loss_gen = loss_gen + self.args.lambda_air * (loss_air_x +
                                                          loss_air_y)
            if self.report_start < self._iter:
                chainer.report({'loss_air': 0.1 * loss_air_x}, self.gen_g)
                chainer.report({'loss_air': 0.1 * loss_air_y}, self.gen_f)

        ## comparison of images before/after conversion in the gradient domain
        if self.args.lambda_grad > 0:
            loss_grad_x = losses.loss_grad(x, x_y, self.args.grad_norm)
            loss_grad_y = losses.loss_grad(y, y_x, self.args.grad_norm)
            loss_gen = loss_gen + self.args.lambda_grad * (loss_grad_x +
                                                           loss_grad_y)
            if self.report_start < self._iter:
                chainer.report({'loss_grad': loss_grad_x}, self.gen_g)
                chainer.report({'loss_grad': loss_grad_y}, self.gen_f)
        ## total variation
        if self.args.lambda_tv > 0:
            loss_tv = losses.total_variation(x_y, self.args.tv_tau)
            loss_gen = loss_gen + self.args.lambda_tv * loss_tv
            if self.report_start < self._iter:
                chainer.report({'loss_tv': loss_tv}, self.gen_g)

        if self.report_start < self._iter:
            chainer.report({'loss_cycle_x': loss_cycle_x}, self.gen_f)
            chainer.report({'loss_adv': loss_gen_g_adv}, self.gen_g)
            chainer.report({'loss_cycle_y': loss_cycle_y}, self.gen_g)
            chainer.report({'loss_adv': loss_gen_f_adv}, self.gen_f)

        self.gen_f.cleargrads()
        self.gen_g.cleargrads()
        loss_gen.backward()
        opt_f.update()
        opt_g.update()

        ### discriminator
        for t in range(self.args.n_critics):
            if self.args.wgan:  ## synthesised -, real +
                loss_dis_y_fake = F.average(self.dis_y(x_y_copy))
                eps = self.xp.random.uniform(0, 1, size=len(batch_y)).astype(
                    self.xp.float32)[:, None, None, None]
                if self.args.conditional_discriminator:
                    loss_dis_y_real = -F.average(self.dis_y(F.concat([x, y])))
                    # interpolate in the same (x, y) channel order that dis_y sees
                    y_mid = eps * F.concat([x, y]) + (1.0 - eps) * x_y_copy
                else:
                    loss_dis_y_real = -F.average(self.dis_y(y))
                    y_mid = eps * y + (1.0 - eps) * x_y_copy
                # gradient penalty
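                # WGAN-GP (Gulrajani et al.): penalise the squared deviation of
                # ||grad D(y_mid)||_2 from 1 at points interpolated between
                # real and generated samples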
                gd_y, = chainer.grad([self.dis_y(y_mid)], [y_mid],
                                     enable_double_backprop=True)
                gd_y = F.sqrt(F.batch_l2_norm_squared(gd_y) + 1e-6)
                loss_dis_y_gp = F.mean_squared_error(
                    gd_y, self.xp.ones_like(gd_y.data))

                if self.report_start < self._iter:
                    chainer.report({'loss_real': loss_dis_y_real}, self.dis_y)
                    chainer.report({'loss_fake': loss_dis_y_fake}, self.dis_y)
                    chainer.report(
                        {'loss_gp': self.args.lambda_wgan_gp * loss_dis_y_gp},
                        self.dis_y)
                loss_dis_y = (loss_dis_y_real + loss_dis_y_fake +
                              self.args.lambda_wgan_gp * loss_dis_y_gp)
                self.dis_y.cleargrads()
                loss_dis_y.backward()
                opt_y.update()

                ## discriminator for B=>A
                loss_dis_x_fake = F.average(self.dis_x(y_x_copy))
                if self.args.conditional_discriminator:
                    loss_dis_x_real = -F.average(
                        self.dis_x(
                            losses.add_noise(F.concat([y, x]),
                                             sigma=self.args.noise)))
                    # interpolate in the same (y, x) channel order that dis_x sees
                    x_mid = eps * F.concat([y, x]) + (1.0 - eps) * y_x_copy
                else:
                    loss_dis_x_real = -F.average(self.dis_x(x))
                    x_mid = eps * x + (1.0 - eps) * y_x_copy
                # gradient penalty
                gd_x, = chainer.grad([self.dis_x(x_mid)], [x_mid],
                                     enable_double_backprop=True)
                gd_x = F.sqrt(F.batch_l2_norm_squared(gd_x) + 1e-6)
                loss_dis_x_gp = F.mean_squared_error(
                    gd_x, self.xp.ones_like(gd_x.data))

                if self.report_start < self._iter:
                    chainer.report({'loss_real': loss_dis_x_real}, self.dis_x)
                    chainer.report({'loss_fake': loss_dis_x_fake}, self.dis_x)
                    chainer.report(
                        {'loss_gp': self.args.lambda_wgan_gp * loss_dis_x_gp},
                        self.dis_x)
                loss_dis_x = (loss_dis_x_real + loss_dis_x_fake +
                              self.args.lambda_wgan_gp * loss_dis_x_gp)
                self.dis_x.cleargrads()
                loss_dis_x.backward()
                opt_x.update()

            else:
                ## discriminator for A=>B (real:1, fake:0)
                loss_dis_y_fake = losses.loss_func_comp(
                    self.dis_y(x_y_copy), 0.0, self.args.dis_jitter)
                if self.args.conditional_discriminator:
                    loss_dis_y_real = losses.loss_func_comp(
                        self.dis_y(F.concat([x, y])), 1.0,
                        self.args.dis_jitter)
                else:
                    loss_dis_y_real = losses.loss_func_comp(
                        self.dis_y(y), 1.0, self.args.dis_jitter)
                loss_dis_y = (loss_dis_y_fake + loss_dis_y_real) * 0.5
                self.dis_y.cleargrads()
                loss_dis_y.backward()
                opt_y.update()

                ## discriminator for B=>A
                loss_dis_x_fake = losses.loss_func_comp(
                    self.dis_x(y_x_copy), 0.0, self.args.dis_jitter)
                if self.args.conditional_discriminator:
                    loss_dis_x_real = losses.loss_func_comp(
                        self.dis_x(F.concat([y, x])), 1.0,
                        self.args.dis_jitter)
                else:
                    loss_dis_x_real = losses.loss_func_comp(
                        self.dis_x(x), 1.0, self.args.dis_jitter)
                loss_dis_x = (loss_dis_x_fake + loss_dis_x_real) * 0.5
                self.dis_x.cleargrads()
                loss_dis_x.backward()
                opt_x.update()

                if self.report_start < self._iter:
                    chainer.report({'loss_real': loss_dis_x_real}, self.dis_x)
                    chainer.report({'loss_fake': loss_dis_x_fake}, self.dis_x)
                    chainer.report({'loss_real': loss_dis_y_real}, self.dis_y)
                    chainer.report({'loss_fake': loss_dis_y_fake}, self.dis_y)

            # prepare next images
            if (t < self.args.n_critics - 1):
                x_y_copy = Variable(
                    self.xp.concatenate(
                        random.sample(self._buffer_y.images,
                                      self.args.batch_size)))
                y_x_copy = Variable(
                    self.xp.concatenate(
                        random.sample(self._buffer_x.images,
                                      self.args.batch_size)))
                batch_x = self.get_iterator('main').next()
                batch_y = self.get_iterator('train_B').next()
                x = Variable(self.converter(batch_x, self.device))
                y = Variable(self.converter(batch_y, self.device))
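
The snippets rely on helper functions from a `losses` module that is not shown here. As a rough sketch under assumed conventions (the actual signature and behaviour in the repository may differ), the LSGAN-style helper `losses.loss_func_comp(output, target, jitter)` could be implemented along these lines:

import chainer.functions as F
from chainer import cuda


def loss_func_comp(y, val, noise=0.0):
    # LSGAN-style loss: mean squared error between the discriminator output
    # and a constant target `val`, optionally jittered to soften the labels
    # (hypothetical re-implementation, not the repository's own code)
    xp = cuda.get_array_module(y.data)
    target = xp.full(y.data.shape, val, dtype=y.data.dtype)
    if noise > 0:
        target += noise * xp.random.randn(*y.data.shape).astype(y.data.dtype)
    return F.mean_squared_error(y, target)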
Code Example #2
    def update_core(self):
        opt_enc_x = self.get_optimizer('opt_enc_x')
        opt_dec_x = self.get_optimizer('opt_dec_x')
        opt_enc_y = self.get_optimizer('opt_enc_y')
        opt_dec_y = self.get_optimizer('opt_dec_y')
        opt_x = self.get_optimizer('opt_x')
        opt_y = self.get_optimizer('opt_y')
        opt_z = self.get_optimizer('opt_z')

        # get mini-batch
        batch_x = self.get_iterator('main').next()
        batch_y = self.get_iterator('train_B').next()
        x = Variable(self.converter(batch_x, self.args.gpu[0]))
        y = Variable(self.converter(batch_y, self.args.gpu[0]))

        # encode to latent (X,Y => Z)
        x_z = self.enc_x(losses.add_noise(x, sigma=self.args.noise))
        y_z = self.enc_y(losses.add_noise(y, sigma=self.args.noise))

        loss_gen = 0
        ## regularisation on the latent space
        if self.args.lambda_reg > 0:
            loss_reg_enc_y = losses.loss_func_reg(y_z[-1], 'l2')
            loss_reg_enc_x = losses.loss_func_reg(x_z[-1], 'l2')
            loss_gen = loss_gen + self.args.lambda_reg * (loss_reg_enc_x +
                                                          loss_reg_enc_y)
            chainer.report({'loss_reg': loss_reg_enc_x}, self.enc_x)
            chainer.report({'loss_reg': loss_reg_enc_y}, self.enc_y)

        ## discriminator on the latent space: the distribution of enc_x outputs should match that of enc_y outputs
        # since z is a list (for u-net), we use only the output of the last layer
        if self.args.lambda_dis_z > 0:
            if self.args.dis_wgan:
                loss_enc_x_adv = -F.average(self.dis_z(x_z[-1]))
                loss_enc_y_adv = F.average(self.dis_z(y_z[-1]))
            else:
                loss_enc_x_adv = losses.loss_func_comp(self.dis_z(x_z[-1]),
                                                       1.0)
                loss_enc_y_adv = losses.loss_func_comp(self.dis_z(y_z[-1]),
                                                       0.0)
            loss_gen = loss_gen + self.args.lambda_dis_z * (loss_enc_x_adv +
                                                            loss_enc_y_adv)
            chainer.report({'loss_adv': loss_enc_x_adv}, self.enc_x)
            chainer.report({'loss_adv': loss_enc_y_adv}, self.enc_y)

        # cycle for X=>Z=>X (Autoencoder)
        x_x = self.dec_x(x_z)
        loss_cycle_xzx = F.mean_absolute_error(x_x, x)
        chainer.report({'loss_cycle': loss_cycle_xzx}, self.enc_x)
        # cycle for Y=>Z=>Y (Autoencoder)
        y_y = self.dec_y(y_z)
        loss_cycle_yzy = F.mean_absolute_error(y_y, y)
        chainer.report({'loss_cycle': loss_cycle_yzy}, self.enc_y)
        loss_gen = loss_gen + self.args.lambda_Az * loss_cycle_xzx + self.args.lambda_Bz * loss_cycle_yzy

        ## decode from latent Z => Y,X
        x_y = self.dec_y(x_z)
        y_x = self.dec_x(y_z)

        # cycle for X=>Z=>Y=>Z=>X  (Z=>Y=>Z does not work well)
        x_y_x = self.dec_x(self.enc_y(x_y))
        loss_cycle_x = F.mean_absolute_error(x_y_x, x)
        chainer.report({'loss_cycle': loss_cycle_x}, self.dec_x)
        # cycle for Y=>Z=>X=>Z=>Y
        y_x_y = self.dec_y(self.enc_x(y_x))
        loss_cycle_y = F.mean_absolute_error(y_x_y, y)
        chainer.report({'loss_cycle': loss_cycle_y}, self.dec_y)
        loss_gen = loss_gen + self.args.lambda_A * loss_cycle_x + self.args.lambda_B * loss_cycle_y

        ## adversarial for Y
        if self.args.lambda_dis_y > 0:
            x_y_copy = Variable(self._buffer_y.query(x_y.data))
            if self.args.dis_wgan:
                loss_dec_y_adv = -F.average(self.dis_y(x_y))
            else:
                loss_dec_y_adv = losses.loss_func_comp(self.dis_y(x_y), 1.0)
            loss_gen = loss_gen + self.args.lambda_dis_y * loss_dec_y_adv
            chainer.report({'loss_adv': loss_dec_y_adv}, self.dec_y)
        ## adversarial for X
        if self.args.lambda_dis_x > 0:
            y_x_copy = Variable(self._buffer_x.query(y_x.data))
            if self.args.dis_wgan:
                loss_dec_x_adv = -F.average(self.dis_x(y_x))
            else:
                loss_dec_x_adv = losses.loss_func_comp(self.dis_x(y_x), 1.0)
            loss_gen = loss_gen + self.args.lambda_dis_x * loss_dec_x_adv
            chainer.report({'loss_adv': loss_dec_x_adv}, self.dec_x)

        ## idempotence
        if self.args.lambda_idempotence > 0:
            loss_idem_x = F.mean_absolute_error(y_x,
                                                self.dec_x(self.enc_y(y_x)))
            loss_idem_y = F.mean_absolute_error(x_y,
                                                self.dec_y(self.enc_x(x_y)))
            loss_gen = loss_gen + self.args.lambda_idempotence * (loss_idem_x +
                                                                  loss_idem_y)
            chainer.report({'loss_idem': loss_idem_x}, self.dec_x)
            chainer.report({'loss_idem': loss_idem_y}, self.dec_y)
        # Y => X shouldn't change X
        if self.args.lambda_domain > 0:
            loss_dom_x = F.mean_absolute_error(x, self.dec_x(self.enc_y(x)))
            loss_dom_y = F.mean_absolute_error(y, self.dec_y(self.enc_x(y)))
            loss_gen = loss_gen + self.args.lambda_domain * (loss_dom_x +
                                                             loss_dom_y)
            chainer.report({'loss_dom': loss_dom_x}, self.dec_x)
            chainer.report({'loss_dom': loss_dom_y}, self.dec_y)

        ## images before/after conversion should look similar in terms of perceptual loss
        if self.args.lambda_identity_x > 0:
            loss_id_x = losses.loss_perceptual(
                x,
                x_y,
                self.vgg,
                layer=self.args.perceptual_layer,
                grey=self.args.grey)
            loss_gen = loss_gen + self.args.lambda_identity_x * loss_id_x
            chainer.report({'loss_id': 1e-3 * loss_id_x}, self.enc_x)
        if self.args.lambda_identity_y > 0:
            loss_id_y = losses.loss_perceptual(
                y,
                y_x,
                self.vgg,
                layer=self.args.perceptual_layer,
                grey=self.args.grey)
            loss_gen = loss_gen + self.args.lambda_identity_y * loss_id_y
            chainer.report({'loss_id': 1e-3 * loss_id_y}, self.enc_y)
        ## background (pixels with value -1) should be preserved
        if self.args.lambda_air > 0:
            loss_air_x = losses.loss_comp_low(x,
                                              x_y,
                                              self.args.air_threshold,
                                              norm='l2')
            loss_air_y = losses.loss_comp_low(y,
                                              y_x,
                                              self.args.air_threshold,
                                              norm='l2')
            loss_gen = loss_gen + self.args.lambda_air * (loss_air_x +
                                                          loss_air_y)
            chainer.report({'loss_air': loss_air_x}, self.dec_y)
            chainer.report({'loss_air': loss_air_y}, self.dec_x)
        ## images before/after conversion should look similar in the gradient domain
        if self.args.lambda_grad > 0:
            loss_grad_x = losses.loss_grad(x, x_y)
            loss_grad_y = losses.loss_grad(y, y_x)
            loss_gen = loss_gen + self.args.lambda_grad * (loss_grad_x +
                                                           loss_grad_y)
            chainer.report({'loss_grad': loss_grad_x}, self.dec_y)
            chainer.report({'loss_grad': loss_grad_y}, self.dec_x)
        ## total variation (only for X -> Y)
        if self.args.lambda_tv > 0:
            loss_tv = losses.total_variation(x_y,
                                             tau=self.args.tv_tau,
                                             method=self.args.tv_method)
            if self.args.imgtype == "dcm" and self.args.num_slices > 1:
                loss_tv += losses.total_variation_ch(x_y)
            loss_gen = loss_gen + self.args.lambda_tv * loss_tv
            chainer.report({'loss_tv': loss_tv}, self.dec_y)

        ## back propagate
        self.enc_x.cleargrads()
        self.dec_x.cleargrads()
        self.enc_y.cleargrads()
        self.dec_y.cleargrads()
        loss_gen.backward()
        opt_enc_x.update(loss=loss_gen)
        opt_dec_x.update(loss=loss_gen)
        if not self.args.single_encoder:
            opt_enc_y.update(loss=loss_gen)
        opt_dec_y.update(loss=loss_gen)

        ##########################################
        ## discriminator for Y
        if self.args.dis_wgan:  ## synthesised -, real +
            eps = self.xp.random.uniform(0, 1, size=len(batch_y)).astype(
                self.xp.float32)[:, None, None, None]
            if self.args.lambda_dis_y > 0:
                ## discriminator for X=>Y
                loss_dis_y = F.average(self.dis_y(x_y_copy) - self.dis_y(y))
                y_mid = eps * y + (1.0 - eps) * x_y_copy
                # gradient penalty
                gd_y, = chainer.grad([self.dis_y(y_mid)], [y_mid],
                                     enable_double_backprop=True)
                gd_y = F.sqrt(F.batch_l2_norm_squared(gd_y) + 1e-6)
                loss_dis_y_gp = F.mean_squared_error(
                    gd_y, self.xp.ones_like(gd_y.data))
                chainer.report({'loss_dis': loss_dis_y}, self.dis_y)
                chainer.report(
                    {'loss_gp': self.args.lambda_wgan_gp * loss_dis_y_gp},
                    self.dis_y)
                loss_dis_y = loss_dis_y + self.args.lambda_wgan_gp * loss_dis_y_gp
                self.dis_y.cleargrads()
                loss_dis_y.backward()
                opt_y.update(loss=loss_dis_y)

            if self.args.lambda_dis_x > 0:
                ## discriminator for B=>A
                loss_dis_x = F.average(self.dis_x(y_x_copy) - self.dis_x(x))
                x_mid = eps * x + (1.0 - eps) * y_x_copy
                # gradient penalty
                gd_x, = chainer.grad([self.dis_x(x_mid)], [x_mid],
                                     enable_double_backprop=True)
                gd_x = F.sqrt(F.batch_l2_norm_squared(gd_x) + 1e-6)
                loss_dis_x_gp = F.mean_squared_error(
                    gd_x, self.xp.ones_like(gd_x.data))
                chainer.report({'loss_dis': loss_dis_x}, self.dis_x)
                chainer.report(
                    {'loss_gp': self.args.lambda_wgan_gp * loss_dis_x_gp},
                    self.dis_x)
                loss_dis_x = loss_dis_x + self.args.lambda_wgan_gp * loss_dis_x_gp
                self.dis_x.cleargrads()
                loss_dis_x.backward()
                opt_x.update(loss=loss_dis_x)

            ## discriminator for latent: X -> Z is - while Y -> Z is +
            if self.args.lambda_dis_z > 0:
                loss_dis_z = F.average(
                    self.dis_z(x_z[-1]) - self.dis_z(y_z[-1]))
                z_mid = eps * x_z[-1] + (1.0 - eps) * y_z[-1]
                # gradient penalty
                gd_z, = chainer.grad([self.dis_z(z_mid)], [z_mid],
                                     enable_double_backprop=True)
                gd_z = F.sqrt(F.batch_l2_norm_squared(gd_z) + 1e-6)
                loss_dis_z_gp = F.mean_squared_error(
                    gd_z, self.xp.ones_like(gd_z.data))
                chainer.report({'loss_dis': loss_dis_z}, self.dis_z)
                chainer.report(
                    {'loss_gp': self.args.lambda_wgan_gp * loss_dis_z_gp},
                    self.dis_z)
                loss_dis_z = loss_dis_z + self.args.lambda_wgan_gp * loss_dis_z_gp
                self.dis_z.cleargrads()
                loss_dis_z.backward()
                opt_z.update(loss=loss_dis_z)

        else:  ## LSGAN
            if self.args.lambda_dis_y > 0:
                ## discriminator for A=>B (real:1, fake:0)
                disy_fake = self.dis_y(x_y_copy)
                loss_dis_y_fake = losses.loss_func_comp(
                    disy_fake, 0.0, self.args.dis_jitter)
                disy_real = self.dis_y(y)
                loss_dis_y_real = losses.loss_func_comp(
                    disy_real, 1.0, self.args.dis_jitter)
                if self.args.dis_reg_weighting > 0:  ## regularization
                    loss_dis_y_reg = (
                        F.average(F.absolute(disy_real[:, 1, :, :])) +
                        F.average(F.absolute(disy_fake[:, 1, :, :])))
                else:
                    loss_dis_y_reg = 0
                chainer.report({'loss_reg': loss_dis_y_reg}, self.dis_y)
                loss_dis_y_gp = 0
                chainer.report({'loss_fake': loss_dis_y_fake}, self.dis_y)
                chainer.report({'loss_real': loss_dis_y_real}, self.dis_y)
                loss_dis_y = (
                    loss_dis_y_fake + loss_dis_y_real
                ) * 0.5 + self.args.dis_reg_weighting * loss_dis_y_reg + self.args.lambda_wgan_gp * loss_dis_y_gp
                self.dis_y.cleargrads()
                loss_dis_y.backward()
                opt_y.update(loss=loss_dis_y)

            if self.args.lambda_dis_x > 0:
                ## discriminator for B=>A
                disx_fake = self.dis_x(y_x_copy)
                loss_dis_x_fake = losses.loss_func_comp(
                    disx_fake, 0.0, self.args.dis_jitter)
                disx_real = self.dis_x(x)
                loss_dis_x_real = losses.loss_func_comp(
                    disx_real, 1.0, self.args.dis_jitter)
                if self.args.dis_reg_weighting > 0:  ## regularization
                    loss_dis_x_reg = (
                        F.average(F.absolute(disx_fake[:, 1, :, :])) +
                        F.average(F.absolute(disx_real[:, 1, :, :])))
                else:
                    loss_dis_x_reg = 0
                chainer.report({'loss_reg': loss_dis_x_reg}, self.dis_x)
                loss_dis_x_gp = 0
                chainer.report({'loss_fake': loss_dis_x_fake}, self.dis_x)
                chainer.report({'loss_real': loss_dis_x_real}, self.dis_x)
                loss_dis_x = (
                    loss_dis_x_fake + loss_dis_x_real
                ) * 0.5 + self.args.dis_reg_weighting * loss_dis_x_reg + self.args.lambda_wgan_gp * loss_dis_x_gp
                self.dis_x.cleargrads()
                loss_dis_x.backward()
                opt_x.update(loss=loss_dis_x)

            ## discriminator for latent: X -> Z is 0.0 while Y -> Z is 1.0
            if self.args.lambda_dis_z > 0:
                disz_xz = self.dis_z(x_z[-1])
                loss_dis_z_x = losses.loss_func_comp(disz_xz, 0.0,
                                                     self.args.dis_jitter)
                disz_yz = self.dis_z(y_z[-1])
                loss_dis_z_y = losses.loss_func_comp(disz_yz, 1.0,
                                                     self.args.dis_jitter)
                if self.args.dis_reg_weighting > 0:  ## regularization
                    loss_dis_z_reg = (
                        F.average(F.absolute(disz_xz[:, 1, :, :])) +
                        F.average(F.absolute(disz_yz[:, 1, :, :])))
                else:
                    loss_dis_z_reg = 0
                chainer.report({'loss_x': loss_dis_z_x}, self.dis_z)
                chainer.report({'loss_y': loss_dis_z_y}, self.dis_z)
                chainer.report({'loss_reg': loss_dis_z_reg}, self.dis_z)
                loss_dis_z = (
                    loss_dis_z_x + loss_dis_z_y
                ) * 0.5 + self.args.dis_reg_weighting * loss_dis_z_reg
                self.dis_z.cleargrads()
                loss_dis_z.backward()
                opt_z.update(loss=loss_dis_z)
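
`self._buffer_x` / `self._buffer_y` are image history buffers in the style of CycleGAN (Shrivastava et al.); their implementation is not part of these snippets. A minimal CPU/numpy sketch that is consistent with how they are used above (a `query` method plus an `images` attribute that can be sampled) might look like this; the actual class may differ:

import random
import numpy as np


class ImageHistoryBuffer:
    """Keeps up to `size` previously generated images and answers queries with
    a 50/50 mix of fresh and stored images (hypothetical sketch)."""

    def __init__(self, size=50):
        self.size = size
        self.images = []          # list of (1, C, H, W) arrays

    def query(self, batch):
        out = []
        for img in batch:         # batch: (N, C, H, W) array
            img = img[None]       # keep a batch axis of 1
            if len(self.images) < self.size:
                self.images.append(img)
                out.append(img)
            elif random.random() < 0.5:
                idx = random.randrange(self.size)
                out.append(self.images[idx])
                self.images[idx] = img   # swap the new image into the pool
            else:
                out.append(img)
        return np.concatenate(out, axis=0)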
Code Example #3
    def evaluate(self):
        batch_x = self._iterators['main'].next()
        batch_y = self._iterators['testB'].next()
        models = self._targets
        if self.eval_hook:
            self.eval_hook(self)

        fig = plt.figure(figsize=(9, 3 * (len(batch_x) + len(batch_y))))
        gs = gridspec.GridSpec(len(batch_x) + len(batch_y),
                               3,
                               wspace=0.1,
                               hspace=0.1)

        x = Variable(self.converter(batch_x, self.device))
        y = Variable(self.converter(batch_y, self.device))

        with chainer.using_config('train', False):
            with chainer.function.no_backprop_mode():
                if len(models) > 2:
                    x_y = models['dec_y'](models['enc_x'](x))
                    if self.single_encoder:
                        x_y_x = models['dec_x'](models['enc_x'](x_y))
                    else:
                        x_y_x = models['dec_x'](
                            models['enc_x'](x))  ## autoencoder reconstruction, not the full X=>Y=>X cycle
                        #x_y_x = models['dec_x'](models['enc_y'](x_y))
                else:
                    x_y = models['gen_g'](x)
                    x_y_x = models['gen_f'](x_y)

#        for i, var in enumerate([x, x_y]):
        for i, var in enumerate([x, x_y, x_y_x]):
            imgs = postprocess(var).astype(np.float32)
            for j in range(len(imgs)):
                ax = fig.add_subplot(gs[j, i])
                if imgs[j].shape[2] == 1:
                    ax.imshow(imgs[j, :, :, 0],
                              interpolation='none',
                              cmap='gray',
                              vmin=0,
                              vmax=1)
                else:
                    ax.imshow(imgs[j], interpolation='none', vmin=0, vmax=1)
                ax.set_xticks([])
                ax.set_yticks([])

        with chainer.using_config('train', False):
            with chainer.function.no_backprop_mode():
                if len(models) > 2:
                    if self.single_encoder:
                        y_x = models['dec_x'](models['enc_x'](y))
                    else:
                        y_x = models['dec_x'](models['enc_y'](y))
#                    y_x_y = models['dec_y'](models['enc_y'](y))   ## autoencoder
                    y_x_y = models['dec_y'](models['enc_x'](y_x))
                else:  # (gen_g, gen_f)
                    y_x = models['gen_f'](y)
                    y_x_y = models['gen_g'](y_x)

#        for i, var in enumerate([y, y_y]):
        for i, var in enumerate([y, y_x, y_x_y]):
            imgs = postprocess(var).astype(np.float32)
            for j in range(len(imgs)):
                ax = fig.add_subplot(gs[j + len(batch_x), i])
                if imgs[j].shape[2] == 1:
                    ax.imshow(imgs[j, :, :, 0],
                              interpolation='none',
                              cmap='gray',
                              vmin=0,
                              vmax=1)
                else:
                    ax.imshow(imgs[j], interpolation='none', vmin=0, vmax=1)
                ax.set_xticks([])
                ax.set_yticks([])

        gs.tight_layout(fig)
        plt.savefig(os.path.join(self.vis_out,
                                 'count{:0>4}.jpg'.format(self.count)),
                    dpi=200)
        self.count += 1
        plt.close()

        cycle_y_l1 = F.mean_absolute_error(y, y_x_y)
        cycle_y_l2 = F.mean_squared_error(y, y_x_y)
        cycle_x_l1 = F.mean_absolute_error(x, x_y_x)
        id_xy_grad = losses.loss_grad(x, x_y)
        id_xy_l1 = F.mean_absolute_error(x, x_y)

        result = {
            "myval/cycle_y_l1": cycle_y_l1,
            "myval/cycle_y_l2": cycle_y_l2,
            "myval/cycle_x_l1": cycle_x_l1,
            "myval/id_xy_grad": id_xy_grad,
            "myval/id_xy_l1": id_xy_l1
        }
        return result
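
`postprocess` is also external to this snippet. Given that the training code treats -1 as the background value and the plotting code indexes images as (H, W, C) with values in [0, 1], it presumably converts an NCHW Variable in [-1, 1] into an NHWC numpy array in [0, 1], roughly as follows (assumed behaviour, not the repository's own code):

import numpy as np
from chainer import cuda


def postprocess(var):
    # bring the data back to the CPU and rescale [-1, 1] -> [0, 1] (assumed range)
    imgs = cuda.to_cpu(var.data)
    imgs = np.clip((imgs + 1.0) / 2.0, 0.0, 1.0)
    # NCHW -> NHWC for matplotlib
    return imgs.transpose(0, 2, 3, 1)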
Code Example #4
    def update_core(self):
        opt_enc_x = self.get_optimizer('opt_enc_x')
        opt_dec_x = self.get_optimizer('opt_dec_x')
        opt_enc_y = self.get_optimizer('opt_enc_y')
        opt_dec_y = self.get_optimizer('opt_dec_y')
        opt_x = self.get_optimizer('opt_x')
        opt_y = self.get_optimizer('opt_y')
        opt_z = self.get_optimizer('opt_z')
        self._iter += 1
        if self.is_new_epoch and self.epoch >= self.args.lrdecay_start:
            decay_step = self.init_alpha / self.args.lrdecay_period
            #            print('lr decay', decay_step)
            if opt_enc_x.alpha > decay_step:
                opt_enc_x.alpha -= decay_step
            if opt_dec_x.alpha > decay_step:
                opt_dec_x.alpha -= decay_step
            if opt_enc_y.alpha > decay_step:
                opt_enc_y.alpha -= decay_step
            if opt_dec_y.alpha > decay_step:
                opt_dec_y.alpha -= decay_step
            if opt_y.alpha > decay_step:
                opt_y.alpha -= decay_step
            if opt_x.alpha > decay_step:
                opt_x.alpha -= decay_step
            if opt_z.alpha > decay_step:
                opt_z.alpha -= decay_step

        # get mini-batch
        batch_x = self.get_iterator('main').next()
        batch_y = self.get_iterator('train_B').next()
        x = Variable(self.converter(batch_x, self.device))
        y = Variable(self.converter(batch_y, self.device))

        # to latent
        x_z = self.enc_x(losses.add_noise(
            x, sigma=self.args.noise))  # noise injection
        y_z = self.enc_y(losses.add_noise(y, sigma=self.args.noise))

        loss_gen = 0
        ## regularisation on the latent space
        if self.args.lambda_reg > 0:
            loss_reg_enc_y = losses.loss_func_reg(y_z[-1], 'l2')
            loss_reg_enc_x = losses.loss_func_reg(x_z[-1], 'l2')
            loss_gen = loss_gen + self.args.lambda_reg * (loss_reg_enc_x +
                                                          loss_reg_enc_y)
            if self.report_start < self._iter:
                chainer.report({'loss_reg': loss_reg_enc_x}, self.enc_x)
                chainer.report({'loss_reg': loss_reg_enc_y}, self.enc_y)

        ## discriminator on the latent space: the distribution of enc_x outputs should match that of enc_y outputs
        if self.args.lambda_dis_z > 0 and self._iter > self.args.dis_z_start:
            x_z_copy = Variable(self._buffer_xz.query(x_z[-1].data))
            loss_enc_x_adv = losses.loss_func_comp(self.dis_z(x_z[-1]), 1.0)
            loss_gen = loss_gen + self.args.lambda_dis_z * loss_enc_x_adv
            if self.report_start < self._iter:
                chainer.report({'loss_adv': loss_enc_x_adv}, self.enc_x)

        # cycle for X=>Z=>X
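        # keep a noise-free copy of the latent code; it serves as the target of
        # the Z-side cycle losses computed further below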
        x_z_clean = x_z[-1].data.copy()
        x_z[-1] = losses.add_noise(x_z[-1], sigma=self.args.noise_z)
        x_x = self.dec_x(x_z)
        loss_cycle_xzx = losses.loss_avg(x_x,
                                         x,
                                         ksize=self.args.cycle_ksize,
                                         norm='l1')
        if self.report_start < self._iter:
            chainer.report({'loss_cycle': loss_cycle_xzx}, self.enc_x)

        # cycle for Y=>Z=>Y
        y_z_clean = y_z[-1].data.copy()
        y_z[-1] = losses.add_noise(y_z[-1], sigma=self.args.noise_z)
        y_y = self.dec_y(y_z)
        loss_cycle_yzy = losses.loss_avg(y_y,
                                         y,
                                         ksize=self.args.cycle_ksize,
                                         norm='l1')
        if self.report_start < self._iter:
            chainer.report({'loss_cycle': loss_cycle_yzy}, self.enc_y)

        loss_gen = loss_gen + self.args.lambda_A * loss_cycle_xzx + self.args.lambda_B * loss_cycle_yzy

        ## conversion
        x_y = self.dec_y(x_z)
        y_x = self.dec_x(y_z)

        # cycle for (X=>)Z=>Y=>Z
        x_y_z = self.enc_y(x_y)
        loss_cycle_zyz = F.mean_absolute_error(x_y_z[-1], x_z_clean)
        if self.report_start < self._iter:
            chainer.report({'loss_cycle': loss_cycle_zyz}, self.dec_y)
        # cycle for (Y=>)Z=>X=>Z
        y_x_z = self.enc_x(y_x)
        loss_cycle_zxz = F.mean_absolute_error(y_x_z[-1], y_z_clean)
        if self.report_start < self._iter:
            chainer.report({'loss_cycle': loss_cycle_zxz}, self.dec_x)

        loss_gen = loss_gen + self.args.lambda_A * loss_cycle_zxz + self.args.lambda_B * loss_cycle_zyz

        ## adversarial for Y
        if self.args.lambda_dis_y > 0:
            if self.args.conditional_discriminator:
                x_y_copy = Variable(
                    self._buffer_y.query(F.concat([x, x_y]).data))
                loss_dec_y_adv = losses.loss_func_comp(
                    self.dis_y(F.concat([x, x_y])), 1.0)
            else:
                x_y_copy = Variable(self._buffer_y.query(x_y.data))
                loss_dec_y_adv = losses.loss_func_comp(self.dis_y(x_y), 1.0)
            loss_gen = loss_gen + self.args.lambda_dis_y * loss_dec_y_adv
            if self.report_start < self._iter:
                chainer.report({'loss_adv': loss_dec_y_adv}, self.dec_y)
        ## adversarial for X
        if self.args.lambda_dis_x > 0:
            if self.args.conditional_discriminator:
                y_x_copy = Variable(
                    self._buffer_x.query(F.concat([y, y_x]).data))
                loss_dec_x_adv = losses.loss_func_comp(
                    self.dis_x(F.concat([y, y_x])), 1.0)
            else:
                y_x_copy = Variable(self._buffer_x.query(y_x.data))
                loss_dec_x_adv = losses.loss_func_comp(self.dis_x(y_x), 1.0)
            loss_gen = loss_gen + self.args.lambda_dis_x * loss_dec_x_adv
            if self.report_start < self._iter:
                chainer.report({'loss_adv': loss_dec_x_adv}, self.dec_x)

        ## images before/after conversion should look similar in terms of perceptual loss
        if self.args.lambda_identity_x > 0:
            loss_id_x = losses.loss_perceptual(x, x_y, self.vgg)
            loss_gen = loss_gen + self.args.lambda_identity_x * loss_id_x
            if self.report_start < self._iter:
                chainer.report({'loss_id': 1e-3 * loss_id_x}, self.enc_x)
        if self.args.lambda_identity_y > 0:
            loss_id_y = losses.loss_perceptual(y, y_x, self.vgg)
            loss_gen = loss_gen + self.args.lambda_identity_y * loss_id_y
            if self.report_start < self._iter:
                chainer.report({'loss_id': 1e-3 * loss_id_y}, self.enc_y)
        ## warm-up
        if self._iter < self.args.warmup:
            loss_gen += losses.loss_avg(y,
                                        y_x,
                                        ksize=self.args.id_ksize,
                                        norm='l2')
            loss_gen += losses.loss_avg(x,
                                        x_y,
                                        ksize=self.args.id_ksize,
                                        norm='l2')
        ## air (background, pixel value -1) should be preserved
        if self.args.lambda_air > 0:
            loss_air_x = losses.loss_range_comp(x, x_y, 0.9, norm='l2')
            loss_air_y = losses.loss_range_comp(y, y_x, 0.9, norm='l2')
            loss_gen = loss_gen + self.args.lambda_air * (loss_air_x +
                                                          loss_air_y)
            if self.report_start < self._iter:
                chainer.report({'loss_air': 0.1 * loss_air_x}, self.dec_y)
                chainer.report({'loss_air': 0.1 * loss_air_y}, self.dec_x)
        ## images before/after conversion should look similar in the gradient domain
        if self.args.lambda_grad > 0:
            loss_grad_x = losses.loss_grad(x, x_y)
            loss_grad_y = losses.loss_grad(y, y_x)
            loss_gen = loss_gen + self.args.lambda_grad * (loss_grad_x +
                                                           loss_grad_y)
            if self.report_start < self._iter:
                chainer.report({'loss_grad': loss_grad_x}, self.dec_y)
                chainer.report({'loss_grad': loss_grad_y}, self.dec_x)
        if self.args.lambda_tv > 0:
            loss_tv = losses.total_variation(x_y, self.args.tv_tau)
            loss_gen = loss_gen + self.args.lambda_tv * loss_tv
            if self.report_start < self._iter:
                chainer.report({'loss_tv': loss_tv}, self.dec_y)

        ## back propagate
        self.enc_x.cleargrads()
        self.dec_x.cleargrads()
        self.enc_y.cleargrads()
        self.dec_y.cleargrads()
        loss_gen.backward()
        opt_enc_x.update()
        opt_dec_x.update()
        if not self.args.single_encoder:
            opt_enc_y.update()
        opt_dec_y.update()

        ## discriminator for Y
        if self.args.lambda_dis_y > 0:
            loss_dis_y_fake = losses.loss_func_comp(self.dis_y(x_y_copy), 0.0,
                                                    self.args.dis_jitter)
            if self.args.conditional_discriminator:
                loss_dis_y_real = losses.loss_func_comp(
                    self.dis_y(F.concat([x, y])), 1.0, self.args.dis_jitter)
            else:
                loss_dis_y_real = losses.loss_func_comp(
                    self.dis_y(y), 1.0, self.args.dis_jitter)
            loss_dis_y = (loss_dis_y_fake + loss_dis_y_real) * 0.5
            if self.report_start < self._iter:
                chainer.report({'loss_fake': loss_dis_y_fake}, self.dis_y)
                chainer.report({'loss_real': loss_dis_y_real}, self.dis_y)
            self.dis_y.cleargrads()
            loss_dis_y.backward()
            opt_y.update()

        ## discriminator for X
        if self.args.lambda_dis_x > 0:
            loss_dis_x_fake = losses.loss_func_comp(self.dis_x(y_x_copy), 0.0,
                                                    self.args.dis_jitter)
            if self.args.conditional_discriminator:
                loss_dis_x_real = losses.loss_func_comp(
                    self.dis_x(F.concat([y, x])), 1.0, self.args.dis_jitter)
            else:
                loss_dis_x_real = losses.loss_func_comp(
                    self.dis_x(x), 1.0, self.args.dis_jitter)
            loss_dis_x = (loss_dis_x_fake + loss_dis_x_real) * 0.5
            if self.report_start < self._iter:
                chainer.report({'loss_fake': loss_dis_x_fake}, self.dis_x)
                chainer.report({'loss_real': loss_dis_x_real}, self.dis_x)
            self.dis_x.cleargrads()
            loss_dis_x.backward()
            opt_x.update()

        ## discriminator for latent
        if self.args.lambda_dis_z > 0 and self._iter > self.args.dis_z_start:
            loss_dis_z_x = losses.loss_func_comp(self.dis_z(x_z_copy), 0.0,
                                                 self.args.dis_jitter)
            loss_dis_z_y = losses.loss_func_comp(self.dis_z(y_z[-1]), 1.0,
                                                 self.args.dis_jitter)
            loss_dis_z = (loss_dis_z_x + loss_dis_z_y) * 0.5
            if self.report_start < self._iter:
                chainer.report({'loss_x': loss_dis_z_x}, self.dis_z)
                chainer.report({'loss_y': loss_dis_z_y}, self.dis_z)
            self.dis_z.cleargrads()
            loss_dis_z.backward()
            opt_z.update()
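
For context, an updater with an `update_core` like the ones above is normally driven by a `chainer.training.Trainer`, and the values passed to `chainer.report` then appear as log entries such as 'gen_g/loss_adv' (exact names depend on how the models are registered with the reporter). The constructor of the custom updater is not shown here, so the wiring below is only a hypothetical sketch; `Updater`, the iterators, the optimizers and `params={'args': args}` are placeholders, not names from the repository:

import chainer
from chainer import training
from chainer.training import extensions

# train_iter_A / train_iter_B, the optimizers and `Updater` are assumed to be
# defined elsewhere
updater = Updater(
    iterator={'main': train_iter_A, 'train_B': train_iter_B},
    optimizer={'opt_g': opt_g, 'opt_f': opt_f, 'opt_x': opt_x, 'opt_y': opt_y},
    converter=chainer.dataset.concat_examples,
    device=args.gpu[0],
    params={'args': args})

trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)
trainer.extend(extensions.LogReport(trigger=(100, 'iteration')))
trainer.extend(extensions.PrintReport(
    # entry names depend on how the observers are registered
    ['iteration', 'gen_g/loss_adv', 'gen_f/loss_adv',
     'gen_f/loss_cycle_x', 'gen_g/loss_cycle_y']))
trainer.run()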