Example #1
0
    def prefetch_to_consumer_device(self, dataset):
        """Wrap *dataset* so elements are prefetched onto the consumer device.

        This must be called on the consumer (trainer) worker,
        i.e. after :func:`map_producer_to_consumer`.

        :param tensorflow.data.Dataset dataset: input pipeline to wrap
        :rtype: tensorflow.data.Dataset
        """
        from tensorflow.python.data.experimental import prefetch_to_device
        target_device = self.get_consumer_device()
        transform = prefetch_to_device(target_device)
        return transform(dataset)
Example #2
0
def main():
    """Entry point: build the input pipeline and train the GAN.

    Reads CLI arguments, prepares the checkpoint directory, constructs a
    shuffled/batched/GPU-prefetched ``tf.data`` pipeline over the requested
    dataset split, and runs training.
    """
    args = parse_args()

    # Checkpoints are written under ./checkpoints/ckpt-*
    checkpoint = './checkpoints'
    os.makedirs(checkpoint, exist_ok=True)  # idempotent; replaces exists() check
    checkpoint_prefix = os.path.join(checkpoint, "ckpt")

    # Select the on-disk split matching the requested phase.
    dataset_name = args.dataset
    dataset_path = './dataset'
    split = 'train' if args.phase == 'train' else 'test'
    datapath = os.path.join(dataset_path, dataset_name, split)

    img_class = Image_data(img_width=args.img_width,
                           img_height=args.img_height,
                           img_depth=args.img_depth,
                           dataset_path=datapath)
    img_class.preprocess()

    dataset = tf.data.Dataset.from_tensor_slices(img_class.dataset)
    dataset_num = len(
        img_class.dataset)  # all the images with different domain
    print("Dataset number : ", dataset_num)

    # Shuffle over the full dataset, repeat indefinitely, batch, and
    # prefetch batches directly onto the GPU.
    gpu_device = '/gpu:0'
    data_set = dataset.shuffle(buffer_size=dataset_num,
                               reshuffle_each_iteration=True).repeat()
    data_set = data_set.batch(args.batch_size, drop_remainder=True)
    data_set = data_set.apply(
        prefetch_to_device(gpu_device, buffer_size=AUTOTUNE))
    data_set_iter = iter(data_set)

    gan = GAN(args.img_width,
              args.img_height,
              args.img_depth,
              args.img_channel,
              data_set_iter,
              batch_size=args.batch_size,
              epochs=args.epochs,
              save_interval=args.save_interval,
              dataset_name=args.dataset,
              checkpoint_prefix=checkpoint_prefix,
              z=args.latentdimension)
    gan.train()
