def run(self,
        finetune_layers,
        epochs,
        learning_rate=0.01,
        batch_size=128,
        keep_prob=1.0,
        memory_usage=1.0,
        device='/gpu:0',
        save_ckpt_dir='',
        init_ckpt_file='',
        use_adam_optimizer=False):
    """Run a training on part of the model (retrain/finetune).

    Args:
        finetune_layers: Names of the layers to retrain; all other layers stay frozen.
        epochs: Number of training epochs.
        learning_rate: Learning rate for the optimizer.
        batch_size: Number of images per batch.
        keep_prob: Dropout keep probability used during training.
        memory_usage: Fraction of the GPU memory this process may allocate.
        device: Device the graph is placed on, e.g. '/gpu:0'.
        save_ckpt_dir: Directory the checkpoint files are written to.
        init_ckpt_file: Checkpoint file to restore instead of the pretrained weights.
        use_adam_optimizer: Use the Adam optimizer instead of plain gradient descent.
    """
    # Create the training and validation datasets
    data_train = self.create_dataset(is_training=True)
    data_val = self.create_dataset(is_training=False)

    # Get ops to init the dataset iterators and fetch the next batch
    init_train_iterator_op, init_val_iterator_op, get_next_batch_op = ops.get_dataset_ops(
        data_train, data_val, batch_size, self.data['training_count'],
        self.data['validation_count'], shuffle=True)

    # Initialize the model and create the input placeholders
    with tf.device(device):
        # The first placeholder dimension stays None (instead of batch_size)
        # so a leftover partial batch can still be fed
        ph_images = tf.placeholder(tf.float32, [
            None, self.model_def.image_size, self.model_def.image_size, 3
        ])
        ph_labels = tf.placeholder(tf.float32, [None, self.num_classes])
        ph_keep_prob = tf.placeholder(tf.float32)
        model = self.model_def(ph_images,
                               keep_prob=ph_keep_prob,
                               num_classes=self.num_classes,
                               retrain_layer=finetune_layers)
        final_op = model.get_final_op()

    # Get the lists of retrainable/restorable variables and print info for this run
    retrain_vars = model.get_retrain_vars()
    restore_vars = model.get_restore_vars()
    self.print_infos(retrain_vars, restore_vars, learning_rate, batch_size,
                     keep_prob, use_adam_optimizer)

    # Add/get the operations to optimize (loss, train and validate)
    with tf.device(device):
        loss_op = ops.get_loss_op(final_op, ph_labels)
        train_op = ops.get_train_op(loss_op, learning_rate, retrain_vars,
                                    use_adam_optimizer)
        accuracy_op, correct_prediction_op, predicted_index_op, true_index_op = ops.get_validation_ops(
            final_op, ph_labels)

    # Number of training/validation steps per epoch to get through all images
    batches_per_epoch_train = int(
        math.ceil(self.data['training_count'] / float(batch_size)))
    batches_per_epoch_val = int(
        math.ceil(self.data['validation_count'] / float(batch_size)))

    # Initialize a saver, create a session config and start a session
    saver = tf.train.Saver()
    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = memory_usage
    with tf.Session(config=config) as sess:
        # Init all variables
        sess.run(tf.global_variables_initializer())

        # Load the pretrained weights or restore a saved checkpoint
        if init_ckpt_file:
            saver.restore(sess, init_ckpt_file)
        else:
            model.load_initial_weights(sess)

        utils.print_output_header(self.data['training_count'],
                                  self.data['validation_count'])
        for epoch in range(epochs):
            is_last_epoch = (epoch + 1) == epochs
            train_loss, train_acc = utils.run_training(
                sess, train_op, loss_op, accuracy_op, init_train_iterator_op,
                get_next_batch_op, ph_images, ph_labels, ph_keep_prob,
                keep_prob, batches_per_epoch_train)

            # Only collect the misclassified images on the last epoch
            return_misclassified = is_last_epoch
            test_loss, test_acc, misclassified = utils.run_validation(
                sess, loss_op, accuracy_op, correct_prediction_op,
                predicted_index_op, true_index_op, final_op,
                init_val_iterator_op, get_next_batch_op, ph_images, ph_labels,
                ph_keep_prob, batches_per_epoch_val, return_misclassified)

            utils.print_output_epoch(epoch + 1, train_loss, train_acc,
                                     test_loss, test_acc)

            # Show the misclassified list on the last epoch
            if is_last_epoch:
                utils.print_misclassified(sess, misclassified,
                                          self.data['labels'])

            # Save the session in a checkpoint file
            if self.write_checkpoints or is_last_epoch:
                utils.save_session_to_checkpoint_file(sess, saver, epoch,
                                                      save_ckpt_dir)
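# --- Usage sketch (not part of the module) ---------------------------------
# A minimal, hypothetical call of the finetune run above. The trainer class
# name, its constructor arguments and the layer names are assumptions made
# for illustration; only the run() keyword arguments come from the signature.
#
#   finetuner = FinetuneTrainer(model_def=AlexNet, data=data)  # hypothetical names
#   finetuner.run(finetune_layers=['fc7', 'fc8'],              # hypothetical layer names
#                 epochs=10,
#                 learning_rate=0.001,
#                 batch_size=64,
#                 keep_prob=0.5,
#                 save_ckpt_dir='./checkpoints')
# ---------------------------------------------------------------------------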
def run(self,
        layer_inputs,
        epochs,
        learning_rate=0.01,
        batch_size=128,
        keep_prob=1.0,
        memory_usage=1.0,
        device='/gpu:0',
        save_ckpt_dir='',
        init_ckpt_file='',
        use_adam_optimizer=False,
        shuffle=True,
        use_regularizer=False):
    """Run a full training of the model.

    Args:
        layer_inputs: Layer definitions the model is built from.
        epochs: Number of training epochs.
        learning_rate: Learning rate for the optimizer.
        batch_size: Number of samples per batch.
        keep_prob: Dropout keep probability used during training.
        memory_usage: Fraction of the GPU memory this process may allocate.
        device: Device the graph is placed on, e.g. '/gpu:0'.
        save_ckpt_dir: Directory the checkpoint files are written to.
        init_ckpt_file: Checkpoint file to restore before training.
        use_adam_optimizer: Use the Adam optimizer instead of plain gradient descent.
        shuffle: Shuffle the training data.
        use_regularizer: Add a regularization term to the loss.
    """
    # Create the training and validation datasets
    data_train = self.create_dataset(is_training=True)
    data_val = self.create_dataset(is_training=False)

    # Get ops to init the dataset iterators and fetch the next batch
    init_train_iterator_op, init_val_iterator_op, get_next_batch_op = ops.get_dataset_ops(
        data_train,
        data_val,
        batch_size,
        train_size=self.data['training_count'],
        val_size=self.data['validation_count'],
        shuffle=shuffle)

    # Initialize the model and create the input placeholders
    with tf.device(device):
        ph_keep_prob = tf.placeholder(tf.float32)
        # Pass the keep_prob placeholder (not the Python float) so that
        # validation can feed 1.0 while training feeds the configured value
        ph_data, ph_labels, final_op = self.create_model(
            layer_inputs, ph_keep_prob, use_regularizer)

    # Get a list of all trainable variables and print info for this run
    train_vars = tf.trainable_variables()
    self.print_infos(train_vars, learning_rate, batch_size, keep_prob,
                     use_adam_optimizer)

    # Add/get the operations to optimize (loss, train and validate)
    with tf.device(device):
        loss_op = ops.get_loss_op(final_op, ph_labels)
        train_op = ops.get_train_op(loss_op, learning_rate, train_vars,
                                    use_adam_optimizer)
        accuracy_op, correct_prediction_op, predicted_index_op, true_index_op = ops.get_validation_ops(
            final_op, ph_labels)

    # Number of training/validation steps per epoch to get through all samples
    batches_per_epoch_train = int(
        math.ceil(self.data['training_count'] / float(batch_size)))
    batches_per_epoch_val = int(
        math.ceil(self.data['validation_count'] / float(batch_size)))

    # Initialize a saver, create a session config and start a session.
    # The config is handed to a local server and the session only receives
    # its target, so the GPU options take effect through the server
    saver = tf.train.Saver()
    gpu_options = tf.GPUOptions()
    gpu_options.per_process_gpu_memory_fraction = memory_usage
    gpu_options.visible_device_list = "1"  # pin this process to the second GPU
    config = tf.ConfigProto(gpu_options=gpu_options)
    server = tf.train.Server.create_local_server(config=config)
    with tf.Session(target=server.target) as sess:
        # Init all variables
        sess.run(tf.global_variables_initializer())

        # Restore a saved checkpoint if one is given
        if init_ckpt_file:
            saver.restore(sess, init_ckpt_file)

        utils.print_output_header(self.data['training_count'],
                                  self.data['validation_count'])
        for epoch in range(epochs):
            is_last_epoch = (epoch + 1) == epochs
            train_loss, train_acc = utils.run_training(
                sess, train_op, loss_op, accuracy_op, init_train_iterator_op,
                get_next_batch_op, ph_data, ph_labels, ph_keep_prob,
                keep_prob, batches_per_epoch_train)

            # Only collect the misclassified samples on the last epoch
            return_misclassified = is_last_epoch
            test_loss, test_acc, misclassified = utils.run_validation(
                sess, loss_op, accuracy_op, correct_prediction_op,
                predicted_index_op, true_index_op, final_op,
                init_val_iterator_op, get_next_batch_op, ph_data, ph_labels,
                ph_keep_prob, batches_per_epoch_val, return_misclassified)

            utils.print_output_epoch(epoch + 1, train_loss, train_acc,
                                     test_loss, test_acc)

            # Show the misclassified list on the last epoch
            if is_last_epoch:
                utils.print_misclassified(sess, misclassified,
                                          self.data['labels'])

            # Save the session in a checkpoint file
            if self.write_checkpoints or is_last_epoch:
                utils.save_session_to_checkpoint_file(sess, saver, epoch,
                                                      save_ckpt_dir)
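# --- Usage sketch (not part of the module) ---------------------------------
# A minimal, hypothetical call of the full-training run above. The trainer
# class name and the layer_inputs value are assumptions made for illustration;
# layer_inputs must be whatever self.create_model expects.
#
#   trainer = Trainer(data=data)                               # hypothetical name
#   trainer.run(layer_inputs=layer_defs,                       # hypothetical value
#               epochs=50,
#               learning_rate=0.01,
#               batch_size=128,
#               use_adam_optimizer=True,
#               use_regularizer=True,
#               save_ckpt_dir='./checkpoints')
# ---------------------------------------------------------------------------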