def main(args):
    """Fine-tune the last layer (fc8) of a pretrained VGG-16 under k-fold cross-validation.

    For every fold produced by the ``Pre`` cross-validation iterator this:

    1. builds a fresh graph with a ``tf.contrib.data`` input pipeline using
       VGG-style preprocessing (resize so the smaller side is 256 px, take a
       224x224 crop, subtract ``VGG_MEAN``; the crop is random and followed by
       a random horizontal flip for training, central for validation);
    2. builds VGG-16 with a new ``vgg_16/fc8`` layer sized to the number of
       classes in the fold, restores every other layer from
       ``args.model_path`` and initialises ``fc8`` from scratch;
    3. trains only the ``fc8`` variables with gradient descent for
       ``args.num_epochs1`` epochs, logging the loss summary and per-epoch
       train/validation accuracy (computed by ``check_accuracy``) to
       TensorBoard under
       ``./logs/vgg_logs/<num_classes>/<model_name>/<data_folds>/<fold_no>``.

    Args:
        args: parsed command-line namespace. Attributes read here:
            ``data_desc``, ``data_path``, ``data_augment``, ``n_augmentations``,
            ``data_folds``, ``data_oversample``, ``augment_dir``,
            ``model_name``, ``num_workers``, ``batch_size``, ``weight_decay``,
            ``dropout_keep_prob``, ``model_path``, ``learning_rate1``,
            ``learning_rate2`` and ``num_epochs1``.

    Raises:
        AssertionError: if a fold's training and validation label sets differ.
        IOError: if ``args.model_path`` is not an existing file.

    NOTE(review): another ``main(args)`` appears later in this source; if both
    live in one module, the later definition shadows this one.
    """
    # Load the cached data description if present; otherwise build it from
    # the images under args.data_path and cache it for the next run.
    if os.path.isfile(args.data_desc):
        p = Pre().loadData(args.data_desc)
    else:
        p = Pre().createData(args.data_desc, args.data_path, data_thresh=4, verbose=True)
        p.saveData(args.data_desc)

    # Cross-validation iterator yielding (train_files, train_labels,
    # val_files, val_labels) per fold, optionally with augmentation and
    # oversampling of the training data.
    if args.data_augment:
        data_iterator = p.get_augmented_train_and_test_data(
            n_augmentations=args.n_augmentations,
            n_splits=args.data_folds,
            balance=args.data_oversample,
            augmentation_dir=args.augment_dir)
    else:
        data_iterator = p.get_cv_train_and_test_data(
            n_splits=args.data_folds, balance=args.data_oversample)

    fold_no = 0
    for train_filenames, train_labels, val_filenames, val_labels in data_iterator:
        tf.reset_default_graph()
        fold_no += 1

        assert set(train_labels) == set(val_labels),\
            "Train and val labels don't correspond:\n{}\n{}".format(set(train_labels), set(val_labels))

        num_classes = len(set(train_labels))
        unique_name = args.model_name

        # ------------------------------------------------------------------
        # Build the whole computation graph for this fold; everything created
        # inside `graph.as_default()` belongs to `graph`.
        graph = tf.Graph()
        with graph.as_default():
            # VGG preprocessing, following
            # https://github.com/tensorflow/models/blob/master/research/slim/preprocessing/vgg_preprocessing.py
            # (see also the VGG paper: https://arxiv.org/pdf/1409.1556.pdf).

            def _parse_function(filename, label):
                """Decode a JPEG file and resize it so its smaller side is 256 px."""
                image_string = tf.read_file(filename)
                image_decoded = tf.image.decode_jpeg(image_string, channels=3)
                image = tf.cast(image_decoded, tf.float32)

                smallest_side = 256.0
                height, width = tf.shape(image)[0], tf.shape(image)[1]
                height = tf.to_float(height)
                width = tf.to_float(width)

                scale = tf.cond(tf.greater(height, width),
                                lambda: smallest_side / width,
                                lambda: smallest_side / height)
                new_height = tf.to_int32(height * scale)
                new_width = tf.to_int32(width * scale)

                resized_image = tf.image.resize_images(image, [new_height, new_width])
                return resized_image, label

            def training_preprocess(image, label):
                """Random 224x224 crop, random horizontal flip, subtract VGG_MEAN.

                No scaling/normalisation beyond the mean subtraction, matching
                how VGG was trained.
                """
                crop_image = tf.random_crop(image, [224, 224, 3])
                flip_image = tf.image.random_flip_left_right(crop_image)
                means = tf.reshape(tf.constant(VGG_MEAN), [1, 1, 3])
                centered_image = flip_image - means
                return centered_image, label

            def val_preprocess(image, label):
                """Central 224x224 crop (or pad), subtract VGG_MEAN."""
                crop_image = tf.image.resize_image_with_crop_or_pad(image, 224, 224)
                means = tf.reshape(tf.constant(VGG_MEAN), [1, 1, 3])
                centered_image = crop_image - means
                return centered_image, label

            # --------------------------------------------------------------
            # Input pipelines (tf.contrib.data): decode and preprocess with
            # args.num_workers parallel calls, then batch.

            # Training dataset. Shuffle the (filename, label) pairs *before*
            # decoding so the shuffle buffer holds short strings instead of up
            # to 10000 decoded 224x224x3 float32 images.
            train_dataset = tf.contrib.data.Dataset.from_tensor_slices((train_filenames, train_labels))
            train_dataset = train_dataset.shuffle(buffer_size=10000)
            train_dataset = train_dataset.map(_parse_function,
                                              num_threads=args.num_workers,
                                              output_buffer_size=args.batch_size)
            train_dataset = train_dataset.map(training_preprocess,
                                              num_threads=args.num_workers,
                                              output_buffer_size=args.batch_size)
            batched_train_dataset = train_dataset.batch(args.batch_size)

            # Validation dataset (deterministic order, no augmentation).
            val_dataset = tf.contrib.data.Dataset.from_tensor_slices((val_filenames, val_labels))
            val_dataset = val_dataset.map(_parse_function,
                                          num_threads=args.num_workers,
                                          output_buffer_size=args.batch_size)
            val_dataset = val_dataset.map(val_preprocess,
                                          num_threads=args.num_workers,
                                          output_buffer_size=args.batch_size)
            batched_val_dataset = val_dataset.batch(args.batch_size)

            # One reinitialisable iterator serves both datasets (their
            # structures are compatible): run train_init_op or val_init_op to
            # start one pass over the corresponding set.
            iterator = tf.contrib.data.Iterator.from_structure(batched_train_dataset.output_types,
                                                               batched_train_dataset.output_shapes)
            images, labels = iterator.get_next()
            train_init_op = iterator.make_initializer(batched_train_dataset)
            val_init_op = iterator.make_initializer(batched_val_dataset)

            # Fed True while training, False while evaluating.
            is_training = tf.placeholder(tf.bool)

            # --------------------------------------------------------------
            # Model: VGG-16 whose last layer "vgg_16/fc8" is replaced by a new
            # fully connected layer with `num_classes` outputs (taken from the
            # fold's labels).
            vgg = tf.contrib.slim.nets.vgg
            with slim.arg_scope(vgg.vgg_arg_scope(weight_decay=args.weight_decay)):
                logits, _ = vgg.vgg_16(images, num_classes=num_classes, is_training=is_training,
                                       dropout_keep_prob=args.dropout_keep_prob)

            # Pretrained checkpoint. Raise explicitly instead of `assert`,
            # which is stripped when Python runs with -O.
            model_path = args.model_path
            if not os.path.isfile(model_path):
                raise IOError("Pretrained checkpoint not found: {}".format(model_path))

            # `init_fn(sess)` restores every layer except fc8 from the checkpoint.
            variables_to_restore = tf.contrib.framework.get_variables_to_restore(exclude=['vgg_16/fc8'])
            init_fn = tf.contrib.framework.assign_from_checkpoint_fn(model_path, variables_to_restore)

            # The new fc8 variables are initialised from scratch instead.
            fc8_variables = tf.contrib.framework.get_variables('vgg_16/fc8')
            fc8_init = tf.variables_initializer(fc8_variables)

            # --------------------------------------------------------------
            # Loss: tf.losses registers the cross-entropy in the LOSSES
            # collection; get_total_loss() also adds regularisation losses.
            with tf.name_scope('Loss'):
                tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
                loss = tf.losses.get_total_loss()
                tf.summary.scalar("train", loss, collections=['train'])
                tf.summary.scalar("valid", loss, collections=['valid'])

            # Phase 1: minimise the loss w.r.t. the fc8 variables only.
            fc8_optimizer = tf.train.GradientDescentOptimizer(args.learning_rate1)
            fc8_train_op = fc8_optimizer.minimize(loss, var_list=fc8_variables)

            # Phase 2 (whole network).
            # NOTE(review): full_train_op is built but never run below; only
            # the fc8 phase is executed.
            full_optimizer = tf.train.GradientDescentOptimizer(args.learning_rate2)
            full_train_op = full_optimizer.minimize(loss)

            # Per-example correctness, consumed by check_accuracy().
            with tf.name_scope('Accuracy'):
                prediction = tf.to_int32(tf.argmax(logits, 1))
                correct_prediction = tf.equal(prediction, tf.cast(labels, tf.int32))

            # Select the 'train' or 'valid' summary collection via is_training.
            merged_summary_op = tf.cond(
                is_training,
                lambda: tf.summary.merge_all('train'),
                lambda: tf.summary.merge_all('valid'))

            graph.finalize()

        # ------------------------------------------------------------------
        # Run the finalized graph.
        with tf.Session(graph=graph) as sess:
            init_fn(sess)          # load the pretrained weights
            sess.run(fc8_init)     # initialise the new fc8 layer

            # TensorBoard: tensorboard --logdir=./logs/vgg_logs
            dirs = ["./logs/vgg_logs/", str(num_classes), unique_name, str(args.data_folds), str(fold_no)]
            run_log_dir = os.path.join(*dirs)
            # Record the graph exactly once; an extra add_graph() call would
            # write a duplicate graph event to the same run.
            summary_writer = tf.summary.FileWriter(run_log_dir, graph=sess.graph)
            try:
                for epoch in range(1, args.num_epochs1 + 1):
                    print('Starting epoch %d / %d' % (epoch, args.num_epochs1))

                    # One full pass over the training set, until the
                    # iterator is exhausted.
                    sess.run(train_init_op)
                    summary = None  # stays None if the epoch yields no batch
                    while True:
                        try:
                            _, summary = sess.run([fc8_train_op, merged_summary_op],
                                                  feed_dict={is_training: True})
                        except tf.errors.OutOfRangeError:
                            break

                    train_acc = check_accuracy(sess, correct_prediction, is_training, train_init_op)
                    val_acc = check_accuracy(sess, correct_prediction, is_training, val_init_op)
                    print('Train accuracy: %f' % train_acc)
                    print('Val accuracy: %f\n' % val_acc)

                    # Accuracy summaries are plain protos, not graph ops, so
                    # they can be created after the graph is finalized.
                    accuracy_summary = tf.Summary()
                    accuracy_summary.value.add(tag="accuracy/validation", simple_value=val_acc)
                    accuracy_summary.value.add(tag="accuracy/training", simple_value=train_acc)
                    summary_writer.add_summary(accuracy_summary, epoch)

                    # Loss summary from the epoch's last training batch.
                    if summary is not None:
                        summary_writer.add_summary(summary, epoch)
                    summary_writer.flush()
            finally:
                summary_writer.close()