Example #3
0
    def build_model(self):
        """Build the two-domain image-translation model for the current phase.

        'train': input pipelines for domains A and B, source/target
        generator-discriminator pairs, Adam optimizers, mean-loss metrics
        (adversarial, cycle, identity), and a resumable checkpoint manager.
        Otherwise ('test'): networks only; the checkpoint is restored with
        ``expect_partial`` since optimizer state is absent for inference.
        """
        if self.phase == 'train':
            """ Input Image"""
            img_class = Image_data(self.img_height, self.img_width,
                                   self.img_ch, self.dataset_path,
                                   self.augment_flag)
            img_class.preprocess()
            # Shuffle/repeat buffer sized to the larger of the two domains.
            dataset_num = max(len(img_class.train_A_dataset),
                              len(img_class.train_B_dataset))

            print("Dataset number : ", dataset_num)

            img_slice_A = tf.data.Dataset.from_tensor_slices(
                img_class.train_A_dataset)
            img_slice_B = tf.data.Dataset.from_tensor_slices(
                img_class.train_B_dataset)

            # Shuffle/repeat, decode + batch in parallel, then prefetch
            # batches directly onto the GPU.
            gpu_device = '/gpu:0'
            img_slice_A = img_slice_A. \
                apply(shuffle_and_repeat(dataset_num)). \
                apply(map_and_batch(img_class.image_processing, self.batch_size, num_parallel_batches=AUTOTUNE,
                                    drop_remainder=True)). \
                apply(prefetch_to_device(gpu_device, AUTOTUNE))

            img_slice_B = img_slice_B. \
                apply(shuffle_and_repeat(dataset_num)). \
                apply(map_and_batch(img_class.image_processing, self.batch_size, num_parallel_batches=AUTOTUNE,
                                    drop_remainder=True)). \
                apply(prefetch_to_device(gpu_device, AUTOTUNE))

            self.dataset_A_iter = iter(img_slice_A)
            self.dataset_B_iter = iter(img_slice_B)
            """ Network """
            # One generator/discriminator pair per translation direction.
            self.source_generator = Generator(self.ch,
                                              self.n_res,
                                              name='source_generator')
            self.target_generator = Generator(self.ch,
                                              self.n_res,
                                              name='target_generator')
            self.source_discriminator = Discriminator(
                self.ch, self.n_dis, self.sn, name='source_discriminator')
            self.target_discriminator = Discriminator(
                self.ch, self.n_dis, self.sn, name='target_discriminator')
            """ Optimizer """
            # Separate Adam optimizers for generators and discriminators;
            # beta_1=0.5 is the customary GAN setting.
            self.g_optimizer = tf.keras.optimizers.Adam(
                learning_rate=self.init_lr,
                beta_1=0.5,
                beta_2=0.999,
                epsilon=1e-08)
            self.d_optimizer = tf.keras.optimizers.Adam(
                learning_rate=self.init_lr,
                beta_1=0.5,
                beta_2=0.999,
                epsilon=1e-08)
            """ Summary """
            # mean metric
            # Running means so TensorBoard shows smoothed loss curves.
            self.g_adv_loss_metric = tf.keras.metrics.Mean('g_adv_loss',
                                                           dtype=tf.float32)
            self.g_cyc_loss_metric = tf.keras.metrics.Mean('g_cyc_loss',
                                                           dtype=tf.float32)
            self.g_identity_loss_metric = tf.keras.metrics.Mean(
                'g_identity_loss', dtype=tf.float32)
            self.g_loss_metric = tf.keras.metrics.Mean('g_loss',
                                                       dtype=tf.float32)

            self.d_adv_loss_metric = tf.keras.metrics.Mean('d_adv_loss',
                                                           dtype=tf.float32)
            self.d_loss_metric = tf.keras.metrics.Mean('d_loss',
                                                       dtype=tf.float32)

            # Build (and print) each sub-network's summary for this input size.
            input_shape = [self.img_height, self.img_width, self.img_ch]
            self.source_generator.build_summary(input_shape)
            self.source_discriminator.build_summary(input_shape)
            self.target_generator.build_summary(input_shape)
            self.target_discriminator.build_summary(input_shape)
            """ Count parameters """
            params = self.source_generator.count_parameter() + self.target_generator.count_parameter() \
                     + self.source_discriminator.count_parameter() + self.target_discriminator.count_parameter()

            print("Total network parameters : ", format(params, ','))
            """ Checkpoint """
            # Optimizer state is checkpointed too so training can resume.
            self.ckpt = tf.train.Checkpoint(
                source_generator=self.source_generator,
                target_generator=self.target_generator,
                source_discriminator=self.source_discriminator,
                target_discriminator=self.target_discriminator,
                g_optimizer=self.g_optimizer,
                d_optimizer=self.d_optimizer)
            self.manager = tf.train.CheckpointManager(self.ckpt,
                                                      self.checkpoint_dir,
                                                      max_to_keep=2)
            self.start_iteration = 0

            if self.manager.latest_checkpoint:
                self.ckpt.restore(self.manager.latest_checkpoint)
                # Step index is the numeric suffix of 'ckpt-<step>'.
                self.start_iteration = int(
                    self.manager.latest_checkpoint.split('-')[-1])
                print('Latest checkpoint restored!!')
                print('start iteration : ', self.start_iteration)
            else:
                print('Not restoring from saved checkpoint')

        else:
            """ Test """
            """ Network """
            self.source_generator = Generator(self.ch,
                                              self.n_res,
                                              name='source_generator')
            self.target_generator = Generator(self.ch,
                                              self.n_res,
                                              name='target_generator')
            self.source_discriminator = Discriminator(
                self.ch, self.n_dis, self.sn, name='source_discriminator')
            self.target_discriminator = Discriminator(
                self.ch, self.n_dis, self.sn, name='target_discriminator')

            input_shape = [self.img_height, self.img_width, self.img_ch]
            self.source_generator.build_summary(input_shape)
            self.source_discriminator.build_summary(input_shape)
            self.target_generator.build_summary(input_shape)
            self.target_discriminator.build_summary(input_shape)
            """ Count parameters """
            params = self.source_generator.count_parameter() + self.target_generator.count_parameter() \
                     + self.source_discriminator.count_parameter() + self.target_discriminator.count_parameter()

            print("Total network parameters : ", format(params, ','))
            """ Checkpoint """
            # Weights only (no optimizers); restore partially for inference.
            self.ckpt = tf.train.Checkpoint(
                source_generator=self.source_generator,
                target_generator=self.target_generator,
                source_discriminator=self.source_discriminator,
                target_discriminator=self.target_discriminator)
            self.manager = tf.train.CheckpointManager(self.ckpt,
                                                      self.checkpoint_dir,
                                                      max_to_keep=2)

            if self.manager.latest_checkpoint:
                self.ckpt.restore(
                    self.manager.latest_checkpoint).expect_partial()
                print('Latest checkpoint restored!!')
            else:
                print('Not restoring from saved checkpoint')
    def build_model(self):
        """Construct the text-to-image GAN for the current phase.

        'train': input pipeline over the training split, RNN text encoder,
        CNN image encoder, conditioning-augmentation net, generator,
        discriminator, one Adam optimizer per trainable component, and a
        checkpoint manager whose state includes the optimizers so training
        can resume.

        Otherwise ('test'): pipeline over the *test* split and only the
        generator-side networks; the checkpoint is restored with
        ``expect_partial()`` because discriminator/optimizer slots are
        intentionally absent.
        """
        # --- Input data ---
        img_data_class = Image_data(self.img_height, self.img_width,
                                    self.img_ch, self.dataset_path,
                                    self.augment_flag)
        train_class_id, train_captions, train_images, test_captions, test_images, idx_to_word, word_to_idx = img_data_class.preprocess(
        )
        self.vocab_size = len(idx_to_word)
        self.idx_to_word = idx_to_word
        self.word_to_idx = word_to_idx
        # Reference shapes from one dataset run:
        #   train_captions: (8855, 10, 18), test_captions: (2933, 10, 18)
        #   train_images: (8855,), test_images: (2933,)
        #   idx_to_word : 5450 5450

        if self.phase == 'train':
            self.dataset_num = len(train_images)

            img_and_caption = tf.data.Dataset.from_tensor_slices(
                (train_images, train_captions, train_class_id))

            # Shuffle/repeat over the whole split, decode and batch in
            # parallel, then prefetch batches onto the GPU.
            gpu_device = '/gpu:0'
            img_and_caption = img_and_caption.apply(
                shuffle_and_repeat(self.dataset_num)).apply(
                    map_and_batch(img_data_class.image_processing,
                                  batch_size=self.batch_size,
                                  num_parallel_batches=16,
                                  drop_remainder=True)).apply(
                                      prefetch_to_device(gpu_device, None))

            self.img_caption_iter = iter(img_and_caption)

            # --- Network ---
            self.rnn_encoder = RnnEncoder(n_words=self.vocab_size,
                                          embed_dim=self.embed_dim,
                                          drop_rate=0.5,
                                          n_hidden=128,
                                          n_layer=1,
                                          bidirectional=True,
                                          rnn_type='lstm')
            self.cnn_encoder = CnnEncoder(embed_dim=self.embed_dim)

            self.ca_net = CA_NET(c_dim=self.z_dim)
            self.generator = Generator(channels=self.g_dim)

            self.discriminator = Discriminator(channels=self.d_dim,
                                               embed_dim=self.embed_dim)

            # --- Optimizers: one for the generator, one per discriminator
            # resolution (64/128/256), and one for the embedding encoders ---
            self.g_optimizer = tf.keras.optimizers.Adam(
                learning_rate=self.init_lr,
                beta_1=0.5,
                beta_2=0.999,
                epsilon=1e-08)

            d_64_optimizer = tf.keras.optimizers.Adam(
                learning_rate=self.init_lr,
                beta_1=0.5,
                beta_2=0.999,
                epsilon=1e-08)
            d_128_optimizer = tf.keras.optimizers.Adam(
                learning_rate=self.init_lr,
                beta_1=0.5,
                beta_2=0.999,
                epsilon=1e-08)
            d_256_optimizer = tf.keras.optimizers.Adam(
                learning_rate=self.init_lr,
                beta_1=0.5,
                beta_2=0.999,
                epsilon=1e-08)
            self.d_optimizer = [
                d_64_optimizer, d_128_optimizer, d_256_optimizer
            ]

            self.embed_optimizer = tf.keras.optimizers.Adam(
                learning_rate=self.init_lr,
                beta_1=0.5,
                beta_2=0.999,
                epsilon=1e-08)

            # --- Checkpoint (optimizer state included so training resumes) ---
            self.ckpt = tf.train.Checkpoint(
                rnn_encoder=self.rnn_encoder,
                cnn_encoder=self.cnn_encoder,
                ca_net=self.ca_net,
                generator=self.generator,
                discriminator=self.discriminator,
                g_optimizer=self.g_optimizer,
                d_64_optimizer=d_64_optimizer,
                d_128_optimizer=d_128_optimizer,
                d_256_optimizer=d_256_optimizer,
                embed_optimizer=self.embed_optimizer)
            self.manager = tf.train.CheckpointManager(self.ckpt,
                                                      self.checkpoint_dir,
                                                      max_to_keep=2)
            self.start_iteration = 0

            if self.manager.latest_checkpoint:
                self.ckpt.restore(self.manager.latest_checkpoint)
                # Step index is the numeric suffix of 'ckpt-<step>'.
                self.start_iteration = int(
                    self.manager.latest_checkpoint.split('-')[-1])
                print('Latest checkpoint restored!!')
                print('start iteration : ', self.start_iteration)
            else:
                print('Not restoring from saved checkpoint')

        else:
            # --- Test ---
            self.dataset_num = len(test_captions)

            gpu_device = '/gpu:0'
            # BUG FIX: this branch previously sliced (train_images,
            # train_captions) even though dataset_num came from the test
            # split, so evaluation ran on training data with a mismatched
            # shuffle buffer. Use the test split (as the graph-mode variant
            # of this model does).
            img_and_caption = tf.data.Dataset.from_tensor_slices(
                (test_images, test_captions))

            img_and_caption = img_and_caption.apply(
                shuffle_and_repeat(self.dataset_num)).apply(
                    map_and_batch(img_data_class.image_processing,
                                  batch_size=self.batch_size,
                                  num_parallel_batches=16,
                                  drop_remainder=True)).apply(
                                      prefetch_to_device(gpu_device, None))

            self.img_caption_iter = iter(img_and_caption)

            # --- Network (generator side only; no discriminator or
            # optimizers are needed for inference) ---
            self.rnn_encoder = RnnEncoder(n_words=self.vocab_size,
                                          embed_dim=self.embed_dim,
                                          drop_rate=0.5,
                                          n_hidden=128,
                                          n_layer=1,
                                          bidirectional=True,
                                          rnn_type='lstm')
            self.cnn_encoder = CnnEncoder(embed_dim=self.embed_dim)
            self.ca_net = CA_NET(c_dim=self.z_dim)
            self.generator = Generator(channels=self.g_dim)

            # --- Checkpoint (weights only; restore is partial by design) ---
            self.ckpt = tf.train.Checkpoint(rnn_encoder=self.rnn_encoder,
                                            cnn_encoder=self.cnn_encoder,
                                            ca_net=self.ca_net,
                                            generator=self.generator)
            self.manager = tf.train.CheckpointManager(self.ckpt,
                                                      self.checkpoint_dir,
                                                      max_to_keep=2)
            self.start_iteration = 0

            if self.manager.latest_checkpoint:
                self.ckpt.restore(
                    self.manager.latest_checkpoint).expect_partial()
                self.start_iteration = int(
                    self.manager.latest_checkpoint.split('-')[-1])
                print('Latest checkpoint restored!!')
                print('start iteration : ', self.start_iteration)
            else:
                print('Not restoring from saved checkpoint')
