import os
import time

import numpy as np
import tensorflow as tf

import dataset_loader
import unet

# Flag definitions (tfrecords_prefix, image_size, batch_size, ...) are assumed
# to be declared elsewhere in this file or module.
FLAGS = tf.app.flags.FLAGS


def train():
    """Train UNet using the specified FLAGS."""
    data_files, data_size = load_datafiles(FLAGS.tfrecords_prefix)
    images, labels, filenames = dataset_loader.inputs(
        data_files=data_files,
        image_size=FLAGS.image_size,
        batch_size=FLAGS.batch_size,
        num_epochs=FLAGS.num_epochs,
        train=True)

    logits = unet.build(images, FLAGS.num_classes, True)

    accuracy = unet.accuracy(logits, labels)

    # Load class weights if available.
    if FLAGS.class_weights is not None:
        weights = np.load(FLAGS.class_weights)
        class_weight_tensor = tf.constant(weights, dtype=tf.float32,
                                          shape=[FLAGS.num_classes, 1])
    else:
        class_weight_tensor = None

    loss = unet.loss(logits, labels, FLAGS.weight_decay_rate, class_weight_tensor)

    global_step = tf.Variable(0, name='global_step', trainable=False)

    train_op = unet.train(loss, FLAGS.learning_rate,
                          FLAGS.learning_rate_decay_steps,
                          FLAGS.learning_rate_decay_rate, global_step)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    saver = tf.train.Saver()
    session_manager = tf.train.SessionManager(
        local_init_op=tf.local_variables_initializer())
    sess = session_manager.prepare_session(
        "", init_op=init_op, saver=saver, checkpoint_dir=FLAGS.checkpoint_dir)

    writer = tf.summary.FileWriter(FLAGS.checkpoint_dir + "/train_logs", sess.graph)
    merged = tf.summary.merge_all()

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    start_time = time.time()
    # Initialize step so the OutOfRangeError handler below never sees an
    # unbound name if the loop body never runs.
    step = 0

    try:
        while not coord.should_stop():
            step = tf.train.global_step(sess, global_step)

            _, loss_value, summary = sess.run([train_op, loss, merged])
            writer.add_summary(summary, step)

            if step % 1000 == 0:
                # Run the tensor directly (not wrapped in a list) so the
                # result is a scalar the %.2f format below can consume.
                acc_seg_value = sess.run(accuracy)
                epoch = step * FLAGS.batch_size // data_size
                duration = time.time() - start_time
                start_time = time.time()
                print('[PROGRESS]\tEpoch %d, Step %d: loss = %.2f, accuracy = %.2f (%.3f sec)'
                      % (epoch, step, loss_value, acc_seg_value, duration))

            if step % 5000 == 0:
                print('[PROGRESS]\tSaving checkpoint')
                checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'unet.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

    except tf.errors.OutOfRangeError:
        print('[INFO ]\tDone training for %d epochs, %d steps.'
              % (FLAGS.num_epochs, step))
    finally:
        # When done, ask the threads to stop.
        coord.request_stop()

    # Wait for threads to finish.
    coord.join(threads)
    writer.close()
    sess.close()
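# train() and both evaluation routines call load_datafiles, which is not
# defined in this file. The sketch below shows the assumed behavior: glob
# TFRecord files matching a prefix and count the records they contain. The
# '*.tfrecords' pattern is an assumption, not the repository's implementation
# (the argument-based evaluate passes a directory, in which case appending a
# path separator to the prefix yields the same glob).
import glob


def load_datafiles(tfrecords_prefix):
    """Hypothetical helper: collect TFRecord files for a prefix and count
    their records. A sketch of assumed behavior, not the original code."""
    data_files = glob.glob(tfrecords_prefix + '*.tfrecords')
    if not data_files:
        raise ValueError('No TFRecord files found for prefix %s' % tfrecords_prefix)
    # tf.python_io.tf_record_iterator is the TF 1.x way to walk records.
    data_size = sum(1 for f in data_files
                    for _ in tf.python_io.tf_record_iterator(f))
    return data_files, data_size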
def evaluate_images(checkpoint_path, tfrecords_dir, image_size, output_dir):
    """Evaluate a two-class UNet from an explicit checkpoint and save the
    predicted images to output_dir (argument-based variant of evaluate below,
    renamed so it no longer shadows it)."""
    data_files, data_size = load_datafiles(tfrecords_dir)
    images, filenames = dataset_loader.inputs(
        data_files=data_files,
        image_size=image_size,
        batch_size=1,
        num_epochs=1,
        train=False)

    # labels = tf.stack([labels, labels], 3)
    # labels = tf.reshape(labels, (FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size, 2))

    logits = unet.build(images, 2, False)
    predicted_images = unet.predict(logits, 1, image_size)
    # accuracy = unet.accuracy(logits, labels)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    sess = tf.Session()
    sess.run(init_op)

    saver = tf.train.Saver()

    if not tf.gfile.Exists(checkpoint_path + '.meta'):
        raise ValueError("Can't find checkpoint file")
    else:
        print('[INFO ]\tFound checkpoint file, restoring model.')
        saver.restore(sess, checkpoint_path)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # global_accuracy = 0.0
    step = 0

    try:
        while not coord.should_stop():
            predicted_images_value, filenames_value = sess.run(
                [predicted_images, filenames])
            # global_accuracy += acc_seg_value
            maybe_save_images(predicted_images_value, filenames_value, output_dir)
            # print('[PROGRESS]\tAccuracy for current batch: %.5f' % (acc_seg_value))
            step += 1
    except tf.errors.OutOfRangeError:
        print('[INFO ]\tDone evaluating in %d steps.' % step)
    finally:
        # When done, ask the threads to stop.
        coord.request_stop()

    # global_accuracy = global_accuracy / step
    # print('[RESULT ]\tGlobal accuracy = %.5f' % (global_accuracy))

    # Wait for threads to finish.
    coord.join(threads)
    sess.close()
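# Example invocation of evaluate_images; the checkpoint path, data directory,
# and output directory below are placeholders, not paths from this repository:
#
#   evaluate_images('checkpoints/unet.ckpt-50000', 'data/val/',
#                   image_size=512, output_dir='output/val')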
def evaluate():
    """Evaluate UNet using the specified FLAGS.

    Note: restores the pretrained model from FLAGS.checkpoint_path.
    """
    data_files, data_size = load_datafiles(FLAGS.tfrecords_prefix)
    images, labels, filenames = dataset_loader.inputs(
        data_files=data_files,
        image_size=FLAGS.image_size,
        batch_size=FLAGS.batch_size,
        num_epochs=1,
        train=False)

    logits = unet.build(images, FLAGS.num_classes, False)
    predicted_images = unet.predict(logits, FLAGS.batch_size, FLAGS.image_size)
    accuracy = unet.accuracy(logits, labels)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    sess = tf.Session()
    sess.run(init_op)

    saver = tf.train.Saver()

    if not tf.gfile.Exists(FLAGS.checkpoint_path + '.meta'):
        raise ValueError("Can't find checkpoint file")
    else:
        print('[INFO ]\tFound checkpoint file, restoring model.')
        saver.restore(sess, FLAGS.checkpoint_path)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    global_accuracy = 0.0
    step = 0

    try:
        while not coord.should_stop():
            acc_seg_value, predicted_images_value, filenames_value = sess.run(
                [accuracy, predicted_images, filenames])
            global_accuracy += acc_seg_value
            maybe_save_images(predicted_images_value, filenames_value)
            print('[PROGRESS]\tAccuracy for current batch: %.5f' % acc_seg_value)
            step += 1
    except tf.errors.OutOfRangeError:
        print('[INFO ]\tDone evaluating in %d steps.' % step)
    finally:
        # When done, ask the threads to stop.
        coord.request_stop()

    # Guard against division by zero when no batches were processed.
    if step > 0:
        global_accuracy /= step
    print('[RESULT ]\tGlobal accuracy = %.5f' % global_accuracy)

    # Wait for threads to finish.
    coord.join(threads)
    sess.close()
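# Both evaluation paths call maybe_save_images, which is not defined in this
# file. The sketch below shows the assumed behavior: write each predicted mask
# to the output directory under its source filename. The PIL usage, the
# FLAGS.output_dir fallback (which reconciles the two-argument and
# three-argument call sites above), and the binary {0, 1} mask scaling are
# all assumptions, not the repository's implementation.
from PIL import Image


def maybe_save_images(predicted_images, filenames, output_dir=None):
    """Hypothetical helper: save predicted masks as 8-bit grayscale images."""
    if output_dir is None:
        output_dir = FLAGS.output_dir
    if not output_dir:
        return
    if not tf.gfile.Exists(output_dir):
        tf.gfile.MakeDirs(output_dir)
    for image, filename in zip(predicted_images, filenames):
        # Filenames come out of the input queue as bytes under Python 3.
        if isinstance(filename, bytes):
            filename = filename.decode('utf-8')
        out_path = os.path.join(output_dir, os.path.basename(filename))
        # Scale a binary {0, 1} class map up to a visible 8-bit image.
        Image.fromarray((image.squeeze() * 255).astype(np.uint8)).save(out_path)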
def build_encoder(input_height, input_width, input_var):
    """Build the encoder network (UNet in place of the original VGG16)."""
    # encoder = vgg16.build(input_height, input_width, input_var)
    # Note: elsewhere in this file unet.build is invoked as
    # unet.build(images, num_classes, train); verify the argument order here
    # matches the unet.build signature before relying on this path.
    encoder = unet.build(input_height, input_width, input_var)
    # set_pretrained_weights(encoder)
    return encoder
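# A minimal command-line entry point, sketched under the assumption that the
# script is driven by tf.app.run(); the 'mode' flag is hypothetical and is
# not defined in this file.
def main(_):
    if FLAGS.mode == 'train':
        train()
    else:
        evaluate()


if __name__ == '__main__':
    tf.app.run()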