def main(args):
    """Train and cross-validate a basic CNN (``conv_net``) fold by fold.

    Loads the data description from ``args.data_desc``. For every fold of the
    cross-validation iterator it:

    - builds training and validation inputs with ``read_images``;
    - builds the network three times with shared weights: training logits,
      and dropout-free logits on the training and validation inputs;
    - trains with Adam for ``args.num_epochs`` epochs;
    - once per epoch, evaluates one pass over the validation data, logs a
      confusion-matrix summary (``plot_confusion_matrix``) plus the merged
      scalar summaries, and prints loss and accuracies.

    Logs go to ``./logs/nn_logs/<num_classes>/<model_name>/<data_folds>/<fold_no>``,
    together with a ``run_info.txt`` describing the run.

    Args:
        args: parsed command-line namespace. Attributes read here:
            ``data_desc``, ``data_augment``, ``n_augmentations``,
            ``data_folds``, ``data_oversample``, ``augment_dir``,
            ``model_name``, ``img_dim``, ``batch_size``, ``num_epochs``,
            ``learning_rate`` and ``dropout_keep_prob``. ``read_images`` and
            ``conv_net`` also receive ``args``.

    Raises:
        SystemExit: (non-zero status) if ``args.data_desc`` does not exist.
        AssertionError: if a fold's training and validation label sets differ.
    """
    import sys

    # Building the description from raw images is disabled in this script,
    # so a missing description file is fatal.
    if not os.path.isfile(args.data_desc):
        sys.exit("Data description file not found: {}".format(args.data_desc))
    p = Pre().loadData(args.data_desc)

    # Cross-validation iterator with optional augmentation,
    # number of folds and training data oversampling.
    if args.data_augment:
        data_iterator = p.get_augmented_train_and_test_data(
            n_augmentations=args.n_augmentations,
            n_splits=args.data_folds,
            balance=args.data_oversample,
            augmentation_dir=args.augment_dir)
    else:
        data_iterator = p.get_cv_train_and_test_data(
            n_splits=args.data_folds, balance=args.data_oversample)

    fold_no = 0
    for train_path, train_label, test_path, test_label in data_iterator:
        # Fresh graph for every fold.
        tf.reset_default_graph()
        fold_no += 1

        # Ensure same labels for train and validation data.
        assert set(train_label) == set(test_label),\
            "Train and val labels don't correspond:\n{}\n{}".format(set(train_label), set(test_label))

        num_classes = len(set(train_label))

        # Input tensors for training and validation data.
        X, Y = read_images(train_path, train_label, args, is_training=True)
        X_V, Y_V = read_images(test_path, test_label, args, is_training=False)

        # Training graph, plus two graphs that reuse its weights without dropout.
        logits_train = conv_net(X, num_classes, args, reuse=False, is_training=True)
        logits_test = conv_net(X, num_classes, args, reuse=True, is_training=False)
        logits_validation = conv_net(X_V, num_classes, args, reuse=True, is_training=False)

        with tf.name_scope("loss"):
            # Loss and optimizer use the training logits, so dropout applies.
            loss_op = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=logits_train, labels=Y))
            tf.summary.scalar("training", loss_op)
            validation_op = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=logits_validation, labels=Y_V))
            tf.summary.scalar("validation", validation_op)
            optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
            train_op = optimizer.minimize(loss_op)

        with tf.name_scope("accuracy"):
            # Training accuracy, dropout disabled.
            correct_pred = tf.equal(tf.argmax(logits_test, 1), tf.cast(Y, tf.int64))
            accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
            tf.summary.scalar("training", accuracy)
            # Validation predictions are built once here. Creating tf.argmax
            # inside the evaluation loop would add a new op to the graph on
            # every iteration.
            pred_v = tf.argmax(logits_validation, 1)
            correct_pred_v = tf.equal(pred_v, tf.cast(Y_V, tf.int64))
            accuracy_valid = tf.reduce_mean(tf.cast(correct_pred_v, tf.float32))
            tf.summary.scalar("validation", accuracy_valid)

        init = tf.global_variables_initializer()

        # Batches per pass over each set. Validation uses its own size; it
        # previously reused the training-set count.
        samples = len(train_path)
        batches_per_epoch = int(ceil(float(samples) / args.batch_size))
        val_batches = max(1, int(ceil(float(len(test_path)) / args.batch_size)))
        # Concrete list: a Python 3 `map` object could only be consumed once.
        class_names = [str(c) for c in range(num_classes)]

        with tf.Session() as sess:
            # TensorBoard: tensorboard --logdir=./logs/nn_logs
            dirs = ["./logs/nn_logs/", str(num_classes), args.model_name, str(args.data_folds), str(fold_no)]
            run_log_dir = os.path.join(*dirs)
            writer = tf.summary.FileWriter(run_log_dir, sess.graph)
            coord = tf.train.Coordinator()
            threads = []
            try:
                merged = tf.summary.merge_all()

                desc = ("net type: " + "Basic CNN" +
                        "\nname: " + str(args.model_name) +
                        "\nn_classes: " + str(num_classes) +
                        "\nimg_size: " + str(args.img_dim) +
                        "\nbatch_mode: " + "shuffle" +
                        "\nbatch_size: " + str(args.batch_size) +
                        "\ndata_folds: " + str(args.data_folds) +
                        "\nnum_epochs: " + str(args.num_epochs) +
                        "\nlearning_rate: " + str(args.learning_rate) +
                        "\ndropout_keep_prob: " + str(args.dropout_keep_prob))

                sess.run(init)

                # Run description next to the event files (FileWriter has
                # already created run_log_dir).
                with open(os.path.join(run_log_dir, 'run_info.txt'), 'w') as f:
                    f.write(desc)

                # Start the input queue runner threads. The inputs from
                # read_images are presumably queue-based.
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)

                # Exactly num_epochs * batches_per_epoch steps. The previous
                # bound, batches_per_epoch * (num_epochs + 1), trained
                # batches_per_epoch - 1 extra steps after the last report.
                for step in range(1, batches_per_epoch * args.num_epochs + 1):
                    if step % batches_per_epoch != 0:
                        # Only run the optimization op (backprop).
                        sess.run(train_op)
                        continue

                    # Last step of an epoch: train, then evaluate and log.
                    epoch = step // batches_per_epoch
                    _, loss, acc = sess.run([train_op, loss_op, accuracy])

                    # One pass over the validation data.
                    y_true = []
                    y_pred = []
                    valid_acc = 0
                    for _ in range(val_batches):
                        valid, pred_batch, y_v = sess.run([accuracy_valid, pred_v, Y_V])
                        y_true.extend(str(x) for x in y_v.tolist())
                        y_pred.extend(str(x) for x in pred_batch.tolist())
                        valid_acc += valid
                    # Mean of the per-batch validation accuracies.
                    valid_acc = valid_acc / val_batches

                    # Validation confusion-matrix summary.
                    img_d_summary = plot_confusion_matrix(
                        y_true, y_pred, class_names, tensor_name='ConfusionMatrix')
                    writer.add_summary(img_d_summary, epoch)

                    summary = sess.run(merged)
                    writer.add_summary(summary, epoch)

                    print("Epoch=" + str(epoch) +
                          ", Minibatch Loss= " + "{:.4f}".format(loss) +
                          ", Validation Accuracy= " + "{:.4f}".format(valid_acc) +
                          ", Training Accuracy= " + "{:.3f}".format(acc))
                    print("True batch labels= " + str(y_v) + "\n" + "Assigned labels= " + str(pred_batch))

                print("Optimization Finished!")
            finally:
                # Stop queue threads and release the event file even if
                # training fails; the `with` block closes the session.
                coord.request_stop()
                coord.join(threads)
                writer.close()