# Assumed imports for this section (the surrounding module is not shown):
# import os
# from math import ceil
# import numpy as np
# import tensorflow as tf
# from tensorflow.keras import backend as K

def test(self, *args, **kwargs):
    self.generator = self.make_generator(**kwargs)
    # Restore only the generator; the keyword name passed to Checkpoint
    # already defines the mapping to the saved "g_model" object, so no
    # extra mapping dict is needed before restore().
    ckpt = tf.train.Checkpoint(g_model=self.generator)
    fname, _ = load_checkpoint(**kwargs)
    print("\nCheckpoint File : {}\n".format(fname))
    # expect_partial() silences warnings about the optimizer and
    # discriminator variables that are intentionally left unrestored.
    ckpt.restore(fname).expect_partial()
    _, model_images_path, _, _ = make_folders_for_model(kwargs['folder'])
    fname = os.path.join(model_images_path, "Test.png")
    self.plot_images(fname)
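# Hedged sketch of the load_checkpoint() helper called above; its definition
# is not part of this section. Judging from the call sites, it is assumed to
# return the newest checkpoint prefix under kwargs["ckpt_path"] together with
# the epoch to resume from, parsed out of the "Epoch-{n}_..." file names that
# ckpt.save() produces in train(). Treat this as an illustration, not the
# author's implementation.
def load_checkpoint_sketch(**kwargs):
    fname = tf.train.latest_checkpoint(kwargs["ckpt_path"])
    # Prefixes look like ".../Epoch-12_G-Loss-0.123456_D-Loss-0.654321-1";
    # pull the 12 out of the leading "Epoch-12" token (hypothetical parsing).
    epoch = int(os.path.basename(fname).split("_")[0].split("-")[1])
    return fname, epoch + 1  # resume from the next epoch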
def train(self, **kwargs):
    interval = kwargs["interval"]
    model_ckpt_path, model_images_path, model_logs_path, model_result_file = \
        make_folders_for_model(kwargs['folder'])
    self.generator = self.make_generator()
    self.discriminator = self.make_discriminator()
    train_dataset = self.get_dataset()
    num_batches = ceil(self.num_train / self.batch_size)
    d_epoch_loss = []
    d_epoch_aux_loss = []
    g_epoch_loss = []
    training_progbar = tf.keras.utils.Progbar(target=self.num_train)
    save_initial_model_info(
        {
            'generator': self.generator,
            'discriminator': self.discriminator
        }, model_logs_path, model_ckpt_path, **kwargs)
    count = 0
    # lr= is deprecated in tf.keras optimizers; use learning_rate= instead.
    self.g_opt = tf.keras.optimizers.Adam(learning_rate=self.g_lr, beta_1=0.5)
    self.d_opt = tf.keras.optimizers.Adam(learning_rate=self.d_lr, beta_1=0.5)
    self.BC_function = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    self.SCC_function = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True)
    # The keyword names passed here define the save/restore name mapping,
    # so no extra mapping dict is needed before restore().
    ckpt = tf.train.Checkpoint(g_opt=self.g_opt,
                               d_opt=self.d_opt,
                               g_model=self.generator,
                               d_model=self.discriminator)
    if kwargs["ckpt_path"] is not None:
        fname, self.initial_epoch = load_checkpoint(**kwargs)
        print("\nCheckpoint File : {}\n".format(fname))
        ckpt.restore(fname)
        self.g_lr = self.g_opt.get_config()["learning_rate"]
        self.d_lr = self.d_opt.get_config()["learning_rate"]

    for epoch in range(self.initial_epoch, self.initial_epoch + 50000):
        count += 1
        start_time = korea_time()
        for real_images, real_labels in train_dataset:
            num_images = K.int_shape(real_labels)[0]
            g_loss = self.train_G(num_images).numpy()
            d_BC_loss, d_SCC_loss = self.train_D(real_images, real_labels)
            d_epoch_loss.append(d_BC_loss.numpy())
            d_epoch_aux_loss.append(d_SCC_loss.numpy())
            g_epoch_loss.append(g_loss)
            training_progbar.add(num_images)
        end_time = korea_time()
        training_progbar.update(0)  # reset the progress bar

        d_mean_loss = np.mean(d_epoch_loss, axis=0)
        d_mean_aux_loss = np.mean(d_epoch_aux_loss, axis=0)
        g_mean_loss = np.mean(g_epoch_loss, axis=0)
        ckpt_prefix = os.path.join(
            model_ckpt_path,
            "Epoch-{}_G-Loss-{:.6f}_D-Loss-{:.6f}".format(
                epoch, g_mean_loss, d_mean_loss + d_mean_aux_loss))
        ckpt.save(file_prefix=ckpt_prefix)
        print("Epoch = [{:5d}]\tG Loss = [{:8.6f}]\t"
              "D Loss = [{:8.6f}]\tD AUX Loss = [{:8.6f}]\n".format(
                  epoch, g_mean_loss, d_mean_loss, d_mean_aux_loss))

        # Append the epoch summary to the model result file.
        str_ = "Epoch = [{:5d}] - End Time [ {} ]\n".format(
            epoch, end_time.strftime("%Y / %m / %d %H:%M:%S"))
        str_ += "Elapsed Time = {}\n".format(end_time - start_time)
        str_ += "G Learning Rate = [{:.6f}] - D Learning Rate = [{:.6f}]\n".format(
            self.g_lr, self.d_lr)
        str_ += ("G Loss : [{:8.6f}] - D Loss : [{:8.6f}] - "
                 "D AUX Loss : [{:8.6f}] - Sum : [{:8.6f}]\n".format(
                     g_mean_loss, d_mean_loss, d_mean_aux_loss,
                     g_mean_loss + d_mean_loss + d_mean_aux_loss))
        str_ += " - " * 15 + "\n\n"
        with open(model_result_file, "a+", encoding='utf-8') as fp:
            fp.write(str_)

        if count == interval:
            fname = os.path.join(model_images_path, "{}.png".format(epoch))
            self.plot_images(fname)
            count = 0

        # Reset the per-epoch loss accumulators.
        d_epoch_loss = []
        d_epoch_aux_loss = []
        g_epoch_loss = []
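# Hedged sketch of the train_D / train_G steps invoked by the ACGAN-style
# train() above; the actual implementations are defined elsewhere in the
# class and are not part of this section. Assumptions: the discriminator
# returns (real/fake logits, class logits), the generator takes
# (noise, labels), and self.z_dim / self.num_classes are attributes of the
# class. These sketches only show shapes consistent with the call sites.
@tf.function
def train_D(self, real_images, real_labels):
    num_images = tf.shape(real_images)[0]
    noise = tf.random.normal([num_images, self.z_dim])
    fake_labels = tf.random.uniform([num_images], minval=0,
                                    maxval=self.num_classes, dtype=tf.int32)
    with tf.GradientTape() as tape:
        fake_images = self.generator([noise, fake_labels], training=True)
        real_logits, real_cls = self.discriminator(real_images, training=True)
        fake_logits, fake_cls = self.discriminator(fake_images, training=True)
        # Real/fake loss plus the auxiliary classification loss, matching the
        # (d_BC_loss, d_SCC_loss) pair unpacked in the training loop.
        bc_loss = (self.BC_function(tf.ones_like(real_logits), real_logits) +
                   self.BC_function(tf.zeros_like(fake_logits), fake_logits))
        scc_loss = (self.SCC_function(real_labels, real_cls) +
                    self.SCC_function(fake_labels, fake_cls))
        d_loss = bc_loss + scc_loss
    grads = tape.gradient(d_loss, self.discriminator.trainable_variables)
    self.d_opt.apply_gradients(zip(grads, self.discriminator.trainable_variables))
    return bc_loss, scc_loss

@tf.function
def train_G(self, num_images):
    noise = tf.random.normal([num_images, self.z_dim])
    fake_labels = tf.random.uniform([num_images], minval=0,
                                    maxval=self.num_classes, dtype=tf.int32)
    with tf.GradientTape() as tape:
        fake_images = self.generator([noise, fake_labels], training=True)
        fake_logits, fake_cls = self.discriminator(fake_images, training=True)
        # The generator tries to pass fakes off as real while also making
        # them classifiable as the conditioned label.
        g_loss = (self.BC_function(tf.ones_like(fake_logits), fake_logits) +
                  self.SCC_function(fake_labels, fake_cls))
    grads = tape.gradient(g_loss, self.generator.trainable_variables)
    self.g_opt.apply_gradients(zip(grads, self.generator.trainable_variables))
    return g_loss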
def train(self, **kwargs):
    interval = kwargs["interval"]
    model_ckpt_path, model_images_path, model_logs_path, model_result_file = \
        make_folders_for_model(kwargs['folder'])
    self.generator = self.make_generator()
    self.critic = self.make_critic()
    train_dataset = self.get_dataset()
    num_batches = ceil(self.num_train / self.batch_size)
    c_epoch_loss = []
    g_epoch_loss = []
    training_progbar = tf.keras.utils.Progbar(target=self.num_train)
    save_initial_model_info(
        {
            'generator': self.generator,
            'critic': self.critic
        }, model_logs_path, model_ckpt_path, **kwargs)
    count = 0
    # WGAN-style Adam settings; lr= is deprecated, use learning_rate=.
    self.g_opt = tf.keras.optimizers.Adam(learning_rate=self.g_lr,
                                          beta_1=0, beta_2=0.9)
    self.c_opt = tf.keras.optimizers.Adam(learning_rate=self.c_lr,
                                          beta_1=0, beta_2=0.9)
    # The keyword names passed here define the save/restore name mapping,
    # so no extra mapping dict is needed before restore().
    ckpt = tf.train.Checkpoint(g_opt=self.g_opt,
                               c_opt=self.c_opt,
                               g_model=self.generator,
                               c_model=self.critic)
    if kwargs["ckpt_path"] is not None:
        fname, self.initial_epoch = load_checkpoint(**kwargs)
        print("\nCheckpoint File : {}\n".format(fname))
        ckpt.restore(fname)
        self.g_lr = self.g_opt.get_config()["learning_rate"]
        self.c_lr = self.c_opt.get_config()["learning_rate"]

    for epoch in range(self.initial_epoch, self.initial_epoch + 50000):
        count += 1
        start_time = korea_time()
        num_batch = 0
        mult = self.n_critic * self.batch_size  # e.g. 5 * 64 = 320
        num_dataset = 0  # accumulates up to self.num_train (e.g. 60000)
        real_images_list = []
        for real_images in train_dataset:
            # Collect n_critic batches of images before one generator update.
            real_images_list.append(real_images)
            num_images = K.int_shape(real_images)[0]
            num_batch += num_images
            num_dataset += num_images
            if (num_batch == mult) or (num_dataset == self.num_train):
                critic_loss_list = [
                    self.train_D(real_images).numpy()
                    for real_images in real_images_list
                ]
                g_loss = self.train_G().numpy()
                c_epoch_loss.extend(critic_loss_list)
                g_epoch_loss.append(g_loss)
                training_progbar.add(num_batch)
                if num_dataset == self.num_train:
                    break
                num_batch = 0
                real_images_list = []
        end_time = korea_time()
        training_progbar.update(0)  # reset the progress bar

        c_mean_loss = np.mean(c_epoch_loss, axis=0)
        g_mean_loss = np.mean(g_epoch_loss, axis=0)
        ckpt_prefix = os.path.join(
            model_ckpt_path,
            "Epoch-{}_G-Loss-{:.6f}_C-Loss-{:.6f}".format(
                epoch, g_mean_loss, c_mean_loss))
        ckpt.save(file_prefix=ckpt_prefix)
        print("Epoch = [{:5d}]\tGenerator Loss = [{:8.6f}]\t"
              "Critic Loss = [{:8.6f}]\n".format(epoch, g_mean_loss,
                                                 c_mean_loss))

        # Append the epoch summary to the model result file.
        str_ = "Epoch = [{:5d}] - End Time [ {} ]\n".format(
            epoch, end_time.strftime("%Y / %m / %d %H:%M:%S"))
        str_ += "Elapsed Time = {}\n".format(end_time - start_time)
        str_ += "Generator Learning Rate = [{:.6f}] - Critic Learning Rate = [{:.6f}]\n".format(
            self.g_lr, self.c_lr)
        str_ += "Generator Loss : [{:8.6f}] - Critic Loss : [{:8.6f}] - Sum : [{:8.6f}]\n".format(
            g_mean_loss, c_mean_loss, g_mean_loss + c_mean_loss)
        str_ += " - " * 15 + "\n\n"
        with open(model_result_file, "a+", encoding='utf-8') as fp:
            fp.write(str_)

        if count == interval:
            fname = os.path.join(model_images_path, "{}.png".format(epoch))
            self.plot_images(fname)
            count = 0

        # Reset the per-epoch loss accumulators.
        c_epoch_loss = []
        g_epoch_loss = []
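# Hedged sketch of the critic and generator steps called by the train() loop
# above; the real implementations live elsewhere in the class. The
# beta_1=0, beta_2=0.9 Adam settings and the n_critic batching strongly
# suggest the WGAN-GP recipe, so this sketch assumes a gradient-penalty
# critic loss. Assumed attributes: self.z_dim, self.gp_weight; images are
# assumed to be NHWC batches. This is an illustration, not the author's code.
@tf.function
def train_D(self, real_images):
    num_images = tf.shape(real_images)[0]
    noise = tf.random.normal([num_images, self.z_dim])
    with tf.GradientTape() as tape:
        fake_images = self.generator(noise, training=True)
        real_scores = self.critic(real_images, training=True)
        fake_scores = self.critic(fake_images, training=True)
        # Gradient penalty on random interpolates between real and fake.
        alpha = tf.random.uniform([num_images, 1, 1, 1], 0.0, 1.0)
        mixed = alpha * real_images + (1.0 - alpha) * fake_images
        with tf.GradientTape() as gp_tape:
            gp_tape.watch(mixed)
            mixed_scores = self.critic(mixed, training=True)
        grads = gp_tape.gradient(mixed_scores, mixed)
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-12)
        gp = tf.reduce_mean(tf.square(norm - 1.0))
        c_loss = (tf.reduce_mean(fake_scores) - tf.reduce_mean(real_scores)
                  + self.gp_weight * gp)
    c_grads = tape.gradient(c_loss, self.critic.trainable_variables)
    self.c_opt.apply_gradients(zip(c_grads, self.critic.trainable_variables))
    return c_loss

@tf.function
def train_G(self):
    # Matches the zero-argument call site: sample a fixed batch of noise.
    noise = tf.random.normal([self.batch_size, self.z_dim])
    with tf.GradientTape() as tape:
        fake_scores = self.critic(self.generator(noise, training=True),
                                  training=True)
        g_loss = -tf.reduce_mean(fake_scores)  # maximize the critic score
    grads = tape.gradient(g_loss, self.generator.trainable_variables)
    self.g_opt.apply_gradients(zip(grads, self.generator.trainable_variables))
    return g_loss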