def test():
    '''
    Testing the CNN algorithm on -n files in set_directory/set0/ and saving results in event dictionaries.
    '''
    with tf.Graph().as_default():
        #Get images, image_set flags, and classification labels from testing event files (non-shuffled batches of 2000)
        images, sets, labels = SKinput.input(0, shuffle=False)

        #Compute CNN prediction and comparison to truth
        logits = SKgraph.inference(images, sets, 1.0)
        correct = SKgraph.correct(logits, labels)

        #Compute accuracy
        accuracy = SKgraph.accuracy(logits, labels)

        #Get images for Tensorboard output
        summary_images = SKinput.get_summary_filter(sets, labels, correct)
        real_images = None

        #Initialize saver object for reading CNN variables from checkpoint files
        saver = SKinput.Saver()

        #Initialize Variables and Session
        initialize = tf.initialize_all_variables()
        sess = tf.InteractiveSession(config=tf.ConfigProto(inter_op_parallelism_threads=FLAGS.num_cores,
                                                           intra_op_parallelism_threads=FLAGS.num_cores))
        sess.run(initialize)

        #Actually Begin Processing the Graph
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
        #!!DO NOT CHANGE THE TENSORFLOW GRAPH AFTER CALLING start_queue_runners!!
        #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

        #Track accuracy to monitor network performance on files
        tracker = SKgraph.Tracker(["Accuracy"])

        #Load CNN variables from checkpoint files
        saver.restore_all(sess)

        #Define the file that will have every event dictionary written to it (temp file in the case of
        # "FLAGS.continue_session" so that the standard file and temp file can be combined afterwards)
        complete_info_file = os.path.join(FLAGS.run_directory, "complete_info.txt")
        temp_info_file = os.path.join(FLAGS.run_directory, "temp_info.txt")
        active_file = complete_info_file if not FLAGS.continue_session else temp_info_file
        if not FLAGS.continue_session and not FLAGS.print_tensorboard:
            #Will be appending JSON strings, so must overwrite the file here at the beginning of the process
            open(active_file, "w").close()

        #Get the paths of all the "info" files containing JSON string dictionaries from the image processing runs
        info_paths = SKinput.get_files(0, "info")
        for info_path in info_paths:
            #Regular Testing
            if not FLAGS.print_tensorboard:
                #Break if file number exceeds user-defined limits
                if info_paths.index(info_path) >= FLAGS.num_iterations:
                    break

                #Open file that saves all event dictionaries with complete information
                with open(active_file, "a") as complete_info:
                    #Evaluate the relevant information for testing (algorithm output, correct classification, and algorithm accuracy)
                    output, worked, acc_value = sess.run([logits, correct, accuracy])

                    #Cumulatively track and print accuracy for monitoring purposes
                    tracker.add([acc_value])
                    tracker.print_average("Testing", reset=False)

                    #Open info file and stream all the dictionaries into the complete_info file with algorithm performance information added
                    with open(info_path, "r") as info_file:
                        #Iterate through the info file and the Tensorflow batch concurrently (MUST be a 1-1 correspondence between these sources)
                        #Correspondence is ensured by:
                        #   1. setting the batch size in this case to the size of the info files
                        #   2. not shuffling the batches in the Tensorflow input
                        #   3. stitching the batches together properly after separating them to implement separate CNNs
                        for prob, line, did_work in zip(output, info_file, worked):
                            #Load dictionary and add algorithm performance information
                            info = json.loads(line[:-1])
                            info["worked" + ("_" + FLAGS.regime_name if FLAGS.regime_name != "" else "")] = bool(did_work)
                            info["algorithm" + ("_" + FLAGS.regime_name if FLAGS.regime_name != "" else "")] = [float(prob[0]), float(prob[1])]

                            #Write dictionary to the complete_info file
                            complete_info.write(json.dumps(info) + "\n")

            #Tensorboard output procedure
            else:
                #Evaluate images and Tensorboard filters, as well as CNN output for use in a custom filter
                img, mask, output = sess.run([images, summary_images, logits])

                #Extra boolean mask if we decide to use a custom filter
                extra_mask = []
                if FLAGS.custom_cut:
                    #Implement the custom filter by iterating over event information dictionaries and checking if each event passes the custom cut
                    with open(info_path, "r") as info_file:
                        for line, prob in zip(info_file, output):
                            #Get dictionary
                            info = json.loads(line[:-1])

                            #Append boolean "passed cut" to the extra_mask filter
                            #DEFINE CUSTOM FILTER HERE IF DESIRED
                            extra_mask.append(bool(info["worked_fiTQun_ms"] and prob[0] < 0.5))  # Set content to custom expression (return False for events to be cut)
                else:
                    #Initialize extra_mask so that all events pass the custom cut (equivalent to not having a cut at all)
                    extra_mask = [True] * len(mask)

                #Add False elements to the extra_mask so that it is the same length as mask.
                #This is needed to avoid errors when the batch is larger than the number of lines read (e.g. in the last file of an image set).
                extra_mask += [False] * (len(mask) - len(extra_mask))

                #Combine the masks and apply them to the images.
                final_mask = np.logical_and(mask, extra_mask)
                batch = img[final_mask]

                #Group the images from different files into the real_images object.
                if real_images is not None:
                    real_images = np.concatenate((real_images, batch))
                else:
                    real_images = np.array(batch)

                #Evaluate the number of images and print it
                size = len(real_images)
                print(size)

                #Break if the output image number has been reached or if there are no more files to search
                if size >= FLAGS.num_iterations or info_paths.index(info_path) == len(info_paths) - 1:
                    #Create a summary object from the images and save the images in Tensorboard format
                    image_summary = SKinput.get_summary(real_images)
                    SKinput.write(sess, image_summary)
                    break

        #If doing testing and --continue is called, combine all the dictionaries in the temporary output file with those in the previous output file
        if FLAGS.continue_session and not FLAGS.print_tensorboard:
            SKinput.combine_info_files(complete_info_file, temp_info_file)

        #Wrap up
        coord.request_stop()
        coord.join(threads)
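#------------------------------------------------------------------------------
#The test() loop above relies on a strict 1-1 correspondence between the JSON
#lines of each "info" file and the rows of the corresponding (non-shuffled)
#Tensorflow batch.  The helper below is an illustrative sketch of just the
#dictionary-augmentation step performed inside that loop; the function name is
#hypothetical and it is not called anywhere in this module.
def _augment_info_line(line, prob, did_work, regime_name=""):
    '''Sketch: add CNN performance fields to one JSON event dictionary.'''
    suffix = "_" + regime_name if regime_name != "" else ""
    info = json.loads(line.rstrip("\n"))
    #Record whether the network classified this event correctly and the raw two-class output
    info["worked" + suffix] = bool(did_work)
    info["algorithm" + suffix] = [float(prob[0]), float(prob[1])]
    return json.dumps(info) + "\n"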
def train(data_set):
    '''
    Defines the training procedure for the CNN.

    :param data_set: Integer indicating which CNN to train; recall there is a separate CNN for each image set
    '''
    with tf.Graph().as_default():
        #Initialize the global step variable that will be incremented during training
        global_step = tf.Variable(0, trainable=False)

        #Get images and image_labels in random batches of size FLAGS.batch_size
        #These are just Tensor objects for now and will not actually be evaluated until sess.run(...) is called in the loop
        images, labels = SKinput.input(data_set)

        #Get the output of the CNN with the images as input
        logits = SKgraph.inference(images, data_set, FLAGS.dropout_prob)

        #Initialize the saver object that takes care of reading and writing parameters to checkpoint files
        saver = SKinput.Saver()

        #Values and Operations to evaluate in each batch
        cost = SKgraph.cost(logits, labels)
        accuracy = SKgraph.accuracy(logits, labels)
        train_op = SKgraph.train(cost, saver, global_step)

        #Initialize all the Tensorflow Variables defined in the appropriate networks, as well as the Tensorflow session object
        initialize = tf.initialize_all_variables()
        sess = tf.InteractiveSession(config=tf.ConfigProto(inter_op_parallelism_threads=FLAGS.num_cores,
                                                           intra_op_parallelism_threads=FLAGS.num_cores))
        sess.run(initialize)

        #Actually Begin Processing the Graph
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
        #!!DO NOT CHANGE THE TENSORFLOW GRAPH AFTER CALLING start_queue_runners!!
        #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

        #Initialize the Tracker object that prints averages of quantities after every 20 batches
        tracker = SKgraph.Tracker(["Cost", "Accuracy"])

        #Load network parameters from the most recent training session if desired
        if FLAGS.continue_session:
            saver.restore(sess, data_set)

        #Iterate over the desired number of batches
        for batch_num in range(1, FLAGS.num_iterations + 1):
            #Run the training step once and return real-number values for the cost and accuracy
            _, cost_value, acc_value = sess.run([train_op, cost, accuracy])
            assert not math.isnan(cost_value), 'Model diverged with cost = NaN'
            tracker.add([cost_value, acc_value])

            #Periodically print cost and accuracy values to monitor the training process
            if not batch_num % 20:
                tracker.print_average(batch_num)

            #Periodically save moving averages to checkpoint files
            if not batch_num % 100 or batch_num == FLAGS.num_iterations:
                saver.save(sess, data_set)

        #Wrap up
        coord.request_stop()
        coord.join(threads)
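#------------------------------------------------------------------------------
#SKgraph.Tracker is used above purely as a running-average monitor for the cost
#and accuracy.  The class below is a minimal sketch of what such a tracker
#might look like, for reference only; the real SKgraph implementation may
#differ in its reset behaviour and output format.
class _TrackerSketch(object):
    '''Sketch: accumulate named quantities and print their running averages.'''

    def __init__(self, names):
        self.names = names
        self.sums = [0.0] * len(names)
        self.count = 0

    def add(self, values):
        #Accumulate one value per tracked quantity
        for i, value in enumerate(values):
            self.sums[i] += value
        self.count += 1

    def print_average(self, label, reset=True):
        #Print the average of each tracked quantity since the last reset
        averages = [s / max(self.count, 1) for s in self.sums]
        print("%s: " % label + ", ".join("%s = %.4f" % (n, a) for n, a in zip(self.names, averages)))
        if reset:
            self.sums = [0.0] * len(self.names)
            self.count = 0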
def train(data_set, plot=True):
    '''
    Defines the training procedure for the CNN.

    :param data_set: Integer indicating which CNN to train; recall there is a separate CNN for each image set
    :param plot: If True, plot how an image is processed through the convolution and pooling layers
    '''
    with tf.Graph().as_default():
        #Initialize the global step variable that will be incremented during training
        global_step = tf.Variable(0, trainable=False)

        #Get images and image_labels in random batches of size FLAGS.batch_size
        #These are just Tensor objects for now and will not actually be evaluated until sess.run(...) is called in the loop
        images, labels = SKinput.input(data_set)

        #Get the output of the CNN with the images as input, along with the layer weights and intermediate activations
        logits, weight_list, middle_steps = SKgraph.inference(images, data_set, FLAGS.dropout_prob)
        W_conv1, W_fc1, W_fc2, W_fc3 = weight_list[0], weight_list[1], weight_list[2], weight_list[3]
        h_conv1, h_pool1, h_fc1, h_fc2 = middle_steps[0], middle_steps[1], middle_steps[2], middle_steps[3]

        #Initialize the saver object that takes care of reading and writing parameters to checkpoint files
        saver = SKinput.Saver()

        #Values and Operations to evaluate in each batch
        cost = SKgraph.cost(logits, labels)
        accuracy = SKgraph.accuracy(logits, labels)
        train_op = SKgraph.train(cost, saver, global_step)

        #Initialize all the Tensorflow Variables defined in the appropriate networks, as well as the Tensorflow session object
        initialize = tf.initialize_all_variables()
        sess = tf.InteractiveSession(config=tf.ConfigProto(inter_op_parallelism_threads=FLAGS.num_cores,
                                                           intra_op_parallelism_threads=FLAGS.num_cores))
        sess.run(initialize)

        #Define a pyROOT TGraph to monitor the training
        train_gr = TGraph()
        j = 0

        #Actually Begin Processing the Graph
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
        #!!DO NOT CHANGE THE TENSORFLOW GRAPH AFTER CALLING start_queue_runners!!
        #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

        #Initialize the Tracker object that prints averages of quantities after every 20 batches
        tracker = SKgraph.Tracker(["Cost", "Accuracy"])

        #Load network parameters from the most recent training session if desired
        if FLAGS.continue_session:
            saver.restore(sess, data_set)

        #Iterate over the desired number of batches
        for batch_num in range(1, FLAGS.num_iterations + 1):
            #Run the training step once and return real-number values for the cost and accuracy
            _, cost_value, acc_value = sess.run([train_op, cost, accuracy])
            assert not math.isnan(cost_value), 'Model diverged with cost = NaN'
            tracker.add([cost_value, acc_value])

            #Periodically print cost and accuracy values to monitor the training process
            if not batch_num % 20:
                tracker.print_average(batch_num)
                #Record the running accuracy in the training TGraph
                train_gr.SetPoint(j, batch_num, acc_value)
                j = j + 1

            #Periodically save moving averages to checkpoint files
            if not batch_num % 100 or batch_num == FLAGS.num_iterations:
                saver.save(sess, data_set)

        # print "W_conv1: ", W_conv1.get_shape()

        #Record what the trained filters look like
        xarray = np.zeros(25)
        yarray = np.zeros(25)
        zarray = np.zeros(25)
        for y in range(5):
            for x in range(5):
                index = 5 * y + x
                xarray[index] = x
                yarray[index] = y
        x = array("d", xarray)
        y = array("d", yarray)

        #Make a TGraph2D for each first-layer filter
        wgraphs = []
        W_conv1_value = W_conv1.eval()
        for n in range(22):
            for j in range(5):
                for i in range(5):
                    index = 5 * j + i
                    zarray[index] = W_conv1_value[i, j, 0, n]
            z = array("d", zarray)
            wgraph = TGraph2D(25, x, y, z)
            wgraph.SetName("w%d" % n)
            wgraphs.append(wgraph)

        #Open an output ROOT file
        fout = TFile('filters.root', 'RECREATE')

        #Write out the training graph
        train_gr.Write('train_gr')

        #Write out the weights
        for n in range(22):
            wgraph = wgraphs[n]
            wgraph.Write('filter%d' % n)
        fout.Write()
        fout.Close()

        #Plot how the image is processed through convolution and pooling with different filters
        if plot:
            SKplot.plot(images, labels, W_conv1, h_conv1, h_pool1)

        #Wrap up
        coord.request_stop()
        coord.join(threads)
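#------------------------------------------------------------------------------
#The plotting variant of train() above writes the training curve and the
#first-layer filters to filters.root via pyROOT.  The helper below is a small
#sketch showing one way to read that file back for inspection; the object names
#match the Write() calls above, the default path assumes the current working
#directory, and num_filters=22 mirrors the hard-coded loop in train().
def _inspect_filters(path="filters.root", num_filters=22):
    '''Sketch: open the ROOT output of train() and print basic information.'''
    fin = TFile(path, "READ")
    #Training accuracy curve written as 'train_gr'
    train_curve = fin.Get("train_gr")
    print("training curve has %d points" % train_curve.GetN())
    #First-layer convolution filters written as 'filter0'..'filter21'
    for n in range(num_filters):
        wgraph = fin.Get("filter%d" % n)
        print("filter%d has %d weights" % (n, wgraph.GetN()))
    fin.Close()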