def train(self, args):
    """Train the model on patches delivered by TF queue runners.

    Builds a gradient-descent training op on ``self.loss`` (restricted to
    ``self.vars``), a Saver and a summary FileWriter under
    ``args.log_directory``. It then initialises all global and local
    variables. For ``self.epoch`` epochs, it fetches ``num_samples`` images'
    worth of patches and runs one optimisation step per mini-batch of
    ``self.batch_size`` patches.

    Args:
        args: parsed options. Fields read here are ``log_directory``,
            ``dataset_training``, ``threshold`` and ``save_epoch_freq``.
    """
    print("\n [*] Training....")
    if not os.path.exists(args.log_directory):
        os.makedirs(args.log_directory)

    # self.learning_rate is fed on every step (see feed_dict below), so the
    # decayed ``lr`` takes effect without rebuilding the training op.
    self.optimizer = tf.train.GradientDescentOptimizer(
        self.learning_rate).minimize(self.loss, var_list=self.vars)
    self.saver = tf.train.Saver()
    self.summary_op = tf.summary.merge_all(self.model_collection[0])
    self.writer = tf.summary.FileWriter(
        os.path.join(args.log_directory, "summary"), graph=self.sess.graph)

    # int(np.prod(...)) keeps the total integral even for scalar variables.
    # For those, np.array([]).prod() would contribute the float 1.0.
    total_num_parameters = 0
    for variable in tf.trainable_variables():
        total_num_parameters += int(np.prod(variable.get_shape().as_list()))
    print(" [*] Number of trainable parameters: {}".format(
        total_num_parameters))

    self.sess.run(tf.global_variables_initializer())
    self.sess.run(tf.local_variables_initializer())

    print(' [*] Loading training set...')
    dataloader = Dataloader(file=args.dataset_training,
                            isTraining=self.isTraining)
    patches_disp, patches_gt = dataloader.get_training_patches(
        self.patch_size, args.threshold)
    line = dataloader.disp_filename
    num_samples = dataloader.count_text_lines(args.dataset_training)
    print(' [*] Training data loaded successfully')

    epoch = 0
    iteration = 0
    lr = self.initial_learning_rate

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    print(" [*] Start Training...")
    # try/finally so the queue-runner threads are always stopped and joined,
    # even when a training step raises.
    try:
        while epoch < self.epoch:
            for i in range(num_samples):
                batch_disp, batch_gt, filename = self.sess.run(
                    [patches_disp, patches_gt, line])
                # Under Python 3, a fetched string tensor is `bytes`, and
                # bytes cannot be concatenated with str.
                if isinstance(filename, bytes):
                    filename = filename.decode("utf-8", errors="replace")
                print(" [*] Training image: " + filename)

                num_patches = len(batch_disp)
                # Walk the patches sequentially. The last batch is shifted
                # back so it overlaps the previous one and keeps a full
                # batch_size rows, rather than being truncated. (When
                # num_patches < batch_size, the single batch is short.)
                last_offset = max(num_patches - self.batch_size, 0)
                for step_image in range(0, num_patches, self.batch_size):
                    offset = min(step_image, last_offset)
                    batch_data = batch_disp[offset:(offset + self.batch_size),
                                            :, :, :]
                    # Keep only the centre pixel (index self.radius on axes
                    # 1 and 2) of each ground-truth patch as the label.
                    batch_labels = batch_gt[offset:(offset + self.batch_size),
                                            self.radius:self.radius + 1,
                                            self.radius:self.radius + 1, :]
                    _, loss, summary_str = self.sess.run(
                        [self.optimizer, self.loss, self.summary_op],
                        feed_dict={
                            self.disp: batch_data,
                            self.gt: batch_labels,
                            self.learning_rate: lr
                        })
                    print("Epoch: [%2d], Image: [%2d], Iter: [%2d], Loss: [%2f]"
                          % (epoch, i, iteration, loss))
                    self.writer.add_summary(summary_str, global_step=iteration)
                    iteration += 1

            epoch += 1
            if np.mod(epoch, args.save_epoch_freq) == 0:
                self.saver.save(self.sess,
                                os.path.join(args.log_directory,
                                             self.model_name),
                                global_step=iteration)
            # One-off step decay: the learning rate is divided by 10 once,
            # after the 10th epoch completes.
            if epoch == 10:
                lr = lr / 10
    finally:
        coord.request_stop()
        coord.join(threads)