Example #5
0
    def build_model(self):
        """Build the classifier, its input pipeline, optimizer and checkpointing.

        In the 'train' phase the full training setup (data pipeline, Adam
        optimizer, mean loss metric) is created; in any other phase only the
        network is built and the latest checkpoint is restored read-only.
        """
        input_shape = [self.img_height, self.img_width, self.img_ch]

        if self.phase == 'train':
            # ---- input pipeline ----
            image_data = Image_data(self.img_height, self.img_width,
                                    self.img_ch, self.dataset_path,
                                    self.augment_flag)
            image_data.preprocess()
            num_images = len(image_data.dataset)
            print("Dataset number : ", num_images)

            device = '/gpu:0'
            pipeline = tf.data.Dataset.from_tensor_slices(image_data.dataset)
            pipeline = pipeline.apply(shuffle_and_repeat(num_images))
            pipeline = pipeline.apply(
                map_and_batch(image_data.image_processing, self.batch_size,
                              num_parallel_batches=AUTOTUNE,
                              drop_remainder=True))
            pipeline = pipeline.apply(prefetch_to_device(device, AUTOTUNE))
            self.dataset_iter = iter(pipeline)

            # ---- network ----
            self.classifier = SubNetwork(channels=64, name='classifier')

            # ---- optimizer ----
            self.optimizer = tf.keras.optimizers.Adam(
                learning_rate=self.init_lr, beta_1=0.5, beta_2=0.999,
                epsilon=1e-08)

            # ---- summary ----
            # In tensorboard, make a loss to smooth graph
            self.loss_metric = tf.keras.metrics.Mean('loss', dtype=tf.float32)

            # print summary
            self.classifier.build_summary(input_shape)

            # ---- parameter count ----
            total_params = self.classifier.count_parameter()
            print("Total network parameters : ", format(total_params, ','))

            # ---- checkpointing (resumable: optimizer state included) ----
            self.ckpt = tf.train.Checkpoint(classifier=self.classifier,
                                            optimizer=self.optimizer)
            self.manager = tf.train.CheckpointManager(self.ckpt,
                                                      self.checkpoint_dir,
                                                      max_to_keep=2)
            self.start_iteration = 0

            latest = self.manager.latest_checkpoint
            if latest:
                self.ckpt.restore(latest)
                self.start_iteration = int(latest.split('-')[-1])
                print('Latest checkpoint restored!!')
                print('start iteration : ', self.start_iteration)
            else:
                print('Not restoring from saved checkpoint')

        else:
            # ---- test: network only ----
            self.classifier = SubNetwork(channels=64, name='classifier')
            self.classifier.build_summary(input_shape)

            total_params = self.classifier.count_parameter()
            print("Total network parameters : ", format(total_params, ','))

            # Weights only; partial restore since optimizer state is absent.
            self.ckpt = tf.train.Checkpoint(classifier=self.classifier)
            self.manager = tf.train.CheckpointManager(self.ckpt,
                                                      self.checkpoint_dir,
                                                      max_to_keep=2)

            latest = self.manager.latest_checkpoint
            if latest:
                self.ckpt.restore(latest).expect_partial()
                print('Latest checkpoint restored!!')
            else:
                print('Not restoring from saved checkpoint')
