def deploy_model(self, model, dir_names=None, epoch=None):
    # Reset the tensorflow graph
    tf.reset_default_graph()

    # Build up the graph
    with tf.device('/gpu:0'):
        input_feature, logits, preds, probs, main_graph \
            = self._build_inference_graph(model)

    # Add saver
    saver = Saver(self.save_dir)

    # Prepare the input array
    self._input_processing()

    # Create a session
    with tf.Session() as sess:
        # Restore the weights
        _ = saver.restore(sess, dir_names=dir_names, epoch=epoch)

        # Initialize the uninitialized variables
        initialize_uninitialized_vars(sess)

        # Keep the fetch order aligned with the unpacked names
        logits_o, preds_o, probs_o = sess.run(
            [logits, preds, probs],
            feed_dict={input_feature: self.input_array,
                       main_graph.is_training: False})

        saver.save(sess, 'model_final' + '.ckpt')
        # Note: this saves the model in the original folder; you need to
        # manually move the file to Deploy/FinalModel.
        # TODO: save the model directly in Deploy/FinalModel

    return logits_o, probs_o, preds_o
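# A minimal sketch of the `initialize_uninitialized_vars` helper used above
# (assumed utility, not defined in this file): after a partial restore it
# initializes only the variables the checkpoint did not cover.
def initialize_uninitialized_vars(sess):
    global_vars = tf.global_variables()
    # One boolean per variable: True if the variable already has a value
    is_initialized = sess.run(
        [tf.is_variable_initialized(var) for var in global_vars])
    uninitialized = [var for var, flag in zip(global_vars, is_initialized)
                     if not flag]
    if uninitialized:
        sess.run(tf.variables_initializer(uninitialized))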
def train(self, Model):
    tf.reset_default_graph()

    # Input node
    im_batch, lm_batch, init_op = self._input_fn()

    # Build up the train graph
    with tf.name_scope("Train"):
        model = Model(self.config)
        out_im_batch, _, out_lm_batch, _ = model.forward(im_batch, lm_batch)
        im_loss, lm_loss = self._loss(im_batch, out_im_batch, lm_batch,
                                      out_lm_batch)
        optimizer = self._Adam_optimizer()
        im_solver, im_grads = self._train_op_w_grads(optimizer, im_loss)
        lm_solver, lm_grads = self._train_op_w_grads(optimizer, lm_loss)

    # Add summary
    if self.config.SUMMARY:
        summary_dict_train = {}
        if self.config.SUMMARY_SCALAR:
            scalar_train = {'appearance_loss': im_loss,
                            'landmarks_loss': lm_loss}
            summary_dict_train['scalar'] = scalar_train
        # Merge summary
        merged_summary_train = \
            self.summary_train.add_summary(summary_dict_train)

    # Add saver
    saver = Saver(self.save_dir)

    # allow_soft_placement lets ops without a GPU kernel fall back to the CPU
    sess_config = tf.ConfigProto(allow_soft_placement=True)

    # Create Session
    with tf.Session(config=sess_config) as sess:
        # Add graph to tensorboard
        if self.config.SUMMARY and self.config.SUMMARY_GRAPH:
            self.summary_train._graph_summary(sess.graph)

        # Restore the weights from the previous training
        if self.config.RESTORE:
            start_epoch = saver.restore(sess)
        else:
            # Create a new folder for saving the model
            saver.set_save_path(comments=self.comments)
            start_epoch = 0

        # Initialize the variables
        init_var = tf.group(tf.global_variables_initializer(),
                            tf.local_variables_initializer())
        sess.run(init_var)

        # Start Training
        tf.logging.info("Start training!")
        for epoch in range(1, self.config.EPOCHS + 1):
            tf.logging.info("Training for epoch {}.".format(epoch))
            # NOTE: the progress-bar target hard-codes a training set of 800 samples
            train_pr_bar = tf.contrib.keras.utils.Progbar(
                target=int(800 / self.config.BATCH_SIZE))
            sess.run(init_op)
            batch = 0
            while True:
                try:
                    im_loss_o, lm_loss_o, summary_o, _, _ = sess.run(
                        [im_loss, lm_loss, merged_summary_train,
                         im_solver, lm_solver])
                    batch += 1
                    train_pr_bar.update(batch)
                    if self.config.SUMMARY:
                        # Add summary
                        self.summary_train.summary_writer.add_summary(
                            summary_o, epoch + start_epoch)
                except (tf.errors.InvalidArgumentError,
                        tf.errors.OutOfRangeError):
                    break

            tf.logging.info(
                "\nThe current epoch {}, appearance loss is {:.2f}, "
                "landmark loss is {:.2f}.\n".format(epoch, im_loss_o, lm_loss_o))

            # Save the model every SAVE_PER_EPOCH epochs
            if epoch % self.config.SAVE_PER_EPOCH == 0:
                save_name = str(epoch + start_epoch)
                saver.save(sess, 'model_' + save_name.zfill(4) + '.ckpt')

        if self.config.SUMMARY:
            self.summary_train.summary_writer.flush()
            self.summary_train.summary_writer.close()

        # Save the model after all epochs
        save_name = str(epoch + start_epoch)
        saver.save(sess, 'model_' + save_name.zfill(4) + '.ckpt')

    return
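# A minimal sketch of the `_input_fn` pattern these trainers rely on
# (assumption: the real implementation lives in a dataset/input module). It
# builds a tf.data pipeline and returns the batch tensors plus the iterator
# init op that the training loop runs at the start of every epoch.
def input_fn_sketch(images, landmarks, batch_size):
    dataset = tf.data.Dataset.from_tensor_slices((images, landmarks))
    dataset = dataset.shuffle(buffer_size=1000).batch(batch_size).prefetch(1)
    iterator = dataset.make_initializable_iterator()
    im_batch, lm_batch = iterator.get_next()
    # Running the initializer rewinds the iterator; get_next() then raises
    # OutOfRangeError at the end of the epoch, which the while-loop catches.
    return im_batch, lm_batch, iterator.initializer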
def train(self, model):
    # Reset tf graph.
    tf.reset_default_graph()

    # Create input node
    init_op_train, init_op_val, real_lab_input, \
        real_lab, real_unl_input, dataset_train = self._input_fn_train_val()

    # Build up the graph
    with tf.device('/gpu:0'):
        d_loss, g_loss, accuracy, roc_auc, pr_auc, \
            update_op, reset_op, preds, probs, main_graph, scalar_train_sum_dict \
            = self._build_train_graph(real_lab_input,
                                      real_unl_input, real_lab, model)

    # Create optimizer
    with tf.name_scope('Train'):
        theta_G = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                    scope='Generator')
        theta_D = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                    scope='Discriminator')
        optimizer = self._Adam_Optimizer()
        d_solver, d_grads = self._train_op_w_grads(optimizer, d_loss,
                                                   var_list=theta_D)
        g_solver, g_grads = self._train_op_w_grads(optimizer, g_loss,
                                                   var_list=theta_G)

    # Print out the variable names in debug mode
    tf.logging.debug(utils.variable_name_string())

    # Add summary
    if self.config.SUMMARY:
        if self.config.SUMMARY_TRAIN_VAL:
            summary_dict_train = {}
            summary_dict_val = {}
            if self.config.SUMMARY_SCALAR:
                scalar_train = {'generator_loss': g_loss,
                                'discriminator_loss': d_loss}
                scalar_train.update(scalar_train_sum_dict)
                scalar_val = {'val_accuracy': accuracy,
                              'val_pr_auc': pr_auc,
                              'val_roc_auc': roc_auc}
                summary_dict_train['scalar'] = scalar_train
                summary_dict_val['scalar'] = scalar_val
            if self.config.SUMMARY_HISTOGRAM:
                ## TODO: add any vectors that you want to visualize.
                pass
            # Merge summary
            merged_summary_train = \
                self.summary_train.add_summary(summary_dict_train)
            merged_summary_val = \
                self.summary_val.add_summary(summary_dict_val)

    # Add saver
    saver = Saver(self.save_dir)

    # Whether to use pre-trained weights
    if self.config.PRE_TRAIN:
        pre_train_saver = tf.train.Saver(theta_D)

    # allow_soft_placement lets ops without a GPU kernel fall back to the CPU
    sess_config = tf.ConfigProto(allow_soft_placement=True)

    # Create Session
    with tf.Session(config=sess_config) as sess:
        # Add graph to tensorboard
        if self.config.SUMMARY and self.config.SUMMARY_GRAPH:
            if self.config.SUMMARY_TRAIN_VAL:
                self.summary_train._graph_summary(sess.graph)

        # Restore the weights from the previous training
        if self.config.RESTORE:
            start_epoch = saver.restore(sess)
        else:
            # Create a new folder for saving the model
            saver.set_save_path(comments=self.comments)
            start_epoch = 0

        if self.config.PRE_TRAIN:
            # Restore the pre-trained weights for D
            pre_train_saver.restore(sess, self.config.PRE_TRAIN_FILE_PATH)
            initialize_uninitialized_vars(sess)
        else:
            # Initialize the variables
            init_var = tf.group(tf.global_variables_initializer(),
                                tf.local_variables_initializer())
            sess.run(init_var)

        # Start Training
        tf.logging.info("Start training!")
        for epoch in range(1, self.config.EPOCHS + 1):
            tf.logging.info("Training for epoch {}.".format(epoch))
            train_pr_bar = tf.contrib.keras.utils.Progbar(
                target=int(self.config.TRAIN_SIZE / self.config.BATCH_SIZE))

            sess.run(init_op_train,
                     feed_dict={dataset_train._is_training_input: True})

            for i in range(
                    int(self.config.TRAIN_SIZE / self.config.BATCH_SIZE)):
                # Update discriminator
                d_loss_o, _, summary_out = sess.run(
                    [d_loss, d_solver, merged_summary_train],
                    feed_dict={main_graph.is_training: True,
                               dataset_train._is_training_input: True})
                # Update generator
                g_loss_o, _, summary_out = sess.run(
                    [g_loss, g_solver, merged_summary_train],
                    feed_dict={main_graph.is_training: True,
                               dataset_train._is_training_input: True})

                tf.logging.debug("Training for batch {}.".format(i))
                # Update progress bar
                train_pr_bar.update(i)

            if self.config.SUMMARY_TRAIN_VAL:
                # Add summary
                self.summary_train.summary_writer.add_summary(
                    summary_out, epoch + start_epoch)

            # Perform validation
            tf.logging.info("\nValidate for epoch {}.".format(epoch))
            sess.run(init_op_val + [reset_op],
                     feed_dict={dataset_train._is_training_input: False})
            count = 1
            val_pr_bar = tf.contrib.keras.utils.Progbar(
                target=int(self.config.VAL_SIZE / self.config.BATCH_SIZE))
            for i in range(
                    int(self.config.VAL_SIZE / self.config.BATCH_SIZE)):
                try:
                    accuracy_o, summary_out, roc_auc_o, \
                        pr_auc_o, val_lab_o, preds_o, \
                        probs_o, _, _, _ = sess.run(
                            [accuracy, merged_summary_val,
                             roc_auc, pr_auc, real_lab,
                             preds, probs] + update_op,
                            feed_dict={main_graph.is_training: False,
                                       dataset_train._is_training_input: False})
                    tf.logging.debug("Validate for batch {}.".format(count))
                    # Update progress bar
                    val_pr_bar.update(count)
                    count += 1
                except (tf.errors.InvalidArgumentError,
                        tf.errors.OutOfRangeError):
                    break

            tf.logging.info(
                "\nThe current validation accuracy for epoch {} is {:.2f}, "
                "roc_auc is {:.2f}, pr_auc is {:.2f}.\n"
                .format(epoch, accuracy_o, roc_auc_o, pr_auc_o))

            # Add summary to tensorboard
            if self.config.SUMMARY_TRAIN_VAL:
                self.summary_val.summary_writer.add_summary(
                    summary_out, epoch + start_epoch)

            # Save the model every SAVE_PER_EPOCH epochs
            if epoch % self.config.SAVE_PER_EPOCH == 0:
                save_name = str(epoch + start_epoch)
                saver.save(sess, 'model_' + save_name.zfill(4) + '.ckpt')

        if self.config.SUMMARY_TRAIN_VAL:
            self.summary_train.summary_writer.flush()
            self.summary_train.summary_writer.close()
            self.summary_val.summary_writer.flush()
            self.summary_val.summary_writer.close()

        # Save the model after all epochs
        save_name = str(epoch + start_epoch)
        saver.save(sess, 'model_' + save_name.zfill(4) + '.ckpt')

    return
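# A rough sketch of how the streaming validation metrics above (accuracy,
# roc_auc, pr_auc, update_op, reset_op) could be built with tf.metrics; the
# helper name and scope are assumptions, not the project's actual code.
def streaming_val_metrics_sketch(labels, probs):
    with tf.variable_scope('val_metrics') as scope:
        accuracy, acc_update = tf.metrics.accuracy(
            labels=labels, predictions=tf.round(probs))
        roc_auc, roc_update = tf.metrics.auc(
            labels=labels, predictions=probs, curve='ROC')
        pr_auc, pr_update = tf.metrics.auc(
            labels=labels, predictions=probs, curve='PR')
        # Re-initializing the metric's local variables resets the running
        # counts before each validation pass.
        local_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                       scope=scope.name)
        reset_op = tf.variables_initializer(local_vars)
    update_op = [acc_update, roc_update, pr_update]
    return accuracy, roc_auc, pr_auc, update_op, reset_op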
def train(self, Model, DataSet, SAMPLE_X=None, SAMPLE_Y=None):
    # Reset tf graph.
    tf.reset_default_graph()

    # Create input node
    if not self.config.Y_LABLE:
        image_batch, init_op, dataset = self._input_fn(DataSet)
    else:
        image_batch, label_batch, init_op, dataset = self._input_fn_w_label(
            DataSet)
    # image, label = self._input_fn_NP()  # using numpy array as feed dict

    # Build up the graph and loss
    with tf.device('/gpu:0'):
        # Create placeholders
        if self.config.Y_LABLE:
            y = tf.placeholder(
                tf.float32,
                [self.config.BATCH_SIZE, self.config.NUM_CLASSES],
                name='y')  # label batch
            x = tf.placeholder(tf.float32,
                               [self.config.BATCH_SIZE] + self.config.IMAGE_DIM,
                               name='real_images')  # real image
        else:
            y = None
            x = tf.placeholder(tf.float32, [self.config.BATCH_SIZE] + [
                self.config.IMAGE_HEIGHT_O, self.config.IMAGE_WIDTH_O,
                self.config.CHANNEL
            ], name='real_images')  # real image

        z = tf.placeholder(
            tf.float32,
            [self.config.BATCH_SIZE, self.config.Z_DIM])  # latent variable

        # Build up the graph
        G, D, D_logits, D_, D_logits_, fm, fm_, model = self._build_train_graph(
            x, y, z, Model)

        # # Create the loss:
        # d_loss, g_loss = self._loss(D, D_logits, D_, D_logits_, x, G, fm,
        #                             fm_, model.discriminator, y)

        samples = model.sampler(z, y)

    # Create optimizer
    with tf.name_scope('Train'):
        t_vars = tf.trainable_variables()
        # theta_G = [var for var in t_vars if 'g_' in var.name]
        # theta_D = [var for var in t_vars if 'd_' in var.name]
        theta_G = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                    scope="generator")
        theta_D = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                    scope="discriminator")
        optimizer = self._Adam_optimizer()

        # Discriminator loss (optionally with one-sided label smoothing)
        if self.config.LABEL_SMOOTH:
            d_loss_real = self._sigmoid_cross_entopy_w_logits(
                0.9 * tf.ones_like(D), D_logits)
        else:
            d_loss_real = self._sigmoid_cross_entopy_w_logits(
                tf.ones_like(D), D_logits)
        d_loss_fake = self._sigmoid_cross_entopy_w_logits(
            tf.zeros_like(D_), D_logits_)
        d_loss = d_loss_fake + d_loss_real

        d_updates = tf.keras.optimizers.Adam(
            lr=1e-4, beta_1=self.config.BETA1).get_updates(d_loss, theta_D)
        d_optim = tf.group(*d_updates, name="d_train_op")

        if self.config.UNROLLED_STEP > 0:
            # Unroll the discriminator updates to build the surrogate g_loss
            update_dict = self._extract_update_dict(d_updates)
            cur_update_dict = update_dict
            for i in range(self.config.UNROLLED_STEP - 1):
                cur_update_dict = self._graph_replace(
                    update_dict, cur_update_dict)
            g_loss = -self._graph_replace(d_loss, cur_update_dict)
        else:
            g_loss = -d_loss

        g_optim = self._train_op(optimizer, g_loss, theta_G)

    # Add summary
    if self.config.SUMMARY:
        summary_dict = {}
        if self.config.SUMMARY_SCALAR:
            scaler = {
                'generator_loss': g_loss,
                'discriminator_loss': d_loss
            }
            summary_dict['scalar'] = scaler
        merged_summary = self.summary.add_summary(summary_dict)

    # Add saver
    saver = Saver(self.save_dir)

    # Create Session
    # allow_soft_placement lets ops without a GPU kernel fall back to the CPU
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=sess_config) as sess:
        # if self.config.DEBUG:
        #     sess = tf_debug.LocalCLIDebugWrapperSession(sess)
        if self.config.SUMMARY and self.config.SUMMARY_GRAPH:
            self.summary._graph_summary(sess.graph)

        if self.config.RESTORE:
            start_epoch = saver.restore(sess)
        else:
            saver.set_save_path(comments=self.comments)
            start_epoch = 0

        # Initialize the variables
        init_var = tf.group(tf.global_variables_initializer(),
                            tf.local_variables_initializer())
        sess.run(init_var)

        # Fixed latent vector for periodic sampling (matches the z placeholder)
        sample_z = np.random.normal(
            size=(self.config.BATCH_SIZE, self.config.Z_DIM))

        if not self.config.Y_LABLE:
            ## TODO: support PacGAN for the conditional case
            sess.run(init_op)
            sample_x = sess.run(image_batch)
        else:
            # sample_x, sample_y = sess.run([image_batch, label_batch])
            sample_x, sample_y = SAMPLE_X, SAMPLE_Y
            # sample_x, sample_y = image[:64, ...], label[:64, ...]  # for numpy input

        # Start Training
        tf.logging.info("Start unrolled GAN training!")
        for epoch in range(start_epoch + 1,
                           self.config.EPOCHS + start_epoch + 1):
            tf.logging.info("Training for epoch {}.".format(epoch))
            train_pr_bar = tf.contrib.keras.utils.Progbar(
                target=int(self.config.TRAIN_SIZE / self.config.BATCH_SIZE))
            sess.run(init_op)

            for i in range(
                    int(self.config.TRAIN_SIZE / self.config.BATCH_SIZE)):
                batch_z = np.random.normal(
                    size=(self.config.BATCH_SIZE,
                          self.config.Z_DIM)).astype(np.float32)

                # Fetch a data batch
                if not self.config.Y_LABLE:
                    image_batch_o = sess.run(image_batch)
                else:
                    image_batch_o, label_batch_o = sess.run(
                        [image_batch, label_batch])
                ## for numpy input
                # image_batch_o, label_batch_o = \
                #     image[i * self.config.BATCH_SIZE:(i + 1) * self.config.BATCH_SIZE], \
                #     label[i * self.config.BATCH_SIZE:(i + 1) * self.config.BATCH_SIZE]

                if not self.config.Y_LABLE:
                    # Update discriminator
                    _, d_loss_o = sess.run([d_optim, d_loss],
                                           feed_dict={x: image_batch_o,
                                                      z: batch_z})
                    # Update generator (twice per discriminator update)
                    _ = sess.run([g_optim],
                                 feed_dict={x: image_batch_o, z: batch_z})
                    _, g_loss_o = sess.run([g_optim, g_loss],
                                           feed_dict={x: image_batch_o,
                                                      z: batch_z})
                else:
                    # Update discriminator
                    _, d_loss_o = sess.run([d_optim, d_loss],
                                           feed_dict={x: image_batch_o,
                                                      y: label_batch_o,
                                                      z: batch_z})
                    # Update generator (twice per discriminator update)
                    _ = sess.run([g_optim],
                                 feed_dict={x: image_batch_o,
                                            y: label_batch_o,
                                            z: batch_z})
                    _, g_loss_o = sess.run([g_optim, g_loss],
                                           feed_dict={x: image_batch_o,
                                                      y: label_batch_o,
                                                      z: batch_z})

                # Update progress bar
                train_pr_bar.update(i)

                if i % 100 == 0:
                    if self.config.DEBUG:
                        ## Sample images every 100 updates in debug mode
                        if not self.config.Y_LABLE:
                            samples_o, d_loss_o, g_loss_o, summary_o = sess.run(
                                [samples, d_loss, g_loss, merged_summary],
                                feed_dict={x: sample_x, z: sample_z})
                        else:
                            samples_o, d_loss_o, g_loss_o, summary_o = sess.run(
                                [samples, d_loss, g_loss, merged_summary],
                                feed_dict={x: sample_x,
                                           y: sample_y,
                                           z: sample_z})
                        save_images(samples_o,
                                    image_manifold_size(samples_o.shape[0]),
                                    os.path.join(
                                        self.config.SAMPLE_DIR,
                                        'train_{:02d}_{:02d}.png'.format(epoch, i)))

            # Save the model every SAVE_PER_EPOCH epochs
            if epoch % self.config.SAVE_PER_EPOCH == 0:
                save_name = str(epoch)
                saver.save(sess, 'model_' + save_name.zfill(4) + '.ckpt')

            print("Epoch: [%2d/%2d], d_loss: %.8f, g_loss: %.8f"
                  % (epoch, self.config.EPOCHS + start_epoch,
                     d_loss_o, g_loss_o))

            ## Sample images after every epoch
            if not self.config.Y_LABLE:
                samples_o, d_loss_o, g_loss_o, summary_o = sess.run(
                    [samples, d_loss, g_loss, merged_summary],
                    feed_dict={x: sample_x, z: sample_z})
            else:
                samples_o, d_loss_o, g_loss_o, summary_o = sess.run(
                    [samples, d_loss, g_loss, merged_summary],
                    feed_dict={x: sample_x, y: sample_y, z: sample_z})

            if self.config.SUMMARY:
                self.summary.summary_writer.add_summary(summary_o, epoch)

            print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss_o, g_loss_o))
            save_images(samples_o, image_manifold_size(samples_o.shape[0]),
                        os.path.join(self.config.SAMPLE_DIR,
                                     'train_{:02d}.png'.format(epoch)))

        if self.config.SUMMARY:
            self.summary.summary_writer.flush()
            self.summary.summary_writer.close()

        # Save the model after all epochs
        save_name = str(epoch)
        saver.save(sess, 'model_' + save_name.zfill(4) + '.ckpt')

    return
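# A rough sketch of the unrolling machinery used above. `_extract_update_dict`
# maps each discriminator variable to the value the Adam update would assign
# to it, and `_graph_replace` (e.g. tf.contrib.graph_editor.graph_replace)
# re-evaluates d_loss as if those updates had been applied, which yields the
# surrogate generator loss after K unrolled steps. This is an assumption-level
# sketch of the standard unrolled-GAN recipe, not this project's exact code.
from collections import OrderedDict

def extract_update_dict_sketch(update_ops):
    """Extract variables and their new values from Assign and AssignAdd ops."""
    name_to_var = {v.name: v for v in tf.global_variables()}
    updates = OrderedDict()
    for update in update_ops:
        var_name = update.op.inputs[0].name
        var = name_to_var[var_name]
        value = update.op.inputs[1]
        if update.op.type == 'Assign':
            updates[var.value()] = value
        elif update.op.type == 'AssignAdd':
            updates[var.value()] = var + value
        else:
            raise ValueError(
                "Update op type (%s) must be Assign or AssignAdd" % update.op.type)
    return updates

# Usage, matching the pattern in the trainer above:
#   update_dict = extract_update_dict_sketch(d_updates)
#   cur = update_dict
#   for _ in range(UNROLLED_STEP - 1):
#       cur = graph_replace(update_dict, cur)
#   g_loss = -graph_replace(d_loss, cur)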
def train(self, Model, DataSet):
    # Reset tf graph.
    tf.reset_default_graph()

    # Create input node
    image_batch, label_batch, init_op_train, init_op_val = \
        self._input_fn_w_label(DataSet)

    # Build up the graph and loss
    with tf.device('/gpu:0'):
        # Build up the graph
        logits = self._build_train_graph(image_batch, Model)

        # Create the loss:
        loss = self._loss(logits, label_batch)

        # Create metric:
        accuracy, update_op, reset_op = self._metric(logits, label_batch)

    # Create optimizer
    with tf.name_scope('Train'):
        optimizer = self._Adam_optimizer()
        optim = self._train_op(optimizer, loss)

    # Add summary
    if self.config.SUMMARY:
        summary_dict = {}
        if self.config.SUMMARY_SCALAR:
            scaler = {'loss': loss, 'accuracy': accuracy}
            summary_dict['scalar'] = scaler
        merged_summary = self.summary.add_summary(summary_dict)

    # Add saver
    saver = Saver(self.save_dir)

    # Create Session
    # allow_soft_placement lets ops without a GPU kernel fall back to the CPU
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=sess_config) as sess:
        # if self.config.DEBUG:
        #     sess = tf_debug.LocalCLIDebugWrapperSession(sess)
        if self.config.SUMMARY and self.config.SUMMARY_GRAPH:
            self.summary._graph_summary(sess.graph)

        if self.config.RESTORE:
            start_epoch = saver.restore(sess)
        else:
            saver.set_save_path(comments=self.comments)
            start_epoch = 0

        # Initialize the variables
        init_var = tf.group(tf.global_variables_initializer(),
                            tf.local_variables_initializer())
        sess.run(init_var)

        # Start Training
        tf.logging.info("Start training!")
        for epoch in range(start_epoch + 1,
                           self.config.EPOCHS + start_epoch + 1):
            tf.logging.info("Training for epoch {}.".format(epoch))
            train_pr_bar = tf.contrib.keras.utils.Progbar(
                target=int(self.config.TRAIN_SIZE / self.config.BATCH_SIZE))
            sess.run(init_op_train)

            for i in range(
                    int(self.config.TRAIN_SIZE / self.config.BATCH_SIZE)):
                _, loss_o, accuracy_o, summary_o, _ = sess.run(
                    [optim, loss, accuracy, merged_summary] + update_op)
                # Update progress bar
                train_pr_bar.update(i)

            print("Epoch: [%2d/%2d], training loss: %.8f, training accuracy: %.8f"
                  % (epoch, self.config.EPOCHS + start_epoch, loss_o, accuracy_o))

            # Do validation
            sess.run(init_op_val + [reset_op])
            for i in range(
                    int(self.config.VAL_SIZE / self.config.BATCH_SIZE)):
                try:
                    accuracy_o, loss_o, _ = sess.run(
                        [accuracy, loss] + update_op)
                except (tf.errors.InvalidArgumentError,
                        tf.errors.OutOfRangeError):
                    break

            print("Epoch: [%2d/%2d], validation loss: %.8f, validation accuracy: %.8f"
                  % (epoch, self.config.EPOCHS + start_epoch, loss_o, accuracy_o))

            # Save the model every SAVE_PER_EPOCH epochs
            if epoch % self.config.SAVE_PER_EPOCH == 0:
                save_name = str(epoch)
                saver.save(sess, 'model_' + save_name.zfill(4) + '.ckpt')

            if self.config.SUMMARY:
                self.summary.summary_writer.add_summary(summary_o, epoch)

        if self.config.SUMMARY:
            self.summary.summary_writer.flush()
            self.summary.summary_writer.close()

        # Save the model after all epochs
        save_name = str(epoch)
        saver.save(sess, 'model_' + save_name.zfill(4) + '.ckpt')

    return
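# Hedged sketches of the optimizer helpers used throughout these trainers
# (`_Adam_optimizer`, `_train_op`, `_train_op_w_grads` are methods in the real
# code; shown here as plain functions under assumed defaults).
def adam_optimizer(lr=1e-4, beta1=0.5):
    return tf.train.AdamOptimizer(learning_rate=lr, beta1=beta1)

def train_op(optimizer, loss, var_list=None):
    # Run pending update ops (e.g. batch-norm moving averages) with each step
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        return optimizer.minimize(loss, var_list=var_list)

def train_op_w_grads(optimizer, loss, var_list=None):
    grads = optimizer.compute_gradients(loss, var_list=var_list)
    return optimizer.apply_gradients(grads), grads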
def train(self, Model, DataSet, SAMPLE_X=None, SAMPLE_Y=None):
    # Reset tf graph.
    tf.reset_default_graph()

    # Create input node
    if not self.config.Y_LABEL:
        image_batch, init_op, dataset = self._input_fn(DataSet)
    else:
        image_batch, label_batch, init_op, dataset = self._input_fn_w_label(
            DataSet)
    # image, label = self._input_fn_NP()  # using numpy array as feed dict

    # Build up the graph and loss
    with tf.device('/gpu:0'):
        if self.config.LOSS == "PacGAN":
            # TODO: support conditional GAN for PacGAN
            y = None
            if self.config.CROP:
                image_dims = [
                    self.config.IMAGE_HEIGHT_O, self.config.IMAGE_WIDTH_O,
                    self.config.CHANNEL * self.config.PAC_NUM
                ]
            else:
                image_dims = self.config.IMAGE_DIM[:-1] + [
                    self.config.CHANNEL * self.config.PAC_NUM
                ]
            x = tf.placeholder(tf.float32,
                               [self.config.BATCH_SIZE] + image_dims,
                               name='real_images')
            z = []
            for i in range(self.config.PAC_NUM):
                z.append(tf.placeholder(
                    tf.float32, [self.config.BATCH_SIZE, self.config.Z_DIM],
                    name='z{}'.format(i)))  # latent variables
        else:
            # Create placeholders
            if self.config.Y_LABEL:
                y = tf.placeholder(
                    tf.float32,
                    [self.config.BATCH_SIZE, self.config.NUM_CLASSES],
                    name='y')  # label batch
                x = tf.placeholder(tf.float32,
                                   [self.config.BATCH_SIZE] + self.config.IMAGE_DIM,
                                   name='real_images')  # real image
            else:
                y = None
                x = tf.placeholder(tf.float32, [self.config.BATCH_SIZE] + [
                    self.config.IMAGE_HEIGHT_O, self.config.IMAGE_WIDTH_O,
                    self.config.CHANNEL
                ], name='real_images')  # real image
            z = tf.placeholder(
                tf.float32,
                [self.config.BATCH_SIZE, self.config.Z_DIM])  # latent variable

        if self.config.LOSS == "MRGAN":
            # Build up the graph for MRGAN
            G, G_mr, D, D_logits, D_, D_logits_, fm, fm_, D_mr, D_mr_logits, _, model = \
                self._build_train_graph(x, y, z, Model)
            # Create the loss:
            d_loss, g_loss, e_loss = self._loss(
                D, D_logits, D_, D_logits_, x, G, fm, fm_,
                model.discriminator, y, G_mr, D_mr_logits)
        else:
            # Build up the graph
            G, D, D_logits, D_, D_logits_, fm, fm_, model = \
                self._build_train_graph(x, y, z, Model)
            # Create the loss:
            d_loss, g_loss = self._loss(D, D_logits, D_, D_logits_, x, G,
                                        fm, fm_, model.discriminator, y)

        # Sample generated images every epoch
        if self.config.LOSS == "PacGAN":
            samples = model.sampler(z[0], y)
        else:
            samples = model.sampler(z, y)

    # Create optimizer
    with tf.name_scope('Train'):
        t_vars = tf.trainable_variables()
        theta_G = [var for var in t_vars if 'g_' in var.name]
        theta_D = [var for var in t_vars if 'd_' in var.name]
        if self.config.LOSS == "MRGAN":
            theta_E = [var for var in t_vars if 'e_' in var.name]

        if self.config.LOSS in ["WGAN", "WGAN_GP", "FMGAN"]:
            optimizer = self._RMSProp_optimizer()
            d_optim_ = self._train_op(optimizer, d_loss, theta_D)
        elif self.config.LOSS in [
                "GAN", "LSGAN", "cGPGAN", "MRGAN", "PacGAN"
        ]:
            optimizer = self._Adam_optimizer()

        if self.config.LOSS == "WGAN":
            # Clip the discriminator weights after every update
            with tf.control_dependencies([d_optim_]):
                d_optim = tf.group(*(tf.assign(
                    var, tf.clip_by_value(var, -self.config.WEIGHT_CLIP,
                                          self.config.WEIGHT_CLIP))
                    for var in theta_D))
        else:
            d_optim = self._train_op(optimizer, d_loss, theta_D)

        if self.config.LOSS == "MRGAN":
            e_optim = self._train_op(optimizer, e_loss, theta_E)

        g_optim = self._train_op(optimizer, g_loss, theta_G)

    # Add summary
    if self.config.SUMMARY:
        summary_dict = {}
        if self.config.SUMMARY_SCALAR:
            scaler = {
                'generator_loss': g_loss,
                'discriminator_loss': d_loss
            }
            if self.config.LOSS == "MRGAN":
                scaler['encoder_loss'] = e_loss
            summary_dict['scalar'] = scaler
        merged_summary = self.summary.add_summary(summary_dict)

    # Add saver
    saver = Saver(self.save_dir)

    # Create Session
    # allow_soft_placement lets ops without a GPU kernel fall back to the CPU
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=sess_config) as sess:
        # if self.config.DEBUG:
        #     sess = tf_debug.LocalCLIDebugWrapperSession(sess)
        if self.config.SUMMARY and self.config.SUMMARY_GRAPH:
            self.summary._graph_summary(sess.graph)

        if self.config.RESTORE:
            start_epoch = saver.restore(sess)
        else:
            saver.set_save_path(comments=self.comments)
            start_epoch = 0

        # Initialize the variables
        init_var = tf.group(tf.global_variables_initializer(),
                            tf.local_variables_initializer())
        sess.run(init_var)

        # Fixed latent vector for periodic sampling (matches the z placeholder)
        sample_z = np.random.normal(
            size=(self.config.BATCH_SIZE, self.config.Z_DIM))

        if not self.config.Y_LABEL:
            ## TODO: support PacGAN for the conditional case
            sess.run(init_op)
            if self.config.LOSS == "PacGAN":
                sample_feed_dict_z = {}
                sample_x_sep = []
                for i in range(self.config.PAC_NUM):
                    sample_feed_dict_z[z[i]] = np.random.normal(
                        size=(self.config.BATCH_SIZE,
                              self.config.Z_DIM)).astype(np.float32)
                    sample_x_sep.append(sess.run(image_batch))
                # Pack PAC_NUM samples along the channel axis
                sample_x = np.concatenate(sample_x_sep, axis=3)
            else:
                sample_x = sess.run(image_batch)
        else:
            sample_x, sample_y = SAMPLE_X, SAMPLE_Y

        # Start Training
        tf.logging.info("Start training!")
        for epoch in range(start_epoch + 1,
                           self.config.EPOCHS + start_epoch + 1):
            tf.logging.info("Training for epoch {}.".format(epoch))
            train_pr_bar = tf.contrib.keras.utils.Progbar(
                target=int(self.config.TRAIN_SIZE / self.config.BATCH_SIZE))
            sess.run(init_op)

            for i in range(
                    int(self.config.TRAIN_SIZE / self.config.BATCH_SIZE)):
                if self.config.LOSS == "PacGAN":
                    image_batch_sep = []
                    feed_dict_z = {}
                    for j in range(self.config.PAC_NUM):
                        feed_dict_z[z[j]] = np.random.normal(
                            size=(self.config.BATCH_SIZE,
                                  self.config.Z_DIM)).astype(np.float32)
                        image_batch_sep.append(sess.run(image_batch))
                    image_batch_o = np.concatenate(image_batch_sep, axis=3)
                else:
                    batch_z = np.random.normal(
                        size=(self.config.BATCH_SIZE,
                              self.config.Z_DIM)).astype(np.float32)
                    # Fetch a data batch
                    if not self.config.Y_LABEL:
                        image_batch_o = sess.run(image_batch)
                    else:
                        image_batch_o, label_batch_o = sess.run(
                            [image_batch, label_batch])

                if not self.config.Y_LABEL:
                    if self.config.LOSS == "PacGAN":
                        # Update discriminator
                        _, d_loss_o = sess.run([d_optim, d_loss],
                                               feed_dict={x: image_batch_o,
                                                          **feed_dict_z})
                        # Update generator (twice per discriminator update)
                        _ = sess.run([g_optim],
                                     feed_dict={x: image_batch_o,
                                                **feed_dict_z})
                        _, g_loss_o = sess.run([g_optim, g_loss],
                                               feed_dict={x: image_batch_o,
                                                          **feed_dict_z})
                        if self.config.LOSS == "MRGAN":
                            # Update encoder
                            _, e_loss_o = sess.run([e_optim, e_loss],
                                                   feed_dict={x: image_batch_o,
                                                              **feed_dict_z})
                    else:
                        # Update discriminator
                        _, d_loss_o = sess.run([d_optim, d_loss],
                                               feed_dict={x: image_batch_o,
                                                          z: batch_z})
                        # Update generator (twice per discriminator update)
                        _ = sess.run([g_optim],
                                     feed_dict={x: image_batch_o, z: batch_z})
                        _, g_loss_o = sess.run([g_optim, g_loss],
                                               feed_dict={x: image_batch_o,
                                                          z: batch_z})
                        if self.config.LOSS == "MRGAN":
                            # Update encoder
                            _, e_loss_o = sess.run([e_optim, e_loss],
                                                   feed_dict={x: image_batch_o,
                                                              z: batch_z})
                else:
                    # Update discriminator
                    _, d_loss_o = sess.run([d_optim, d_loss],
                                           feed_dict={x: image_batch_o,
                                                      y: label_batch_o,
                                                      z: batch_z})
                    # Update generator (twice per discriminator update)
                    _ = sess.run([g_optim],
                                 feed_dict={x: image_batch_o,
                                            y: label_batch_o,
                                            z: batch_z})
                    _, g_loss_o = sess.run([g_optim, g_loss],
                                           feed_dict={x: image_batch_o,
                                                      y: label_batch_o,
                                                      z: batch_z})
                    if self.config.LOSS == "MRGAN":
                        # Update encoder
                        _, e_loss_o = sess.run([e_optim, e_loss],
                                               feed_dict={x: image_batch_o,
                                                          y: label_batch_o,
                                                          z: batch_z})

                # Update progress bar
                train_pr_bar.update(i)

                if i % 100 == 0:
                    if self.config.DEBUG:
                        ## Sample images every 100 updates in debug mode
                        if not self.config.Y_LABEL:
                            if self.config.LOSS == "PacGAN":
                                samples_o, d_loss_o, g_loss_o, summary_o = sess.run(
                                    [samples, d_loss, g_loss, merged_summary],
                                    feed_dict={x: image_batch_o,
                                               **sample_feed_dict_z})
                            else:
                                samples_o, d_loss_o, g_loss_o, summary_o = sess.run(
                                    [samples, d_loss, g_loss, merged_summary],
                                    feed_dict={x: sample_x, z: sample_z})
                        else:
                            samples_o, d_loss_o, g_loss_o, summary_o = sess.run(
                                [samples, d_loss, g_loss, merged_summary],
                                feed_dict={x: sample_x,
                                           y: sample_y,
                                           z: sample_z})
                        save_images(samples_o[:64], image_manifold_size(64),
                                    os.path.join(
                                        self.config.SAMPLE_DIR,
                                        'train_{:02d}_{:02d}.png'.format(epoch, i)))

            # Save the model every SAVE_PER_EPOCH epochs
            if epoch % self.config.SAVE_PER_EPOCH == 0:
                save_name = str(epoch)
                saver.save(sess, 'model_' + save_name.zfill(4) + '.ckpt')

            if self.config.LOSS == "MRGAN":
                print("Epoch: [%2d/%2d], d_loss: %.8f, g_loss: %.8f, e_loss: %.8f"
                      % (epoch, self.config.EPOCHS + start_epoch,
                         d_loss_o, g_loss_o, e_loss_o))
            else:
                print("Epoch: [%2d/%2d], d_loss: %.8f, g_loss: %.8f"
                      % (epoch, self.config.EPOCHS + start_epoch,
                         d_loss_o, g_loss_o))

            ## Sample images after every epoch
            if not self.config.Y_LABEL:
                if self.config.LOSS == "PacGAN":
                    samples_o, d_loss_o, g_loss_o, summary_o = sess.run(
                        [samples, d_loss, g_loss, merged_summary],
                        feed_dict={x: image_batch_o, **sample_feed_dict_z})
                else:
                    samples_o, d_loss_o, g_loss_o, summary_o = sess.run(
                        [samples, d_loss, g_loss, merged_summary],
                        feed_dict={x: sample_x, z: sample_z})
            else:
                samples_o, d_loss_o, g_loss_o, summary_o = sess.run(
                    [samples, d_loss, g_loss, merged_summary],
                    feed_dict={x: sample_x, y: sample_y, z: sample_z})

            if self.config.SUMMARY:
                self.summary.summary_writer.add_summary(summary_o, epoch)

            print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss_o, g_loss_o))
            save_images(samples_o[:64], image_manifold_size(64),
                        os.path.join(self.config.SAMPLE_DIR,
                                     'train_{:02d}.png'.format(epoch)))

        if self.config.SUMMARY:
            self.summary.summary_writer.flush()
            self.summary.summary_writer.close()

        # Save the model after all epochs
        save_name = str(epoch)
        saver.save(sess, 'model_' + save_name.zfill(4) + '.ckpt')

    return
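# `save_images` / `image_manifold_size` follow the usual DCGAN-style helpers:
# the sampled batch is tiled into a roughly square grid before being written
# to SAMPLE_DIR. A minimal sketch (assumed implementation) of the grid sizing:
def image_manifold_size(num_images):
    manifold_h = int(np.floor(np.sqrt(num_images)))
    manifold_w = int(np.ceil(np.sqrt(num_images)))
    # Only exact grids are supported, e.g. 64 samples -> an 8 x 8 manifold
    assert manifold_h * manifold_w == num_images
    return manifold_h, manifold_w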
def train(self, Dataset, Model, sample_y):
    # Create input node
    init_op_train, init_op_val, NNIO = self._input_fn_train_val(Dataset)
    x_l_c, y_l_c, x_l_d, y_l_d, x_u = NNIO

    with tf.device('/gpu:0'):
        # Build up the graph
        PH, G, D, C, model = self._build_train_graph(Model)
        z_g_ph, y_g_ph, x_l_c_ph, y_l_c_ph, x_l_d_ph, y_l_d_ph, x_u_d_ph, x_u_c_ph, \
            train_ph, Lambda = PH
        if self.config.DATA_NAME == "cifar10":
            C_real_logits, C_unl_logits, C_unl_d_logits, C_fake_logits, C_unl_logits_rep = C
        else:
            C_real_logits, C_unl_logits, C_unl_d_logits, C_fake_logits = C
        lambda_1_ph = Lambda[0]
        if self.config.DATA_NAME == "cifar10":
            lambda_2_ph = Lambda[1]

        ## Sample the images
        sample_z_ph = tf.placeholder(
            tf.float32, [self.config.SAMPLE_SIZE, self.config.Z_DIM],
            name='sample_latent_variable')  # latent variable
        sample_y_ph = tf.placeholder(
            tf.float32, [self.config.SAMPLE_SIZE, self.config.NUM_CLASSES],
            name='condition_label_for_sampler')
        samples = model.good_sampler(sample_z_ph, sample_y_ph)

        # Create the loss functions
        d_loss, g_loss, c_loss = self._goodGAN_loss(
            G, D, C, None, [y_g_ph, y_l_c_ph], Lambda, model.discriminator)

        # Create the metric
        accuracy, update_op, reset_op, preds, probs = self._metric(
            C_real_logits, y_l_c_ph)

    # Create optimizer
    with tf.name_scope('Train'):
        t_vars = tf.trainable_variables()
        g_vars = [var for var in t_vars if 'good_generator' in var.name]
        d_vars = [var for var in t_vars if 'discriminator' in var.name]
        c_vars = [var for var in t_vars if 'classifier' in var.name]

        d_optimizer = self._Adam_optimizer(lr=self.lr_ph,
                                           beta1=self.config.BETA1)
        g_optimizer = self._Adam_optimizer(lr=self.lr_ph,
                                           beta1=self.config.BETA1)
        c_optimizer = self._Adam_optimizer(lr=self.cla_lr_ph, beta1=0.5)
        # optimizer = self._Adam_optimizer(self.lr_ph, self.config.BETA1)

        d_solver = self._train_op(d_optimizer, d_loss, var_list=d_vars)
        g_solver = self._train_op(g_optimizer, g_loss, var_list=g_vars)
        c_solver_ = self._train_op(c_optimizer, c_loss, var_list=c_vars)

        ## Add weight normalization for the classifier
        # c_weight_loss = self._loss_weight_l2(
        #     [var for var in t_vars if 'c_h2_lin' in var.name and 'b' not in var.name],
        #     eta=1e-4)
        # c_loss += c_weight_loss

        # Add a moving average of the classifier weights
        ema = tf.train.ExponentialMovingAverage(decay=0.9999)
        with tf.control_dependencies([c_solver_]):
            c_solver = ema.apply(c_vars)

    # Add summary
    if self.config.SUMMARY:
        summary_dict_train = {}
        summary_dict_val = {}
        if self.config.SUMMARY_SCALAR:
            scalar_train = {
                'g_loss': g_loss,
                'd_loss': d_loss,
                'c_loss': c_loss,
                'train_accuracy': accuracy
            }
            scalar_val = {'val_accuracy': accuracy}
            summary_dict_train['scalar'] = scalar_train
            summary_dict_val['scalar'] = scalar_val
        # Merge summary
        merged_summary_train = \
            self.summary_train.add_summary(summary_dict_train)
        merged_summary_val = \
            self.summary_val.add_summary(summary_dict_val)

    # Add saver
    saver = Saver(self.save_dir)

    # Create a session config for training
    # allow_soft_placement lets ops without a GPU kernel fall back to the CPU
    sess_config = tf.ConfigProto(
        allow_soft_placement=True,
        gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.8))

    # Build up the latent variables for the samples
    sample_z = np.random.uniform(
        low=-1.0, high=1.0,
        size=(self.config.SAMPLE_SIZE, self.config.Z_DIM)).astype(np.float32)

    with tf.Session(config=sess_config) as sess:
        if self.config.DEBUG:
            sess = tf_debug.LocalCLIDebugWrapperSession(sess)

        # Add graph to tensorboard
        if self.config.SUMMARY and self.config.SUMMARY_GRAPH:
            self.summary_train._graph_summary(sess.graph)

        lr = self.config.LEARNING_RATE
        cla_lr = self.config.CLA_LEARNINIG_RATE

        # Restore the weights from the previous training
        if self.config.RESTORE:
            start_epoch = saver.restore(sess,
                                        dir_names=self.config.RUN,
                                        epoch=self.config.RESTORE_EPOCH)
            # Replay the learning-rate decay applied after epoch 300
            if start_epoch >= 300:
                lr = lr * 0.995**(start_epoch - 300)
                cla_lr = cla_lr * 0.99**(start_epoch - 300)
            initialize_uninitialized_vars(sess)
        else:
            # Create a new folder for saving the model
            saver.set_save_path(comments=self.comments)
            start_epoch = 0
            # Initialize the variables
            init_var = tf.group(tf.global_variables_initializer(),
                                tf.local_variables_initializer())
            sess.run(init_var)

        # Start Training
        tf.logging.info("Start training!")
        for epoch in range(1, self.config.EPOCHS + 1):
            tf.logging.info("Training for epoch {}.".format(epoch + start_epoch))
            train_pr_bar = tf.contrib.keras.utils.Progbar(
                target=int(self.config.TRAIN_SIZE / self.config.BATCH_SIZE))

            lambda_1 = self.config.FAKE_G_LAMBDA if (start_epoch + epoch) > 200 else 0.
            if self.config.DATA_NAME == "cifar10":
                # rampup_value = rampup(start_epoch + epoch)
                # rampdown_value = rampdown(start_epoch + epoch)
                # lambda_2 = rampup_value * 20 if epoch > 1 else 0
                lambda_2 = 0.5 if epoch > 67 else 0
                # b1_c = rampdown_value * 0.9 + (1.0 - rampdown_value) * 0.5
                # lambda_2 = 0

            # Decay the learning rates after epoch 300
            if (start_epoch + epoch >= 300):
                lr = lr * 0.995
                cla_lr = cla_lr * 0.99

            tf.logging.info("Lambda_1 {}.".format(lambda_1))
            sess.run(init_op_train)

            if self.config.PRE_TRAIN and (start_epoch + epoch <= 30):
                tf.logging.info("Pre Training!")
                for i in range(
                        int(self.config.TRAIN_SIZE / self.config.BATCH_SIZE)):
                    # Fetch labeled and unlabeled data
                    x_l_c_o, y_l_c_o, x_l_d_o, y_l_d_o, x_u_o = sess.run(
                        [x_l_c, y_l_c, x_l_d, y_l_d, x_u])

                    # Define the latent vector
                    z = np.random.uniform(
                        low=-1.0, high=1.0,
                        size=(self.config.BATCH_SIZE_G,
                              self.config.Z_DIM)).astype(np.float32)

                    # y = y_l_c_o
                    ## TODO: think about whether this is the best way
                    y_temp = np.random.randint(
                        low=0, high=self.config.NUM_CLASSES,
                        size=(self.config.BATCH_SIZE_G))
                    y = np.zeros((self.config.BATCH_SIZE_G,
                                  self.config.NUM_CLASSES))
                    y[np.arange(self.config.BATCH_SIZE_G), y_temp] = 1

                    if self.config.DATA_NAME == "cifar10":
                        z_g_ph, y_g_ph, x_l_c_ph, y_l_c_ph, x_l_d_ph, y_l_d_ph, \
                            x_u_d_ph, x_u_c_ph, train_ph, Lambda_ph = PH
                        lambda_1_ph, lambda_2_ph = Lambda_ph
                    else:
                        z_g_ph, y_g_ph, x_l_c_ph, y_l_c_ph, x_l_d_ph, y_l_d_ph, \
                            x_u_d_ph, x_u_c_ph, train_ph, lambda_1_ph = PH
                        lambda_1_ph = lambda_1_ph[0]

                    feed_dict = {
                        z_g_ph: z,
                        y_g_ph: y,
                        x_l_c_ph: x_l_c_o,
                        y_l_c_ph: y_l_c_o,
                        x_l_d_ph: x_l_d_o,
                        y_l_d_ph: y_l_d_o,
                        x_u_d_ph: x_u_o[:self.config.BATCH_SIZE_U_D, ...],
                        x_u_c_ph: x_u_o[self.config.BATCH_SIZE_U_D:
                                        self.config.BATCH_SIZE_U_D +
                                        self.config.BATCH_SIZE_U_C, ...],
                        train_ph: True,
                        lambda_1_ph: lambda_1,
                        self.lr_ph: lr,
                        self.cla_lr_ph: cla_lr
                    }
                    if self.config.DATA_NAME == "cifar10":
                        feed_dict[lambda_2_ph] = lambda_2
                        # feed_dict[self.cla_beta1] = b1_c

                    # Update classifier
                    _, c_loss_o = sess.run([c_solver, c_loss],
                                           feed_dict=feed_dict)

                    # Update progress bar
                    train_pr_bar.update(i)
            else:
                tf.logging.info("Good-GAN Training!")
                for i in range(
                        int(self.config.TRAIN_SIZE / self.config.BATCH_SIZE)):
                    # Fetch labeled and unlabeled data
                    x_l_c_o, y_l_c_o, x_l_d_o, y_l_d_o, x_u_o = sess.run(
                        [x_l_c, y_l_c, x_l_d, y_l_d, x_u])

                    # Define the latent vector
                    z = np.random.uniform(
                        low=-1.0, high=1.0,
                        size=(self.config.BATCH_SIZE_G,
                              self.config.Z_DIM)).astype(np.float32)

                    y_temp = np.random.randint(
                        low=0, high=self.config.NUM_CLASSES,
                        size=(self.config.BATCH_SIZE_G))
                    y = np.zeros((self.config.BATCH_SIZE_G,
                                  self.config.NUM_CLASSES))
                    y[np.arange(self.config.BATCH_SIZE_G), y_temp] = 1

                    if self.config.DATA_NAME == "cifar10":
                        z_g_ph, y_g_ph, x_l_c_ph, y_l_c_ph, x_l_d_ph, y_l_d_ph, \
                            x_u_d_ph, x_u_c_ph, train_ph, Lambda_ph = PH
                        lambda_1_ph, lambda_2_ph = Lambda_ph
                    else:
                        z_g_ph, y_g_ph, x_l_c_ph, y_l_c_ph, x_l_d_ph, y_l_d_ph, \
                            x_u_d_ph, x_u_c_ph, train_ph, lambda_1_ph = PH
                        lambda_1_ph = lambda_1_ph[0]

                    feed_dict = {
                        z_g_ph: z,
                        y_g_ph: y,
                        x_l_c_ph: x_l_c_o,
                        y_l_c_ph: y_l_c_o,
                        x_l_d_ph: x_l_d_o,
                        y_l_d_ph: y_l_d_o,
                        x_u_d_ph: x_u_o[:self.config.BATCH_SIZE_U_D, ...],
                        x_u_c_ph: x_u_o[self.config.BATCH_SIZE_U_D:
                                        self.config.BATCH_SIZE_U_D +
                                        self.config.BATCH_SIZE_U_C, ...],
                        train_ph: True,
                        lambda_1_ph: lambda_1,
                        self.lr_ph: lr,
                        self.cla_lr_ph: cla_lr
                    }
                    if self.config.DATA_NAME == "cifar10":
                        feed_dict[lambda_2_ph] = lambda_2
                        # feed_dict[self.cla_beta1] = b1_c

                    # Update discriminator
                    _, d_loss_o = sess.run([d_solver, d_loss],
                                           feed_dict=feed_dict)
                    # Update generator
                    _, g_loss_o = sess.run([g_solver, g_loss],
                                           feed_dict=feed_dict)
                    # _, g_loss_o = sess.run([g_solver, g_loss],
                    #                        feed_dict=feed_dict)
                    # Update classifier
                    _, c_loss_o = sess.run([c_solver, c_loss],
                                           feed_dict=feed_dict)

                    # Update progress bar
                    train_pr_bar.update(i)

            # Get the training statistics
            summary_train_o, d_loss_o, g_loss_o, c_loss_o = sess.run(
                [merged_summary_train, d_loss, g_loss, c_loss],
                feed_dict=feed_dict)
            tf.logging.info(
                "The training stats for epoch {}: g_loss: {:.2f}, "
                "d_loss {:.2f}, c_loss: {:.2f}."
                .format(epoch + start_epoch, g_loss_o, d_loss_o, c_loss_o))

            if self.config.SUMMARY:
                # Add summary
                self.summary_train.summary_writer.add_summary(
                    summary_train_o, epoch + start_epoch)

            # Perform validation
            tf.logging.info("\nValidate for epoch {}.".format(epoch + start_epoch))
            sess.run(init_op_val + [reset_op])

            if not self.config.VAL_STEP:
                # Perform full validation
                while True:
                    try:
                        x_l_c_o, y_l_c_o = sess.run([x_l_c, y_l_c])
                        feed_dict = {
                            z_g_ph: z,
                            y_g_ph: y,
                            x_l_c_ph: x_l_c_o,
                            y_l_c_ph: y_l_c_o,
                            x_l_d_ph: x_l_d_o,
                            y_l_d_ph: y_l_d_o,
                            x_u_d_ph: x_u_o[:self.config.BATCH_SIZE_U_D, ...],
                            x_u_c_ph: x_u_o[self.config.BATCH_SIZE_U_D:
                                            self.config.BATCH_SIZE_U_D +
                                            self.config.BATCH_SIZE_U_C, ...],
                            train_ph: False,
                            lambda_1_ph: lambda_1,
                            self.lr_ph: lr,
                            self.cla_lr_ph: cla_lr
                        }
                        if self.config.DATA_NAME == "cifar10":
                            feed_dict[lambda_2_ph] = lambda_2
                            # feed_dict[self.cla_beta1] = b1_c
                        accuracy_o, summary_val_o, _ = sess.run(
                            [accuracy, merged_summary_val] + update_op,
                            feed_dict=feed_dict)
                    except (tf.errors.InvalidArgumentError,
                            tf.errors.OutOfRangeError, ValueError):
                        break
            else:
                for i in range(self.config.VAL_STEP):
                    x_l_c_o, y_l_c_o = sess.run([x_l_c, y_l_c])
                    feed_dict = {
                        z_g_ph: z,
                        y_g_ph: y,
                        x_l_c_ph: x_l_c_o,
                        y_l_c_ph: y_l_c_o,
                        x_l_d_ph: x_l_d_o,
                        y_l_d_ph: y_l_d_o,
                        x_u_d_ph: x_u_o[:self.config.BATCH_SIZE_U_D, ...],
                        x_u_c_ph: x_u_o[self.config.BATCH_SIZE_U_D:
                                        self.config.BATCH_SIZE_U_D +
                                        self.config.BATCH_SIZE_U_C, ...],
                        train_ph: False,
                        lambda_1_ph: lambda_1,
                        self.lr_ph: lr,
                        self.cla_lr_ph: cla_lr
                    }
                    if self.config.DATA_NAME == "cifar10":
                        feed_dict[lambda_2_ph] = lambda_2
                        # feed_dict[self.cla_beta1] = b1_c
                    accuracy_o, summary_val_o, _ = sess.run(
                        [accuracy, merged_summary_val] + update_op,
                        feed_dict=feed_dict)

            tf.logging.info(
                "\nThe current validation accuracy for epoch {} is {:.2f}.\n"
                .format(epoch + start_epoch, accuracy_o))

            # Add summary to tensorboard
            if self.config.SUMMARY:
                self.summary_val.summary_writer.add_summary(
                    summary_val_o, epoch + start_epoch)

            # Sample generated images
            samples_o = sess.run(samples,
                                 feed_dict={sample_z_ph: sample_z,
                                            sample_y_ph: sample_y})
            if self.config.DATA_NAME == "mnist":
                samples_o = np.reshape(samples_o, [-1, 28, 28, 1])
            save_images(samples_o[:64], image_manifold_size(64),
                        os.path.join(self.config.SAMPLE_DIR,
                                     'train_{:02d}.png'.format(epoch + start_epoch)))

            # Save the model every SAVE_PER_EPOCH epochs
            if epoch % self.config.SAVE_PER_EPOCH == 0:
                save_name = str(epoch + start_epoch)
                saver.save(sess, 'model_' + save_name.zfill(4) + '.ckpt')

        if self.config.SUMMARY:
            self.summary_train.summary_writer.flush()
            self.summary_train.summary_writer.close()
            self.summary_val.summary_writer.flush()
            self.summary_val.summary_writer.close()

        # Save the model after all epochs
        save_name = str(epoch + start_epoch)
        saver.save(sess, 'model_' + save_name.zfill(4) + '.ckpt')

    return
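# The classifier above is updated through `ema.apply(c_vars)`, so the EMA
# shadow weights are what should be evaluated at test time. A hedged sketch of
# how an evaluation script could restore them (standard tf.train pattern; the
# function name and checkpoint path are placeholders):
def restore_ema_classifier(sess, c_vars, checkpoint_path):
    ema = tf.train.ExponentialMovingAverage(decay=0.9999)
    # Maps each variable to its moving-average (shadow) name in the checkpoint
    ema_saver = tf.train.Saver(ema.variables_to_restore(c_vars))
    ema_saver.restore(sess, checkpoint_path)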
def train(self, model):
    tf.reset_default_graph()

    # Input node
    image, label, init_op_train, init_op_val \
        = self._input_fn_train_val()

    # Build up the train graph
    with tf.device('/cpu:0'):  # TODO: parameterize the device
        loss, accuracy, update_op, reset_op, histogram \
            = self._build_train_graph(image, label, model)

        with tf.name_scope('Train'):
            optimizer = self._SGD_w_Momentum_optimizer()
            train_op, grads = self._train_op_w_grads(optimizer, loss)

    tf.logging.debug(utils.variable_name_string())

    # Add summary
    if self.config.SUMMARY:
        if self.config.SUMMARY_TRAIN_VAL:
            summary_dict_train = {}
            summary_dict_val = {}
            if self.config.SUMMARY_SCALAR:
                scalar_train = {'train_loss': loss}
                scalar_val = {'val_loss': loss, 'val_accuracy': accuracy}
                summary_dict_train['scalar'] = scalar_train
                summary_dict_val['scalar'] = scalar_val
            if self.config.SUMMARY_IMAGE:
                # Keep a separate name so the input tensor is not shadowed
                image_sum = {'input_image': image}
                summary_dict_train['image'] = image_sum
            if self.config.SUMMARY_HISTOGRAM:
                histogram['Conv_Block_0_Weight'] = \
                    [var for var in tf.global_variables()
                     if var.name == 'conv2d/kernel:0'][0]
                histogram = utils.grads_dict(grads, histogram)
                summary_dict_train['histogram'] = histogram
            merged_summary_train = \
                self.summary_train.add_summary(summary_dict_train)
            merged_summary_val = \
                self.summary_val.add_summary(summary_dict_val)

    # Add saver
    saver = Saver(self.save_dir)

    with tf.Session() as sess:
        if self.config.SUMMARY and self.config.SUMMARY_GRAPH:
            if self.config.SUMMARY_TRAIN_VAL:
                self.summary_train._graph_summary(sess.graph)

        if self.config.RESTORE:
            start_epoch = saver.restore(sess)
        else:
            # Create a new folder for saving the model
            saver.set_save_path(comments=self.comments)
            start_epoch = 0

        # Initialize the variables
        init_var = tf.group(tf.global_variables_initializer(),
                            tf.local_variables_initializer())
        sess.run(init_var)

        for epoch in range(1, self.config.EPOCHS + 1):
            sess.run(init_op_train)
            while True:
                try:
                    _, loss_out, summary_out = \
                        sess.run([train_op, loss, merged_summary_train])
                except tf.errors.OutOfRangeError:
                    break

            if self.config.SUMMARY_TRAIN_VAL:
                self.summary_train.summary_writer.add_summary(
                    summary_out, epoch)

            ## Perform test on validation
            sess.run([init_op_val, reset_op])
            loss_val = []
            while True:
                try:
                    loss_out, accuracy_out, _, summary_out = \
                        sess.run([loss, accuracy, update_op,
                                  merged_summary_val])
                    loss_val.append(loss_out)
                except tf.errors.OutOfRangeError:
                    tf.logging.info(
                        "The current validation loss for epoch {} is {:.2f}, "
                        "accuracy is {:.2f}."
                        .format(epoch, np.mean(loss_val), accuracy_out))
                    break

            if self.config.SUMMARY_TRAIN_VAL:
                self.summary_val.summary_writer.add_summary(
                    summary_out, epoch)

            # Save the model every SAVE_PER_EPOCH epochs
            if epoch % self.config.SAVE_PER_EPOCH == 0:
                save_name = str(epoch + start_epoch)
                saver.save(sess, 'model_' + save_name.zfill(4) + '.ckpt')

        if self.config.SUMMARY_TRAIN_VAL:
            self.summary_train.summary_writer.flush()
            self.summary_train.summary_writer.close()
            self.summary_val.summary_writer.flush()
            self.summary_val.summary_writer.close()

        # Save the model after all epochs
        save_name = str(epoch + start_epoch)
        saver.save(sess, 'model_' + save_name.zfill(4) + '.ckpt')

    return
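# `utils.grads_dict` above merges the gradients into the histogram-summary
# dict. A minimal sketch of such a helper (assumed implementation): `grads` is
# the list of (gradient, variable) pairs returned by
# optimizer.compute_gradients(loss).
def grads_dict_sketch(grads, histogram):
    for grad, var in grads:
        if grad is not None:
            # e.g. 'conv2d/kernel:0' -> 'grad_conv2d/kernel_0'
            histogram['grad_' + var.name.replace(':', '_')] = grad
    return histogram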