def train(self,
          num_epochs=200,            # number of epochs
          learning_rate=0.0002,      # learning rate of optimizer
          beta1=0.5,                 # parameter for Adam optimizer
          decay_rate=1.0,            # learning rate decay (0, 1], 1 means no decay
          enable_shuffle=True,       # enable shuffle of the dataset
          use_trained_model=True,    # use the saved checkpoint to initialize the network
          use_init_model=True,       # use the init model to initialize the network
          weigts=(0.0001, 0, 0)      # weights of (adversarial, E-z, TV) loss terms
          ):
    """Train encoder+generator and both discriminators on ./data/<dataset_name>/*.jpg.

    Builds the three combined losses and their Adam optimizers (with
    exponential learning-rate decay), wires up TensorBoard summaries, holds
    out one batch of images as fixed visualization samples, then runs the
    epoch/batch loop: every batch updates all three optimizers and logs a
    summary; every epoch saves sample/test images; every 5th epoch (and at
    the end) saves a checkpoint.

    Age-group and gender labels are parsed from the image file names, which
    are assumed to look like '<age>_<gender>_... .jpg' — TODO confirm with
    the dataset layout.
    """

    def _age_group(age):
        """Map an age in years to one of 10 bucket indices (0-9)."""
        if age < 0:
            return 9  # out-of-range values fall into the last bucket
        for idx, upper in enumerate((5, 10, 15, 20, 30, 40, 50, 60, 70)):
            if age <= upper:
                return idx
        return 9  # 71+

    def _parse_labels(files):
        """Build one-hot age/gender label arrays from image file names.

        "One-hot" here uses image_value_range[0] as the off value and
        image_value_range[-1] as the on value.
        """
        label_age = np.ones(
            shape=(len(files), self.num_categories),
            dtype=np.float32) * self.image_value_range[0]
        label_gender = np.ones(
            shape=(len(files), 2),
            dtype=np.float32) * self.image_value_range[0]
        for i, path in enumerate(files):
            # basename (not split('/')) so parsing works with any OS path separator
            parts = os.path.basename(str(path)).split('_')
            label_age[i, _age_group(int(parts[0]))] = self.image_value_range[-1]
            label_gender[i, int(parts[1])] = self.image_value_range[-1]
        return label_age, label_gender

    def _load_images(files):
        """Load image files into a float32 NHWC batch array."""
        imgs = [
            load_image(
                image_path=f,
                image_size=self.size_image,
                image_value_range=self.image_value_range,
                is_gray=(self.num_input_channels == 1),
            ) for f in files
        ]
        imgs = np.array(imgs).astype(np.float32)
        if self.num_input_channels == 1:
            imgs = imgs[:, :, :, None]  # add the trailing channel axis for grayscale
        return imgs

    # *************************** load file names of images ********************
    file_names = glob(os.path.join('./data', self.dataset_name, '*.jpg'))
    size_data = len(file_names)
    np.random.seed(seed=2017)  # fixed seed so shuffles are reproducible
    if enable_shuffle:
        np.random.shuffle(file_names)

    # *********************************** optimizer ****************************
    # over all, there are three loss functions; weights may differ from the
    # paper because of different datasets
    self.loss_EG = (self.EG_loss
                    + weigts[0] * self.G_img_loss
                    + weigts[1] * self.E_z_loss
                    + weigts[2] * self.tv_loss)  # slightly increase the params
    self.loss_Dz = self.D_z_loss_prior + self.D_z_loss_z
    self.loss_Di = self.D_img_loss_input + self.D_img_loss_G

    # set learning rate decay
    self.EG_global_step = tf.Variable(0, trainable=False, name='global_step')
    EG_learning_rate = tf.compat.v1.train.exponential_decay(
        learning_rate=learning_rate,
        global_step=self.EG_global_step,
        decay_steps=size_data / self.size_batch * 2,
        decay_rate=decay_rate,
        staircase=True)

    with tf.compat.v1.variable_scope('opt', reuse=tf.compat.v1.AUTO_REUSE):
        # optimizer for encoder + generator (only this one advances the global step)
        self.EG_optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=EG_learning_rate, beta1=beta1).minimize(
                loss=self.loss_EG,
                global_step=self.EG_global_step,
                var_list=self.E_variables + self.G_variables)
        # optimizer for discriminator on z
        self.D_z_optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=EG_learning_rate, beta1=beta1).minimize(
                loss=self.loss_Dz, var_list=self.D_z_variables)
        # optimizer for discriminator on image
        self.D_img_optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=EG_learning_rate, beta1=beta1).minimize(
                loss=self.loss_Di, var_list=self.D_img_variables)

    # *********************************** tensorboard **************************
    # for visualization (TensorBoard): $ tensorboard --logdir path/to/log-directory
    # NOTE: compat.v1 summaries are used consistently here — TF2's
    # tf.summary.scalar is incompatible with graph-mode merge(), and
    # tf.summary.FileWriter does not exist in TF2.
    self.EG_learning_rate_summary = tf.compat.v1.summary.scalar(
        'EG_learning_rate', EG_learning_rate)
    self.summary = tf.compat.v1.summary.merge([
        self.z_summary, self.z_prior_summary, self.D_z_loss_z_summary,
        self.D_z_loss_prior_summary, self.D_z_logits_summary,
        self.D_z_prior_logits_summary, self.EG_loss_summary,
        self.E_z_loss_summary, self.D_img_loss_input_summary,
        self.D_img_loss_G_summary, self.G_img_loss_summary,
        self.EG_learning_rate_summary, self.D_G_logits_summary,
        self.D_input_logits_summary
    ])
    self.writer = tf.compat.v1.summary.FileWriter(
        os.path.join(self.save_dir, 'summary'), self.session.graph)

    # ******* get some random samples as testing data to visualize progress ****
    sample_files = file_names[0:self.size_batch]
    file_names[0:self.size_batch] = []  # keep the held-out samples out of training
    sample_images = _load_images(sample_files)
    sample_label_age, sample_label_gender = _parse_labels(sample_files)

    # ******************************* training *********************************
    # initialize the graph
    tf.compat.v1.global_variables_initializer().run()

    # load check point
    if use_trained_model:
        if self.load_checkpoint():
            print("\tSUCCESS ^_^")
        else:
            print("\tFAILED >_<!")
            # fall back to the packaged init model, joining its split parts
            # on first use
            if use_init_model:
                if not os.path.exists(
                        'init_model/model-init.data-00000-of-00001'):
                    from init_model.zip_opt import join
                    try:
                        join('init_model/model_parts',
                             'init_model/model-init.data-00000-of-00001')
                    except Exception as exc:
                        # chain the cause instead of silently discarding it
                        raise Exception('Error joining files') from exc
                self.load_checkpoint(model_path='init_model')

    # epoch iteration
    num_batches = len(file_names) // self.size_batch
    for epoch in range(num_epochs):
        if enable_shuffle:
            np.random.shuffle(file_names)
        for ind_batch in range(num_batches):
            start_time = time.time()
            # read batch images and labels
            batch_files = file_names[ind_batch * self.size_batch:
                                     (ind_batch + 1) * self.size_batch]
            batch_images = _load_images(batch_files)
            batch_label_age, batch_label_gender = _parse_labels(batch_files)

            # prior distribution on z
            batch_z_prior = np.random.uniform(
                self.image_value_range[0],
                self.image_value_range[-1],
                [self.size_batch, self.num_z_channels]).astype(np.float32)

            # update all three optimizers and fetch the individual loss terms
            _, _, _, EG_err, Ez_err, Dz_err, Dzp_err, Gi_err, DiG_err, Di_err, TV = \
                self.session.run(
                    fetches=[
                        self.EG_optimizer,
                        self.D_z_optimizer,
                        self.D_img_optimizer,
                        self.EG_loss,
                        self.E_z_loss,
                        self.D_z_loss_z,
                        self.D_z_loss_prior,
                        self.G_img_loss,
                        self.D_img_loss_G,
                        self.D_img_loss_input,
                        self.tv_loss
                    ],
                    feed_dict={
                        self.input_image: batch_images,
                        self.age: batch_label_age,
                        self.gender: batch_label_gender,
                        self.z_prior: batch_z_prior
                    })

            print(
                "\nEpoch: [%3d/%3d] Batch: [%3d/%3d]\n\tEG_err=%.4f\tTV=%.4f"
                % (epoch + 1, num_epochs, ind_batch + 1, num_batches,
                   EG_err, TV))
            print("\tEz=%.4f\tDz=%.4f\tDzp=%.4f" % (Ez_err, Dz_err, Dzp_err))
            print("\tGi=%.4f\tDi=%.4f\tDiG=%.4f" % (Gi_err, Di_err, DiG_err))

            # estimate left run time
            elapse = time.time() - start_time
            time_left = ((num_epochs - epoch - 1) * num_batches +
                         (num_batches - ind_batch - 1)) * elapse
            print("\tTime left: %02d:%02d:%02d" % (int(time_left / 3600), int(
                time_left % 3600 / 60), time_left % 60))

            # add to summary
            summary = self.summary.eval(
                feed_dict={
                    self.input_image: batch_images,
                    self.age: batch_label_age,
                    self.gender: batch_label_gender,
                    self.z_prior: batch_z_prior
                })
            self.writer.add_summary(summary, self.EG_global_step.eval())

        # save sample images for each epoch
        name = '{:02d}.png'.format(epoch + 1)
        self.sample(sample_images, sample_label_age, sample_label_gender, name)
        self.test(sample_images, sample_label_gender, name)

        # save checkpoint for each 5 epoch
        if np.mod(epoch, 5) == 4:
            self.save_checkpoint()

    # save the trained model
    self.save_checkpoint()
    # close the summary writer
    self.writer.close()