Example #6
0
    def build_model(self):
        """Build the TF1 (graph-mode) text-to-image GAN for the current phase.

        'train': input pipeline over the training split, sentence encoding,
        three-scale generator/discriminator losses, Adam train ops, and
        scalar summaries. Otherwise ('test'): pipeline over the test split
        and the generator forward pass only (is_training=False).
        """
        """ Input Image"""
        img_data_class = Image_data(self.img_height, self.img_width,
                                    self.img_ch, self.dataset_path,
                                    self.augment_flag)
        train_captions, train_images, test_captions, test_images, idx_to_word, word_to_idx = img_data_class.preprocess(
        )
        """
        train_captions: (8855, 10, 18), test_captions: (2933, 10, 18)
        train_images: (8855,), test_images: (2933,)
        idx_to_word : 5450 5450
        """

        if self.phase == 'train':
            # Learning rate is fed at session-run time via placeholder.
            self.lr = tf.placeholder(tf.float32, name='learning_rate')

            self.dataset_num = len(train_images)

            img_and_caption = tf.data.Dataset.from_tensor_slices(
                (train_images, train_captions))

            # Shuffle/repeat, decode+batch in parallel, prefetch to the GPU.
            gpu_device = '/gpu:0'
            img_and_caption = img_and_caption.apply(
                shuffle_and_repeat(self.dataset_num)).apply(
                    map_and_batch(img_data_class.image_processing,
                                  batch_size=self.batch_size,
                                  num_parallel_batches=16,
                                  drop_remainder=True)).apply(
                                      prefetch_to_device(gpu_device, None))

            img_and_caption_iterator = img_and_caption.make_one_shot_iterator()
            real_img_256, caption = img_and_caption_iterator.get_next()
            # Each image carries 10 captions (see shapes above); pick one
            # uniformly at random per step.
            target_sentence_index = tf.random_uniform(shape=[],
                                                      minval=0,
                                                      maxval=10,
                                                      dtype=tf.int32)
            caption = tf.gather(caption, target_sentence_index, axis=1)

            # Encode the sentence into word-level and sentence-level
            # embeddings plus a mask.
            word_emb, sent_emb, mask = self.rnn_encoder(
                caption,
                n_words=len(idx_to_word),
                embed_dim=self.embed_dim,
                drop_rate=0.5,
                n_hidden=128,
                n_layers=1,
                bidirectional=True,
                rnn_type='lstm')

            noise = tf.random_normal(shape=[self.batch_size, self.z_dim],
                                     mean=0.0,
                                     stddev=1.0)
            # Generator emits images at three scales plus (mu, logvar) used
            # by the KL term below.
            fake_imgs, _, mu, logvar = self.generator(noise, sent_emb,
                                                      word_emb, mask)

            # Downscale the real 256px image to match the 64/128px outputs.
            real_img_64, real_img_128 = resize(real_img_256,
                                               target_size=[64, 64]), resize(
                                                   real_img_256,
                                                   target_size=[128, 128])
            fake_img_64, fake_img_128, fake_img_256 = fake_imgs[0], fake_imgs[
                1], fake_imgs[2]

            # Discriminator returns unconditional and text-conditional logits
            # per scale.
            uncond_real_logits, cond_real_logits = self.discriminator(
                [real_img_64, real_img_128, real_img_256], sent_emb)
            # "Wrong" pairs: real images matched with shifted-by-one (i.e.
            # mismatched) sentence embeddings.
            _, cond_wrong_logits = self.discriminator([
                real_img_64[:(self.batch_size - 1)],
                real_img_128[:(self.batch_size - 1)],
                real_img_256[:(self.batch_size - 1)]
            ], sent_emb[1:self.batch_size])
            uncond_fake_logits, cond_fake_logits = self.discriminator(
                [fake_img_64, fake_img_128, fake_img_256], sent_emb)

            # Re-caption the generated 256px image — used by the caption
            # consistency loss below.
            fake_img_256_feature = self.caption_cnn(fake_img_256)
            fake_img_256_caption = self.caption_rnn(fake_img_256_feature,
                                                    caption,
                                                    n_words=len(idx_to_word),
                                                    embed_dim=self.embed_dim,
                                                    n_hidden=256 * 2,
                                                    n_layers=1)

            # Accumulate adversarial losses over the three scales.
            self.g_adv_loss, self.d_adv_loss = 0, 0
            for i in range(3):
                self.g_adv_loss += self.adv_weight * (
                    generator_loss(self.gan_type, uncond_fake_logits[i]) +
                    generator_loss(self.gan_type, cond_fake_logits[i]))

                uncond_real_loss, uncond_fake_loss = discriminator_loss(
                    self.gan_type, uncond_real_logits[i],
                    uncond_fake_logits[i])
                cond_real_loss, cond_fake_loss = discriminator_loss(
                    self.gan_type, cond_real_logits[i], cond_fake_logits[i])
                _, cond_wrong_loss = discriminator_loss(
                    self.gan_type, None, cond_wrong_logits[i])

                # Real terms averaged over 2; fake terms (incl. wrong pairs)
                # averaged over 3.
                self.d_adv_loss += self.adv_weight * (
                    ((uncond_real_loss + cond_real_loss) / 2) +
                    (uncond_fake_loss + cond_fake_loss + cond_wrong_loss) / 3)

            # KL term from the generator's (mu, logvar) distribution.
            self.g_kl_loss = self.kl_weight * kl_loss(mu, logvar)
            # One-hot targets for the caption-consistency loss.
            caption = tf.one_hot(caption, len(idx_to_word))
            caption = tf.reshape(caption, [-1, len(idx_to_word)])
            self.g_cap_loss = self.cap_weight * caption_loss(
                fake_img_256_caption, caption)

            self.g_loss = self.g_adv_loss + self.g_kl_loss + self.g_cap_loss
            self.d_loss = self.d_adv_loss

            self.real_img = real_img_256
            self.fake_img = fake_img_256
            """ Training """
            # Separate G/D train ops built by filtering variables by name.
            t_vars = tf.trainable_variables()
            G_vars = [var for var in t_vars if 'generator' in var.name]
            D_vars = [var for var in t_vars if 'discriminator' in var.name]

            self.g_optim = tf.train.AdamOptimizer(self.lr,
                                                  beta1=0.5,
                                                  beta2=0.999).minimize(
                                                      self.g_loss,
                                                      var_list=G_vars)
            self.d_optim = tf.train.AdamOptimizer(self.lr,
                                                  beta1=0.5,
                                                  beta2=0.999).minimize(
                                                      self.d_loss,
                                                      var_list=D_vars)
            """" Summary """
            self.summary_g_loss = tf.summary.scalar("g_loss", self.g_loss)
            self.summary_d_loss = tf.summary.scalar("d_loss", self.d_loss)

            self.summary_g_adv_loss = tf.summary.scalar(
                "g_adv_loss", self.g_adv_loss)
            self.summary_g_kl_loss = tf.summary.scalar("g_kl_loss",
                                                       self.g_kl_loss)
            self.summary_g_cap_loss = tf.summary.scalar(
                "g_cap_loss", self.g_cap_loss)

            self.summary_d_adv_loss = tf.summary.scalar(
                "d_adv_loss", self.d_adv_loss)

            g_summary_list = [
                self.summary_g_loss, self.summary_g_adv_loss,
                self.summary_g_kl_loss, self.summary_g_cap_loss
            ]

            d_summary_list = [self.summary_d_loss, self.summary_d_adv_loss]

            self.summary_merge_g_loss = tf.summary.merge(g_summary_list)
            self.summary_merge_d_loss = tf.summary.merge(d_summary_list)

        else:
            """ Test """
            self.dataset_num = len(test_captions)

            gpu_device = '/gpu:0'
            img_and_caption = tf.data.Dataset.from_tensor_slices(
                (test_images, test_captions))

            img_and_caption = img_and_caption.apply(
                shuffle_and_repeat(self.dataset_num)).apply(
                    map_and_batch(img_data_class.image_processing,
                                  batch_size=self.batch_size,
                                  num_parallel_batches=16,
                                  drop_remainder=True)).apply(
                                      prefetch_to_device(gpu_device, None))

            img_and_caption_iterator = img_and_caption.make_one_shot_iterator()
            real_img_256, caption = img_and_caption_iterator.get_next()
            target_sentence_index = tf.random_uniform(shape=[],
                                                      minval=0,
                                                      maxval=10,
                                                      dtype=tf.int32)
            caption = tf.gather(caption, target_sentence_index, axis=1)

            # Inference-mode encoding; is_training=False presumably disables
            # dropout — TODO confirm against RnnEncoder.
            word_emb, sent_emb, mask = self.rnn_encoder(
                caption,
                n_words=len(idx_to_word),
                embed_dim=self.embed_dim,
                drop_rate=0.5,
                n_hidden=128,
                n_layers=1,
                bidirectional=True,
                rnn_type='lstm',
                is_training=False)

            noise = tf.random_normal(shape=[self.batch_size, self.z_dim],
                                     mean=0.0,
                                     stddev=1.0)
            fake_imgs, _, _, _ = self.generator(noise,
                                                sent_emb,
                                                word_emb,
                                                mask,
                                                is_training=False)

            # Expose only the highest-resolution (256px) outputs for testing.
            self.test_real_img = real_img_256
            self.test_fake_img = fake_imgs[2]
