def train_ops():
    # Get training parameters
    data_dir = FLAGS.data_dir
    batch_size = FLAGS.batch_size
    learning_rate = FLAGS.learning_rate

    # Create global step counter
    global_step = tf.Variable(0, name='global_step', trainable=False)

    # Instantiate async producers for images and labels
    images, labels = data.train_inputs(data_dir=data_dir)

    # Instantiate the model
    model = select.by_name(FLAGS.model)

    # Create a 'virtual' graph node based on images that represents the input
    # node to be used for graph retrieval
    inputs = tf.identity(images, 'inputs')

    # Build a Graph that computes the logits predictions from the
    # inference model
    logits = model.inference(inputs)

    # In the same way, create a 'virtual' node for outputs
    outputs = tf.identity(logits, 'predictions')

    # Calculate loss
    loss = model.loss(logits, labels)

    # Evaluate training accuracy
    accuracy = model.accuracy(logits, labels)

    # Attach a scalar summary only to the total loss
    tf.summary.scalar('loss', loss)
    tf.summary.scalar('batch accuracy', accuracy)

    # Note that for debugging purposes, we could also track other losses
    # for l in tf.get_collection('losses'):
    #     tf.summary.scalar(l.op.name, l)

    # Build a graph that applies gradient descent to update model parameters
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    sgd_op = optimizer.minimize(loss, global_step=global_step)

    # Build yet another graph to evaluate moving averages of variables after
    # each step: these smoothed parameters will be loaded instead of the raw
    # trained values during evaluation
    variable_averages = \
        tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())

    # For batch normalization, we also need to update some variables
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    # Create a meta-graph that includes sgd and variables moving average
    with tf.control_dependencies([sgd_op, variables_averages_op] + update_ops):
        train_op = tf.no_op(name='train')

    # Build another graph to provide training summary information
    summary_op = tf.summary.merge_all()

    return (train_op, loss, accuracy, summary_op)
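# --- Illustrative sketch (not part of the original source): a minimal driver
# loop that consumes the ops returned by train_ops(). It assumes the TF1-style
# queue-based input pipeline set up by data.train_inputs(), and uses a
# hypothetical FLAGS.max_steps; the repo's actual training driver may differ.
def example_train_loop():
    train_op, loss, accuracy, summary_op = train_ops()
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)
        # Start the async producers created by data.train_inputs()
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for step in range(FLAGS.max_steps):
                _, loss_value, acc_value = sess.run([train_op, loss, accuracy])
                if step % 100 == 0:
                    print('step %d: loss=%.4f, batch accuracy=%.3f'
                          % (step, loss_value, acc_value))
                    summary_writer.add_summary(sess.run(summary_op), step)
                    saver.save(sess, os.path.join(FLAGS.log_dir, 'model.ckpt'),
                               global_step=step)
        finally:
            coord.request_stop()
            coord.join(threads)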
def main(argv=None):
    # Instantiate the model
    model = select.by_name(FLAGS.model)
    if FLAGS.data_aug:
        images = tf.zeros((1, 24, 24, 3))
    else:
        images = tf.zeros((1, 32, 32, 3))
    logits = model.inference(images)
    print("Model: %s" % FLAGS.model)
    print("Size : %.2f Millions of parameters" % (model.get_size() / 10**6))
    print("Flops: %.2f Millions of operations" % (model.get_flops() / 10**6))
def evaluation_loop():
    """Eval model accuracy at regular intervals"""
    with tf.Graph().as_default() as g:
        # Do we evaluate the net on the training data or the test data?
        test_data = FLAGS.eval_data == 'test'

        # Get images and labels for CIFAR-10
        images, labels = data.eval_inputs(test_data=test_data,
                                          data_dir=FLAGS.data_dir,
                                          batch_size=FLAGS.batch_size)

        # Instantiate the model
        model = select.by_name(FLAGS.model)

        # Force dropout to zero for evaluation
        model.dropout = 0.0

        # Build a Graph that computes the logits predictions from the model
        logits = model.inference(images)

        # Calculate predictions (we are only interested in perfect matches, i.e. k=1)
        predictions_op = tf.nn.in_top_k(logits, labels, 1)

        # We restore moving averages instead of raw values
        # Note that at evaluation time, the decay parameter is not used
        variables_averages = tf.train.ExponentialMovingAverage(1.0)  # 1.0 decay is unused
        variables_to_restore = variables_averages.variables_to_restore()

        # Instantiate a saver to restore model variables from checkpoint
        saver = tf.train.Saver(variables_to_restore)

        # Build the summary operation based on the TF collection of Summaries
        summary_op = tf.summary.merge_all()

        # Since we don't use a session, we need to write summaries ourselves
        run_dir = get_run_dir(FLAGS.log_dir, FLAGS.model)
        eval_dir = os.path.join(run_dir, 'eval', FLAGS.eval_data)
        tf.gfile.MakeDirs(eval_dir)
        summary_writer = tf.summary.FileWriter(eval_dir, g)

        # We need a checkpoint dir to restore model parameters
        checkpoint_dir = os.path.join(run_dir, 'train')

        last_step = 0
        while True:
            global_step = evaluate(saver, checkpoint_dir, summary_writer,
                                   predictions_op, summary_op)
            if FLAGS.run_once or last_step == global_step:
                break
            last_step = global_step
            time.sleep(FLAGS.eval_interval_secs)
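# --- Rough sketch (an assumption, not the repo's code): the evaluate() helper
# called above typically restores the latest checkpoint, runs predictions_op
# over the evaluation set, and writes a precision summary, returning the
# checkpoint's global step. FLAGS.num_examples is hypothetical here.
def evaluate_sketch(saver, checkpoint_dir, summary_writer, predictions_op,
                    summary_op):
    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if not (ckpt and ckpt.model_checkpoint_path):
            print('No checkpoint file found in', checkpoint_dir)
            return 0
        saver.restore(sess, ckpt.model_checkpoint_path)
        # Checkpoint paths end in '-<global_step>'
        global_step = int(ckpt.model_checkpoint_path.split('-')[-1])

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            num_iter = FLAGS.num_examples // FLAGS.batch_size
            true_count, total_count = 0, 0
            for _ in range(num_iter):
                predictions = sess.run(predictions_op)  # bool vector, one entry per image
                true_count += np.sum(predictions)
                total_count += predictions.shape[0]
            precision = true_count / total_count
            print('global step %s: precision @ 1 = %.3f' % (global_step, precision))
            # summary_op could also be evaluated and written here
            summary = tf.Summary()
            summary.value.add(tag='precision @ 1', simple_value=precision)
            summary_writer.add_summary(summary, global_step)
        finally:
            coord.request_stop()
            coord.join(threads)
        return global_step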
def save_weights():
    """Saves CIFAR10 weights"""
    FLAGS.resume = True  # Get saved weights, not new ones
    print(FLAGS.save_dir)
    run_dir = get_run_dir(FLAGS.save_dir, FLAGS.model)
    print('run_dir', run_dir)
    checkpoint_dir = os.path.join(run_dir, 'train')
    with tf.Graph().as_default() as g:
        # Get images and labels for CIFAR-10.
        images, labels = data.train_inputs(data_dir=FLAGS.data_dir)
        model = select.by_name(FLAGS.model, FLAGS, training=True)

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = model.inference(images)
        print('Multiplicative depth', model.mult_depth())

        saver = tf.train.Saver()
        with tf.Session() as sess:
            ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                # Restore variables from checkpoint
                saver.restore(sess, ckpt.model_checkpoint_path)
                global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
            else:
                # Note: ckpt may be None here, so don't dereference
                # ckpt.model_checkpoint_path in this error path
                print('### ERROR: No checkpoint file found ###')
                print('ckpt_dir', checkpoint_dir)
                print('ckpt', ckpt)
                return

            # Save each trainable variable as a flat text file
            for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
                weight = sess.run([var])[0].flatten().tolist()
                filename = model._name_to_filename(var.name)
                dir_name = filename.rsplit('/', 1)[0]
                os.makedirs(dir_name, exist_ok=True)
                print("saving", filename)
                np.savetxt(str(filename), weight)
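# --- Illustrative helper (assumption): reading one of the flat text files
# written by save_weights() back into its original shape with NumPy. The actual
# filenames depend on model._name_to_filename(), so the example path below is
# purely hypothetical.
def load_saved_weight(filename, shape):
    flat = np.loadtxt(filename)   # 1-D array, as written by np.savetxt above
    return flat.reshape(shape)

# e.g. kernel = load_saved_weight('weights/conv1.txt', (5, 5, 3, 64))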
def optimize_model_for_inference():
    """Optimizes CIFAR-10 model for inference"""
    FLAGS.resume = True  # Get saved weights, not new ones
    run_dir = get_run_dir(FLAGS.log_dir, FLAGS.model)
    checkpoint_dir = os.path.join(run_dir, 'train')
    print('run_dir', run_dir)
    print('checkpoint dir', checkpoint_dir)
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    train_graph = os.path.join(checkpoint_dir, 'graph.pbtxt')
    frozen_graph = os.path.join(checkpoint_dir, 'graph_constants.pb')
    fused_graph = os.path.join(checkpoint_dir, 'fused_graph.pb')

    with tf.Session() as sess:
        # TODO: this should be a placeholder, right?
        # Build a new inference graph, with variables to be restored from the
        # training graph.
        IMAGE_SIZE = 24 if FLAGS.data_aug else 32
        if FLAGS.batch_norm:
            images = tf.constant(1, dtype=tf.float32,
                                 shape=[1, IMAGE_SIZE, IMAGE_SIZE, 3])
        else:
            images = tf.constant(1, dtype=tf.float32,
                                 shape=[FLAGS.batch_size, IMAGE_SIZE, IMAGE_SIZE, 3])
        model = select.by_name(FLAGS.model, FLAGS, training=False)

        # Create dummy input and output nodes
        images = tf.identity(images, 'XXX')
        logits = model.inference(images)
        logits = tf.identity(logits, 'YYY')

        if FLAGS.batch_norm:
            # Restore values from the trained model into the corresponding
            # variables in the inference graph.
            ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
            print('ckpt.model_checkpoint_path', ckpt.model_checkpoint_path)
            assert ckpt and ckpt.model_checkpoint_path, \
                "No checkpoint found in {}".format(checkpoint_dir)
            saver = tf.train.Saver()
            saver.restore(sess, ckpt.model_checkpoint_path)

            # Write the fully-assembled inference graph to a file, so
            # freeze_graph can use it
            tf.io.write_graph(sess.graph, checkpoint_dir,
                              'inference_graph.pbtxt', as_text=True)

            # Freeze the graph, converting variables to inline constants in the pb file
            constant_graph = os.path.join(checkpoint_dir, 'graph_constants.pb')
            freeze_graph.freeze_graph(
                input_graph=os.path.join(checkpoint_dir, 'inference_graph.pbtxt'),
                input_saver="",
                input_binary=False,
                input_checkpoint=ckpt.model_checkpoint_path,
                output_node_names='YYY',
                restore_op_name='save/restore_all',
                filename_tensor_name='save/Const:0',
                initializer_nodes=[],
                output_graph=constant_graph,
                clear_devices=True)

            # Load the frozen graph into a graph_def for optimize_for_inference_lib
            with gfile.FastGFile(constant_graph, 'rb') as f:
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(f.read())
                sess.graph.as_default()
                tf.import_graph_def(graph_def, name='')

            # Optimize the graph for inference, folding batch-norm ops into conv/matmul
            fused_graph_def = optimize_for_inference_lib.optimize_for_inference(
                input_graph_def=graph_def,
                input_node_names=['XXX'],
                output_node_names=['YYY'],
                placeholder_type_enum=dtypes.float32.as_datatype_enum,
                toco_compatible=False)
            print('Optimized for inference.')
            tf.io.write_graph(fused_graph_def, checkpoint_dir,
                              name='fused_graph.pb', as_text=False)
        else:
            tf.io.write_graph(sess.graph, checkpoint_dir, 'fused_graph.pb',
                              as_text=False)
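# --- Illustrative sketch (not part of the original source): loading the
# fused_graph.pb produced above (the frozen, batch-norm-folded path) and
# running inference through the 'XXX'/'YYY' nodes. The input batch must match
# the shape the graph was built with (24 with data augmentation, 32 without).
def run_fused_graph(fused_graph_path, batch):
    graph_def = tf.GraphDef()
    with gfile.FastGFile(fused_graph_path, 'rb') as f:
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as g:
        tf.import_graph_def(graph_def, name='')
        inputs = g.get_tensor_by_name('XXX:0')
        outputs = g.get_tensor_by_name('YYY:0')
        with tf.Session(graph=g) as sess:
            # Feeding 'XXX:0' overrides the dummy constant baked into the graph
            return sess.run(outputs, feed_dict={inputs: batch})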
def train_ops():
    # Get training parameters
    data_dir = FLAGS.data_dir
    batch_size = FLAGS.batch_size

    # Create global step counter
    global_step = tf.Variable(0, name='global_step', trainable=False)

    # Instantiate async producers for images and labels
    images, labels = data.train_inputs(data_dir=data_dir)

    # Instantiate the model
    model = select.by_name(FLAGS.model, FLAGS, training=True)

    # Create a 'virtual' graph node based on images that represents the input
    # node to be used for graph retrieval
    inputs = tf.identity(images, 'XXX')

    # Build a Graph that computes the logits predictions from the
    # inference model
    logits = model.inference(inputs)
    print('Multiplicative depth', model.mult_depth())

    # In the same way, create a 'virtual' node for outputs
    outputs = tf.identity(logits, 'YYY')

    # Calculate loss
    loss = model.loss(logits, labels)

    # Evaluate training accuracy
    accuracy = model.accuracy(logits, labels)

    # Attach a scalar summary only to the total loss
    tf.summary.scalar('loss', loss)
    tf.summary.scalar('batch accuracy', accuracy)

    # For debugging purposes, we also track the other losses
    for l in tf.get_collection('losses'):
        tf.summary.scalar(l.op.name, l)

    learning_rate = 0.1
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)

    # Optionally clip gradients to [-0.25, 0.25]
    if FLAGS.clip_grads:
        print("Clipping gradients to [-0.25, 0.25]")
        gvs = optimizer.compute_gradients(loss)
        capped_gvs = []
        for grad, var in gvs:
            if grad is None:
                continue
            capped_gvs.append((tf.clip_by_value(grad, -0.25, 0.25), var))
        sgd_op = optimizer.apply_gradients(capped_gvs, global_step=global_step)
    else:
        print("Not clipping gradients")
        sgd_op = optimizer.minimize(loss, global_step=global_step)

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    # Create a meta-graph that groups SGD and the batch-norm update ops
    with tf.control_dependencies([sgd_op] + update_ops):
        train_op = tf.no_op(name='train')

    return (train_op, loss, accuracy)