def train(self,
          num_epochs=200,            # number of epochs
          learning_rate=0.0002,      # learning rate of optimizer
          beta1=0.5,                 # parameter for Adam optimizer
          decay_rate=1.0,            # learning rate decay (0, 1], 1 means no decay
          enable_shuffle=True,       # enable shuffle of the dataset
          use_trained_model=True,    # use the saved checkpoint to initialize the network
          use_init_model=True,       # use the init model to initialize the network
          weigts=(0.0001, 0.0001, 0.001)  # weights of (adversarial, E-z, TV) loss terms
          ):
    """Train encoder+generator and both discriminators on ./data/<dataset_name>/*.jpg.

    Builds the three combined losses and their Adam optimizers (with
    exponential learning-rate decay), wires up TensorBoard summaries, holds
    out one batch as fixed visualization samples (the file list is pickled
    to save/test/test_files.pkl so later runs can reuse it), then runs the
    epoch/batch loop: every batch updates all three optimizers and logs a
    summary; every epoch saves sample/test images; every 5th epoch (and at
    the end) saves a checkpoint.

    Image/label batches come from self.get_dataset(file_list).
    """
    # *************************** load file names of images ********************
    file_names = glob(os.path.join('./data', self.dataset_name, '*.jpg'))
    size_data = len(file_names)
    np.random.seed(seed=2017)  # fixed seed so shuffles are reproducible
    if enable_shuffle:
        np.random.shuffle(file_names)

    # *********************************** optimizer ****************************
    # over all, there are three loss functions; weights may differ from the
    # paper because of different datasets
    self.loss_EG = (self.EG_loss
                    + weigts[0] * self.G_img_loss
                    + weigts[1] * self.E_z_loss
                    + weigts[2] * self.tv_loss)  # slightly increase the params
    self.loss_Dz = self.D_z_loss_prior + self.D_z_loss_z
    self.loss_Di = self.D_img_loss_input + self.D_img_loss_G

    self.loss_EG_summary = tf.summary.scalar('loss_EG', self.loss_EG)
    self.loss_Dz_summary = tf.summary.scalar('loss_Dz', self.loss_Dz)
    self.loss_Di_summary = tf.summary.scalar('loss_Di', self.loss_Di)

    # set learning rate decay
    self.EG_global_step = tf.Variable(0, trainable=False, name='global_step')
    EG_learning_rate = tf.train.exponential_decay(
        learning_rate=learning_rate,
        global_step=self.EG_global_step,
        decay_steps=size_data / self.size_batch * 2,
        decay_rate=decay_rate,
        staircase=True
    )

    with tf.variable_scope('opt', reuse=tf.AUTO_REUSE):
        # optimizer for encoder + generator (only this one advances the global step)
        self.EG_optimizer = tf.train.AdamOptimizer(
            learning_rate=EG_learning_rate,
            beta1=beta1
        ).minimize(
            loss=self.loss_EG,
            global_step=self.EG_global_step,
            var_list=self.E_variables + self.G_variables
        )
        # optimizer for discriminator on z
        self.D_z_optimizer = tf.train.AdamOptimizer(
            learning_rate=EG_learning_rate,
            beta1=beta1
        ).minimize(
            loss=self.loss_Dz,
            var_list=self.D_z_variables
        )
        # optimizer for discriminator on image
        self.D_img_optimizer = tf.train.AdamOptimizer(
            learning_rate=EG_learning_rate,
            beta1=beta1
        ).minimize(
            loss=self.loss_Di,
            var_list=self.D_img_variables
        )

    # *********************************** tensorboard **************************
    # for visualization (TensorBoard): $ tensorboard --logdir path/to/log-directory
    self.EG_learning_rate_summary = tf.summary.scalar('EG_learning_rate',
                                                      EG_learning_rate)
    self.summary = tf.summary.merge([
        self.z_summary, self.z_prior_summary, self.D_z_loss_z_summary,
        self.D_z_loss_prior_summary, self.D_z_logits_summary,
        self.D_z_prior_logits_summary, self.EG_loss_summary,
        self.E_z_loss_summary, self.D_img_loss_input_summary,
        self.D_img_loss_G_summary, self.G_img_loss_summary,
        self.EG_learning_rate_summary, self.D_G_logits_summary,
        self.D_input_logits_summary, self.tv_loss_summary,
        self.loss_EG_summary, self.loss_Dz_summary, self.loss_Di_summary,
    ])
    self.writer = tf.summary.FileWriter(os.path.join(self.save_dir, 'summary'),
                                        self.session.graph)

    # ******* get some random samples as testing data to visualize progress ****
    print("\n\tLoading Dataset...")
    sample_files = file_names[0:self.size_batch]
    file_names[0:self.size_batch] = []  # keep the held-out samples out of training
    sample_images, sample_label_age, sample_label_gender = \
        self.get_dataset(sample_files)

    # cache the held-out sample file list so later runs can reuse the
    # exact same images
    os.makedirs(os.path.join('save', 'test'), exist_ok=True)
    cache_file = os.path.join('save', 'test', 'test_files.pkl')
    with open(cache_file, 'wb') as fid:
        cPickle.dump(sample_files, fid, cPickle.HIGHEST_PROTOCOL)

    # ******************************* training *********************************
    # initialize the graph
    tf.global_variables_initializer().run()

    # load check point
    if use_trained_model:
        if self.load_checkpoint():
            print("\tSUCCESS ^_^")
        else:
            print("\tFAILED >_<!")
            # fall back to the packaged init model, joining its split parts
            # on first use
            if use_init_model:
                if not os.path.exists(
                        'init_model/model-init.data-00000-of-00001'):
                    from init_model.zip_opt import join
                    try:
                        join('init_model/model_parts',
                             'init_model/model-init.data-00000-of-00001')
                    except Exception as exc:
                        # chain the cause instead of silently discarding it
                        raise Exception('Error joining files') from exc
                self.load_checkpoint(model_path='init_model')

    # epoch iteration
    print("\n\tStart Training...")
    num_batches = len(file_names) // self.size_batch
    for epoch in range(num_epochs):
        if enable_shuffle:
            np.random.shuffle(file_names)
        for ind_batch in range(num_batches):
            start_time = time.time()
            # read batch images and labels
            batch_files = file_names[ind_batch * self.size_batch:
                                     (ind_batch + 1) * self.size_batch]
            batch_images, batch_label_age, batch_label_gender = \
                self.get_dataset(batch_files)

            # prior distribution on z
            batch_z_prior = np.random.uniform(
                self.image_value_range[0],
                self.image_value_range[-1],
                [self.size_batch, self.num_z_channels]
            ).astype(np.float32)

            # update all three optimizers and fetch the individual loss terms
            _, _, _, EG_err, Ez_err, Dz_err, Dzp_err, Gi_err, DiG_err, Di_err, TV = \
                self.session.run(
                    fetches=[
                        self.EG_optimizer,
                        self.D_z_optimizer,
                        self.D_img_optimizer,
                        self.EG_loss,
                        self.E_z_loss,
                        self.D_z_loss_z,
                        self.D_z_loss_prior,
                        self.G_img_loss,
                        self.D_img_loss_G,
                        self.D_img_loss_input,
                        self.tv_loss
                    ],
                    feed_dict={
                        self.input_image: batch_images,
                        self.age: batch_label_age,
                        self.gender: batch_label_gender,
                        self.z_prior: batch_z_prior
                    }
                )

            print("\nEpoch: [%3d/%3d] Batch: [%3d/%3d]\n\tEG_err=%.4f\tTV=%.4f"
                  % (epoch + 1, num_epochs, ind_batch + 1, num_batches,
                     EG_err, TV))
            print("\tEz=%.4f\tDz=%.4f\tDzp=%.4f" % (Ez_err, Dz_err, Dzp_err))
            print("\tGi=%.4f\tDi=%.4f\tDiG=%.4f" % (Gi_err, Di_err, DiG_err))

            # estimate left run time
            elapse = time.time() - start_time
            time_left = ((num_epochs - epoch - 1) * num_batches +
                         (num_batches - ind_batch - 1)) * elapse
            print("\tTime left: %02d:%02d:%02d"
                  % (int(time_left / 3600), int(time_left % 3600 / 60),
                     time_left % 60))

            # add to summary
            summary = self.summary.eval(
                feed_dict={
                    self.input_image: batch_images,
                    self.age: batch_label_age,
                    self.gender: batch_label_gender,
                    self.z_prior: batch_z_prior
                }
            )
            self.writer.add_summary(summary, self.EG_global_step.eval())

        # save sample images for each epoch
        name = '{:02d}.png'.format(epoch + 1)
        self.sample(sample_images, sample_label_age, sample_label_gender, name)
        self.test(sample_images, sample_label_gender, name)

        # save checkpoint for each 5 epoch
        if np.mod(epoch, 5) == 4:
            self.save_checkpoint()

    # save the trained model
    self.save_checkpoint()
    # close the summary writer
    self.writer.close()