Example #7
0
    def build_model(self):
        """Construct the networks, optimizers, and checkpoint machinery.

        In 'train' phase: builds the input pipeline, the four trainable
        networks (generator, mapping network, style encoder, discriminator)
        plus EMA copies of the first three, their Adam optimizers, and a
        checkpoint manager; restores the latest checkpoint if one exists
        and sets ``self.start_iteration`` accordingly.

        In any other phase (test): builds only the three EMA networks and a
        checkpoint manager for restoring them.

        All networks are finalized (variables created) by running dummy
        forward passes with all-ones inputs, which is required before a
        checkpoint restore can map saved variables onto them.
        """
        if self.phase == 'train':
            """ Input Image"""
            img_class = Image_data(self.img_size, self.img_ch,
                                   self.dataset_path, self.augment_flag)
            img_class.preprocess()

            # Total sample count across both domains (masked / unmasked).
            dataset_num = len(img_class.mask_images) + len(
                img_class.nomask_images)
            print("Dataset number : ", dataset_num)

            # Each element bundles a masked image/mask pair with two
            # unmasked image/mask pairs (the second presumably serves as an
            # extra reference for style diversity — confirm against the
            # training step).
            img_and_domain = tf.data.Dataset.from_tensor_slices(
                (img_class.mask_images, img_class.mask_masks,
                 img_class.nomask_images, img_class.nomask_masks,
                 img_class.nomask_images2, img_class.nomask_masks2))

            gpu_device = '/gpu:0'

            # Full-dataset shuffle, infinite repeat; decode/augment in
            # parallel, then drop the ragged final batch so shapes stay
            # static, and prefetch batches directly onto the GPU.
            img_and_domain = img_and_domain.shuffle(
                buffer_size=dataset_num,
                reshuffle_each_iteration=True).repeat()
            img_and_domain = img_and_domain.map(
                map_func=img_class.image_processing,
                num_parallel_calls=AUTOTUNE).batch(self.batch_size,
                                                   drop_remainder=True)
            img_and_domain = img_and_domain.apply(
                prefetch_to_device(gpu_device, buffer_size=AUTOTUNE))

            self.img_and_domain_iter = iter(img_and_domain)
            """ Network """
            self.generator = Generator(self.img_size,
                                       self.img_ch,
                                       self.style_dim,
                                       max_conv_dim=self.hidden_dim,
                                       sn=False,
                                       name='Generator')
            self.mapping_network = MappingNetwork(self.style_dim,
                                                  self.hidden_dim,
                                                  sn=False,
                                                  name='MappingNetwork')
            self.style_encoder = StyleEncoder(self.img_size,
                                              self.style_dim,
                                              max_conv_dim=self.hidden_dim,
                                              sn=False,
                                              name='StyleEncoder')
            # Only the discriminator honors the spectral-norm flag.
            self.discriminator = Discriminator(self.img_size,
                                               max_conv_dim=self.hidden_dim,
                                               sn=self.sn,
                                               name='Discriminator')

            # EMA (exponential moving average) shadow copies used for
            # evaluation; initialized as deep copies of the live networks.
            self.generator_ema = deepcopy(self.generator)
            self.mapping_network_ema = deepcopy(self.mapping_network)
            self.style_encoder_ema = deepcopy(self.style_encoder)
            """ Finalize model (build) """
            # Dummy inputs: image x, latent z, style code s, boolean mask m.
            x = np.ones(shape=[
                self.batch_size, self.img_size, self.img_size, self.img_ch
            ],
                        dtype=np.float32)
            z = np.ones(shape=[self.batch_size, self.latent_dim],
                        dtype=np.float32)
            s = np.ones(shape=[self.batch_size, self.style_dim],
                        dtype=np.float32)
            # NOTE: builtin `bool`, not `np.bool` — the `np.bool` alias was
            # deprecated in NumPy 1.20 and removed in 1.24.
            m = np.ones(shape=[self.batch_size, self.img_size, self.img_size],
                        dtype=bool)

            # Trigger variable creation via one forward pass per network.
            _ = self.mapping_network(z)
            _ = self.mapping_network_ema(z)
            _ = self.style_encoder(x)
            _ = self.style_encoder_ema(x)
            _ = self.generator([x, s, m])
            _ = self.generator_ema([x, s, m])
            _ = self.discriminator([x, m])
            """ Optimizer """
            # G, E, D share one learning rate; the mapping network F uses
            # its own (typically smaller) rate, per StarGAN v2.
            self.g_optimizer = tf.keras.optimizers.Adam(learning_rate=self.lr,
                                                        beta_1=self.beta1,
                                                        beta_2=self.beta2,
                                                        epsilon=1e-08)
            self.e_optimizer = tf.keras.optimizers.Adam(learning_rate=self.lr,
                                                        beta_1=self.beta1,
                                                        beta_2=self.beta2,
                                                        epsilon=1e-08)
            self.f_optimizer = tf.keras.optimizers.Adam(
                learning_rate=self.f_lr,
                beta_1=self.beta1,
                beta_2=self.beta2,
                epsilon=1e-08)
            self.d_optimizer = tf.keras.optimizers.Adam(learning_rate=self.lr,
                                                        beta_1=self.beta1,
                                                        beta_2=self.beta2,
                                                        epsilon=1e-08)
            """ Checkpoint """
            self.ckpt = tf.train.Checkpoint(
                generator=self.generator,
                generator_ema=self.generator_ema,
                mapping_network=self.mapping_network,
                mapping_network_ema=self.mapping_network_ema,
                style_encoder=self.style_encoder,
                style_encoder_ema=self.style_encoder_ema,
                discriminator=self.discriminator,
                g_optimizer=self.g_optimizer,
                e_optimizer=self.e_optimizer,
                f_optimizer=self.f_optimizer,
                d_optimizer=self.d_optimizer)
            self.manager = tf.train.CheckpointManager(self.ckpt,
                                                      self.checkpoint_dir,
                                                      max_to_keep=10)
            self.start_iteration = 0

            if self.manager.latest_checkpoint:
                self.ckpt.restore(
                    self.manager.latest_checkpoint).expect_partial()
                # Checkpoint names end in '-<step>'; resume from that step.
                self.start_iteration = int(
                    self.manager.latest_checkpoint.split('-')[-1])
                print('Latest checkpoint restored!!')
                print('start iteration : ', self.start_iteration)
            else:
                print('Not restoring from saved checkpoint')

        else:
            """ Test """
            """ Network """
            # Test phase needs only the EMA networks (no discriminator,
            # no optimizers).
            self.generator_ema = Generator(self.img_size,
                                           self.img_ch,
                                           self.style_dim,
                                           max_conv_dim=self.hidden_dim,
                                           sn=False,
                                           name='Generator')
            self.mapping_network_ema = MappingNetwork(self.style_dim,
                                                      self.hidden_dim,
                                                      sn=False,
                                                      name='MappingNetwork')
            self.style_encoder_ema = StyleEncoder(self.img_size,
                                                  self.style_dim,
                                                  max_conv_dim=self.hidden_dim,
                                                  sn=False,
                                                  name='StyleEncoder')
            """ Finalize model (build) """
            x = np.ones(shape=[
                self.batch_size, self.img_size, self.img_size, self.img_ch
            ],
                        dtype=np.float32)
            z = np.ones(shape=[self.batch_size, self.latent_dim],
                        dtype=np.float32)
            s = np.ones(shape=[self.batch_size, self.style_dim],
                        dtype=np.float32)
            # Builtin `bool` — `np.bool` was removed in NumPy 1.24.
            m = np.ones(shape=[self.batch_size, self.img_size, self.img_size],
                        dtype=bool)

            _ = self.mapping_network_ema(z, training=False)
            _ = self.style_encoder_ema(x, training=False)
            _ = self.generator_ema([x, s, m], training=False)
            """ Checkpoint """
            self.ckpt = tf.train.Checkpoint(
                generator_ema=self.generator_ema,
                mapping_network_ema=self.mapping_network_ema,
                style_encoder_ema=self.style_encoder_ema)
            self.manager = tf.train.CheckpointManager(self.ckpt,
                                                      self.checkpoint_dir,
                                                      max_to_keep=10)

            if self.manager.latest_checkpoint:
                self.ckpt.restore(
                    self.manager.latest_checkpoint).expect_partial()
                print('Latest checkpoint restored!!')
            else:
                print('Not restoring from saved checkpoint')