def train(self, args):
    """Train the model one (left image, ground truth) file pair at a time.

    Builds a gradient-descent training op on ``self.loss`` (restricted to
    ``self.vars``), a Saver and a summary FileWriter under
    ``args.log_directory``. It then initialises all variables and reads the
    file lists from ``args.dataset_training``. For ``self.epoch`` epochs,
    it extracts patches from every file pair and runs one optimisation step
    per mini-batch of ``self.batch_size`` patches.

    Args:
        args: parsed options. Fields read here are ``log_directory``,
            ``dataset_training`` and ``save_epoch_freq``.
    """
    print("\n [*] Training....")
    if not os.path.exists(args.log_directory):
        os.makedirs(args.log_directory)

    # self.learning_rate is fed on every step (see feed_dict below), so the
    # decayed ``lr`` takes effect without rebuilding the training op.
    self.optimizer = tf.train.GradientDescentOptimizer(
        self.learning_rate).minimize(self.loss, var_list=self.vars)
    self.saver = tf.train.Saver()
    self.summary_op = tf.summary.merge_all(self.model_collection[0])
    self.writer = tf.summary.FileWriter(
        os.path.join(args.log_directory, "summary"), graph=self.sess.graph)

    # int(np.prod(...)) keeps the total integral even for scalar variables.
    # For those, np.array([]).prod() would contribute the float 1.0.
    total_num_parameters = 0
    for variable in tf.trainable_variables():
        total_num_parameters += int(np.prod(variable.get_shape().as_list()))
    print(" [*] Number of trainable parameters: {}".format(
        total_num_parameters))

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    self.sess.run(init_op)

    print(' [*] Loading training set...')
    dataloader = Dataloader(file=args.dataset_training)
    left_files, gt_files = dataloader.read_list_file()
    print(' [*] Training data loaded successfully')

    # Centre index of a patch. Labels keep only this pixel of each
    # ground-truth patch (sliced on axes 1 and 2, presumably height/width).
    center = int(self.patch_size / 2)

    epoch = 0
    iteration = 0
    lr = self.initial_learning_rate
    print(" [*] Start Training...")
    while epoch < self.epoch:
        for i, left_file in enumerate(left_files):
            gt_file = gt_files[i]
            print(" [*] Loading train image: " + left_file)
            # NOTE(review): this runs for every file in every epoch. If
            # get_training_patches builds new graph ops (its results are
            # passed to sess.run), the graph grows each epoch. Consider
            # creating the ops once per file -- confirm against Dataloader.
            disp_patches, gt_patches = dataloader.get_training_patches(
                left_file, gt_file, self.patch_size)
            batch_disp, batch_gt = self.sess.run([disp_patches, gt_patches])

            num_patches = len(batch_disp)
            # Walk the patches sequentially. The last batch is shifted back
            # so it overlaps the previous one and keeps a full batch_size
            # rows, rather than being truncated. (When
            # num_patches < batch_size, the single batch is short.)
            last_offset = max(num_patches - self.batch_size, 0)
            for step_image in range(0, num_patches, self.batch_size):
                offset = min(step_image, last_offset)
                batch_data = batch_disp[offset:(offset + self.batch_size),
                                        :, :, :]
                batch_labels = batch_gt[offset:(offset + self.batch_size),
                                        center:center + 1,
                                        center:center + 1, :]
                _, loss, summary_str = self.sess.run(
                    [self.optimizer, self.loss, self.summary_op],
                    feed_dict={
                        self.disp: batch_data,
                        self.gt: batch_labels,
                        self.learning_rate: lr
                    })
                print("Epoch: [%2d], Image: [%2d], Iter: [%2d], Loss: [%2f]"
                      % (epoch, i, iteration, loss))
                self.writer.add_summary(summary_str, global_step=iteration)
                iteration += 1

        epoch += 1
        if np.mod(epoch, args.save_epoch_freq) == 0:
            self.saver.save(self.sess,
                            os.path.join(args.log_directory, self.model_name),
                            global_step=iteration)
        # One-off step decay: the learning rate is divided by 10 once,
        # after the 10th epoch completes.
        if epoch == 10:
            lr = lr / 10