def prefetch_to_consumer_device(self, dataset):
    """
    This must be called on the consumer (trainer) worker,
    i.e. after :func:`map_producer_to_consumer`.

    :param tensorflow.data.Dataset dataset:
    :rtype: tensorflow.data.Dataset
    """
    from tensorflow.python.data.experimental import prefetch_to_device
    return prefetch_to_device(self.get_consumer_device())(dataset)
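# --- Hedged usage sketch (not part of the original helper) ---
# Shows the same prefetch_to_device(...)(dataset) call pattern on a toy pipeline;
# the '/gpu:0' choice and the CPU fallback are assumptions for illustration.
import tensorflow as tf
from tensorflow.python.data.experimental import prefetch_to_device

consumer_device = '/gpu:0' if tf.config.list_physical_devices('GPU') else '/cpu:0'
toy_dataset = tf.data.Dataset.range(8).batch(2)
toy_dataset = prefetch_to_device(consumer_device)(toy_dataset)
for batch in toy_dataset:
    print(batch.numpy())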
def main():
    args = parse_args()

    checkpoint = './checkpoints'
    if not os.path.exists(checkpoint):
        os.makedirs(checkpoint)
    checkpoint_prefix = os.path.join(checkpoint, "ckpt")

    dataset_name = args.dataset
    dataset_path = './dataset'
    if args.phase == 'train':
        datapath = os.path.join(dataset_path, dataset_name, 'train')
    else:
        datapath = os.path.join(dataset_path, dataset_name, 'test')

    img_class = Image_data(img_width=args.img_width,
                           img_height=args.img_height,
                           img_depth=args.img_depth,
                           dataset_path=datapath)
    img_class.preprocess()

    dataset = tf.data.Dataset.from_tensor_slices(img_class.dataset)
    dataset_num = len(img_class.dataset)  # all the images with different domains
    print("Dataset number : ", dataset_num)

    gpu_device = '/gpu:0'
    data_set = dataset.shuffle(buffer_size=dataset_num,
                               reshuffle_each_iteration=True).repeat()
    data_set = data_set.batch(args.batch_size, drop_remainder=True)
    data_set = data_set.apply(prefetch_to_device(gpu_device, buffer_size=AUTOTUNE))
    data_set_iter = iter(data_set)

    gan = GAN(args.img_width,
              args.img_height,
              args.img_depth,
              args.img_channel,
              data_set_iter,
              batch_size=args.batch_size,
              epochs=args.epochs,
              save_interval=args.save_interval,
              dataset_name=args.dataset,
              checkpoint_prefix=checkpoint_prefix,
              z=args.latentdimension)
    gan.train()
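# --- Hedged sketch of the parse_args() that main() assumes (not in the original) ---
# Only the flags actually read above are declared; all defaults are illustrative.
import argparse

def parse_args():
    parser = argparse.ArgumentParser(description='Training entry point (sketch)')
    parser.add_argument('--phase', type=str, default='train', choices=['train', 'test'])
    parser.add_argument('--dataset', type=str, default='my_dataset')
    parser.add_argument('--img_width', type=int, default=64)
    parser.add_argument('--img_height', type=int, default=64)
    parser.add_argument('--img_depth', type=int, default=64)
    parser.add_argument('--img_channel', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--save_interval', type=int, default=1000)
    parser.add_argument('--latentdimension', type=int, default=128)
    return parser.parse_args()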
def build_model(self):
    if self.phase == 'train':
        """ Input Image"""
        img_class = Image_data(self.img_height, self.img_width, self.img_ch,
                               self.dataset_path, self.augment_flag)
        img_class.preprocess()

        dataset_num = max(len(img_class.train_A_dataset), len(img_class.train_B_dataset))
        print("Dataset number : ", dataset_num)

        img_slice_A = tf.data.Dataset.from_tensor_slices(img_class.train_A_dataset)
        img_slice_B = tf.data.Dataset.from_tensor_slices(img_class.train_B_dataset)

        gpu_device = '/gpu:0'
        img_slice_A = img_slice_A. \
            apply(shuffle_and_repeat(dataset_num)). \
            apply(map_and_batch(img_class.image_processing,
                                self.batch_size,
                                num_parallel_batches=AUTOTUNE,
                                drop_remainder=True)). \
            apply(prefetch_to_device(gpu_device, AUTOTUNE))

        img_slice_B = img_slice_B. \
            apply(shuffle_and_repeat(dataset_num)). \
            apply(map_and_batch(img_class.image_processing,
                                self.batch_size,
                                num_parallel_batches=AUTOTUNE,
                                drop_remainder=True)). \
            apply(prefetch_to_device(gpu_device, AUTOTUNE))

        self.dataset_A_iter = iter(img_slice_A)
        self.dataset_B_iter = iter(img_slice_B)

        """ Network """
        self.source_generator = Generator(self.ch, self.n_res, name='source_generator')
        self.target_generator = Generator(self.ch, self.n_res, name='target_generator')
        self.source_discriminator = Discriminator(self.ch, self.n_dis, self.sn,
                                                  name='source_discriminator')
        self.target_discriminator = Discriminator(self.ch, self.n_dis, self.sn,
                                                  name='target_discriminator')

        """ Optimizer """
        self.g_optimizer = tf.keras.optimizers.Adam(learning_rate=self.init_lr,
                                                    beta_1=0.5, beta_2=0.999, epsilon=1e-08)
        self.d_optimizer = tf.keras.optimizers.Adam(learning_rate=self.init_lr,
                                                    beta_1=0.5, beta_2=0.999, epsilon=1e-08)

        """ Summary """
        # mean metric
        self.g_adv_loss_metric = tf.keras.metrics.Mean('g_adv_loss', dtype=tf.float32)
        self.g_cyc_loss_metric = tf.keras.metrics.Mean('g_cyc_loss', dtype=tf.float32)
        self.g_identity_loss_metric = tf.keras.metrics.Mean('g_identity_loss', dtype=tf.float32)
        self.g_loss_metric = tf.keras.metrics.Mean('g_loss', dtype=tf.float32)
        self.d_adv_loss_metric = tf.keras.metrics.Mean('d_adv_loss', dtype=tf.float32)
        self.d_loss_metric = tf.keras.metrics.Mean('d_loss', dtype=tf.float32)

        input_shape = [self.img_height, self.img_width, self.img_ch]
        self.source_generator.build_summary(input_shape)
        self.source_discriminator.build_summary(input_shape)
        self.target_generator.build_summary(input_shape)
        self.target_discriminator.build_summary(input_shape)

        """ Count parameters """
        params = self.source_generator.count_parameter() + self.target_generator.count_parameter() \
                 + self.source_discriminator.count_parameter() + self.target_discriminator.count_parameter()
        print("Total network parameters : ", format(params, ','))

        """ Checkpoint """
        self.ckpt = tf.train.Checkpoint(source_generator=self.source_generator,
                                        target_generator=self.target_generator,
                                        source_discriminator=self.source_discriminator,
                                        target_discriminator=self.target_discriminator,
                                        g_optimizer=self.g_optimizer,
                                        d_optimizer=self.d_optimizer)
        self.manager = tf.train.CheckpointManager(self.ckpt, self.checkpoint_dir, max_to_keep=2)
        self.start_iteration = 0

        if self.manager.latest_checkpoint:
            self.ckpt.restore(self.manager.latest_checkpoint)
            self.start_iteration = int(self.manager.latest_checkpoint.split('-')[-1])
            print('Latest checkpoint restored!!')
            print('start iteration : ', self.start_iteration)
        else:
            print('Not restoring from saved checkpoint')

    else:
        """ Test """
        """ Network """
        self.source_generator = Generator(self.ch, self.n_res, name='source_generator')
        self.target_generator = Generator(self.ch, self.n_res, name='target_generator')
        self.source_discriminator = Discriminator(self.ch, self.n_dis, self.sn,
                                                  name='source_discriminator')
        self.target_discriminator = Discriminator(self.ch, self.n_dis, self.sn,
                                                  name='target_discriminator')

        input_shape = [self.img_height, self.img_width, self.img_ch]
        self.source_generator.build_summary(input_shape)
        self.source_discriminator.build_summary(input_shape)
        self.target_generator.build_summary(input_shape)
        self.target_discriminator.build_summary(input_shape)

        """ Count parameters """
        params = self.source_generator.count_parameter() + self.target_generator.count_parameter() \
                 + self.source_discriminator.count_parameter() + self.target_discriminator.count_parameter()
        print("Total network parameters : ", format(params, ','))

        """ Checkpoint """
        self.ckpt = tf.train.Checkpoint(source_generator=self.source_generator,
                                        target_generator=self.target_generator,
                                        source_discriminator=self.source_discriminator,
                                        target_discriminator=self.target_discriminator)
        self.manager = tf.train.CheckpointManager(self.ckpt, self.checkpoint_dir, max_to_keep=2)

        if self.manager.latest_checkpoint:
            self.ckpt.restore(self.manager.latest_checkpoint).expect_partial()
            print('Latest checkpoint restored!!')
        else:
            print('Not restoring from saved checkpoint')
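# --- Hedged test-time sketch for the restored generators above (not the repo's test() code) ---
# Pushes one dummy batch through a restored generator; the translation direction
# (target_generator) and the random stand-in input are assumptions for illustration.
import tensorflow as tf

def translate_dummy_batch(gan):
    x = tf.random.normal([1, gan.img_height, gan.img_width, gan.img_ch])  # stand-in input image
    fake = gan.target_generator(x, training=False)
    print('translated image shape :', fake.shape)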
def build_model(self): """ Input Image""" img_data_class = Image_data(self.img_height, self.img_width, self.img_ch, self.dataset_path, self.augment_flag) train_class_id, train_captions, train_images, test_captions, test_images, idx_to_word, word_to_idx = img_data_class.preprocess( ) self.vocab_size = len(idx_to_word) self.idx_to_word = idx_to_word self.word_to_idx = word_to_idx """ train_captions: (8855, 10, 18), test_captions: (2933, 10, 18) train_images: (8855,), test_images: (2933,) idx_to_word : 5450 5450 """ if self.phase == 'train': self.dataset_num = len(train_images) img_and_caption = tf.data.Dataset.from_tensor_slices( (train_images, train_captions, train_class_id)) gpu_device = '/gpu:0' img_and_caption = img_and_caption.apply( shuffle_and_repeat(self.dataset_num)).apply( map_and_batch(img_data_class.image_processing, batch_size=self.batch_size, num_parallel_batches=16, drop_remainder=True)).apply( prefetch_to_device(gpu_device, None)) self.img_caption_iter = iter(img_and_caption) # real_img_256, caption = iter(img_and_caption) """ Network """ self.rnn_encoder = RnnEncoder(n_words=self.vocab_size, embed_dim=self.embed_dim, drop_rate=0.5, n_hidden=128, n_layer=1, bidirectional=True, rnn_type='lstm') self.cnn_encoder = CnnEncoder(embed_dim=self.embed_dim) self.ca_net = CA_NET(c_dim=self.z_dim) self.generator = Generator(channels=self.g_dim) self.discriminator = Discriminator(channels=self.d_dim, embed_dim=self.embed_dim) """ Optimizer """ self.g_optimizer = tf.keras.optimizers.Adam( learning_rate=self.init_lr, beta_1=0.5, beta_2=0.999, epsilon=1e-08) d_64_optimizer = tf.keras.optimizers.Adam( learning_rate=self.init_lr, beta_1=0.5, beta_2=0.999, epsilon=1e-08) d_128_optimizer = tf.keras.optimizers.Adam( learning_rate=self.init_lr, beta_1=0.5, beta_2=0.999, epsilon=1e-08) d_256_optimizer = tf.keras.optimizers.Adam( learning_rate=self.init_lr, beta_1=0.5, beta_2=0.999, epsilon=1e-08) self.d_optimizer = [ d_64_optimizer, d_128_optimizer, d_256_optimizer ] self.embed_optimizer = tf.keras.optimizers.Adam( learning_rate=self.init_lr, beta_1=0.5, beta_2=0.999, epsilon=1e-08) """ Checkpoint """ self.ckpt = tf.train.Checkpoint( rnn_encoder=self.rnn_encoder, cnn_encoder=self.cnn_encoder, ca_net=self.ca_net, generator=self.generator, discriminator=self.discriminator, g_optimizer=self.g_optimizer, d_64_optimizer=d_64_optimizer, d_128_optimizer=d_128_optimizer, d_256_optimizer=d_256_optimizer, embed_optimizer=self.embed_optimizer) self.manager = tf.train.CheckpointManager(self.ckpt, self.checkpoint_dir, max_to_keep=2) self.start_iteration = 0 if self.manager.latest_checkpoint: self.ckpt.restore(self.manager.latest_checkpoint) self.start_iteration = int( self.manager.latest_checkpoint.split('-')[-1]) print('Latest checkpoint restored!!') print('start iteration : ', self.start_iteration) else: print('Not restoring from saved checkpoint') else: """ Test """ self.dataset_num = len(test_captions) gpu_device = '/gpu:0' img_and_caption = tf.data.Dataset.from_tensor_slices( (train_images, train_captions)) img_and_caption = img_and_caption.apply( shuffle_and_repeat(self.dataset_num)).apply( map_and_batch(img_data_class.image_processing, batch_size=self.batch_size, num_parallel_batches=16, drop_remainder=True)).apply( prefetch_to_device(gpu_device, None)) self.img_caption_iter = iter(img_and_caption) """ Network """ self.rnn_encoder = RnnEncoder(n_words=self.vocab_size, embed_dim=self.embed_dim, drop_rate=0.5, n_hidden=128, n_layer=1, bidirectional=True, rnn_type='lstm') self.cnn_encoder = 
CnnEncoder(embed_dim=self.embed_dim) self.ca_net = CA_NET(c_dim=self.z_dim) self.generator = Generator(channels=self.g_dim) # self.discriminator = Discriminator(channels=self.d_dim, embed_dim=self.embed_dim) # """ Optimizer """ # self.g_optimizer = tf.keras.optimizers.Adam(learning_rate=self.init_lr, beta_1=0.5, beta_2=0.999, epsilon=1e-08) # d_64_optimizer = tf.keras.optimizers.Adam(learning_rate=self.init_lr, beta_1=0.5, beta_2=0.999, epsilon=1e-08) # d_128_optimizer = tf.keras.optimizers.Adam(learning_rate=self.init_lr, beta_1=0.5, beta_2=0.999, epsilon=1e-08) # d_256_optimizer = tf.keras.optimizers.Adam(learning_rate=self.init_lr, beta_1=0.5, beta_2=0.999, epsilon=1e-08) # self.d_optimizer = [d_64_optimizer, d_128_optimizer, d_256_optimizer] # self.embed_optimizer = tf.keras.optimizers.Adam(learning_rate=self.init_lr, beta_1=0.5, beta_2=0.999, epsilon=1e-08) """ Checkpoint """ self.ckpt = tf.train.Checkpoint(rnn_encoder=self.rnn_encoder, cnn_encoder=self.cnn_encoder, ca_net=self.ca_net, generator=self.generator) # discriminator=self.discriminator, # g_optimizer=self.g_optimizer, # d_64_optimizer=d_64_optimizer, # d_128_optimizer=d_128_optimizer, # d_256_optimizer=d_256_optimizer, # embed_optimizer=self.embed_optimizer) self.manager = tf.train.CheckpointManager(self.ckpt, self.checkpoint_dir, max_to_keep=2) self.start_iteration = 0 if self.manager.latest_checkpoint: self.ckpt.restore( self.manager.latest_checkpoint).expect_partial() self.start_iteration = int( self.manager.latest_checkpoint.split('-')[-1]) print('Latest checkpoint restored!!') print('start iteration : ', self.start_iteration) else: print('Not restoring from saved checkpoint')
def build_model(self):
    if self.phase == 'train':
        """ Input Image"""
        img_class = Image_data(self.img_height, self.img_width, self.img_ch,
                               self.dataset_path, self.augment_flag)
        img_class.preprocess()

        dataset_num = len(img_class.dataset)
        print("Dataset number : ", dataset_num)

        img_slice = tf.data.Dataset.from_tensor_slices(img_class.dataset)

        gpu_device = '/gpu:0'
        img_slice = img_slice. \
            apply(shuffle_and_repeat(dataset_num)). \
            apply(map_and_batch(img_class.image_processing,
                                self.batch_size,
                                num_parallel_batches=AUTOTUNE,
                                drop_remainder=True)). \
            apply(prefetch_to_device(gpu_device, AUTOTUNE))

        self.dataset_iter = iter(img_slice)

        """ Network """
        self.classifier = SubNetwork(channels=64, name='classifier')

        """ Optimizer """
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=self.init_lr,
                                                  beta_1=0.5, beta_2=0.999, epsilon=1e-08)

        """ Summary """
        # mean metric: smooths the loss curve in TensorBoard
        self.loss_metric = tf.keras.metrics.Mean('loss', dtype=tf.float32)

        # print summary
        input_shape = [self.img_height, self.img_width, self.img_ch]
        self.classifier.build_summary(input_shape)

        """ Count parameters """
        params = self.classifier.count_parameter()
        print("Total network parameters : ", format(params, ','))

        """ Checkpoint """
        self.ckpt = tf.train.Checkpoint(classifier=self.classifier, optimizer=self.optimizer)
        self.manager = tf.train.CheckpointManager(self.ckpt, self.checkpoint_dir, max_to_keep=2)
        self.start_iteration = 0

        if self.manager.latest_checkpoint:
            self.ckpt.restore(self.manager.latest_checkpoint)
            self.start_iteration = int(self.manager.latest_checkpoint.split('-')[-1])
            print('Latest checkpoint restored!!')
            print('start iteration : ', self.start_iteration)
        else:
            print('Not restoring from saved checkpoint')

    else:
        """ Test """
        """ Network """
        self.classifier = SubNetwork(channels=64, name='classifier')

        """ Summary """
        input_shape = [self.img_height, self.img_width, self.img_ch]
        self.classifier.build_summary(input_shape)

        """ Count parameters """
        params = self.classifier.count_parameter()
        print("Total network parameters : ", format(params, ','))

        """ Checkpoint """
        self.ckpt = tf.train.Checkpoint(classifier=self.classifier)
        self.manager = tf.train.CheckpointManager(self.ckpt, self.checkpoint_dir, max_to_keep=2)

        if self.manager.latest_checkpoint:
            self.ckpt.restore(self.manager.latest_checkpoint).expect_partial()
            print('Latest checkpoint restored!!')
        else:
            print('Not restoring from saved checkpoint')
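# --- Hedged training-step sketch that would consume the pieces built above (not in the original) ---
# Assumes the dataset iterator yields (image, label) batches and a softmax
# cross-entropy objective; both are illustrative choices, not confirmed by the repo.
@tf.function
def train_step(self):
    images, labels = next(self.dataset_iter)
    with tf.GradientTape() as tape:
        logits = self.classifier(images, training=True)
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))
    gradients = tape.gradient(loss, self.classifier.trainable_variables)
    self.optimizer.apply_gradients(zip(gradients, self.classifier.trainable_variables))
    self.loss_metric(loss)  # feeds the Mean metric that smooths the TensorBoard curve
    return loss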
def build_model(self): """ Input Image""" img_data_class = Image_data(self.img_height, self.img_width, self.img_ch, self.dataset_path, self.augment_flag) train_captions, train_images, test_captions, test_images, idx_to_word, word_to_idx = img_data_class.preprocess( ) """ train_captions: (8855, 10, 18), test_captions: (2933, 10, 18) train_images: (8855,), test_images: (2933,) idx_to_word : 5450 5450 """ if self.phase == 'train': self.lr = tf.placeholder(tf.float32, name='learning_rate') self.dataset_num = len(train_images) img_and_caption = tf.data.Dataset.from_tensor_slices( (train_images, train_captions)) gpu_device = '/gpu:0' img_and_caption = img_and_caption.apply( shuffle_and_repeat(self.dataset_num)).apply( map_and_batch(img_data_class.image_processing, batch_size=self.batch_size, num_parallel_batches=16, drop_remainder=True)).apply( prefetch_to_device(gpu_device, None)) img_and_caption_iterator = img_and_caption.make_one_shot_iterator() real_img_256, caption = img_and_caption_iterator.get_next() target_sentence_index = tf.random_uniform(shape=[], minval=0, maxval=10, dtype=tf.int32) caption = tf.gather(caption, target_sentence_index, axis=1) word_emb, sent_emb, mask = self.rnn_encoder( caption, n_words=len(idx_to_word), embed_dim=self.embed_dim, drop_rate=0.5, n_hidden=128, n_layers=1, bidirectional=True, rnn_type='lstm') noise = tf.random_normal(shape=[self.batch_size, self.z_dim], mean=0.0, stddev=1.0) fake_imgs, _, mu, logvar = self.generator(noise, sent_emb, word_emb, mask) real_img_64, real_img_128 = resize(real_img_256, target_size=[64, 64]), resize( real_img_256, target_size=[128, 128]) fake_img_64, fake_img_128, fake_img_256 = fake_imgs[0], fake_imgs[ 1], fake_imgs[2] uncond_real_logits, cond_real_logits = self.discriminator( [real_img_64, real_img_128, real_img_256], sent_emb) _, cond_wrong_logits = self.discriminator([ real_img_64[:(self.batch_size - 1)], real_img_128[:(self.batch_size - 1)], real_img_256[:(self.batch_size - 1)] ], sent_emb[1:self.batch_size]) uncond_fake_logits, cond_fake_logits = self.discriminator( [fake_img_64, fake_img_128, fake_img_256], sent_emb) fake_img_256_feature = self.caption_cnn(fake_img_256) fake_img_256_caption = self.caption_rnn(fake_img_256_feature, caption, n_words=len(idx_to_word), embed_dim=self.embed_dim, n_hidden=256 * 2, n_layers=1) self.g_adv_loss, self.d_adv_loss = 0, 0 for i in range(3): self.g_adv_loss += self.adv_weight * ( generator_loss(self.gan_type, uncond_fake_logits[i]) + generator_loss(self.gan_type, cond_fake_logits[i])) uncond_real_loss, uncond_fake_loss = discriminator_loss( self.gan_type, uncond_real_logits[i], uncond_fake_logits[i]) cond_real_loss, cond_fake_loss = discriminator_loss( self.gan_type, cond_real_logits[i], cond_fake_logits[i]) _, cond_wrong_loss = discriminator_loss( self.gan_type, None, cond_wrong_logits[i]) self.d_adv_loss += self.adv_weight * ( ((uncond_real_loss + cond_real_loss) / 2) + (uncond_fake_loss + cond_fake_loss + cond_wrong_loss) / 3) self.g_kl_loss = self.kl_weight * kl_loss(mu, logvar) caption = tf.one_hot(caption, len(idx_to_word)) caption = tf.reshape(caption, [-1, len(idx_to_word)]) self.g_cap_loss = self.cap_weight * caption_loss( fake_img_256_caption, caption) self.g_loss = self.g_adv_loss + self.g_kl_loss + self.g_cap_loss self.d_loss = self.d_adv_loss self.real_img = real_img_256 self.fake_img = fake_img_256 """ Training """ t_vars = tf.trainable_variables() G_vars = [var for var in t_vars if 'generator' in var.name] D_vars = [var for var in t_vars if 'discriminator' in var.name] 
        self.g_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(
            self.g_loss, var_list=G_vars)
        self.d_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(
            self.d_loss, var_list=D_vars)

        """ Summary """
        self.summary_g_loss = tf.summary.scalar("g_loss", self.g_loss)
        self.summary_d_loss = tf.summary.scalar("d_loss", self.d_loss)

        self.summary_g_adv_loss = tf.summary.scalar("g_adv_loss", self.g_adv_loss)
        self.summary_g_kl_loss = tf.summary.scalar("g_kl_loss", self.g_kl_loss)
        self.summary_g_cap_loss = tf.summary.scalar("g_cap_loss", self.g_cap_loss)
        self.summary_d_adv_loss = tf.summary.scalar("d_adv_loss", self.d_adv_loss)

        g_summary_list = [self.summary_g_loss, self.summary_g_adv_loss,
                          self.summary_g_kl_loss, self.summary_g_cap_loss]
        d_summary_list = [self.summary_d_loss, self.summary_d_adv_loss]

        self.summary_merge_g_loss = tf.summary.merge(g_summary_list)
        self.summary_merge_d_loss = tf.summary.merge(d_summary_list)

    else:
        """ Test """
        self.dataset_num = len(test_captions)

        gpu_device = '/gpu:0'
        img_and_caption = tf.data.Dataset.from_tensor_slices((test_images, test_captions))

        img_and_caption = img_and_caption.apply(
            shuffle_and_repeat(self.dataset_num)).apply(
                map_and_batch(img_data_class.image_processing,
                              batch_size=self.batch_size,
                              num_parallel_batches=16,
                              drop_remainder=True)).apply(
                                  prefetch_to_device(gpu_device, None))

        img_and_caption_iterator = img_and_caption.make_one_shot_iterator()
        real_img_256, caption = img_and_caption_iterator.get_next()

        target_sentence_index = tf.random_uniform(shape=[], minval=0, maxval=10, dtype=tf.int32)
        caption = tf.gather(caption, target_sentence_index, axis=1)

        word_emb, sent_emb, mask = self.rnn_encoder(caption,
                                                    n_words=len(idx_to_word),
                                                    embed_dim=self.embed_dim,
                                                    drop_rate=0.5,
                                                    n_hidden=128,
                                                    n_layers=1,
                                                    bidirectional=True,
                                                    rnn_type='lstm',
                                                    is_training=False)

        noise = tf.random_normal(shape=[self.batch_size, self.z_dim], mean=0.0, stddev=1.0)
        fake_imgs, _, _, _ = self.generator(noise, sent_emb, word_emb, mask, is_training=False)

        self.test_real_img = real_img_256
        self.test_fake_img = fake_imgs[2]
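# --- Hedged TF1-style training-loop sketch for the graph above (not the repo's train()) ---
# The session config, iteration count, fixed learning rate, and log directory
# are illustrative assumptions; checkpointing and sample dumps are omitted.
def train_loop_sketch(self, iterations=100000, lr=2e-4, log_dir='./logs'):
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(tf.global_variables_initializer())
        writer = tf.summary.FileWriter(log_dir, sess.graph)
        for step in range(iterations):
            _, d_loss, d_summary = sess.run(
                [self.d_optim, self.d_loss, self.summary_merge_d_loss],
                feed_dict={self.lr: lr})
            _, g_loss, g_summary = sess.run(
                [self.g_optim, self.g_loss, self.summary_merge_g_loss],
                feed_dict={self.lr: lr})
            writer.add_summary(d_summary, step)
            writer.add_summary(g_summary, step)
            if step % 100 == 0:
                print('step %d: d_loss %.4f, g_loss %.4f' % (step, d_loss, g_loss))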
def build_model(self):
    if self.phase == 'train':
        """ Input Image"""
        img_class = Image_data(self.img_size, self.img_ch, self.dataset_path, self.augment_flag)
        img_class.preprocess()

        dataset_num = len(img_class.mask_images) + len(img_class.nomask_images)
        print("Dataset number : ", dataset_num)

        img_and_domain = tf.data.Dataset.from_tensor_slices(
            (img_class.mask_images, img_class.mask_masks,
             img_class.nomask_images, img_class.nomask_masks,
             img_class.nomask_images2, img_class.nomask_masks2))

        gpu_device = '/gpu:0'
        img_and_domain = img_and_domain.shuffle(buffer_size=dataset_num,
                                                reshuffle_each_iteration=True).repeat()
        img_and_domain = img_and_domain.map(map_func=img_class.image_processing,
                                            num_parallel_calls=AUTOTUNE).batch(self.batch_size,
                                                                               drop_remainder=True)
        img_and_domain = img_and_domain.apply(prefetch_to_device(gpu_device, buffer_size=AUTOTUNE))

        self.img_and_domain_iter = iter(img_and_domain)

        """ Network """
        self.generator = Generator(self.img_size, self.img_ch, self.style_dim,
                                   max_conv_dim=self.hidden_dim, sn=False, name='Generator')
        self.mapping_network = MappingNetwork(self.style_dim, self.hidden_dim,
                                              sn=False, name='MappingNetwork')
        self.style_encoder = StyleEncoder(self.img_size, self.style_dim,
                                          max_conv_dim=self.hidden_dim, sn=False, name='StyleEncoder')
        self.discriminator = Discriminator(self.img_size, max_conv_dim=self.hidden_dim,
                                           sn=self.sn, name='Discriminator')

        self.generator_ema = deepcopy(self.generator)
        self.mapping_network_ema = deepcopy(self.mapping_network)
        self.style_encoder_ema = deepcopy(self.style_encoder)

        """ Finalize model (build) """
        x = np.ones(shape=[self.batch_size, self.img_size, self.img_size, self.img_ch], dtype=np.float32)
        z = np.ones(shape=[self.batch_size, self.latent_dim], dtype=np.float32)
        s = np.ones(shape=[self.batch_size, self.style_dim], dtype=np.float32)
        m = np.ones(shape=[self.batch_size, self.img_size, self.img_size], dtype=bool)  # np.bool is removed in recent NumPy

        _ = self.mapping_network(z)
        _ = self.mapping_network_ema(z)
        _ = self.style_encoder(x)
        _ = self.style_encoder_ema(x)
        _ = self.generator([x, s, m])
        _ = self.generator_ema([x, s, m])
        _ = self.discriminator([x, m])

        """ Optimizer """
        self.g_optimizer = tf.keras.optimizers.Adam(learning_rate=self.lr,
                                                    beta_1=self.beta1, beta_2=self.beta2, epsilon=1e-08)
        self.e_optimizer = tf.keras.optimizers.Adam(learning_rate=self.lr,
                                                    beta_1=self.beta1, beta_2=self.beta2, epsilon=1e-08)
        self.f_optimizer = tf.keras.optimizers.Adam(learning_rate=self.f_lr,
                                                    beta_1=self.beta1, beta_2=self.beta2, epsilon=1e-08)
        self.d_optimizer = tf.keras.optimizers.Adam(learning_rate=self.lr,
                                                    beta_1=self.beta1, beta_2=self.beta2, epsilon=1e-08)

        """ Checkpoint """
        self.ckpt = tf.train.Checkpoint(generator=self.generator,
                                        generator_ema=self.generator_ema,
                                        mapping_network=self.mapping_network,
                                        mapping_network_ema=self.mapping_network_ema,
                                        style_encoder=self.style_encoder,
                                        style_encoder_ema=self.style_encoder_ema,
                                        discriminator=self.discriminator,
                                        g_optimizer=self.g_optimizer,
                                        e_optimizer=self.e_optimizer,
                                        f_optimizer=self.f_optimizer,
                                        d_optimizer=self.d_optimizer)
        self.manager = tf.train.CheckpointManager(self.ckpt, self.checkpoint_dir, max_to_keep=10)
        self.start_iteration = 0

        if self.manager.latest_checkpoint:
            self.ckpt.restore(self.manager.latest_checkpoint).expect_partial()
            self.start_iteration = int(self.manager.latest_checkpoint.split('-')[-1])
            print('Latest checkpoint restored!!')
            print('start iteration : ', self.start_iteration)
        else:
            print('Not restoring from saved checkpoint')

    else:
        """ Test """
        """ Network """
        self.generator_ema = Generator(self.img_size, self.img_ch, self.style_dim,
                                       max_conv_dim=self.hidden_dim, sn=False, name='Generator')
        self.mapping_network_ema = MappingNetwork(self.style_dim, self.hidden_dim,
                                                  sn=False, name='MappingNetwork')
        self.style_encoder_ema = StyleEncoder(self.img_size, self.style_dim,
                                              max_conv_dim=self.hidden_dim, sn=False, name='StyleEncoder')

        """ Finalize model (build) """
        x = np.ones(shape=[self.batch_size, self.img_size, self.img_size, self.img_ch], dtype=np.float32)
        z = np.ones(shape=[self.batch_size, self.latent_dim], dtype=np.float32)
        s = np.ones(shape=[self.batch_size, self.style_dim], dtype=np.float32)
        m = np.ones(shape=[self.batch_size, self.img_size, self.img_size], dtype=bool)  # np.bool is removed in recent NumPy

        _ = self.mapping_network_ema(z, training=False)
        _ = self.style_encoder_ema(x, training=False)
        _ = self.generator_ema([x, s, m], training=False)

        """ Checkpoint """
        self.ckpt = tf.train.Checkpoint(generator_ema=self.generator_ema,
                                        mapping_network_ema=self.mapping_network_ema,
                                        style_encoder_ema=self.style_encoder_ema)
        self.manager = tf.train.CheckpointManager(self.ckpt, self.checkpoint_dir, max_to_keep=10)

        if self.manager.latest_checkpoint:
            self.ckpt.restore(self.manager.latest_checkpoint).expect_partial()
            print('Latest checkpoint restored!!')
        else:
            print('Not restoring from saved checkpoint')
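# --- Hedged sketch of the exponential-moving-average update implied by the *_ema copies above ---
# Not taken from the repo; the 0.999 decay is an assumed value.
def moving_average(model, model_ema, decay=0.999):
    for weight, weight_ema in zip(model.weights, model_ema.weights):
        weight_ema.assign(decay * weight_ema + (1.0 - decay) * weight)

# Typical call site after each generator/encoder update (illustrative):
#   moving_average(self.generator, self.generator_ema)
#   moving_average(self.mapping_network, self.mapping_network_ema)
#   moving_average(self.style_encoder, self.style_encoder_ema)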