Code Example #1
    def build_model(self):
        """ Graph """
        if self.phase == 'train':
            self.d_loss_per_res = {}
            self.g_loss_per_res = {}
            self.generator_optim = {}
            self.discriminator_optim = {}
            self.alpha_summary_per_res = {}
            self.d_summary_per_res = {}
            self.g_summary_per_res = {}
            self.train_fake_images = {}

            for res in self.resolutions[self.resolutions.index(self.start_res):]:
                g_loss_per_gpu = []
                d_loss_per_gpu = []
                train_fake_images_per_gpu = []

                batch_size = self.batch_sizes.get(res, self.batch_size_base)
                global_step = tf.get_variable(
                    'global_step_{}'.format(res),
                    shape=[],
                    dtype=tf.float32,
                    initializer=tf.initializers.zeros(),
                    trainable=False,
                    aggregation=tf.VariableAggregation.ONLY_FIRST_TOWER)
                alpha_const, zero_constant = get_alpha_const(
                    self.iteration // 2, batch_size * self.gpu_num,
                    global_step)

                # smooth transition variable
                do_train_trans = self.train_with_trans[res]

                alpha = tf.get_variable(
                    'alpha_{}'.format(res),
                    shape=[],
                    dtype=tf.float32,
                    initializer=tf.initializers.ones()
                    if do_train_trans else tf.initializers.zeros(),
                    trainable=False,
                    aggregation=tf.VariableAggregation.ONLY_FIRST_TOWER)

                if do_train_trans:
                    alpha_assign_op = tf.assign(alpha, alpha_const)
                else:
                    alpha_assign_op = tf.assign(alpha, zero_constant)

                image_class = ImageData(res)
                main_ds = tf.data.Dataset.from_tensor_slices(self.dataset)
                main_ds = main_ds. \
                    apply(shuffle_and_repeat(self.dataset_num)). \
                    apply(map_and_batch(image_class.image_processing, batch_size, num_parallel_batches=32, drop_remainder=True)). \
                    prefetch(self.gpu_num * 256)

                with tf.control_dependencies([alpha_assign_op]):
                    for gpu_id in range(self.gpu_num):
                        with tf.device(
                                tf.DeviceSpec(device_type="GPU",
                                              device_index=gpu_id)):
                            with tf.variable_scope(tf.get_variable_scope(),
                                                   reuse=(gpu_id > 0)):
                                # images
                                gpu_device = '/gpu:{}'.format(gpu_id)

                                inputs = main_ds. \
                                    shard(self.gpu_num, gpu_id). \
                                    apply(prefetch_to_device(gpu_device, None))
                                # buffer_size=None lets prefetch_to_device choose the buffer size automatically

                                inputs_iterator = inputs.make_one_shot_iterator()

                                real_img = inputs_iterator.get_next()
                                z = tf.random_normal(
                                    shape=[batch_size, self.z_dim])

                                fake_img = self.generator(z, alpha, res)
                                real_img = smooth_crossfade(real_img, alpha)

                                real_logit = self.discriminator(
                                    real_img, alpha, res)
                                fake_logit = self.discriminator(
                                    fake_img, alpha, res)

                                # compute loss
                                d_loss, g_loss = compute_loss(
                                    real_img, real_logit, fake_logit)

                                d_loss_per_gpu.append(d_loss)
                                g_loss_per_gpu.append(g_loss)
                                train_fake_images_per_gpu.append(fake_img)

                print("Create graph for {} resolution".format(res))

                # prepare appropriate training vars
                d_vars, g_vars = filter_trainable_variables(res)

                d_loss = tf.reduce_mean(d_loss_per_gpu)
                g_loss = tf.reduce_mean(g_loss_per_gpu)

                d_lr = self.d_learning_rates.get(res, self.learning_rate_base)
                g_lr = self.g_learning_rates.get(res, self.learning_rate_base)

                # colocate gradients with their forward ops only in the multi-GPU case
                colocate_grad = self.gpu_num > 1

                d_optim = tf.train.AdamOptimizer(
                    d_lr, beta1=0.0, beta2=0.99, epsilon=1e-8).minimize(
                        d_loss,
                        var_list=d_vars,
                        colocate_gradients_with_ops=colocate_grad)

                g_optim = tf.train.AdamOptimizer(
                    g_lr, beta1=0.0, beta2=0.99, epsilon=1e-8).minimize(
                        g_loss,
                        var_list=g_vars,
                        global_step=global_step,
                        colocate_gradients_with_ops=colocate_grad)

                self.discriminator_optim[res] = d_optim
                self.generator_optim[res] = g_optim

                self.d_loss_per_res[res] = d_loss
                self.g_loss_per_res[res] = g_loss

                self.train_fake_images[res] = tf.concat(
                    train_fake_images_per_gpu, axis=0)
                """ Summary """
                self.alpha_summary_per_res[res] = tf.summary.scalar(
                    "alpha_{}".format(res), alpha)

                self.d_summary_per_res[res] = tf.summary.scalar(
                    "d_loss_{}".format(res), self.d_loss_per_res[res])
                self.g_summary_per_res[res] = tf.summary.scalar(
                    "g_loss_{}".format(res), self.g_loss_per_res[res])

        else:
            """" Testing """
            test_z = tf.random_normal(shape=[self.batch_size, self.z_dim])
            alpha = tf.constant(0.0, dtype=tf.float32, shape=[])
            self.fake_images = self.generator(test_z,
                                              alpha=alpha,
                                              target_img_size=self.img_size,
                                              is_training=False)
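
The listing above relies on project helpers that are not shown, such as get_alpha_const and smooth_crossfade. As orientation only, here is a minimal, hypothetical sketch of a progressive-growing smooth_crossfade; it assumes alpha starts at 1.0 when a new resolution is introduced and decays to 0.0, and the real helper may blend in the opposite direction.

import tensorflow as tf

def smooth_crossfade(img, alpha):
    # Blend the full-resolution image with its 2x-downscaled-then-upscaled
    # version; with alpha = 1 the low-resolution path dominates, with
    # alpha = 0 the image passes through unchanged.
    shape = tf.shape(img)
    low = tf.layers.average_pooling2d(img, pool_size=2, strides=2, padding='same')
    low = tf.image.resize_nearest_neighbor(low, size=tf.stack([shape[1], shape[2]]))
    return alpha * low + (1.0 - alpha) * img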
Code Example #2
File: UGATIT.py  Project: jqueguiner/ai-api-template
    def build_model(self):
        if self.phase == 'train':
            self.lr = tf.placeholder(tf.float32, name='learning_rate')
            """ Input Image"""
            Image_Data_Class = ImageData(self.img_size, self.img_ch,
                                         self.augment_flag)

            trainA = tf.data.Dataset.from_tensor_slices(self.trainA_dataset)
            trainB = tf.data.Dataset.from_tensor_slices(self.trainB_dataset)

            gpu_device = '/gpu:0'
            trainA = trainA.apply(shuffle_and_repeat(self.dataset_num)).apply(
                map_and_batch(Image_Data_Class.image_processing,
                              self.batch_size,
                              num_parallel_batches=16,
                              drop_remainder=True)).apply(
                                  prefetch_to_device(gpu_device, None))
            trainB = trainB.apply(shuffle_and_repeat(self.dataset_num)).apply(
                map_and_batch(Image_Data_Class.image_processing,
                              self.batch_size,
                              num_parallel_batches=16,
                              drop_remainder=True)).apply(
                                  prefetch_to_device(gpu_device, None))

            trainA_iterator = trainA.make_one_shot_iterator()
            trainB_iterator = trainB.make_one_shot_iterator()

            self.domain_A = trainA_iterator.get_next()
            self.domain_B = trainB_iterator.get_next()
            """ Define Generator, Discriminator """
            x_ab, cam_ab = self.generate_a2b(self.domain_A)  # translate real A -> fake B
            x_ba, cam_ba = self.generate_b2a(self.domain_B)  # translate real B -> fake A

            x_aba, _ = self.generate_b2a(x_ab, reuse=True)  # cycle A -> B -> A
            x_bab, _ = self.generate_a2b(x_ba, reuse=True)  # cycle B -> A -> B

            x_aa, cam_aa = self.generate_b2a(self.domain_A,
                                             reuse=True)  # identity mapping A -> A
            x_bb, cam_bb = self.generate_a2b(self.domain_B,
                                             reuse=True)  # identity mapping B -> B

            real_A_logit, real_A_cam_logit, real_B_logit, real_B_cam_logit = self.discriminate_real(
                self.domain_A, self.domain_B)
            fake_A_logit, fake_A_cam_logit, fake_B_logit, fake_B_cam_logit = self.discriminate_fake(
                x_ba, x_ab)
            """ Define Loss """
            if self.gan_type.__contains__('wgan') or self.gan_type == 'dragan':
                GP_A, GP_CAM_A = self.gradient_panalty(real=self.domain_A,
                                                       fake=x_ba,
                                                       scope="discriminator_A")
                GP_B, GP_CAM_B = self.gradient_panalty(real=self.domain_B,
                                                       fake=x_ab,
                                                       scope="discriminator_B")
            else:
                GP_A, GP_CAM_A = 0, 0
                GP_B, GP_CAM_B = 0, 0

            G_ad_loss_A = (generator_loss(self.gan_type, fake_A_logit) +
                           generator_loss(self.gan_type, fake_A_cam_logit))
            G_ad_loss_B = (generator_loss(self.gan_type, fake_B_logit) +
                           generator_loss(self.gan_type, fake_B_cam_logit))

            D_ad_loss_A = (
                discriminator_loss(self.gan_type, real_A_logit, fake_A_logit) +
                discriminator_loss(self.gan_type, real_A_cam_logit,
                                   fake_A_cam_logit) + GP_A + GP_CAM_A)
            D_ad_loss_B = (
                discriminator_loss(self.gan_type, real_B_logit, fake_B_logit) +
                discriminator_loss(self.gan_type, real_B_cam_logit,
                                   fake_B_cam_logit) + GP_B + GP_CAM_B)

            reconstruction_A = L1_loss(x_aba, self.domain_A)  # reconstruction
            reconstruction_B = L1_loss(x_bab, self.domain_B)  # reconstruction

            identity_A = L1_loss(x_aa, self.domain_A)
            identity_B = L1_loss(x_bb, self.domain_B)

            cam_A = cam_loss(source=cam_ba, non_source=cam_aa)
            cam_B = cam_loss(source=cam_ab, non_source=cam_bb)

            Generator_A_gan = self.adv_weight * G_ad_loss_A
            Generator_A_cycle = self.cycle_weight * reconstruction_B
            Generator_A_identity = self.identity_weight * identity_A
            Generator_A_cam = self.cam_weight * cam_A

            Generator_B_gan = self.adv_weight * G_ad_loss_B
            Generator_B_cycle = self.cycle_weight * reconstruction_A
            Generator_B_identity = self.identity_weight * identity_B
            Generator_B_cam = self.cam_weight * cam_B

            Generator_A_loss = Generator_A_gan + Generator_A_cycle + Generator_A_identity + Generator_A_cam
            Generator_B_loss = Generator_B_gan + Generator_B_cycle + Generator_B_identity + Generator_B_cam

            Discriminator_A_loss = self.adv_weight * D_ad_loss_A
            Discriminator_B_loss = self.adv_weight * D_ad_loss_B

            self.Generator_loss = Generator_A_loss + Generator_B_loss + regularization_loss(
                'generator')
            self.Discriminator_loss = Discriminator_A_loss + Discriminator_B_loss + regularization_loss(
                'discriminator')
            """ Result Image """
            self.fake_A = x_ba
            self.fake_B = x_ab

            self.real_A = self.domain_A
            self.real_B = self.domain_B
            """ Training """
            t_vars = tf.trainable_variables()
            G_vars = [var for var in t_vars if 'generator' in var.name]
            D_vars = [var for var in t_vars if 'discriminator' in var.name]

            self.G_optim = tf.train.AdamOptimizer(self.lr,
                                                  beta1=0.5,
                                                  beta2=0.999).minimize(
                                                      self.Generator_loss,
                                                      var_list=G_vars)
            self.D_optim = tf.train.AdamOptimizer(self.lr,
                                                  beta1=0.5,
                                                  beta2=0.999).minimize(
                                                      self.Discriminator_loss,
                                                      var_list=D_vars)
            """" Summary """
            self.all_G_loss = tf.summary.scalar("Generator_loss",
                                                self.Generator_loss)
            self.all_D_loss = tf.summary.scalar("Discriminator_loss",
                                                self.Discriminator_loss)

            self.G_A_loss = tf.summary.scalar("G_A_loss", Generator_A_loss)
            self.G_A_gan = tf.summary.scalar("G_A_gan", Generator_A_gan)
            self.G_A_cycle = tf.summary.scalar("G_A_cycle", Generator_A_cycle)
            self.G_A_identity = tf.summary.scalar("G_A_identity",
                                                  Generator_A_identity)
            self.G_A_cam = tf.summary.scalar("G_A_cam", Generator_A_cam)

            self.G_B_loss = tf.summary.scalar("G_B_loss", Generator_B_loss)
            self.G_B_gan = tf.summary.scalar("G_B_gan", Generator_B_gan)
            self.G_B_cycle = tf.summary.scalar("G_B_cycle", Generator_B_cycle)
            self.G_B_identity = tf.summary.scalar("G_B_identity",
                                                  Generator_B_identity)
            self.G_B_cam = tf.summary.scalar("G_B_cam", Generator_B_cam)

            self.D_A_loss = tf.summary.scalar("D_A_loss", Discriminator_A_loss)
            self.D_B_loss = tf.summary.scalar("D_B_loss", Discriminator_B_loss)

            self.rho_var = []
            for var in tf.trainable_variables():
                if 'rho' in var.name:
                    self.rho_var.append(tf.summary.histogram(var.name, var))
                    self.rho_var.append(
                        tf.summary.scalar(var.name + "_min",
                                          tf.reduce_min(var)))
                    self.rho_var.append(
                        tf.summary.scalar(var.name + "_max",
                                          tf.reduce_max(var)))
                    self.rho_var.append(
                        tf.summary.scalar(var.name + "_mean",
                                          tf.reduce_mean(var)))

            g_summary_list = [
                self.G_A_loss, self.G_A_gan, self.G_A_cycle, self.G_A_identity,
                self.G_A_cam, self.G_B_loss, self.G_B_gan, self.G_B_cycle,
                self.G_B_identity, self.G_B_cam, self.all_G_loss
            ]

            g_summary_list.extend(self.rho_var)
            d_summary_list = [self.D_A_loss, self.D_B_loss, self.all_D_loss]

            self.G_loss = tf.summary.merge(g_summary_list)
            self.D_loss = tf.summary.merge(d_summary_list)

        else:
            """ Test """
            self.test_domain_A = tf.placeholder(
                tf.float32, [1, self.img_size, self.img_size, self.img_ch],
                name='test_domain_A')
            self.test_domain_B = tf.placeholder(
                tf.float32, [1, self.img_size, self.img_size, self.img_ch],
                name='test_domain_B')

            self.test_fake_B, _ = self.generate_a2b(self.test_domain_A)
            self.test_fake_A, _ = self.generate_b2a(self.test_domain_B)
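
The cam_loss call above is U-GAT-IT's auxiliary CAM loss on the generator's attention logits. A minimal sketch under the assumption that it is a sigmoid cross-entropy pushing the translating branch ("source") toward 1 and the identity branch ("non_source") toward 0; the repository's exact reduction may differ.

import tensorflow as tf

def cam_loss(source, non_source):
    # Encourage the CAM logits of the translation branch toward 1 and the
    # CAM logits of the identity branch toward 0.
    source_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(source),
                                                logits=source))
    non_source_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(non_source),
                                                logits=non_source))
    return source_loss + non_source_loss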
Code Example #3
    def build_model(self):
        """ Graph Input """
        # images
        if self.custom_dataset:
            Image_Data_Class = ImageData(self.img_size,
                                         self.c_dim,
                                         crop_pos=self.crop_pos,
                                         zoom_range=self.zoom_range)
            inputs = tf.data.Dataset.from_tensor_slices(self.data)
            gpu_device = '/gpu:0'
            inputs = inputs.apply(shuffle_and_repeat(self.dataset_num)).apply(
                map_and_batch(Image_Data_Class.image_processing,
                              self.batch_size,
                              num_parallel_batches=16,
                              drop_remainder=True)).apply(
                                  prefetch_to_device(gpu_device,
                                                     self.batch_size))

            inputs_iterator = tf.compat.v1.data.make_one_shot_iterator(inputs)

            self.inputs = inputs_iterator.get_next()

        else:
            self.inputs = tf.compat.v1.placeholder(
                tf.float32,
                [self.batch_size, self.img_size, self.img_size, self.c_dim],
                name='real_images')

        # noises
        self.z = tf.compat.v1.placeholder(tf.float32,
                                          [self.batch_size, 1, 1, self.z_dim],
                                          name='z')
        """ Loss Function """
        # output of D for real images
        real_logits = self.discriminator(self.inputs)

        # output of D for fake images
        fake_images = self.generator(self.z)
        fake_logits = self.discriminator(fake_images, reuse=True)

        if self.gan_type.__contains__('wgan') or self.gan_type == 'dragan':
            GP = self.gradient_penalty(real=self.inputs, fake=fake_images)
        else:
            GP = 0

        # get loss for discriminator
        self.d_loss = discriminator_loss(
            self.gan_type, real=real_logits, fake=fake_logits) + GP

        # get loss for generator
        self.g_loss = generator_loss(self.gan_type, fake=fake_logits)
        """ Training """
        # divide trainable variables into a group for D and a group for G
        t_vars = tf.compat.v1.trainable_variables()
        d_vars = [var for var in t_vars if 'discriminator' in var.name]
        g_vars = [var for var in t_vars if 'generator' in var.name]

        # optimizers
        self.d_optim = tf.compat.v1.train.AdamOptimizer(
            self.d_learning_rate, beta1=self.beta1,
            beta2=self.beta2).minimize(self.d_loss, var_list=d_vars)
        self.g_optim = tf.compat.v1.train.AdamOptimizer(
            self.g_learning_rate, beta1=self.beta1,
            beta2=self.beta2).minimize(self.g_loss, var_list=g_vars)
        """" Testing """
        # for test
        self.fake_images = self.generator(self.z,
                                          is_training=False,
                                          reuse=True)
        """ Summary """
        self.d_sum = tf.compat.v1.summary.scalar("d_loss", self.d_loss)
        self.g_sum = tf.compat.v1.summary.scalar("g_loss", self.g_loss)
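
All of these listings delegate to shared generator_loss / discriminator_loss helpers selected by self.gan_type. For orientation, a minimal sketch of two common cases ('lsgan' and 'hinge') follows; the projects' real helpers cover more variants (gan, wgan-gp, dragan, ...), so this is an assumption rather than their exact code.

import tensorflow as tf

def discriminator_loss(gan_type, real, fake):
    # Real samples are pushed toward the "real" target, fakes toward "fake".
    if gan_type == 'lsgan':
        real_loss = tf.reduce_mean(tf.squared_difference(real, 1.0))
        fake_loss = tf.reduce_mean(tf.square(fake))
    elif gan_type == 'hinge':
        real_loss = tf.reduce_mean(tf.nn.relu(1.0 - real))
        fake_loss = tf.reduce_mean(tf.nn.relu(1.0 + fake))
    else:
        raise NotImplementedError(gan_type)
    return real_loss + fake_loss

def generator_loss(gan_type, fake):
    # The generator tries to make the discriminator score fakes as real.
    if gan_type == 'lsgan':
        return tf.reduce_mean(tf.squared_difference(fake, 1.0))
    elif gan_type == 'hinge':
        return -tf.reduce_mean(fake)
    else:
        raise NotImplementedError(gan_type)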
Code Example #4
    def build_model(self):
        if self.custom_dataset:
            Image_Data_Class = ImageData(self.img_size, self.c_dim)
            inputs = tf.data.Dataset.from_tensor_slices(self.data)

            gpu_device = '/gpu:0'
            inputs = inputs.apply(shuffle_and_repeat(self.dataset_num)).apply(
                map_and_batch(Image_Data_Class.image_processing,
                              self.batch_size,
                              num_parallel_batches=16,
                              drop_remainder=True)).apply(
                                  prefetch_to_device(gpu_device,
                                                     self.batch_size))

            inputs_iterator = inputs.make_one_shot_iterator()

            self.inputs = inputs_iterator.get_next()

        else:
            self.inputs = tf.placeholder(
                tf.float32,
                [self.batch_size, self.img_size, self.img_size, self.c_dim],
                name='real_images')

        self.z = tf.placeholder(tf.float32,
                                [self.batch_size, 1, 1, self.z_dim],
                                name='z')

        real_logits = self.discriminator(self.inputs)

        fake_images = self.generator(self.z)
        fake_logits = self.discriminator(fake_images, reuse=True)

        if self.gan_type.__contains__('wgan') or self.gan_type == 'dragan':
            GP = self.gradient_penalty(real=self.inputs, fake=fake_images)
        else:
            GP = 0

        self.d_loss = discriminator_loss(
            self.gan_type, real=real_logits, fake=fake_logits) + GP

        self.g_loss = generator_loss(self.gan_type, fake=fake_logits)

        t_vars = tf.trainable_variables()
        d_vars = [var for var in t_vars if 'discriminator' in var.name]
        g_vars = [var for var in t_vars if 'generator' in var.name]

        self.d_optim = tf.train.AdamOptimizer(self.d_learning_rate,
                                              beta1=self.beta1,
                                              beta2=self.beta2).minimize(
                                                  self.d_loss, var_list=d_vars)
        self.g_optim = tf.train.AdamOptimizer(self.g_learning_rate,
                                              beta1=self.beta1,
                                              beta2=self.beta2).minimize(
                                                  self.g_loss, var_list=g_vars)

        self.fake_images = self.generator(self.z,
                                          is_training=False,
                                          reuse=True)

        self.d_sum = tf.summary.scalar("d_loss", self.d_loss)
        self.g_sum = tf.summary.scalar("g_loss", self.g_loss)
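
For context, a hypothetical TF1 training loop that drives the ops built above; function and variable names here are placeholders, and with custom_dataset enabled only the noise placeholder needs to be fed because the images come from the one-shot iterator.

import numpy as np
import tensorflow as tf

def train(model, iterations=10000, log_dir='./logs'):
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        writer = tf.summary.FileWriter(log_dir, sess.graph)
        for step in range(iterations):
            batch_z = np.random.normal(size=(model.batch_size, 1, 1, model.z_dim))
            # one discriminator update, then one generator update
            _, d_summary = sess.run([model.d_optim, model.d_sum],
                                    feed_dict={model.z: batch_z})
            _, g_summary = sess.run([model.g_optim, model.g_sum],
                                    feed_dict={model.z: batch_z})
            writer.add_summary(d_summary, step)
            writer.add_summary(g_summary, step)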
Code Example #5
    def build_model(self):
        self.lr = tf.placeholder(tf.float32, name='learning_rate')


        """ Input Image"""
        Image_Data_Class = ImageData(self.img_size, self.img_ch, self.augment_flag)

        trainA = tf.data.Dataset.from_tensor_slices(self.trainA_dataset)
        trainB = tf.data.Dataset.from_tensor_slices(self.trainB_dataset)

        gpu_device = '/gpu:0'
        trainA = trainA.apply(shuffle_and_repeat(self.dataset_num)).apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16, drop_remainder=True)).apply(prefetch_to_device(gpu_device, self.batch_size))
        trainB = trainB.apply(shuffle_and_repeat(self.dataset_num)).apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16, drop_remainder=True)).apply(prefetch_to_device(gpu_device, self.batch_size))

        trainA_iterator = trainA.make_one_shot_iterator()
        trainB_iterator = trainB.make_one_shot_iterator()


        self.identity_A = trainA_iterator.get_next()
        self.shape_A = trainA_iterator.get_next()
        self.other_A = trainA_iterator.get_next()

        self.shape_B = trainB_iterator.get_next()


        self.test_identity_A = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='test_identity_A')
        self.test_shape_B = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='test_shape_B')


        """ Define Generator, Discriminator """
        self.fake_same = self.generator(x_identity=self.identity_A, x_shape=self.shape_A)
        self.fake_diff = self.generator(x_identity=self.identity_A, x_shape=self.shape_B, reuse=True)
        fake_diff_shape = self.generator(x_identity=self.shape_B, x_shape=self.fake_diff, reuse=True)
        fake_diff_identity = self.generator(x_identity=self.fake_diff, x_shape=self.shape_B, reuse=True)


        real_logit = self.discriminator(x_identity=self.identity_A, x=self.other_A)
        fake_logit = self.discriminator(x_identity=self.identity_A, x=self.fake_diff, reuse=True)


        """ Define Loss """
        g_identity_loss = self.adv_weight * generator_loss(self.gan_type, fake=minpool(fake_logit)) * 64
        g_shape_loss_same = self.L1_weight * L1_loss(self.fake_same, self.shape_A)
        g_shape_loss_diff_shape = self.L1_weight * L1_loss(fake_diff_shape, self.shape_B)
        g_shape_loss_diff_identity = self.L1_weight * L1_loss(fake_diff_identity, self.fake_diff)


        self.Generator_loss = g_identity_loss + g_shape_loss_same + g_shape_loss_diff_shape + g_shape_loss_diff_identity
        self.Discriminator_loss = self.adv_weight * discriminator_loss(self.gan_type, real=real_logit, fake=fake_logit)


        """ Result Image """
        self.test_fake = self.generator(x_identity=self.test_identity_A, x_shape=self.test_shape_B, reuse=True)


        """ Training """
        t_vars = tf.trainable_variables()
        G_vars = [var for var in t_vars if 'generator' in var.name]
        D_vars = [var for var in t_vars if 'discriminator' in var.name]

        self.G_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Generator_loss, var_list=G_vars)
        self.D_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Discriminator_loss, var_list=D_vars)


        """" Summary """
        self.G_loss = tf.summary.scalar("Generator_loss", self.Generator_loss)
        self.D_loss = tf.summary.scalar("Discriminator_loss", self.Discriminator_loss)


        self.G_identity = tf.summary.scalar("G_identity", g_identity_loss)
        self.G_shape_loss_same = tf.summary.scalar("G_shape_loss_same", g_shape_loss_same)
        self.G_shape_loss_diff_shape = tf.summary.scalar("G_shape_loss_diff_shape", g_shape_loss_diff_shape)
        self.G_shape_loss_diff_identity = tf.summary.scalar("G_shape_loss_diff_identity", g_shape_loss_diff_identity)


        self.G_loss_merge = tf.summary.merge([self.G_loss, self.G_identity, self.G_shape_loss_same, self.G_shape_loss_diff_shape, self.G_shape_loss_diff_identity])
        self.D_loss_merge = tf.summary.merge([self.D_loss])
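
L1_loss above (as in most of these repositories) is simply a mean absolute error; the minpool call is project-specific and not sketched here. A one-line reference implementation, stated as an assumption:

import tensorflow as tf

def L1_loss(x, y):
    # mean absolute error between two tensors of the same shape
    return tf.reduce_mean(tf.abs(x - y))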
Code Example #6
File: SPADE.py  Project: san-guy/SPADE-Tensorflow
    def build_model(self):
        self.lr = tf.placeholder(tf.float32, name='learning_rate')

        """ Input Image"""
        img_class = Image_data(self.img_height, self.img_width, self.img_ch, self.segmap_ch, self.dataset_path, self.augment_flag)
        img_class.preprocess()

        self.dataset_num = len(img_class.image)
        self.test_dataset_num = len(img_class.segmap_test)


        img_and_segmap = tf.data.Dataset.from_tensor_slices((img_class.image, img_class.segmap))
        segmap_test = tf.data.Dataset.from_tensor_slices(img_class.segmap_test)


        gpu_device = '/gpu:0'
        img_and_segmap = img_and_segmap.apply(shuffle_and_repeat(self.dataset_num)).apply(
            map_and_batch(img_class.image_processing, self.batch_size, num_parallel_batches=16,
                          drop_remainder=True)).apply(prefetch_to_device(gpu_device, self.batch_size))

        segmap_test = segmap_test.apply(shuffle_and_repeat(self.dataset_num)).apply(
            map_and_batch(img_class.test_image_processing, batch_size=self.batch_size, num_parallel_batches=16,
                          drop_remainder=True)).apply(prefetch_to_device(gpu_device, self.batch_size))


        img_and_segmap_iterator = img_and_segmap.make_one_shot_iterator()
        segmap_test_iterator = segmap_test.make_one_shot_iterator()

        self.real_x, self.real_x_segmap, self.real_x_segmap_onehot = img_and_segmap_iterator.get_next()
        self.real_x_segmap_test, self.real_x_segmap_test_onehot = segmap_test_iterator.get_next()


        """ Define Generator, Discriminator """
        fake_x, x_mean, x_var = self.image_translate(segmap_img=self.real_x_segmap_onehot, x_img=self.real_x)
        real_logit, fake_logit = self.image_discriminate(segmap_img=self.real_x_segmap_onehot, real_img=self.real_x, fake_img=fake_x)

        if self.gan_type.__contains__('wgan') or self.gan_type == 'dragan':
            GP = self.gradient_penalty(real=self.real_x, segmap=self.real_x_segmap_onehot, fake=fake_x)
        else:
            GP = 0

        """ Define Loss """
        g_adv_loss = self.adv_weight * generator_loss(self.gan_type, fake_logit)
        g_kl_loss = self.kl_weight * kl_loss(x_mean, x_var)
        g_vgg_loss = self.vgg_weight * VGGLoss()(self.real_x, fake_x)
        g_feature_loss = self.feature_weight * feature_loss(real_logit, fake_logit)
        g_reg_loss = regularization_loss('generator') + regularization_loss('encoder')


        d_adv_loss = self.adv_weight * (discriminator_loss(self.gan_type, real_logit, fake_logit) + GP)
        d_reg_loss = regularization_loss('discriminator')

        self.g_loss = g_adv_loss + g_kl_loss + g_vgg_loss + g_feature_loss + g_reg_loss
        self.d_loss = d_adv_loss + d_reg_loss

        """ Result Image """
        self.fake_x = fake_x
        self.random_fake_x, _, _ = self.image_translate(segmap_img=self.real_x_segmap_onehot, random_style=True, reuse=True)

        """ Test """
        self.test_segmap_image = tf.placeholder(tf.float32, [1, self.img_height, self.img_width, len(img_class.color_value_dict)])
        self.random_test_fake_x, _, _ = self.image_translate(segmap_img=self.test_segmap_image, random_style=True, reuse=True)

        self.test_guide_image = tf.placeholder(tf.float32, [1, self.img_height, self.img_width, self.img_ch])
        self.guide_test_fake_x, _, _ = self.image_translate(segmap_img=self.test_segmap_image, x_img=self.test_guide_image, reuse=True)


        """ Training """
        t_vars = tf.trainable_variables()
        G_vars = [var for var in t_vars if 'encoder' in var.name or 'generator' in var.name]
        D_vars = [var for var in t_vars if 'discriminator' in var.name]

        if self.TTUR :
            beta1 = 0.0
            beta2 = 0.9

            g_lr = self.lr / 2
            d_lr = self.lr * 2

        else :
            beta1 = self.beta1
            beta2 = self.beta2
            g_lr = self.lr
            d_lr = self.lr

        self.G_optim = tf.train.AdamOptimizer(g_lr, beta1=beta1, beta2=beta2).minimize(self.g_loss, var_list=G_vars)
        self.D_optim = tf.train.AdamOptimizer(d_lr, beta1=beta1, beta2=beta2).minimize(self.d_loss, var_list=D_vars)

        """" Summary """
        self.summary_g_loss = tf.summary.scalar("g_loss", self.g_loss)
        self.summary_d_loss = tf.summary.scalar("d_loss", self.d_loss)

        self.summary_g_adv_loss = tf.summary.scalar("g_adv_loss", g_adv_loss)
        self.summary_g_kl_loss = tf.summary.scalar("g_kl_loss", g_kl_loss)
        self.summary_g_vgg_loss = tf.summary.scalar("g_vgg_loss", g_vgg_loss)
        self.summary_g_feature_loss = tf.summary.scalar("g_feature_loss", g_feature_loss)


        g_summary_list = [self.summary_g_loss, self.summary_g_adv_loss, self.summary_g_kl_loss, self.summary_g_vgg_loss, self.summary_g_feature_loss]
        d_summary_list = [self.summary_d_loss]

        self.G_loss = tf.summary.merge(g_summary_list)
        self.D_loss = tf.summary.merge(d_summary_list)
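
kl_loss(x_mean, x_var) above is the VAE-style KL term for the SPADE image encoder. Assuming x_var is a log-variance, a minimal sketch of the standard closed-form KL against a unit Gaussian (the project may scale or reduce differently):

import tensorflow as tf

def kl_loss(mean, logvar):
    # KL divergence between N(mean, exp(logvar)) and the standard normal prior,
    # summed over all latent dimensions.
    return 0.5 * tf.reduce_sum(tf.square(mean) + tf.exp(logvar) - 1.0 - logvar)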
Code Example #7
    def build_model(self):
        self.lr = tf.placeholder(tf.float32, name='learning_rate')
        """ Input Image"""
        Image_data_class = ImageData(load_size=self.img_size,
                                     channels=self.img_ch,
                                     data_path=self.dataset_path,
                                     selected_attrs=self.selected_attrs,
                                     augment_flag=self.augment_flag)
        Image_data_class.preprocess()

        train_dataset_num = len(Image_data_class.train_dataset)
        test_dataset_num = len(Image_data_class.test_dataset)

        train_dataset = tf.data.Dataset.from_tensor_slices(
            (Image_data_class.train_dataset,
             Image_data_class.train_dataset_label,
             Image_data_class.train_dataset_fix_label))
        test_dataset = tf.data.Dataset.from_tensor_slices(
            (Image_data_class.test_dataset,
             Image_data_class.test_dataset_label,
             Image_data_class.test_dataset_fix_label))

        gpu_device = '/gpu:0'
        train_dataset = train_dataset.\
            apply(shuffle_and_repeat(train_dataset_num)).\
            apply(map_and_batch(Image_data_class.image_processing, self.batch_size, num_parallel_batches=8, drop_remainder=True)).\
            apply(prefetch_to_device(gpu_device, self.batch_size))

        test_dataset = test_dataset.\
            apply(shuffle_and_repeat(test_dataset_num)).\
            apply(map_and_batch(Image_data_class.image_processing, self.batch_size, num_parallel_batches=8, drop_remainder=True)).\
            apply(prefetch_to_device(gpu_device, self.batch_size))

        train_dataset_iterator = train_dataset.make_one_shot_iterator()
        test_dataset_iterator = test_dataset.make_one_shot_iterator()

        # Input image / original domain labels
        self.x_real, label_org, label_fix_list = train_dataset_iterator.get_next()
        label_trg = tf.random_shuffle(label_org)  # Target domain labels
        label_fix_list = tf.transpose(label_fix_list, perm=[1, 0, 2])

        # Test input image / original domain labels
        self.x_test, test_label_org, test_label_fix_list = test_dataset_iterator.get_next()
        test_label_fix_list = tf.transpose(test_label_fix_list, perm=[1, 0, 2])

        self.custom_image = tf.placeholder(
            tf.float32, [1, self.img_size, self.img_size, self.img_ch],
            name='custom_image')  # Custom Image
        custom_label_fix_list = tf.transpose(create_labels(
            self.custom_label, self.selected_attrs),
                                             perm=[1, 0, 2])
        """ Define Generator, Discriminator """
        x_fake = self.generator(self.x_real, label_trg)  # translate to the target domain
        x_recon = self.generator(x_fake, label_org, reuse=True)  # reconstruct the original domain

        real_logit, real_cls = self.discriminator(self.x_real)
        fake_logit, fake_cls = self.discriminator(x_fake, reuse=True)
        """ Define Loss """
        if self.gan_type.__contains__('wgan') or self.gan_type == 'dragan':
            GP = self.gradient_panalty(real=self.x_real, fake=x_fake)
        else:
            GP = 0

        g_adv_loss = generator_loss(loss_func=self.gan_type, fake=fake_logit)
        g_cls_loss = classification_loss(logit=fake_cls, label=label_trg)
        g_rec_loss = L1_loss(self.x_real, x_recon)

        d_adv_loss = discriminator_loss(
            loss_func=self.gan_type, real=real_logit, fake=fake_logit) + GP
        d_cls_loss = classification_loss(logit=real_cls, label=label_org)

        self.d_loss = self.adv_weight * d_adv_loss + self.cls_weight * d_cls_loss
        self.g_loss = self.adv_weight * g_adv_loss + self.cls_weight * g_cls_loss + self.rec_weight * g_rec_loss
        """ Result Image """
        self.x_fake_list = tf.map_fn(
            lambda x: self.generator(self.x_real, x, reuse=True),
            label_fix_list,
            dtype=tf.float32)
        """ Test Image """
        self.x_test_fake_list = tf.map_fn(
            lambda x: self.generator(self.x_test, x, reuse=True),
            test_label_fix_list,
            dtype=tf.float32)
        self.custom_fake_image = tf.map_fn(
            lambda x: self.generator(self.custom_image, x, reuse=True),
            custom_label_fix_list,
            dtype=tf.float32)
        """ Training """
        t_vars = tf.trainable_variables()
        G_vars = [var for var in t_vars if 'generator' in var.name]
        D_vars = [var for var in t_vars if 'discriminator' in var.name]

        self.g_optimizer = tf.train.AdamOptimizer(self.lr,
                                                  beta1=0.5,
                                                  beta2=0.999).minimize(
                                                      self.g_loss,
                                                      var_list=G_vars)
        self.d_optimizer = tf.train.AdamOptimizer(self.lr,
                                                  beta1=0.5,
                                                  beta2=0.999).minimize(
                                                      self.d_loss,
                                                      var_list=D_vars)
        """" Summary """
        self.Generator_loss = tf.summary.scalar("Generator_loss", self.g_loss)
        self.Discriminator_loss = tf.summary.scalar("Discriminator_loss",
                                                    self.d_loss)

        self.g_adv_loss = tf.summary.scalar("g_adv_loss", g_adv_loss)
        self.g_cls_loss = tf.summary.scalar("g_cls_loss", g_cls_loss)
        self.g_rec_loss = tf.summary.scalar("g_rec_loss", g_rec_loss)

        self.d_adv_loss = tf.summary.scalar("d_adv_loss", d_adv_loss)
        self.d_cls_loss = tf.summary.scalar("d_cls_loss", d_cls_loss)

        self.g_summary_loss = tf.summary.merge([
            self.Generator_loss, self.g_adv_loss, self.g_cls_loss,
            self.g_rec_loss
        ])
        self.d_summary_loss = tf.summary.merge(
            [self.Discriminator_loss, self.d_adv_loss, self.d_cls_loss])
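
classification_loss in this StarGAN-style listing is typically a sigmoid cross-entropy between the discriminator's attribute logits and the binary domain labels; a minimal sketch under that assumption:

import tensorflow as tf

def classification_loss(logit, label):
    # multi-attribute classification: independent sigmoid cross-entropy per label
    return tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=label, logits=logit))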
Code Example #8
    def build_model(self):
        label_fix_onehot_list = []
        """ Input Image"""
        if self.dataset_name == 'celebA-HQ' or self.dataset_name == 'celebA':
            img_class = ImageData_celebA(self.img_height, self.img_width,
                                         self.img_ch, self.dataset_path,
                                         self.label_list, self.augment_flag)
            img_class.preprocess(self.phase)

        else:
            img_class = Image_data(self.img_height, self.img_width,
                                   self.img_ch, self.dataset_path,
                                   self.label_list, self.augment_flag)
            img_class.preprocess()

            label_fix_onehot_list = img_class.label_onehot_list
            label_fix_onehot_list = tf.tile(
                tf.expand_dims(label_fix_onehot_list, axis=1),
                [1, self.batch_size, 1])

        dataset_num = len(img_class.image)
        print("Dataset number : ", dataset_num)

        if self.phase == 'train':
            self.lr = tf.placeholder(tf.float32, name='learning_rate')

            if self.dataset_name == 'celebA-HQ' or self.dataset_name == 'celebA':
                img_and_label = tf.data.Dataset.from_tensor_slices(
                    (img_class.image, img_class.label,
                     img_class.train_label_onehot_list))
            else:
                img_and_label = tf.data.Dataset.from_tensor_slices(
                    (img_class.image, img_class.label))

            gpu_device = '/gpu:0'
            img_and_label = img_and_label.apply(
                shuffle_and_repeat(dataset_num)).apply(
                    map_and_batch(img_class.image_processing,
                                  self.batch_size,
                                  num_parallel_batches=16,
                                  drop_remainder=True)).apply(
                                      prefetch_to_device(gpu_device, None))

            img_and_label_iterator = img_and_label.make_one_shot_iterator()

            if self.dataset_name == 'celebA-HQ' or self.dataset_name == 'celebA':
                self.x_real, label_org, label_fix_onehot_list = img_and_label_iterator.get_next()
                label_trg = tf.random_shuffle(
                    label_org)  # Target domain labels
                label_fix_onehot_list = tf.transpose(label_fix_onehot_list,
                                                     perm=[1, 0, 2])
            else:
                self.x_real, label_org = img_and_label_iterator.get_next()
                label_trg = tf.random_shuffle(
                    label_org)  # Target domain labels
            """ Define Generator, Discriminator """
            fake_style_code = tf.random_normal(
                shape=[self.batch_size, self.style_dim])
            x_fake = self.generator(self.x_real, label_trg,
                                    fake_style_code)  # translate to the target domain

            recon_style_code = tf.random_normal(
                shape=[self.batch_size, self.style_dim])
            x_recon = self.generator(x_fake,
                                     label_org,
                                     recon_style_code,
                                     reuse=True)  # reconstruct the original domain

            real_logit, real_cls, _ = self.discriminator(self.x_real)
            fake_logit, fake_cls, fake_noise = self.discriminator(x_fake,
                                                                  reuse=True)
            """ Define Loss """
            if self.gan_type.__contains__('wgan') or self.gan_type == 'dragan':
                GP = self.gradient_panalty(real=self.x_real, fake=x_fake)
            else:
                GP = 0

            g_adv_loss = self.adv_weight * generator_loss(
                self.gan_type, fake_logit)
            g_cls_loss = self.cls_weight * classification_loss(logit=fake_cls,
                                                               label=label_trg)
            g_rec_loss = self.rec_weight * L1_loss(self.x_real, x_recon)
            g_noise_loss = self.noise_weight * L1_loss(fake_style_code,
                                                       fake_noise)

            d_adv_loss = self.adv_weight * discriminator_loss(
                self.gan_type, real_logit, fake_logit) + GP
            d_cls_loss = self.cls_weight * classification_loss(logit=real_cls,
                                                               label=label_org)
            d_noise_loss = self.noise_weight * L1_loss(fake_style_code,
                                                       fake_noise)

            self.d_loss = d_adv_loss + d_cls_loss + d_noise_loss
            self.g_loss = g_adv_loss + g_cls_loss + g_rec_loss + g_noise_loss
            """ Result Image """
            if self.dataset_name == 'celebA-HQ' or self.dataset_name == 'celebA':
                self.x_fake_list = []

                for _ in range(self.num_style):
                    random_style_code = tf.random_normal(
                        shape=[self.batch_size, self.style_dim])
                    self.x_fake_list.append(
                        tf.map_fn(lambda c: self.generator(
                            self.x_real, c, random_style_code, reuse=True),
                                  label_fix_onehot_list,
                                  dtype=tf.float32))

            else:
                self.x_fake_list = []

                for _ in range(self.num_style):
                    random_style_code = tf.random_normal(
                        shape=[self.batch_size, self.style_dim])
                    self.x_fake_list.append(
                        tf.map_fn(lambda c: self.generator(
                            self.x_real, c, random_style_code, reuse=True),
                                  label_fix_onehot_list,
                                  dtype=tf.float32))
            """ Training """
            t_vars = tf.trainable_variables()
            G_vars = [var for var in t_vars if 'generator' in var.name]
            D_vars = [var for var in t_vars if 'discriminator' in var.name]

            self.g_optimizer = tf.train.AdamOptimizer(self.lr,
                                                      beta1=0.5,
                                                      beta2=0.999).minimize(
                                                          self.g_loss,
                                                          var_list=G_vars)
            self.d_optimizer = tf.train.AdamOptimizer(self.lr,
                                                      beta1=0.5,
                                                      beta2=0.999).minimize(
                                                          self.d_loss,
                                                          var_list=D_vars)
            """" Summary """
            self.Generator_loss = tf.summary.scalar("g_loss", self.g_loss)
            self.Discriminator_loss = tf.summary.scalar("d_loss", self.d_loss)

            self.g_adv_loss = tf.summary.scalar("g_adv_loss", g_adv_loss)
            self.g_cls_loss = tf.summary.scalar("g_cls_loss", g_cls_loss)
            self.g_rec_loss = tf.summary.scalar("g_rec_loss", g_rec_loss)
            self.g_noise_loss = tf.summary.scalar("g_noise_loss", g_noise_loss)

            self.d_adv_loss = tf.summary.scalar("d_adv_loss", d_adv_loss)
            self.d_cls_loss = tf.summary.scalar("d_cls_loss", d_cls_loss)
            self.d_noise_loss = tf.summary.scalar("d_noise_loss", d_noise_loss)

            self.g_summary_loss = tf.summary.merge([
                self.Generator_loss, self.g_adv_loss, self.g_cls_loss,
                self.g_rec_loss, self.g_noise_loss
            ])
            self.d_summary_loss = tf.summary.merge([
                self.Discriminator_loss, self.d_adv_loss, self.d_cls_loss,
                self.d_noise_loss
            ])

        else:
            """ Test """
            if self.dataset_name == 'celebA-HQ' or self.dataset_name == 'celebA':
                img_and_label = tf.data.Dataset.from_tensor_slices(
                    (img_class.test_image, img_class.test_label,
                     img_class.test_label_onehot_list))
                dataset_num = len(img_class.test_image)

                gpu_device = '/gpu:0'
                img_and_label = img_and_label.apply(
                    shuffle_and_repeat(dataset_num)).apply(
                        map_and_batch(img_class.image_processing,
                                      batch_size=self.batch_size,
                                      num_parallel_batches=16,
                                      drop_remainder=True)).apply(
                                          prefetch_to_device(gpu_device, None))

                img_and_label_iterator = img_and_label.make_one_shot_iterator()

                self.x_test, _, self.test_label_fix_onehot_list = img_and_label_iterator.get_next()
                self.test_img_placeholder = tf.placeholder(
                    tf.float32,
                    [1, self.img_height, self.img_width, self.img_ch])
                self.test_label_fix_placeholder = tf.placeholder(
                    tf.float32, [self.c_dim, 1, self.c_dim])

                self.custom_image = tf.placeholder(
                    tf.float32,
                    [1, self.img_height, self.img_width, self.img_ch],
                    name='custom_image')  # Custom Image
                custom_label_fix_onehot_list = tf.transpose(
                    np.expand_dims(label2onehot(self.label_list), axis=0),
                    perm=[1, 0, 2])  # [c_dim, bs, c_dim]
                """ Test Image """
                test_random_style_code = tf.random_normal(
                    shape=[1, self.style_dim])

                self.x_test_fake_list = tf.map_fn(
                    lambda c: self.generator(self.test_img_placeholder, c,
                                             test_random_style_code),
                    self.test_label_fix_placeholder,
                    dtype=tf.float32)
                self.custom_fake_image = tf.map_fn(lambda c: self.generator(
                    self.custom_image, c, test_random_style_code, reuse=True),
                                                   custom_label_fix_onehot_list,
                                                   dtype=tf.float32)

            else:
                self.custom_image = tf.placeholder(
                    tf.float32,
                    [1, self.img_height, self.img_width, self.img_ch],
                    name='custom_image')  # Custom Image
                custom_label_fix_onehot_list = tf.transpose(
                    np.expand_dims(label2onehot(self.label_list), axis=0),
                    perm=[1, 0, 2])  # [c_dim, bs, c_dim]

                test_random_style_code = tf.random_normal(
                    shape=[1, self.style_dim])
                self.custom_fake_image = tf.map_fn(
                    lambda c: self.generator(self.custom_image, c,
                                             test_random_style_code),
                    custom_label_fix_onehot_list,
                    dtype=tf.float32)
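
Several of these listings branch on 'wgan'/'dragan' and add a gradient penalty (gradient_penalty / gradient_panalty) term. Below is a generic WGAN-GP style sketch, assuming a discriminator callable that returns a single logit; the DRAGAN variant perturbs the real samples instead of interpolating, and the projects' own helpers additionally handle their CAM or class outputs.

import tensorflow as tf

def gradient_penalty(discriminator_fn, real, fake, weight=10.0):
    # Interpolate between real and fake samples and penalise deviation of the
    # discriminator's gradient norm from 1 at those points (WGAN-GP).
    batch = tf.shape(real)[0]
    epsilon = tf.random_uniform(shape=[batch, 1, 1, 1], minval=0.0, maxval=1.0)
    interpolated = real + epsilon * (fake - real)
    logit = discriminator_fn(interpolated)
    grad = tf.gradients(logit, interpolated)[0]
    grad_norm = tf.norm(tf.layers.flatten(grad), axis=1)
    return weight * tf.reduce_mean(tf.square(grad_norm - 1.0))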
Code Example #9
    def build_model(self):
        """ Graph Input """
        # images
        Image_Data_Class = ImageData(self.img_size, self.c_dim,
                                     self.custom_dataset)
        inputs = tf.data.Dataset.from_tensor_slices(self.data)

        gpu_device = '/gpu:0'
        inputs = inputs.\
            apply(shuffle_and_repeat(self.dataset_num)).\
            apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16, drop_remainder=True)).\
            apply(prefetch_to_device(gpu_device, self.batch_size))

        inputs_iterator = inputs.make_one_shot_iterator()

        self.inputs = inputs_iterator.get_next()

        # noises
        self.z = tf.truncated_normal(shape=[self.batch_size, 1, 1, self.z_dim],
                                     name='random_z')
        """ Loss Function """
        # output of D for real images
        real_logits = self.discriminator(self.inputs)

        # output of D for fake images
        fake_images = self.generator(self.z)
        fake_logits = self.discriminator(fake_images, reuse=True)

        if self.gan_type.__contains__('wgan') or self.gan_type == 'dragan':
            GP = self.gradient_penalty(real=self.inputs, fake=fake_images)
        else:
            GP = 0

        # get loss for discriminator
        self.d_loss = discriminator_loss(
            self.gan_type, real=real_logits, fake=fake_logits) + GP

        # get loss for generator
        self.g_loss = generator_loss(self.gan_type, fake=fake_logits)
        """ Training """
        # divide trainable variables into a group for D and a group for G
        t_vars = tf.trainable_variables()
        d_vars = [var for var in t_vars if 'discriminator' in var.name]
        g_vars = [var for var in t_vars if 'generator' in var.name]

        # optimizers
        with tf.control_dependencies(tf.get_collection(
                tf.GraphKeys.UPDATE_OPS)):
            self.d_optim = tf.train.AdamOptimizer(self.d_learning_rate,
                                                  beta1=self.beta1,
                                                  beta2=self.beta2).minimize(
                                                      self.d_loss,
                                                      var_list=d_vars)

            self.opt = MovingAverageOptimizer(tf.train.AdamOptimizer(
                self.g_learning_rate, beta1=self.beta1, beta2=self.beta2),
                                              average_decay=self.moving_decay)
            self.g_optim = self.opt.minimize(self.g_loss, var_list=g_vars)
        """" Testing """
        # for test
        self.fake_images = self.generator(self.z,
                                          is_training=False,
                                          reuse=True)
        """ Summary """
        self.d_sum = tf.summary.scalar("d_loss", self.d_loss)
        self.g_sum = tf.summary.scalar("g_loss", self.g_loss)
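
One detail worth noting in Example #9: MovingAverageOptimizer (from tf.contrib.opt) keeps an exponential moving average of the generator weights, and checkpoints that should contain the averaged weights are normally written with the optimizer's swapping saver rather than a plain tf.train.Saver. A usage sketch, assuming the attribute names from the listing above:

def build_saver(model):
    # model.opt is the MovingAverageOptimizer created in build_model();
    # swapping_saver() writes checkpoints in which each variable is swapped
    # for its moving average, so restored weights are the averaged ones.
    return model.opt.swapping_saver()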
Code Example #10
    def build_model(self):
        self.lr = tf.placeholder(tf.float32, name='lr')

        """ Input Image"""
        Image_Data_Class = ImageData(self.img_size, self.img_ch, self.augment_flag)

        trainA = tf.data.Dataset.from_tensor_slices(self.trainA_dataset)
        trainB = tf.data.Dataset.from_tensor_slices(self.trainB_dataset)

        gpu_device = '/gpu:0'
        trainA = trainA.apply(shuffle_and_repeat(self.dataset_num)).apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16, drop_remainder=True)).apply(prefetch_to_device(gpu_device, self.batch_size))
        trainB = trainB.apply(shuffle_and_repeat(self.dataset_num)).apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16, drop_remainder=True)).apply(prefetch_to_device(gpu_device, self.batch_size))


        trainA_iterator = trainA.make_one_shot_iterator()
        trainB_iterator = trainB.make_one_shot_iterator()


        self.domain_A = trainA_iterator.get_next()
        self.domain_B = trainB_iterator.get_next()


        """ Define Encoder, Generator, Discriminator """
        random_z = tf.random_normal(shape=[self.batch_size, self.n_z], mean=0.0, stddev=1.0, dtype=tf.float32)

        # encode
        content_a, attribute_a, mean_a, logvar_a = self.Encoder_A(self.domain_A)
        content_b, attribute_b, mean_b, logvar_b = self.Encoder_B(self.domain_B)

        # decode (fake, identity, random)
        fake_a = self.Decoder_A(content_B=content_b, attribute_A=attribute_a)
        fake_b = self.Decoder_B(content_A=content_a, attribute_B=attribute_b)

        recon_a = self.Decoder_A(content_B=content_a, attribute_A=attribute_a, reuse=True)
        recon_b = self.Decoder_B(content_A=content_b, attribute_B=attribute_b, reuse=True)

        random_fake_a = self.Decoder_A(content_B=content_b, attribute_A=random_z, reuse=True)
        random_fake_b = self.Decoder_B(content_A=content_a, attribute_B=random_z, reuse=True)

        # encode & decode again for cycle-consistency
        content_fake_a, attribute_fake_a, _, _ = self.Encoder_A(fake_a, reuse=True)
        content_fake_b, attribute_fake_b, _, _ = self.Encoder_B(fake_b, reuse=True)

        cycle_a = self.Decoder_A(content_B=content_fake_b, attribute_A=attribute_fake_a, reuse=True)
        cycle_b = self.Decoder_B(content_A=content_fake_a, attribute_B=attribute_fake_b, reuse=True)

        # for latent regression
        _, attribute_fake_random_a, _, _ = self.Encoder_A(random_fake_a, random_fake=True, reuse=True)
        _, attribute_fake_random_b, _, _ = self.Encoder_B(random_fake_b, random_fake=True, reuse=True)


        # discriminate
        real_A_logit, real_B_logit = self.discriminate_real(self.domain_A, self.domain_B)
        fake_A_logit, fake_B_logit = self.discriminate_fake(fake_a, fake_b)
        random_fake_A_logit, random_fake_B_logit = self.discriminate_fake(random_fake_a, random_fake_b)
        content_A_logit, content_B_logit = self.discriminate_content(content_a, content_b)


        """ Define Loss """
        g_adv_loss_a = generator_loss(self.gan_type, fake_A_logit) + generator_loss(self.gan_type, random_fake_A_logit)
        g_adv_loss_b = generator_loss(self.gan_type, fake_B_logit) + generator_loss(self.gan_type, random_fake_B_logit)

        g_con_loss_a = generator_loss(self.gan_type, content_A_logit, content=True)
        g_con_loss_b = generator_loss(self.gan_type, content_B_logit, content=True)

        g_cyc_loss_a = L1_loss(cycle_a, self.domain_A)
        g_cyc_loss_b = L1_loss(cycle_b, self.domain_B)

        g_rec_loss_a = L1_loss(recon_a, self.domain_A)
        g_rec_loss_b = L1_loss(recon_b, self.domain_B)

        g_latent_loss_a = L1_loss(attribute_fake_random_a, random_z)
        g_latent_loss_b = L1_loss(attribute_fake_random_b, random_z)

        if self.concat :
            g_kl_loss_a = kl_loss(mean_a, logvar_a) + l2_regularize(content_a)
            g_kl_loss_b = kl_loss(mean_b, logvar_b) + l2_regularize(content_b)
        else :
            g_kl_loss_a = l2_regularize(attribute_a) + l2_regularize(content_a)
            g_kl_loss_b = l2_regularize(attribute_b) + l2_regularize(content_b)


        d_adv_loss_a = discriminator_loss(self.gan_type, real_A_logit, fake_A_logit, random_fake_A_logit)
        d_adv_loss_b = discriminator_loss(self.gan_type, real_B_logit, fake_B_logit, random_fake_B_logit)

        d_con_loss = discriminator_loss(self.gan_type, content_A_logit, content_B_logit, content=True)

        Generator_A_domain_loss = self.domain_adv_w * g_adv_loss_a
        Generator_A_content_loss = self.content_adv_w * g_con_loss_a
        Generator_A_cycle_loss = self.cycle_w * g_cyc_loss_b
        Generator_A_recon_loss = self.recon_w * g_rec_loss_a
        Generator_A_latent_loss = self.latent_w * g_latent_loss_a
        Generator_A_kl_loss = self.kl_w * g_kl_loss_a

        Generator_A_loss = Generator_A_domain_loss + \
                           Generator_A_content_loss + \
                           Generator_A_cycle_loss + \
                           Generator_A_recon_loss + \
                           Generator_A_latent_loss + \
                           Generator_A_kl_loss

        Generator_B_domain_loss = self.domain_adv_w * g_adv_loss_b
        Generator_B_content_loss = self.content_adv_w * g_con_loss_b
        Generator_B_cycle_loss = self.cycle_w * g_cyc_loss_a
        Generator_B_recon_loss = self.recon_w * g_rec_loss_b
        Generator_B_latent_loss = self.latent_w * g_latent_loss_b
        Generator_B_kl_loss = self.kl_w * g_kl_loss_b

        Generator_B_loss = Generator_B_domain_loss + \
                           Generator_B_content_loss + \
                           Generator_B_cycle_loss + \
                           Generator_B_recon_loss + \
                           Generator_B_latent_loss + \
                           Generator_B_kl_loss

        Discriminator_A_loss = self.domain_adv_w * d_adv_loss_a
        Discriminator_B_loss = self.domain_adv_w * d_adv_loss_b
        Discriminator_content_loss = self.content_adv_w * d_con_loss

        self.Generator_loss = Generator_A_loss + Generator_B_loss
        self.Discriminator_loss = Discriminator_A_loss + Discriminator_B_loss
        self.Discriminator_content_loss = Discriminator_content_loss

        """ Training """
        t_vars = tf.trainable_variables()
        G_vars = [var for var in t_vars if 'encoder' in var.name or 'generator' in var.name]
        D_vars = [var for var in t_vars if 'discriminator' in var.name and 'content' not in var.name]
        D_content_vars = [var for var in t_vars if 'content_discriminator' in var.name]

        grads, _ = tf.clip_by_global_norm(tf.gradients(self.Discriminator_content_loss, D_content_vars), clip_norm=5)

        self.G_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Generator_loss, var_list=G_vars)
        self.D_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Discriminator_loss, var_list=D_vars)
        self.D_content_optim = tf.train.AdamOptimizer(self.d_content_init_lr, beta1=0.5, beta2=0.999).apply_gradients(zip(grads, D_content_vars))


        """" Summary """
        self.lr_write = tf.summary.scalar("learning_rate", self.lr)

        self.all_G_loss = tf.summary.scalar("Generator_loss", self.Generator_loss)
        self.all_D_loss = tf.summary.scalar("Discriminator_loss", self.Discriminator_loss)

        self.G_A_loss = tf.summary.scalar("G_A_loss", Generator_A_loss)
        self.G_A_domain_loss = tf.summary.scalar("G_A_domain_loss", Generator_A_domain_loss)
        self.G_A_content_loss = tf.summary.scalar("G_A_content_loss", Generator_A_content_loss)
        self.G_A_cycle_loss = tf.summary.scalar("G_A_cycle_loss", Generator_A_cycle_loss)
        self.G_A_recon_loss = tf.summary.scalar("G_A_recon_loss", Generator_A_recon_loss)
        self.G_A_latent_loss = tf.summary.scalar("G_A_latent_loss", Generator_A_latent_loss)
        self.G_A_kl_loss = tf.summary.scalar("G_A_kl_loss", Generator_A_kl_loss)


        self.G_B_loss = tf.summary.scalar("G_B_loss", Generator_B_loss)
        self.G_B_domain_loss = tf.summary.scalar("G_B_domain_loss", Generator_B_domain_loss)
        self.G_B_content_loss = tf.summary.scalar("G_B_content_loss", Generator_B_content_loss)
        self.G_B_cycle_loss = tf.summary.scalar("G_B_cycle_loss", Generator_B_cycle_loss)
        self.G_B_recon_loss = tf.summary.scalar("G_B_recon_loss", Generator_B_recon_loss)
        self.G_B_latent_loss = tf.summary.scalar("G_B_latent_loss", Generator_B_latent_loss)
        self.G_B_kl_loss = tf.summary.scalar("G_B_kl_loss", Generator_B_kl_loss)

        self.D_A_loss = tf.summary.scalar("D_A_loss", Discriminator_A_loss)
        self.D_B_loss = tf.summary.scalar("D_B_loss", Discriminator_B_loss)

        self.G_loss = tf.summary.merge([self.G_A_loss,
                                        self.G_A_domain_loss, self.G_A_content_loss,
                                        self.G_A_cycle_loss, self.G_A_recon_loss,
                                        self.G_A_latent_loss, self.G_A_kl_loss,

                                        self.G_B_loss,
                                        self.G_B_domain_loss, self.G_B_content_loss,
                                        self.G_B_cycle_loss, self.G_B_recon_loss,
                                        self.G_B_latent_loss, self.G_B_kl_loss,

                                        self.all_G_loss])

        self.D_loss = tf.summary.merge([self.D_A_loss,
                                        self.D_B_loss,
                                        self.all_D_loss])

        self.D_content_loss = tf.summary.scalar("Discriminator_content_loss", self.Discriminator_content_loss)



        """ Image """
        self.fake_A = random_fake_a
        self.fake_B = random_fake_b

        self.real_A = self.domain_A
        self.real_B = self.domain_B


        """ Test """
        self.test_image = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='test_image')
        self.test_random_z = tf.random_normal(shape=[1, self.n_z], mean=0.0, stddev=1.0, dtype=tf.float32)

        test_content_a, _, _, _ = self.Encoder_A(self.test_image, is_training=False, reuse=True)
        test_content_b, _, _, _ = self.Encoder_B(self.test_image, is_training=False, reuse=True)

        self.test_fake_A = self.Decoder_A(content_B=test_content_b, attribute_A=self.test_random_z, reuse=True)
        self.test_fake_B = self.Decoder_B(content_A=test_content_a, attribute_B=self.test_random_z, reuse=True)

        """ Guided Image Translation """
        self.content_image = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='content_image')
        self.attribute_image = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='guide_attribute_image')

        if self.direction == 'a2b' :
            guide_content_A, _, _, _ = self.Encoder_A(self.content_image, is_training=False, reuse=True)
            _, guide_attribute_B, _, _ = self.Encoder_B(self.attribute_image, is_training=False, reuse=True)
            self.guide_fake_B = self.Decoder_B(content_A=guide_content_A, attribute_B=guide_attribute_B, reuse=True)

        else :
            guide_content_B, _, _, _ = self.Encoder_B(self.content_image, is_training=False, reuse=True)
            _, guide_attribute_A, _, _ = self.Encoder_A(self.attribute_image, is_training=False, reuse=True)
            self.guide_fake_A = self.Decoder_A(content_B=guide_content_B, attribute_A=guide_attribute_A, reuse=True)
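The listing above relies on a few small loss helpers (L1_loss, kl_loss, l2_regularize) whose definitions are not shown. A minimal TF 1.x sketch of how such helpers are commonly written, assuming only the call signatures used above:

import tensorflow as tf

def L1_loss(x, y):
    # mean absolute difference; used for the cycle, reconstruction and latent-regression terms
    return tf.reduce_mean(tf.abs(x - y))

def kl_loss(mean, logvar):
    # KL divergence of N(mean, exp(logvar)) from a standard normal, averaged over the batch
    return tf.reduce_mean(0.5 * tf.reduce_sum(tf.square(mean) + tf.exp(logvar) - 1.0 - logvar, axis=-1))

def l2_regularize(x):
    # simple L2 penalty that keeps the content / attribute codes small
    return tf.reduce_mean(tf.square(x))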
コード例 #11
0
ファイル: model.py プロジェクト: WYu-Feng/ISMC
    def build_model(self):
        if self.phase == 'train':
            # initialize the learning rate (step size)
            self.lr = tf.placeholder(tf.float32, name='learning_rate')
            """ Input Image"""
            Image_Data_Class = ImageData(self.img_size, self.img_ch,
                                         self.augment_flag)
            trainA = tf.data.Dataset.from_tensor_slices(self.trainA_dataset)
            trainB = tf.data.Dataset.from_tensor_slices(self.trainB_dataset)
            # slice the image file paths into a dataset
            #print('ok trainA:',trainA)

            gpu_device = '/gpu:0'
            trainA = trainA.apply(shuffle_and_repeat(100)).apply(
                map_and_batch(Image_Data_Class.image_processing,
                              self.batch_size,
                              num_parallel_batches=16,
                              drop_remainder=True)).apply(
                                  prefetch_to_device(gpu_device, None))
            trainB = trainB.apply(shuffle_and_repeat(100)).apply(
                map_and_batch(Image_Data_Class.image_processing,
                              self.batch_size,
                              num_parallel_batches=16,
                              drop_remainder=True)).apply(
                                  prefetch_to_device(gpu_device, None))
            # shuffle the original image order
            # map_and_batch 1: a function mapping a nested structure of tensors to another nested structure of tensors
            #               2: the number of consecutive elements of this dataset to combine into a single batch (32 elements per batch, i.e. a batch dimension of 32)
            #               3: the number of batches to create in parallel; higher values help mitigate stragglers, but may increase contention when the CPU is otherwise idle
            #               4: whether the last batch should be dropped if it is smaller than the requested size
            # returns a randomly augmented 32*64*64*3 array
            # prefetch to the GPU for acceleration
            #print('ok trainA:',trainA.shape)
            trainA_iterator = trainA.make_one_shot_iterator()
            trainB_iterator = trainB.make_one_shot_iterator()

            self.domain_A = trainA_iterator.get_next()
            self.domain_B = trainB_iterator.get_next()
            """ Define Generator, Discriminator """
            x_ab, cam_ab = self.generate_a2b(self.domain_A)  # real a
            # self.domain_A is the cartoon image
            # x_ab is the image obtained by down-sampling and then up-sampling self.domain_A
            x_ba, cam_ba = self.generate_b2a(self.domain_B)  # real b
            # generate_a2b and generate_b2a use two separate sets of parameters

            x_aba, _ = self.generate_b2a(x_ab, reuse=True)  # real b
            x_bab, _ = self.generate_a2b(x_ba, reuse=True)  # real a
            # with the parameters fixed (reuse), run the image produced by generate_a2b back through generate_b2a
            # generate_b2a tries to turn photos into cartoon images
            # generate_a2b tries to turn cartoon images into photos
            # generate_a2b and generate_b2a can be regarded as inverse transforms of each other

            x_aa, cam_aa = self.generate_b2a(self.domain_A,
                                             reuse=True)  # fake b
            x_bb, cam_bb = self.generate_a2b(self.domain_B,
                                             reuse=True)  # fake a
            # parameters fixed (reuse)
            # *** maps a cartoon image back to a cartoon image
            # ensures the colored regions stay unchanged through generate_b2a and generate_a2b
            real_A_logit, real_A_cam_logit, real_B_logit, real_B_cam_logit = self.discriminate_real(
                self.domain_A, self.domain_B)
            # discriminate real cartoon images and real photos
            fake_A_logit, fake_A_cam_logit, fake_B_logit, fake_B_cam_logit = self.discriminate_fake(
                x_ba, x_ab)
            # discriminate fake cartoon images and fake photos
            # the inputs are the images produced by the generators
            # the outputs are the 32*8*8*1 tensor from the convolutions and the concatenation of the pooled results (32, 2)
            """ Define Loss """
            if self.gan_type.__contains__('wgan') or self.gan_type == 'dragan':
                GP_A, GP_CAM_A = self.gradient_panalty(real=self.domain_A,
                                                       fake=x_ba,
                                                       scope="discriminator_A")
                GP_B, GP_CAM_B = self.gradient_panalty(real=self.domain_B,
                                                       fake=x_ab,
                                                       scope="discriminator_B")
            else:
                GP_A, GP_CAM_A = 0, 0
                GP_B, GP_CAM_B = 0, 0

            # next, the loss computation for the fake-image discriminator (discriminate_fake) and the real-image discriminator (discriminate_real)
            # i.e. the losses for both the discriminators and the generators
            '''
            1. Adversarial loss (T)
                uses the least-squares loss
            '''
            # sum of squared differences of fake_A_logit[0], fake_A_logit[1] from 1, plus sum of squared differences of fake_A_cam_logit[0], fake_A_cam_logit[1] from 1
            # from the generator's point of view, real images are scored 0 by the discriminator, and the generated images passed through the discriminator (1 - fake_*) should also be pushed toward 0
            G_ad_loss_A = (generator_loss(self.gan_type, fake_A_logit) +
                           generator_loss(self.gan_type, fake_A_cam_logit))
            # loss for generating fake cartoon images
            G_ad_loss_B = (generator_loss(self.gan_type, fake_B_logit) +
                           generator_loss(self.gan_type, fake_B_cam_logit))
            # loss for generating fake photos
            # generator losses

            D_ad_loss_A = (
                discriminator_loss(self.gan_type, real_A_logit, fake_A_logit) +
                discriminator_loss(self.gan_type, real_A_cam_logit,
                                   fake_A_cam_logit) + GP_A + GP_CAM_A)
            # discrimination loss on cartoon images
            # discriminator_loss pushes real images toward 1 and fake images toward 0
            D_ad_loss_B = (
                discriminator_loss(self.gan_type, real_B_logit, fake_B_logit) +
                discriminator_loss(self.gan_type, real_B_cam_logit,
                                   fake_B_cam_logit) + GP_B + GP_CAM_B)
            # discriminator losses
            '''
            2. Cycle loss (T)
                 the image should be successfully translated back to the original domain
            '''
            reconstruction_A = L1_loss(x_aba, self.domain_A)  # reconstruction
            # loss for cartoon -> photo -> cartoon reconstruction
            reconstruction_B = L1_loss(x_bab, self.domain_B)  # reconstruction
            # loss for photo -> cartoon -> photo reconstruction
            '''
            3. Identity loss
                ensures that identity information is not lost when translating between A and B
            '''
            identity_A = L1_loss(x_aa, self.domain_A)
            identity_B = L1_loss(x_bb, self.domain_B)
            # loss for mapping a cartoon image to a cartoon image
            '''
            4. CAM loss
                uses the auxiliary classifier so that G and D know where to focus the transformation
                when translating A to B, the heat map should respond clearly
                when translating A to A, the heat map should stay silent
            '''
            cam_A = cam_loss(source=cam_ba, non_source=cam_aa)
            # cam_ba is the fully connected output of the photo-to-cartoon path (computed twice with different pooling methods)
            # cam_aa is the fully connected output of the cartoon-to-cartoon path
            cam_B = cam_loss(source=cam_ab, non_source=cam_bb)

            # how are the initial loss weights decided???
            #
            Generator_A_gan = self.adv_weight * G_ad_loss_A
            #1
            # loss for generating cartoon images from photos * its relative weight
            Generator_A_cycle = self.cycle_weight * reconstruction_B
            #10
            #self.generate_a2b(self.generate_b2a(self.domain_B), reuse=True)
            # loss for photo -> cartoon -> photo reconstruction * its relative weight
            Generator_A_identity = self.identity_weight * identity_A
            #10
            # loss for mapping a cartoon image to a cartoon image * its relative weight
            Generator_A_cam = self.cam_weight * cam_A
            #1000
            # the photo-to-cartoon CAM (fully connected) output * its relative weight

            Generator_B_gan = self.adv_weight * G_ad_loss_B
            Generator_B_cycle = self.cycle_weight * reconstruction_A
            Generator_B_identity = self.identity_weight * identity_B
            Generator_B_cam = self.cam_weight * cam_B
            print('ok 5')

            Generator_A_loss = Generator_A_gan + Generator_A_cycle + Generator_A_identity + Generator_A_cam
            # total loss for generating cartoon images
            Generator_B_loss = Generator_B_gan + Generator_B_cycle + Generator_B_identity + Generator_B_cam

            Discriminator_A_loss = self.adv_weight * D_ad_loss_A
            # (discrimination loss between generated and real cartoon images + discrimination loss between their CAM fully connected outputs) * weight
            Discriminator_B_loss = self.adv_weight * D_ad_loss_B

            self.Generator_loss = Generator_A_loss + Generator_B_loss + regularization_loss(
                'generator')
            # total generator loss
            self.Discriminator_loss = Discriminator_A_loss + Discriminator_B_loss + regularization_loss(
                'discriminator')
            # total discriminator loss
            print('55')
            """ Result Image """
            # the generated fake images (kept for saving)
            self.fake_A = x_ba
            self.fake_B = x_ab

            # the input real images
            self.real_A = self.domain_A
            self.real_B = self.domain_B

            # imgba / imgab were undefined names in this listing; assuming they refer to the
            # cross-domain generator outputs, mirror fake_A / fake_B here
            self.imgba = x_ba
            self.imgab = x_ab
            """ Training """
            t_vars = tf.trainable_variables()
            G_vars = [var for var in t_vars if 'generator' in var.name]
            D_vars = [var for var in t_vars if 'discriminator' in var.name]

            self.G_optim = tf.train.AdamOptimizer(self.lr,
                                                  beta1=0.5,
                                                  beta2=0.999).minimize(
                                                      self.Generator_loss,
                                                      var_list=G_vars)
            self.D_optim = tf.train.AdamOptimizer(self.lr,
                                                  beta1=0.5,
                                                  beta2=0.999).minimize(
                                                      self.Discriminator_loss,
                                                      var_list=D_vars)
            # var_list: the set of parameters to update at each optimization step
            """ Summary """
            # self.all_G_loss = tf.summary.scalar("Generator_loss", self.Generator_loss)
            # self.all_D_loss = tf.summary.scalar("Discriminator_loss", self.Discriminator_loss)

            # self.G_A_loss = tf.summary.scalar("G_A_loss", Generator_A_loss)
            # self.G_A_gan = tf.summary.scalar("G_A_gan", Generator_A_gan)
            # self.G_A_cycle = tf.summary.scalar("G_A_cycle", Generator_A_cycle)
            # self.G_A_identity = tf.summary.scalar("G_A_identity", Generator_A_identity)
            # self.G_A_cam = tf.summary.scalar("G_A_cam", Generator_A_cam)

            # self.G_B_loss = tf.summary.scalar("G_B_loss", Generator_B_loss)
            # self.G_B_gan = tf.summary.scalar("G_B_gan", Generator_B_gan)
            # self.G_B_cycle = tf.summary.scalar("G_B_cycle", Generator_B_cycle)
            # self.G_B_identity = tf.summary.scalar("G_B_identity", Generator_B_identity)
            # self.G_B_cam = tf.summary.scalar("G_B_cam", Generator_B_cam)

            # self.D_A_loss = tf.summary.scalar("D_A_loss", Discriminator_A_loss)
            # self.D_B_loss = tf.summary.scalar("D_B_loss", Discriminator_B_loss)
            '''
            Plotting
            '''
            # self.rho_var = []
            # for var in tf.trainable_variables():
            # if 'rho' in var.name:
            # self.rho_var.append(tf.summary.histogram(var.name, var))
            # self.rho_var.append(tf.summary.scalar(var.name + "_min", tf.reduce_min(var)))
            # self.rho_var.append(tf.summary.scalar(var.name + "_max", tf.reduce_max(var)))
            # self.rho_var.append(tf.summary.scalar(var.name + "_mean", tf.reduce_mean(var)))
            # print('ok 7')

            # g_summary_list = [self.G_A_loss, self.G_A_gan, self.G_A_cycle, self.G_A_identity, self.G_A_cam,
            # self.G_B_loss, self.G_B_gan, self.G_B_cycle, self.G_B_identity, self.G_B_cam,
            # self.all_G_loss]

            # g_summary_list.extend(self.rho_var)
            # d_summary_list = [self.D_A_loss, self.D_B_loss, self.all_D_loss]

            # self.G_loss = tf.summary.merge(g_summary_list)
            # self.D_loss = tf.summary.merge(d_summary_list)

        else:
            """ Test """
            self.test_domain_A = tf.placeholder(
                tf.float32, [1, self.img_size, self.img_size, self.img_ch],
                name='test_domain_A')
            self.test_domain_B = tf.placeholder(
                tf.float32, [1, self.img_size, self.img_size, self.img_ch],
                name='test_domain_B')

            self.test_fake_B, _ = self.generate_a2b(self.test_domain_A)
            self.test_fake_A, _ = self.generate_b2a(self.test_domain_B)
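The comments above note that the adversarial terms use a least-squares (lsgan) objective. A hedged sketch of what the lsgan branch of the generator_loss / discriminator_loss helpers referenced here typically computes (single-logit version; this project sums over logit lists):

import tensorflow as tf

def generator_loss_lsgan(fake_logit):
    # the generator wants fake samples to be scored as 1
    return tf.reduce_mean(tf.squared_difference(fake_logit, 1.0))

def discriminator_loss_lsgan(real_logit, fake_logit):
    # the discriminator pushes real logits toward 1 and fake logits toward 0
    real_loss = tf.reduce_mean(tf.squared_difference(real_logit, 1.0))
    fake_loss = tf.reduce_mean(tf.square(fake_logit))
    return real_loss + fake_loss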
コード例 #12
0
    def build_model(self):
        """ Graph Input """
        # images
        if self.custom_dataset :
            Image_Data_Class = ImageData(self.img_size, self.c_dim)
            inputs = tf.data.Dataset.from_tensor_slices(self.data)

            gpu_device = '/gpu:0'
            inputs = inputs.apply(shuffle_and_repeat(self.dataset_num)).apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=8, drop_remainder=True)).apply(prefetch_to_device(gpu_device, self.batch_size))

            inputs_iterator = inputs.make_one_shot_iterator()

            self.inputs = inputs_iterator.get_next()

        else :
            self.inputs = tf.placeholder(tf.float32, [self.batch_size, self.img_size, self.img_size, self.c_dim], name='real_images')

        # noises
        self.z = tf.placeholder(tf.float32, [self.batch_size, self.z_dim], name='z')

        """ Loss Function """
        # output of D for real images
        real_logits = self.discriminator(self.inputs)

        # output of D for fake images
        fake_images = self.generator(self.z)
        fake_logits = self.discriminator(fake_images, reuse=True)

        if self.gan_type.__contains__('wgan') or self.gan_type == 'dragan' :
            GP = self.gradient_penalty(real=self.inputs, fake=fake_images)
        else :
            GP = 0

        # get loss for discriminator
        self.d_loss = discriminator_loss(self.gan_type, real=real_logits, fake=fake_logits) + GP

        # get loss for generator
        self.g_loss = generator_loss(self.gan_type, fake=fake_logits)

        """ Training """
        # divide trainable variables into a group for D and a group for G
        t_vars = tf.trainable_variables()
        d_vars = [var for var in t_vars if 'discriminator' in var.name]
        g_vars = [var for var in t_vars if 'generator' in var.name]

        # optimizers
        self.d_optim = tf.train.AdamOptimizer(self.d_learning_rate, beta1=self.beta1, beta2=self.beta2).minimize(self.d_loss, var_list=d_vars)
        self.g_optim = tf.train.AdamOptimizer(self.g_learning_rate, beta1=self.beta1, beta2=self.beta2).minimize(self.g_loss, var_list=g_vars)

        """" Testing """
        # for test
        self.fake_images = self.generator(self.z, is_training=False, reuse=True)

        """ Summary """
        self.d_sum = tf.summary.scalar("d_loss", self.d_loss)
        self.g_sum = tf.summary.scalar("g_loss", self.g_loss)
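This example calls self.gradient_penalty for the wgan / dragan settings without showing it. A rough WGAN-GP-style sketch of such a penalty, assuming a discriminator callable D(x, reuse=...) like the one above (not this repository's exact implementation):

import tensorflow as tf

def gradient_penalty(D, real, fake, lambda_gp=10.0):
    batch = tf.shape(real)[0]
    eps = tf.random_uniform([batch, 1, 1, 1], 0.0, 1.0)
    interpolated = real + eps * (fake - real)    # random points on the line between real and fake
    logit = D(interpolated, reuse=True)
    grad = tf.gradients(logit, interpolated)[0]  # gradient of D at the interpolated points
    grad_norm = tf.sqrt(tf.reduce_sum(tf.square(grad), axis=[1, 2, 3]) + 1e-10)
    return lambda_gp * tf.reduce_mean(tf.square(grad_norm - 1.0))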
コード例 #13
0
    def build_model(self):
        self.lr = tf.placeholder(tf.float32, name='learning_rate')


        """ Input Image"""
        Image_Data_Class = ImageData(self.img_size, self.img_ch, self.augment_flag)

        trainA = tf.data.Dataset.from_tensor_slices(self.trainA_dataset)
        trainB = tf.data.Dataset.from_tensor_slices(self.trainB_dataset)
        trainB_smooth = tf.data.Dataset.from_tensor_slices(self.trainB_smooth_dataset)

        gpu_device = '/gpu:0'
        trainA = trainA.apply(shuffle_and_repeat(self.dataset_num)).apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16, drop_remainder=True)).apply(prefetch_to_device(gpu_device, self.batch_size))
        trainB = trainB.apply(shuffle_and_repeat(self.dataset_num)).apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16, drop_remainder=True)).apply(prefetch_to_device(gpu_device, self.batch_size))
        trainB_smooth = trainB_smooth.apply(shuffle_and_repeat(self.dataset_num)).apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16, drop_remainder=True)).apply(prefetch_to_device(gpu_device, self.batch_size))

        trainA_iterator = trainA.make_one_shot_iterator()
        trainB_iterator = trainB.make_one_shot_iterator()
        trainB_smooth_iterator = trainB_smooth.make_one_shot_iterator()


        self.real_A = trainA_iterator.get_next()
        self.real_B = trainB_iterator.get_next()
        self.real_B_smooth = trainB_smooth_iterator.get_next()

        self.test_real_A = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='test_real_A')


        """ Define Generator, Discriminator """
        self.fake_B = self.generator(self.real_A)

        real_B_logit = self.discriminator(self.real_B)
        fake_B_logit = self.discriminator(self.fake_B, reuse=True)
        real_B_smooth_logit = self.discriminator(self.real_B_smooth, reuse=True)


        """ Define Loss """
        if self.gan_type.__contains__('gp') or self.gan_type.__contains__('lp') or self.gan_type.__contains__('dragan') :
            GP = self.gradient_panalty(real=self.real_B, fake=self.fake_B) + self.gradient_panalty(self.real_B, fake=self.real_B_smooth)
        else :
            GP = 0.0

        v_loss = self.vgg_weight * vgg_loss(self.real_A, self.fake_B)
        g_loss = self.adv_weight * generator_loss(self.gan_type, fake_B_logit)
        d_loss = self.adv_weight * discriminator_loss(self.gan_type, real_B_logit, fake_B_logit, real_B_smooth_logit) + GP

        self.Vgg_loss = v_loss
        self.Generator_loss = g_loss + v_loss
        self.Discriminator_loss = d_loss


        """ Result Image """
        self.test_fake_B = self.generator(self.test_real_A, reuse=True)

        """ Training """
        t_vars = tf.trainable_variables()
        G_vars = [var for var in t_vars if 'generator' in var.name]
        D_vars = [var for var in t_vars if 'discriminator' in var.name]

        self.init_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Vgg_loss, var_list=G_vars)
        self.G_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Generator_loss, var_list=G_vars)
        self.D_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Discriminator_loss, var_list=D_vars)


        """" Summary """
        self.G_loss = tf.summary.scalar("Generator_loss", self.Generator_loss)
        self.D_loss = tf.summary.scalar("Discriminator_loss", self.Discriminator_loss)

        self.G_gan = tf.summary.scalar("G_gan", g_loss)
        self.G_vgg = tf.summary.scalar("G_vgg", v_loss)

        self.V_loss_merge = tf.summary.merge([self.G_vgg])
        self.G_loss_merge = tf.summary.merge([self.G_loss, self.G_gan, self.G_vgg])
        self.D_loss_merge = tf.summary.merge([self.D_loss])
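The vgg_loss term above compares the input photo and the stylized output in the feature space of a pre-trained VGG network. A minimal sketch of that idea, assuming a hypothetical fixed feature extractor vgg_features (the project's own VGG wrapper is not shown here):

import tensorflow as tf

def perceptual_loss(vgg_features, real, fake):
    # L1 distance between feature maps of the source photo and the generated image
    real_feat = vgg_features(real)
    fake_feat = vgg_features(fake)
    return tf.reduce_mean(tf.abs(real_feat - fake_feat))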
コード例 #14
0
ファイル: TGAN_128.py プロジェクト: xiaoanshi/tf-T-GANs
    def build_model(self):
        # some parameters
        bs = self.batch_size

        """ Graph Input """
        # images
        if self.custom_dataset:
            Image_Data_Class = ImageData(self.output_height, self.c_dim)
            inputs = tf.data.Dataset.from_tensor_slices(self.data)

            gpu_device = '/gpu:0'
            inputs = inputs.apply(shuffle_and_repeat(self.dataset_num)).apply(
                map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=8,
                              drop_remainder=True)).apply(prefetch_to_device(gpu_device, self.batch_size))

            inputs_iterator = inputs.make_one_shot_iterator()

            self.inputs = inputs_iterator.get_next()

        else:
            self.inputs = tf.placeholder(tf.float32, [self.batch_size, self.output_height, self.output_height, self.c_dim],
                                         name='real_images')

        # noises
        self.z = tf.placeholder(tf.float32, [bs, self.z_dim], name='z')

        """ Loss Function """
        x_fake = self.generator(self.z, is_training=True, reuse=False)
        x_real_encoder = self.encoder(self.inputs, is_training=True, reuse=False, sn=True)
        x_fake_encoder = self.encoder(x_fake, is_training=True, reuse=True, sn=True)
        x_real_fake = tf.subtract(x_real_encoder, x_fake_encoder)
        x_fake_real = tf.subtract(x_fake_encoder, x_real_encoder)
        x_real_fake_score = self.discriminator(x_real_fake, reuse=False, sn=True)
        x_fake_real_score = self.discriminator(x_fake_real, reuse=True, sn=True)

        # get loss for discriminator
        self.d_loss = discriminator_loss(self.loss_type, real=x_real_fake_score, fake=x_fake_real_score)

        # get loss for generator
        self.g_loss = generator_loss(self.loss_type, real=x_real_fake_score, fake=x_fake_real_score)

        """ Training """
        # divide trainable variables into a group for D and a group for G
        t_vars = tf.trainable_variables()
        d_vars = [var for var in t_vars if 'discriminator' in var.name or 'encoder' in var.name]
        g_vars = [var for var in t_vars if 'generator' in var.name]

        # optimizers
        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            self.d_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=self.beta1) \
                .minimize(self.d_loss, var_list=d_vars)
            self.g_optim = tf.train.AdamOptimizer(self.learning_rate * 5, beta1=self.beta1) \
                .minimize(self.g_loss, var_list=g_vars)

        """" Testing """
        # for test
        self.fake_images = self.generator(self.z, is_training=False, reuse=True)

        """ Summary """
        self.d_sum = tf.summary.scalar("d_loss", self.d_loss)
        self.g_sum = tf.summary.scalar("g_loss", self.g_loss)
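The optimizers in this example are created inside tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)); that dependency is what forces batch-norm moving averages to be refreshed on every training step. A small self-contained illustration of the same pattern (unrelated to this project's networks):

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 8])
h = tf.layers.batch_normalization(tf.layers.dense(x, 8), training=True)
loss = tf.reduce_mean(tf.square(h))

# without the dependency, the moving-average update ops would never run
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)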
コード例 #15
0
ファイル: GDWCT.py プロジェクト: taki0112/GDWCT-Tensorflow
    def build_model(self):
        self.lr = tf.placeholder(tf.float32, name='learning_rate')
        """ Input Image"""
        Image_Data_Class = ImageData(self.img_h, self.img_w, self.img_ch,
                                     self.augment_flag)

        trainA = tf.data.Dataset.from_tensor_slices(self.trainA_dataset)
        trainB = tf.data.Dataset.from_tensor_slices(self.trainB_dataset)

        gpu_device = '/gpu:0'

        trainA = trainA.\
            apply(shuffle_and_repeat(self.dataset_num)). \
            apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16, drop_remainder=True)). \
            apply(prefetch_to_device(gpu_device, None))

        trainB = trainB. \
            apply(shuffle_and_repeat(self.dataset_num)). \
            apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16, drop_remainder=True)). \
            apply(prefetch_to_device(gpu_device, None))
        # When using dataset.prefetch, use buffer_size=None to let it detect optimal buffer size

        trainA_iterator = trainA.make_one_shot_iterator()
        trainB_iterator = trainB.make_one_shot_iterator()

        self.domain_A = trainA_iterator.get_next()
        self.domain_B = trainB_iterator.get_next()
        """ Define Encoder, Generator, Discriminator """

        # encode
        content_a, style_a = self.encoder_A(self.domain_A)
        content_b, style_b = self.encoder_B(self.domain_B)

        # decode (cross domain)
        x_ba, U_A = self.decoder_A(content_B=content_b, style_A=style_a)
        x_ab, U_B = self.decoder_B(content_A=content_a, style_B=style_b)

        # decode (within domain)
        x_aa, _ = self.decoder_A(content_B=content_a,
                                 style_A=style_a,
                                 reuse=True)
        x_bb, _ = self.decoder_B(content_A=content_b,
                                 style_B=style_b,
                                 reuse=True)

        # encode again
        content_ba, style_ba = self.encoder_A(x_ba, reuse=True)
        content_ab, style_ab = self.encoder_B(x_ab, reuse=True)

        # decode again (if needed)
        x_aba, _ = self.decoder_A(content_B=content_ab,
                                  style_A=style_ba,
                                  reuse=True)
        x_bab, _ = self.decoder_B(content_A=content_ba,
                                  style_B=style_ab,
                                  reuse=True)

        real_A_logit, real_B_logit = self.discriminate_real(
            self.domain_A, self.domain_B)
        fake_A_logit, fake_B_logit = self.discriminate_fake(x_ba, x_ab)
        """ Define Loss """
        G_adv_A = self.gan_w * generator_loss(self.gan_type, fake_A_logit)
        G_adv_B = self.gan_w * generator_loss(self.gan_type, fake_B_logit)

        D_adv_A = self.gan_w * discriminator_loss(self.gan_type, real_A_logit,
                                                  fake_A_logit)
        D_adv_B = self.gan_w * discriminator_loss(self.gan_type, real_B_logit,
                                                  fake_B_logit)

        recon_style_A = self.recon_s_w * L1_loss(style_ba, style_a)
        recon_style_B = self.recon_s_w * L1_loss(style_ab, style_b)

        recon_content_A = self.recon_c_w * L1_loss(content_ab, content_a)
        recon_content_B = self.recon_c_w * L1_loss(content_ba, content_b)

        cyc_recon_A = self.recon_x_cyc_w * L1_loss(x_aba, self.domain_A)
        cyc_recon_B = self.recon_x_cyc_w * L1_loss(x_bab, self.domain_B)

        recon_A = self.recon_x_w * L1_loss(x_aa,
                                           self.domain_A)  # reconstruction
        recon_B = self.recon_x_w * L1_loss(x_bb,
                                           self.domain_B)  # reconstruction

        whitening_A, coloring_A = group_wise_regularization(
            deep_whitening_transform(content_a), U_A, self.group_num)
        whitening_B, coloring_B = group_wise_regularization(
            deep_whitening_transform(content_b), U_B, self.group_num)

        whitening_A = self.lambda_w * whitening_A
        whitening_B = self.lambda_w * whitening_B

        coloring_A = self.lambda_c * coloring_A
        coloring_B = self.lambda_c * coloring_B

        G_reg_A = regularization_loss('decoder_A') + regularization_loss(
            'encoder_A')
        G_reg_B = regularization_loss('decoder_B') + regularization_loss(
            'encoder_B')

        D_reg_A = regularization_loss('discriminator_A')
        D_reg_B = regularization_loss('discriminator_B')


        Generator_A_loss = G_adv_A + \
                           recon_A + \
                           recon_style_A + \
                           recon_content_A + \
                           cyc_recon_B + \
                           whitening_A + \
                           coloring_A + \
                           G_reg_A

        Generator_B_loss = G_adv_B + \
                           recon_B + \
                           recon_style_B + \
                           recon_content_B + \
                           cyc_recon_A + \
                           whitening_B + \
                           coloring_B + \
                           G_reg_B

        Discriminator_A_loss = D_adv_A + D_reg_A
        Discriminator_B_loss = D_adv_B + D_reg_B

        self.Generator_loss = Generator_A_loss + Generator_B_loss
        self.Discriminator_loss = Discriminator_A_loss + Discriminator_B_loss
        """ Training """
        t_vars = tf.trainable_variables()
        G_vars = [
            var for var in t_vars
            if 'decoder' in var.name or 'encoder' in var.name
        ]
        D_vars = [var for var in t_vars if 'discriminator' in var.name]

        self.G_optim = tf.train.AdamOptimizer(
            self.lr, beta1=0.5, beta2=0.999).minimize(self.Generator_loss,
                                                      var_list=G_vars)
        self.D_optim = tf.train.AdamOptimizer(
            self.lr, beta1=0.5, beta2=0.999).minimize(self.Discriminator_loss,
                                                      var_list=D_vars)
        """" Summary """
        self.all_G_loss = tf.summary.scalar("Generator_loss",
                                            self.Generator_loss)
        self.all_D_loss = tf.summary.scalar("Discriminator_loss",
                                            self.Discriminator_loss)

        self.G_A_loss = tf.summary.scalar("G_A_loss", Generator_A_loss)
        self.G_B_loss = tf.summary.scalar("G_B_loss", Generator_B_loss)
        self.D_A_loss = tf.summary.scalar("D_A_loss", Discriminator_A_loss)
        self.D_B_loss = tf.summary.scalar("D_B_loss", Discriminator_B_loss)

        self.G_A_adv_loss = tf.summary.scalar("G_A_adv_loss", G_adv_A)
        self.G_A_style_loss = tf.summary.scalar("G_A_style_loss",
                                                recon_style_A)
        self.G_A_content_loss = tf.summary.scalar("G_A_content_loss",
                                                  recon_content_A)
        self.G_A_cyc_loss = tf.summary.scalar("G_A_cyc_loss", cyc_recon_A)
        self.G_A_identity_loss = tf.summary.scalar("G_A_identity_loss",
                                                   recon_A)
        self.G_A_whitening_loss = tf.summary.scalar("G_A_whitening_loss",
                                                    whitening_A)
        self.G_A_coloring_loss = tf.summary.scalar("G_A_coloring_loss",
                                                   coloring_A)

        self.G_B_adv_loss = tf.summary.scalar("G_B_adv_loss", G_adv_B)
        self.G_B_style_loss = tf.summary.scalar("G_B_style_loss",
                                                recon_style_B)
        self.G_B_content_loss = tf.summary.scalar("G_B_content_loss",
                                                  recon_content_B)
        self.G_B_cyc_loss = tf.summary.scalar("G_B_cyc_loss", cyc_recon_B)
        self.G_B_identity_loss = tf.summary.scalar("G_B_identity_loss",
                                                   recon_B)
        self.G_B_whitening_loss = tf.summary.scalar("G_B_whitening_loss",
                                                    whitening_B)
        self.G_B_coloring_loss = tf.summary.scalar("G_B_coloring_loss",
                                                   coloring_B)

        self.alpha_var = []
        for var in tf.trainable_variables():
            if 'alpha' in var.name:
                self.alpha_var.append(tf.summary.histogram(var.name, var))
                self.alpha_var.append(
                    tf.summary.scalar(var.name, tf.reduce_max(var)))

        G_summary_list = [
            self.G_A_adv_loss, self.G_A_style_loss, self.G_A_content_loss,
            self.G_A_cyc_loss, self.G_A_identity_loss, self.G_A_whitening_loss,
            self.G_A_coloring_loss, self.G_A_loss, self.G_B_adv_loss,
            self.G_B_style_loss, self.G_B_content_loss, self.G_B_cyc_loss,
            self.G_B_identity_loss, self.G_B_whitening_loss,
            self.G_B_coloring_loss, self.G_B_loss, self.all_G_loss
        ]

        G_summary_list.extend(self.alpha_var)

        self.G_loss = tf.summary.merge(G_summary_list)
        self.D_loss = tf.summary.merge(
            [self.D_A_loss, self.D_B_loss, self.all_D_loss])
        """ Image """
        self.fake_A = x_ba
        self.fake_B = x_ab

        self.real_A = self.domain_A
        self.real_B = self.domain_B
        """ Test """
        """ Guided Image Translation """
        self.content_image = tf.placeholder(
            tf.float32, [1, self.img_h, self.img_w, self.img_ch],
            name='content_image')
        self.style_image = tf.placeholder(
            tf.float32, [1, self.img_h, self.img_w, self.img_ch],
            name='guide_style_image')

        if self.direction == 'a2b':
            guide_content_A, _ = self.encoder_A(self.content_image, reuse=True)
            _, guide_style_B = self.encoder_B(self.style_image, reuse=True)

            self.guide_fake_B, _ = self.decoder_B(content_A=guide_content_A,
                                                  style_B=guide_style_B,
                                                  reuse=True)

        else:
            guide_content_B, _ = self.encoder_B(self.content_image, reuse=True)
            _, guide_style_A = self.encoder_A(self.style_image, reuse=True)

            self.guide_fake_A, _ = self.decoder_A(content_B=guide_content_B,
                                                  style_A=guide_style_A,
                                                  reuse=True)
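The regularization_loss('decoder_A') style calls above most likely just sum the weight-decay terms that the layers registered under tf.GraphKeys.REGULARIZATION_LOSSES for a given scope. A sketch under that assumption:

import tensorflow as tf

def regularization_loss(scope_name):
    # collect the regularization terms whose name matches the given scope
    losses = [l for l in tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
              if scope_name in l.name]
    return tf.add_n(losses) if losses else tf.constant(0.0)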
コード例 #16
0
    def build_model(self):
        """ Input Image"""
        img_data_class = Image_data(self.img_height, self.img_width, self.img_ch, self.dataset_path, self.augment_flag)
        train_captions, train_images, test_captions, test_images, idx_to_word, word_to_idx = img_data_class.preprocess()
        """
        train_captions: (8855, 10, 18), test_captions: (2933, 10, 18)
        train_images: (8855,), test_images: (2933,)
        idx_to_word : 5450 5450
        """

        if self.phase == 'train' :
            self.lr = tf.placeholder(tf.float32, name='learning_rate')

            self.dataset_num = len(train_images)


            img_and_caption = tf.data.Dataset.from_tensor_slices((train_images, train_captions))

            gpu_device = '/gpu:0'
            img_and_caption = img_and_caption.apply(shuffle_and_repeat(self.dataset_num)).apply(
                map_and_batch(img_data_class.image_processing, batch_size=self.batch_size, num_parallel_batches=16,
                              drop_remainder=True)).apply(prefetch_to_device(gpu_device, None))


            img_and_caption_iterator = img_and_caption.make_one_shot_iterator()
            real_img_256, caption = img_and_caption_iterator.get_next()
            target_sentence_index = tf.random_uniform(shape=[], minval=0, maxval=10, dtype=tf.int32)
            caption = tf.gather(caption, target_sentence_index, axis=1)

            word_emb, sent_emb, mask = self.rnn_encoder(caption, n_words=len(idx_to_word),
                                                        embed_dim=self.embed_dim, drop_rate=0.5, n_hidden=128, n_layers=1,
                                                        bidirectional=True, rnn_type='lstm')

            noise = tf.random_normal(shape=[self.batch_size, self.z_dim], mean=0.0, stddev=1.0)
            fake_imgs, _, mu, logvar = self.generator(noise, sent_emb, word_emb, mask)

            real_img_64, real_img_128 = resize(real_img_256, target_size=[64, 64]), resize(real_img_256, target_size=[128, 128])
            fake_img_64, fake_img_128, fake_img_256 = fake_imgs[0], fake_imgs[1], fake_imgs[2]
            real_imgs = [real_img_64, real_img_128, real_img_256]

            uncond_real_logits, cond_real_logits, real_emb_features = self.discriminator([real_img_64, real_img_128, real_img_256], sent_emb)
            _, cond_wrong_logits, _ = self.discriminator([real_img_64[:(self.batch_size - 1)], real_img_128[:(self.batch_size - 1)], real_img_256[:(self.batch_size - 1)]], sent_emb[1:self.batch_size])
            uncond_fake_logits, cond_fake_logits, _ = self.discriminator([fake_img_64, fake_img_128, fake_img_256], sent_emb)

            self.g_adv_loss, self.d_adv_loss = 0, 0
            self.g_vgg_loss = 0
            self.d_word_loss = 0

            for i in range(3):
                self.g_adv_loss += self.adv_weight * (generator_loss(self.gan_type, uncond_fake_logits[i]) + generator_loss(self.gan_type, cond_fake_logits[i]))
                self.g_vgg_loss += self.vgg_weight * (vgg16_perceptual_loss(real_imgs[i], fake_imgs[i], class_vgg_16))

                uncond_real_loss, uncond_fake_loss = discriminator_loss(self.gan_type, uncond_real_logits[i], uncond_fake_logits[i])
                cond_real_loss, cond_fake_loss = discriminator_loss(self.gan_type, cond_real_logits[i], cond_fake_logits[i])
                _, cond_wrong_loss = discriminator_loss(self.gan_type, None, cond_wrong_logits[i])

                self.d_adv_loss += self.adv_weight * (((uncond_real_loss + cond_real_loss) / 2) + (uncond_fake_loss + cond_fake_loss + cond_wrong_loss) / 3)
                self.d_word_loss += self.word_weight * word_level_correlation_loss(real_emb_features[i], word_emb, gamma1=4.0, gamma2=5.0)

            self.g_kl_loss = self.kl_weight * kl_loss(mu, logvar)
            self.g_vgg_loss = self.g_vgg_loss / 3.0

            self.g_loss = self.g_adv_loss + self.g_kl_loss + self.g_vgg_loss
            self.d_loss = self.d_adv_loss + self.d_word_loss

            self.real_img = real_img_256
            self.fake_img = fake_img_256


            """ Training """
            t_vars = tf.trainable_variables()
            G_vars = [var for var in t_vars if 'generator' in var.name]
            D_vars = [var for var in t_vars if 'discriminator' in var.name]

            self.g_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.g_loss, var_list=G_vars)
            self.d_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.d_loss, var_list=D_vars)


            """" Summary """
            self.summary_g_loss = tf.summary.scalar("g_loss", self.g_loss)
            self.summary_d_loss = tf.summary.scalar("d_loss", self.d_loss)

            self.summary_g_adv_loss = tf.summary.scalar("g_adv_loss", self.g_adv_loss)
            self.summary_g_kl_loss = tf.summary.scalar("g_kl_loss", self.g_kl_loss)
            self.summary_g_vgg_loss = tf.summary.scalar("g_vgg_loss", self.g_vgg_loss)

            self.summary_d_adv_loss = tf.summary.scalar("d_adv_loss", self.d_adv_loss)
            self.summary_d_word_loss = tf.summary.scalar("d_word_loss", self.d_word_loss)


            g_summary_list = [self.summary_g_loss,
                              self.summary_g_adv_loss, self.summary_g_kl_loss, self.summary_g_vgg_loss]

            d_summary_list = [self.summary_d_loss,
                              self.summary_d_adv_loss, self.summary_d_word_loss]

            self.summary_merge_g_loss = tf.summary.merge(g_summary_list)
            self.summary_merge_d_loss = tf.summary.merge(d_summary_list)

        else :
            """ Test """
            self.dataset_num = len(test_captions)

            gpu_device = '/gpu:0'
            img_and_caption = tf.data.Dataset.from_tensor_slices((test_images, test_captions))

            img_and_caption = img_and_caption.apply(
                shuffle_and_repeat(self.dataset_num)).apply(
                map_and_batch(img_data_class.image_processing, batch_size=self.batch_size, num_parallel_batches=16, drop_remainder=True)).apply(
                prefetch_to_device(gpu_device, None))

            img_and_caption_iterator = img_and_caption.make_one_shot_iterator()
            real_img_256, caption = img_and_caption_iterator.get_next()
            target_sentence_index = tf.random_uniform(shape=[], minval=0, maxval=10, dtype=tf.int32)
            caption = tf.gather(caption, target_sentence_index, axis=1)

            word_emb, sent_emb, mask = self.rnn_encoder(caption, n_words=len(idx_to_word),
                                                        embed_dim=self.embed_dim, drop_rate=0.5, n_hidden=128,
                                                        n_layers=1,
                                                        bidirectional=True, rnn_type='lstm',
                                                        is_training=False)

            noise = tf.random_normal(shape=[self.batch_size, self.z_dim], mean=0.0, stddev=1.0)
            fake_imgs, _, _, _ = self.generator(noise, sent_emb, word_emb, mask, is_training=False)

            self.test_real_img = real_img_256
            self.test_fake_img = fake_imgs[2]
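The resize(real_img_256, target_size=[64, 64]) calls above presumably wrap one of TensorFlow's image-resize ops to build the 64x64 / 128x128 pyramid for the multi-scale discriminator. A minimal sketch under that assumption:

import tensorflow as tf

def resize(x, target_size):
    # bilinear down-scaling of an NHWC image batch
    return tf.image.resize_bilinear(x, size=target_size)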