Example #1
    def _tfconfig(self, save: str = None, preload: str = None) -> None:
        # self.all_images
        labeling_func = self.config.labeling_func
        preprocessing_func = self.config.preprocessing_func
        if not labeling_func:
            labeling_func = self._default_labeling_func
        if not preprocessing_func:
            preprocessing_func = self._default_preprocessing_func_tensor_input
        self.labels = []
        for each_image in self.all_images:
            self.labels.append(labeling_func(each_image))

        # preload
        if preload:
            self._load_tfrecord(preload)
            return
        if labeling_func == self._default_labeling_func:
            self._process_default_labeling_list(self.labels)
        self.inputs = tf.data.Dataset.from_tensor_slices(
            (self.all_images, self.labels))
        self.inputs = self.inputs. \
            apply(shuffle_and_repeat(self.all_images_num)). \
            apply(map_and_batch(preprocessing_func, self.config.batch_size, num_parallel_batches=16, drop_remainder=True))
        if self.config.gpu_device:
            self.inputs = self.inputs.apply(
                prefetch_to_device('/gpu:{}'.format(self.config.gpu_device),
                                   None))
        self.iterator = self.inputs.make_one_shot_iterator()
        if save:
            self._save_tfrecord(save, preprocessing_func)
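The three `tf.contrib.data` helpers used above (`shuffle_and_repeat`, `map_and_batch`, `prefetch_to_device`) later moved to `tf.data.experimental` and can be expressed with chained core `tf.data` ops. A minimal sketch of an equivalent pipeline, assuming illustrative names `images`, `labels`, `preprocess_fn`, `batch_size`, and `gpu_device` that are not taken from the class above:

import tensorflow as tf

def build_pipeline(images, labels, preprocess_fn, batch_size, gpu_device=None):
    # Shuffle over the whole dataset, repeat indefinitely, then map and batch.
    ds = tf.data.Dataset.from_tensor_slices((images, labels))
    ds = ds.shuffle(buffer_size=len(images)).repeat()
    ds = ds.map(preprocess_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.batch(batch_size, drop_remainder=True)
    if gpu_device:
        # prefetch_to_device must stay the final transformation in the pipeline.
        ds = ds.apply(tf.data.experimental.prefetch_to_device(gpu_device))
    else:
        ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
    return ds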
Example #2
    def build_model(self):
        # Input image processing
        if self.custom_dataset:
            Image_Data_Class = ImageData(self.img_size, self.c_dim)
            inputs = tf.data.Dataset.from_tensor_slices(self.data)
            gpu_device = '/gpu:0'
            inputs = inputs.apply(shuffle_and_repeat(self.dataset_num)).apply(
                map_and_batch(Image_Data_Class.image_processing,
                              self.batch_size,
                              num_parallel_batches=16,
                              drop_remainder=True)).apply(
                                  prefetch_to_device(gpu_device,
                                                     self.batch_size))
            inputs_iterator = inputs.make_one_shot_iterator()
            self.inputs = inputs_iterator.get_next()
        else:
            self.inputs = tf.placeholder(
                tf.float32,
                [self.batch_size, self.img_size, self.img_size, self.c_dim],
                name='real_images')
        # Noise input
        self.z = tf.placeholder(tf.float32,
                                [self.batch_size, 1, 1, self.z_dim],
                                name='z')
        # Loss functions
        # D:
        real_logits = self.discriminator(self.inputs)
        fake_images = self.generator(self.z)
        fake_logits = self.discriminator(fake_images, reuse=True)
        self.d_loss = discriminator_loss(self.gan_type,
                                         real=real_logits,
                                         fake=fake_logits)
        # G:
        self.g_loss = generator_loss(self.gan_type, fake=fake_logits)
        # Training
        t_vars = tf.trainable_variables()
        d_vars = [var for var in t_vars if 'discriminator' in var.name]
        g_vars = [var for var in t_vars if 'generator' in var.name]
        self.d_optim = tf.train.AdamOptimizer(self.d_learning_rate,
                                              beta1=self.beta1,
                                              beta2=self.beta2).minimize(
                                                  self.d_loss, var_list=d_vars)
        self.g_optim = tf.train.AdamOptimizer(self.g_learning_rate,
                                              beta1=self.beta1,
                                              beta2=self.beta2).minimize(
                                                  self.g_loss, var_list=g_vars)
        # Testing
        self.fake_images = self.generator(self.z,
                                          is_training=False,
                                          reuse=True)

        self.d_sum = tf.summary.scalar("d_loss", self.d_loss)
        self.g_sum = tf.summary.scalar("g_loss", self.g_loss)
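Example #2 only builds the graph; a training step that consumes `d_optim`, `g_optim`, and the summaries would typically look like the sketch below. It assumes `custom_dataset=True` (so real images come from the one-shot iterator), numpy imported as `np`, plus a `self.sess` (`tf.Session`) and a `self.writer` (`tf.summary.FileWriter`); these names are illustrative and not shown in the snippet:

    def train_step(self, counter):
        # Sample a noise batch for the self.z placeholder; real images are
        # pulled automatically from the dataset iterator behind self.inputs.
        batch_z = np.random.normal(size=(self.batch_size, 1, 1, self.z_dim))

        # Update the discriminator, then the generator, logging both summaries.
        _, d_summary, d_loss = self.sess.run(
            [self.d_optim, self.d_sum, self.d_loss], feed_dict={self.z: batch_z})
        _, g_summary, g_loss = self.sess.run(
            [self.g_optim, self.g_sum, self.g_loss], feed_dict={self.z: batch_z})

        self.writer.add_summary(d_summary, counter)
        self.writer.add_summary(g_summary, counter)
        return d_loss, g_loss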
Example #3
    def _load_tfrecord(self, preload: str):
        assert os.path.exists(preload)
        self.inputs = tf.data.TFRecordDataset(preload)
        self.inputs = self.inputs.apply(
            map_and_batch(self._parse_func,
                          self.config.batch_size,
                          num_parallel_batches=16,
                          drop_remainder=True))

        if self.config.gpu_device:
            self.inputs = self.inputs.apply(
                prefetch_to_device('/gpu:{}'.format(self.config.gpu_device),
                                   None))
        self.iterator = self.inputs.make_one_shot_iterator()
        print('TFRecord loaded: {}'.format(preload))
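`_parse_func` is referenced but not shown. A minimal sketch of such a parser, assuming (hypothetically) that `_save_tfrecord` stored each example as a raw image byte string and an integer label under the keys 'image_raw' and 'label':

    def _parse_func(self, serialized_example):
        # Hypothetical feature spec; the real keys, dtypes, and shapes depend
        # on how _save_tfrecord wrote the records.
        features = tf.parse_single_example(
            serialized_example,
            features={
                'image_raw': tf.FixedLenFeature([], tf.string),
                'label': tf.FixedLenFeature([], tf.int64),
            })
        image = tf.decode_raw(features['image_raw'], tf.uint8)
        image = tf.cast(image, tf.float32) / 255.0
        label = tf.cast(features['label'], tf.int32)
        return image, label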
Example #4
    def minibatch(self, dataset, subset, cache_data=False):

        with tf.name_scope('batch_processing'):

            glob_pattern = dataset.tf_record_pattern(subset)
            file_names = gfile.Glob(glob_pattern)
            if not file_names:
                raise ValueError(
                    'Found no files in --data_dir matching: {}'.format(
                        glob_pattern))
            ds = tf.data.TFRecordDataset.list_files(file_names)

            # The number of parallel open files and TFRecord readers should be
            # tuned to the batch size.
            ds = ds.apply(
                parallel_interleave(tf.data.TFRecordDataset,
                                    cycle_length=28,
                                    block_length=5,
                                    sloppy=True,
                                    buffer_output_elements=10000,
                                    prefetch_input_elements=10000))

            if cache_data:
                ds = ds.take(1).cache().repeat()

            ds = ds.prefetch(buffer_size=10000)
            #ds = ds.prefetch(buffer_size=self.batch_size)

            ds = ds.apply(
                map_and_batch(
                    map_func=self.parse_and_preprocess,
                    batch_size=self.batch_size,
                    num_parallel_batches=56,
                    num_parallel_calls=None))  # this number should be tuned

            ds = ds.prefetch(
                buffer_size=tf.contrib.data.AUTOTUNE)  # this number can be tuned

            ds_iterator = ds.make_one_shot_iterator()
            images, labels = ds_iterator.get_next()
            # reshape
            labels = tf.reshape(labels, [self.batch_size])

            return images, labels
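A minimal usage sketch for this `minibatch` method under TF 1.x graph mode, assuming a `preprocessor` object that exposes it and a `dataset` wrapper that provides `tf_record_pattern` (both names are illustrative):

import tensorflow as tf

# Build the input tensors once; the one-shot iterator refills them on each run.
images, labels = preprocessor.minibatch(dataset, subset='train')

num_steps = 100  # illustrative
with tf.Session() as sess:
    for step in range(num_steps):
        batch_images, batch_labels = sess.run([images, labels])
        # ... feed the batch to the training op here ...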
Example #5
    def build_model(self):

        self.lr = tf.compat.v1.placeholder(tf.float32, name='learning_rate')
        """ Input Image"""
        img_class = Image_data(self.img_height, self.img_width, self.img_ch,
                               self.segmap_ch, self.dataset_path,
                               self.augment_flag)
        img_class.preprocess()

        self.dataset_num = len(img_class.image)
        self.test_dataset_num = len(img_class.segmap_test)

        img_and_segmap = tf.data.Dataset.from_tensor_slices(
            (img_class.image, img_class.segmap))
        segmap_test = tf.data.Dataset.from_tensor_slices(img_class.segmap_test)

        gpu_device = '/gpu:0'
        img_and_segmap = img_and_segmap.apply(
            shuffle_and_repeat(self.dataset_num)).apply(
                map_and_batch(img_class.image_processing,
                              self.batch_size,
                              num_parallel_batches=16,
                              drop_remainder=True)).apply(
                                  prefetch_to_device(gpu_device,
                                                     self.batch_size))

        segmap_test = segmap_test.apply(shuffle_and_repeat(
            self.dataset_num)).apply(
                map_and_batch(img_class.test_image_processing,
                              batch_size=self.batch_size,
                              num_parallel_batches=16,
                              drop_remainder=True)).apply(
                                  prefetch_to_device(gpu_device,
                                                     self.batch_size))

        img_and_segmap_iterator = img_and_segmap.make_one_shot_iterator()
        segmap_test_iterator = segmap_test.make_one_shot_iterator()

        (self.real_x, self.real_x_segmap,
         self.real_x_segmap_onehot) = img_and_segmap_iterator.get_next()
        (self.real_x_segmap_test,
         self.real_x_segmap_test_onehot) = segmap_test_iterator.get_next()
        """ Define Generator, Discriminator """
        fake_x, x_mean, x_var = self.image_translate(
            segmap_img=self.real_x_segmap_onehot, x_img=self.real_x)
        real_logit, fake_logit = self.image_discriminate(
            segmap_img=self.real_x_segmap_onehot,
            real_img=self.real_x,
            fake_img=fake_x)

        if 'wgan' in self.gan_type or self.gan_type == 'dragan':
            GP = self.gradient_penalty(real=self.real_x,
                                       segmap=self.real_x_segmap_onehot,
                                       fake=fake_x)
        else:
            GP = 0
        """ Define Loss """
        g_adv_loss = self.adv_weight * generator_loss(self.gan_type,
                                                      fake_logit)
        g_kl_loss = self.kl_weight * kl_loss(x_mean, x_var)
        g_vgg_loss = self.vgg_weight * VGGLoss()(self.real_x, fake_x)
        g_feature_loss = self.feature_weight * feature_loss(
            real_logit, fake_logit)
        g_reg_loss = regularization_loss('generator') + regularization_loss(
            'encoder')

        d_adv_loss = self.adv_weight * (
            discriminator_loss(self.gan_type, real_logit, fake_logit) + GP)
        d_reg_loss = regularization_loss('discriminator')

        self.g_loss = g_adv_loss + g_kl_loss + g_vgg_loss + g_feature_loss + g_reg_loss
        self.d_loss = d_adv_loss + d_reg_loss
        """ Result Image """
        self.fake_x = fake_x
        self.random_fake_x, _, _ = self.image_translate(
            segmap_img=self.real_x_segmap_onehot,
            random_style=True,
            reuse=True)
        """ Test """
        self.test_segmap_image = tf.compat.v1.placeholder(
            tf.float32, [
                1, self.img_height, self.img_width,
                len(img_class.color_value_dict)
            ])
        self.random_test_fake_x, _, _ = self.image_translate(
            segmap_img=self.test_segmap_image, random_style=True, reuse=True)

        self.test_guide_image = tf.compat.v1.placeholder(
            tf.float32, [1, self.img_height, self.img_width, self.img_ch])
        self.guide_test_fake_x, _, _ = self.image_translate(
            segmap_img=self.test_segmap_image,
            x_img=self.test_guide_image,
            reuse=True)
        """ Training """
        t_vars = tf.compat.v1.trainable_variables()
        G_vars = [
            var for var in t_vars
            if 'encoder' in var.name or 'generator' in var.name
        ]
        D_vars = [var for var in t_vars if 'discriminator' in var.name]

        if self.TTUR:
            beta1 = 0.0
            beta2 = 0.9

            g_lr = self.lr / 2
            d_lr = self.lr * 2

        else:
            beta1 = self.beta1
            beta2 = self.beta2
            g_lr = self.lr
            d_lr = self.lr

        self.G_optim = tf.compat.v1.train.AdamOptimizer(g_lr,
                                                        beta1=beta1,
                                                        beta2=beta2).minimize(
                                                            self.g_loss,
                                                            var_list=G_vars)
        self.D_optim = tf.compat.v1.train.AdamOptimizer(d_lr,
                                                        beta1=beta1,
                                                        beta2=beta2).minimize(
                                                            self.d_loss,
                                                            var_list=D_vars)
        """" Summary """
        self.summary_g_loss = tf.compat.v1.summary.scalar(
            "g_loss", self.g_loss)
        self.summary_d_loss = tf.compat.v1.summary.scalar(
            "d_loss", self.d_loss)

        self.summary_g_adv_loss = tf.compat.v1.summary.scalar(
            "g_adv_loss", g_adv_loss)
        self.summary_g_kl_loss = tf.compat.v1.summary.scalar(
            "g_kl_loss", g_kl_loss)
        self.summary_g_vgg_loss = tf.compat.v1.summary.scalar(
            "g_vgg_loss", g_vgg_loss)
        self.summary_g_feature_loss = tf.compat.v1.summary.scalar(
            "g_feature_loss", g_feature_loss)

        g_summary_list = [
            self.summary_g_loss, self.summary_g_adv_loss,
            self.summary_g_kl_loss, self.summary_g_vgg_loss,
            self.summary_g_feature_loss
        ]
        d_summary_list = [self.summary_d_loss]

        self.G_loss = tf.compat.v1.summary.merge(g_summary_list)
        self.D_loss = tf.compat.v1.summary.merge(d_summary_list)
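Because the learning rate is a placeholder here, every optimizer run must feed it; the TTUR scaling (g_lr = lr / 2, d_lr = lr * 2) is already baked into the graph. A minimal training-step sketch, assuming a `self.sess` and a `self.writer` that are not shown in the snippet:

    def train_step(self, lr, counter):
        # self.D_loss / self.G_loss are the merged summary ops defined above,
        # while self.d_loss / self.g_loss are the scalar loss tensors.
        _, d_summary, d_loss = self.sess.run(
            [self.D_optim, self.D_loss, self.d_loss], feed_dict={self.lr: lr})
        _, g_summary, g_loss = self.sess.run(
            [self.G_optim, self.G_loss, self.g_loss], feed_dict={self.lr: lr})
        self.writer.add_summary(d_summary, counter)
        self.writer.add_summary(g_summary, counter)
        return d_loss, g_loss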
Example #6
    def build_model(self):
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        self.input = tf.placeholder(tf.float32,
                                    shape=[None, 4, 4, 64],
                                    name='input')
        self.lr_g = tf.placeholder(tf.float32, name='lr_g')
        self.lr_d = tf.placeholder(tf.float32, name='lr_d')
        """ Dataset """
        self.Image_Data = ImageData(self.train_dataset_path,
                                    img_shape=self.img_shape,
                                    augment_flag=self.augment_flag,
                                    data_type=self.train_dataset_type,
                                    img_type=self.train_dataset_img_type,
                                    label_size=self.label_size)

        trainA = tf.data.Dataset.from_tensor_slices(
            (self.Image_Data.train_dataset, self.Image_Data.train_label))

        dataset_num = len(self.Image_Data.train_dataset)
        gpu_device = '/gpu:0'
        trainA = trainA.\
            apply(shuffle_and_repeat(dataset_num)).\
            apply(map_and_batch(self.Image_Data.image_processing, self.batch_size, num_parallel_batches=8, drop_remainder=True)).\
            apply(prefetch_to_device(gpu_device, self.batch_size))

        trainA_iterator = trainA.make_one_shot_iterator()

        self.real_imgs, self.label_o = trainA_iterator.get_next()
        """ generation """
        self.fake_imgs = self.generator(self.input, self.label_o)
        """ Discriminator for real """
        real_logits, real_label = self.discriminator(self.real_imgs)
        """ Discriminator for fake """
        fake_logits, fake_label = self.discriminator(self.fake_imgs,
                                                     reuse=True)
        """ Define Loss """
        if 'wgan' in self.gan_type or self.gan_type == 'dragan':
            grad_pen = self.gradient_panalty(self.real_imgs, self.fake_imgs,
                                             real_logits, fake_logits)
        else:
            grad_pen = 0

        g_cls_loss = classification_loss2(logit=fake_label, label=self.label_o)
        d_cls_loss = classification_loss2(logit=real_label, label=self.label_o)

        dis_loss = discriminator_loss(self.gan_type, real_logits,
                                      fake_logits) + self.ld * grad_pen
        gen_loss = generator_loss(self.gan_type, fake_logits)

        D_loss = dis_loss + self.cls_weight * d_cls_loss
        G_loss = gen_loss + self.cls_weight * g_cls_loss
        """ Optimizer """
        D_loss += regularization_loss('discriminator')
        G_loss += regularization_loss('generator')
        self.gen_optimizer, self.dis_optimizer = self.optimizer_graph_generator(
            G_loss, D_loss, self.lr_g, self.lr_d, self.beta1)
        """ Summaries """
        self.g_summary = summary({
            G_loss: 'G_loss',
            gen_loss: 'gen_loss',
            g_cls_loss: 'g_cls_loss'
        })
        self.d_summary = summary({
            D_loss: 'D_loss',
            dis_loss: 'dis_loss',
            d_cls_loss: 'd_cls_loss'
        })
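`optimizer_graph_generator` is referenced but not shown. A plausible sketch of such a helper, assuming it splits the trainable variables by name scope in the same way as the other examples; the body is a guess, not the original implementation:

    def optimizer_graph_generator(self, g_loss, d_loss, lr_g, lr_d, beta1):
        # Split trainable variables between generator and discriminator scopes.
        t_vars = tf.trainable_variables()
        g_vars = [var for var in t_vars if 'generator' in var.name]
        d_vars = [var for var in t_vars if 'discriminator' in var.name]

        gen_optimizer = tf.train.AdamOptimizer(lr_g, beta1=beta1).minimize(
            g_loss, var_list=g_vars)
        dis_optimizer = tf.train.AdamOptimizer(lr_d, beta1=beta1).minimize(
            d_loss, var_list=d_vars)
        return gen_optimizer, dis_optimizer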