import os

import numpy as np
import tensorflow as tf
from tensorflow.contrib import quantize

import model  # local module providing mobilenet_v1()


def freeze():
    if not os.path.exists(FLAGS.train_dir):
        os.mkdir(FLAGS.train_dir)

    sess = tf.InteractiveSession()

    # build the inference graph (adds the model ops to the default graph)
    images = tf.placeholder(tf.float32, [None] + IMG_SHAPE, name='images')
    logits = model.mobilenet_v1(images,
                                num_classes=NUM_CLASSES,
                                depth_multiplier=DEPTH_MULTIPLIER,
                                dropout_prob=0.0,
                                is_training=False)

    # write an inference graph (for debug)
    # with tf.io.gfile.GFile(os.path.join(FLAGS.output_dir, 'inference_graph_before_quant.pb'), 'wb') as f:
    #     f.write(sess.graph_def.SerializeToString())

    # rewrite the graph with fake-quantization ops for inference
    quantize.create_eval_graph()

    # write an inference graph (for debug)
    # with tf.io.gfile.GFile(os.path.join(FLAGS.output_dir, 'inference_graph.pb'), 'wb') as f:
    #     f.write(sess.graph_def.SerializeToString())

    # restore the trained weights and write a frozen graph
    saver = tf.train.Saver(tf.global_variables())
    saver.restore(sess, os.path.join(FLAGS.train_dir, FLAGS.ckpt))
    frozen_gd = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph_def, [FLAGS.output_node])
    tf.train.write_graph(frozen_gd,
                         FLAGS.output_dir,
                         FLAGS.frozen_pb_name,
                         as_text=False)
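

# --- Optional follow-on (sketch, not part of the original script) ----------
# Once freeze() has written the frozen, fake-quantized GraphDef, it can be
# converted to a fully quantized TFLite flatbuffer with the TF 1.x converter.
# The input shape, quantized input stats, and .tflite file name below are
# assumptions and must be adapted to the actual model.
def convert_to_tflite():
    converter = tf.lite.TFLiteConverter.from_frozen_graph(
        graph_def_file=os.path.join(FLAGS.output_dir, FLAGS.frozen_pb_name),
        input_arrays=['images'],                   # placeholder name used in freeze()
        output_arrays=[FLAGS.output_node],         # same output node passed to freeze()
        input_shapes={'images': [1] + IMG_SHAPE})  # assumes IMG_SHAPE = [H, W, C]
    converter.inference_type = tf.lite.constants.QUANTIZED_UINT8
    # (mean, stddev) mapping float input values to uint8; example values only
    converter.quantized_input_stats = {'images': (128.0, 128.0)}
    tflite_model = converter.convert()
    with open(os.path.join(FLAGS.output_dir, 'model_quant.tflite'), 'wb') as f:
        f.write(tflite_model)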


def train_val_loop(ds_train=None, ds_val=None):
    """Create a train-validate loop.

    Args:
        ds_train: dataset for training; should be a batched tf.data.Dataset.
        ds_val: dataset for validation; should be a batched tf.data.Dataset.
    """
    LR_DECAY_FACTOR = 0.94
    LR_DECAY_STEPS = int(TRAIN_SIZE / BATCH_SIZE * 1.5)
    if FLAGS.quantize:
        LR_START = FLAGS.lr_finetune
    else:
        LR_START = FLAGS.lr_start

    ## create train dir
    if not os.path.exists(FLAGS.train_dir):
        os.mkdir(FLAGS.train_dir)

    ## start a new session
    sess = tf.InteractiveSession()

    ## dataset iterators
    ds_train_iterator = ds_train.make_initializable_iterator()
    next_train_images, next_train_labels = ds_train_iterator.get_next()
    # next_train_labels = tf.one_hot(next_train_labels, depth=NUM_CLASSES)
    # next_train_labels = tf.cast(next_train_labels, dtype=tf.int64)
    ds_train_iterator.initializer.run()

    ds_val_iterator = ds_val.make_initializable_iterator()
    next_val_images, next_val_labels = ds_val_iterator.get_next()
    # next_val_labels = tf.one_hot(next_val_labels, depth=NUM_CLASSES)
    # next_val_labels = tf.cast(next_val_labels, dtype=tf.int64)
    ds_val_iterator.initializer.run()

    ## images/labels placeholders
    images = tf.placeholder(tf.float32, [BATCH_SIZE] + IMG_SHAPE, name='images')
    labels = tf.placeholder(tf.int64, [BATCH_SIZE], name='labels')

    ## build model
    logits = model.mobilenet_v1(images,
                                num_classes=NUM_CLASSES,
                                depth_multiplier=DEPTH_MULTIPLIER,
                                dropout_prob=DROPOUT_PROB,
                                is_training=True)

    ## create train_op
    # define loss_op
    loss_op = tf.losses.sparse_softmax_cross_entropy(labels, logits)
    # loss_op = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=labels, name='loss')

    # define acc_op
    correct_pred = tf.equal(tf.argmax(logits, 1), labels)
    # self.accuracy = tf.reduce_mean(tf.cast(correct_predictions, "float"))
    acc_op = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    # create quantized training graph
    if FLAGS.quantize:
        quantize.create_training_graph(quant_delay=0)

    # configure learning rate
    global_step = tf.train.get_or_create_global_step()
    learning_rate = tf.train.exponential_decay(LR_START,
                                               global_step,
                                               LR_DECAY_STEPS,
                                               LR_DECAY_FACTOR,
                                               staircase=True)

    # create train_op (global_step is incremented here)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(
            loss_op, global_step)

    ## create summaries and merge
    # tf.summary.scalar('loss', loss_op)
    tf.summary.scalar('accuracy', acc_op)
    tf.summary.scalar('learning_rate', learning_rate)
    merged_summaries = tf.summary.merge_all()

    ## saver
    saver = tf.train.Saver(tf.global_variables())

    ## writers
    train_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)
    val_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

    ## initialize variables
    tf.global_variables_initializer().run()

    ## load checkpoint
    if FLAGS.start_ckpt:
        if FLAGS.is_first_finetuning:
            # first restore variables with ignore_missing_vars
            variables_to_restore = tf.contrib.slim.get_variables_to_restore()
            restore_fn = tf.contrib.slim.assign_from_checkpoint_fn(
                os.path.join(FLAGS.train_dir, FLAGS.start_ckpt),
                variables_to_restore,
                ignore_missing_vars=True)
            restore_fn(sess)
            # then, reset global step
            global_step_reset = tf.assign(global_step, 0)
            sess.run(global_step_reset)
        else:
            saver.restore(sess, os.path.join(FLAGS.train_dir, FLAGS.start_ckpt))

    start_step = global_step.eval()
    for train_step in range(start_step, FLAGS.train_step_max + 1):
        # get current global_step
        curr_step = global_step.eval()

        # get a training batch
        images_batch, labels_batch = sess.run(
            [next_train_images, next_train_labels])

        # train
        train_acc, _, train_loss, train_summary = sess.run(
            [acc_op, train_op, loss_op, merged_summaries],
            feed_dict={
                images: images_batch,
                labels: labels_batch
            })
        train_writer.add_summary(train_summary, curr_step)
        print('Step: ', curr_step, 'Train Loss = ', train_loss)

        # validation
        if curr_step != 0 and curr_step % FLAGS.val_step_interval == 0:
            # total_val_acc = 0
            total_val_acc = []
            for i in range(0, VAL_SIZE, BATCH_SIZE):
                images_batch, labels_batch = sess.run(
                    [next_val_images, next_val_labels])
                val_acc, val_summary = sess.run(
                    [acc_op, merged_summaries],
                    feed_dict={
                        images: images_batch,
                        labels: labels_batch
                    })
                total_val_acc.append(val_acc)
                # total_val_acc += val_acc * BATCH_SIZE / VAL_SIZE
            total_val_acc = np.mean(total_val_acc)
            val_writer.add_summary(val_summary, curr_step)
            print('Step: ', curr_step, 'Train Accuracy = ', train_acc)
            print('Step: ', curr_step, 'Validation Accuracy = ', total_val_acc)

        # save checkpoint periodically
        if curr_step != 0 and curr_step % FLAGS.save_step_interval == 0:
            if FLAGS.quantize:
                ckpt_name = 'model.quant.ckpt'
            else:
                ckpt_name = 'model.ckpt'
            saver.save(sess,
                       os.path.join(FLAGS.train_dir, ckpt_name),
                       global_step=curr_step)
            print('Step: ', curr_step, 'Saving to ', FLAGS.train_dir)
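

# --- Entry-point sketch (not part of the original excerpt) -----------------
# FLAGS, the constants, and the input pipeline are defined elsewhere in the
# file; this sketch only illustrates how the two functions above are typically
# driven. The 'do_freeze' flag and the build_datasets() helper are hypothetical
# placeholders, not names from the original code.
def main(_):
    if FLAGS.do_freeze:
        # Freezing rebuilds an inference graph from scratch, so it runs on its
        # own rather than after train_val_loop() in the same process.
        freeze()
    else:
        ds_train, ds_val = build_datasets()  # returns two batched tf.data.Dataset
        train_val_loop(ds_train=ds_train, ds_val=ds_val)


if __name__ == '__main__':
    tf.app.run()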