def main(argv=None):
    """Build the segmentation graph, then train, visualize or test it.

    The behaviour is selected by ``FLAGS.mode``:

    * ``"train"``: run ``MAX_ITERATION`` optimisation steps (dropout keep
      probability 0.85). At steps 0-9 and every 200th step the training loss
      is printed and the merged summary is written; past step 10 the loss on
      one validation batch and one test batch (100 images each) is printed
      too. Every 4000th step (including step 0) the validation/test losses
      are printed with a timestamp and a checkpoint is saved to
      ``FLAGS.logs_dir``.
    * ``"visualize"``: predict one random validation batch and save the
      input, ground-truth and prediction images to ``FLAGS.vis_dir``.
    * ``"test"``: predict every test record one at a time, print its loss
      and save input, ground-truth and prediction images to ``FLAGS.vis_dir``.

    The latest checkpoint in ``FLAGS.logs_dir`` (if any) is restored before
    any mode runs.

    Args:
        argv: Unused (presumably present so this can be run via ``tf.app.run``).
    """
    import os

    keep_probability = tf.placeholder(tf.float32, name="keep_probabilty")
    image = tf.placeholder(tf.float32, shape=[None, IMAGE_HEIGHT, IMAGE_WIDTH, 3],
                           name="input_image")
    annotation = tf.placeholder(tf.int32, shape=[None, IMAGE_HEIGHT, IMAGE_WIDTH, 1],
                                name="annotation")

    pred_annotation, pred_annotation_no0, logits = inference(image, keep_probability)
    tf.summary.image("input_image", image, max_outputs=2)
    tf.summary.image("ground_truth", tf.cast(annotation, tf.uint8), max_outputs=2)
    tf.summary.image("pred_annotation", tf.cast(pred_annotation, tf.uint8), max_outputs=2)

    # Drop the trailing channel axis: sparse cross-entropy takes (N, H, W) labels.
    labels = tf.squeeze(annotation, axis=[3])
    # NOTE(review): FixLogitsWithIgnoreClass is defined elsewhere; presumably it
    # neutralises logits for an "ignore" label — confirm.
    logits = FixLogitsWithIgnoreClass(logits, labels)

    # Per-pixel cross-entropy scaled by the weight of each pixel's true class.
    # Alternative weighting (unused): [0.3, 4.7, 10., 0.56, 21.9, 0.4, 0.2]
    class_weights = tf.constant([0.3, 2., 2., 0.5, 2., 0.5, 0.5])
    onehot_labels = tf.one_hot(labels, depth=NUM_OF_CLASSESS)
    # Selecting the weight of the true class via the one-hot mask -> (N, H, W).
    weights = tf.reduce_sum(class_weights * onehot_labels, axis=3)
    unweighted_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=labels, name="entropy")
    weighted_loss = unweighted_loss * weights
    loss = tf.reduce_mean(weighted_loss)
    tf.summary.scalar("entropy", loss)

    trainable_var = tf.trainable_variables()
    if FLAGS.debug:
        for var in trainable_var:
            utils.add_to_regularization_and_summary(var)
    train_op = train(loss, trainable_var)

    print("Setting up summary op...")
    summary_op = tf.summary.merge_all()

    print("Setting up image reader...")
    train_records, valid_records, test_records = scene_parsing.read_dataset(FLAGS.data_dir)
    print(len(train_records))
    print(len(valid_records))
    print(len(test_records))

    print("Setting up dataset reader")
    image_options = {'resize': False, 'resize_size': IMAGE_SIZE}
    if FLAGS.mode == 'train':
        train_dataset_reader = dataset.BatchDatset(train_records, image_options)
    # Validation and test readers are needed by every mode (train reports
    # their losses, visualize/test read from them), so build them always.
    validation_dataset_reader = dataset.BatchDatset(valid_records, image_options)
    test_dataset_reader = dataset.BatchDatset(test_records, image_options)

    sess = tf.Session()

    print("Setting up Saver...")
    saver = tf.train.Saver(max_to_keep=50)
    summary_writer = tf.summary.FileWriter(FLAGS.logs_dir, sess.graph)

    try:
        sess.run(tf.global_variables_initializer())
        ckpt = tf.train.get_checkpoint_state(FLAGS.logs_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print("Model restored...")

        def eval_loss(reader, batch_size=100):
            """Return the loss on ``reader``'s next batch with dropout disabled."""
            batch_images, batch_annotations = reader.next_batch(batch_size)
            return sess.run(loss, feed_dict={image: batch_images,
                                             annotation: batch_annotations,
                                             keep_probability: 1.0})

        if FLAGS.mode == "train":
            # os.path.join keeps the checkpoint inside logs_dir even when the
            # flag has no trailing separator, so get_checkpoint_state finds it.
            checkpoint_prefix = os.path.join(FLAGS.logs_dir, "model.ckpt")
            for itr in xrange(MAX_ITERATION):
                train_images, train_annotations = train_dataset_reader.next_batch(FLAGS.batch_size)
                feed_dict = {image: train_images, annotation: train_annotations,
                             keep_probability: 0.85}
                sess.run(train_op, feed_dict=feed_dict)

                if itr % 200 == 0 or itr < 10:
                    train_loss, summary_str = sess.run([loss, summary_op], feed_dict=feed_dict)
                    print("Step: %d, Train_loss:%g" % (itr, train_loss))
                    # Held-out losses are only reported after the first steps;
                    # evaluate them only when they are actually printed.
                    if itr > 10:
                        valid_loss = eval_loss(validation_dataset_reader)
                        print("itr: %d, Validation_loss:%g" % (itr, valid_loss))
                        test_loss = eval_loss(test_dataset_reader)
                        print("itr: %d, Test_loss:%g" % (itr, test_loss))
                    summary_writer.add_summary(summary_str, itr)

                if itr % 4000 == 0:
                    valid_loss = eval_loss(validation_dataset_reader)
                    print("%s ---> Validation_loss: %g" % (datetime.datetime.now(), valid_loss))
                    test_loss = eval_loss(test_dataset_reader)
                    print("%s ---> Test_loss: %g" % (datetime.datetime.now(), test_loss))
                    saver.save(sess, checkpoint_prefix, itr)

        elif FLAGS.mode == "visualize":
            valid_images, valid_annotations, valid_files = \
                validation_dataset_reader.get_random_batch(FLAGS.batch_size)
            pred = sess.run(pred_annotation, feed_dict={image: valid_images,
                                                        annotation: valid_annotations,
                                                        keep_probability: 1.0})
            valid_annotations = np.squeeze(valid_annotations, axis=3)
            pred = np.squeeze(pred, axis=3)

            for itr in range(FLAGS.batch_size):
                utils.save_image(valid_images[itr].astype(np.uint8), FLAGS.vis_dir,
                                 name="inp_" + valid_files[itr])
                utils.save_image(valid_annotations[itr].astype(np.uint8), FLAGS.vis_dir,
                                 name="gt_" + valid_files[itr])
                utils.save_image(pred[itr].astype(np.uint8), FLAGS.vis_dir,
                                 name="pred_" + valid_files[itr])
                print("Saved image: %d" % itr)

        elif FLAGS.mode == "test":
            # NOTE(review): file names come from test_records[itr], which
            # assumes next_batch(1) yields records in list order without
            # shuffling — confirm against dataset.BatchDatset.
            for itr, record in enumerate(test_records):
                test_images, test_annotations = test_dataset_reader.next_batch(1)
                pred, test_loss = sess.run([pred_annotation, loss],
                                           feed_dict={image: test_images,
                                                      annotation: test_annotations,
                                                      keep_probability: 1.0})
                test_annotations = np.squeeze(test_annotations, axis=3)
                pred = np.squeeze(pred, axis=3)
                print(itr, 'loss:', test_loss)

                filename = record['filename']
                utils.save_image(test_images[0].astype(np.uint8), FLAGS.vis_dir,
                                 name="inp_" + filename)
                utils.save_image(test_annotations[0].astype(np.uint8), FLAGS.vis_dir,
                                 name="gt_" + filename)
                utils.save_image(pred[0].astype(np.uint8), FLAGS.vis_dir,
                                 name="pred_" + filename)
    finally:
        # Flush pending summaries and release the session even if a mode fails.
        summary_writer.close()
        sess.close()
def main(argv=None):
    """Build the cropped-input segmentation graph, then train, visualize or test it.

    The loss is the class-weighted pixel cross-entropy plus the mean of
    ``GetConstraintLoss`` (computed from the raw logits and the input image).
    The behaviour is selected by ``FLAGS.mode``:

    * ``"train"``: run ``MAX_ITERATION`` steps on ``RandomCrop``-ed batches
      with dropout disabled (keep probability 1.0). When ``FLAGS.timeline``
      is set, every step is traced and written as a Chrome trace to
      ``timeline/timeline_step_<itr>.json``. The wall time of steps 0-9 is
      printed, the training loss and merged summary are logged every 10th
      step, and every 5000th step (including step 0) prints the loss on one
      cropped validation batch and saves a checkpoint to ``FLAGS.logs_dir``.
    * ``"visualize"``: predict one random validation batch and save input,
      ground-truth and prediction images to ``FLAGS.vis_dir``.
    * ``"test"``: predict each full-width test image by running the network
      on ``IMAGE_WIDTH // IMAGE_WIDTH_CROP`` crops from ``SubCrop`` and
      pasting them side by side, then save the images to ``FLAGS.vis_dir``.

    The latest checkpoint in ``FLAGS.logs_dir`` (if any) is restored before
    any mode runs.

    Args:
        argv: Unused (presumably present so this can be run via ``tf.app.run``).
    """
    import os

    indexList = tf.constant(index)
    neighbourList = tf.constant(neighbour)
    keep_probability = tf.placeholder(tf.float32, name="keep_probabilty")
    image = tf.placeholder(tf.float32, shape=[None, IMAGE_HEIGHT_CROP, IMAGE_WIDTH_CROP, 3],
                           name="input_image")
    annotation = tf.placeholder(tf.int32, shape=[None, IMAGE_HEIGHT_CROP, IMAGE_WIDTH_CROP, 1],
                                name="annotation")

    pred_annotation, pred_annotation_no0, logits = inference(image, keep_probability)
    tf.summary.image("input_image", image, max_outputs=2)
    tf.summary.image("ground_truth", tf.cast(annotation, tf.uint8), max_outputs=2)
    tf.summary.image("pred_annotation", tf.cast(pred_annotation, tf.uint8), max_outputs=2)

    # Drop the trailing channel axis: sparse cross-entropy takes (N, H, W) labels.
    labels = tf.squeeze(annotation, axis=[3])
    # NOTE(review): FixLogitsWithIgnoreClass is defined elsewhere; presumably it
    # neutralises logits for an "ignore" label — confirm.
    fixedLogits = FixLogitsWithIgnoreClass(logits, labels)

    # Per-pixel cross-entropy scaled by the weight of each pixel's true class.
    # The one-hot depth is tied to the weight list so the two cannot drift apart.
    class_weight_values = [0.1, 1., 1., 0.1, 20., 0.1, 8., 8., 0.1]
    class_weights = tf.constant(class_weight_values)
    onehot_labels = tf.one_hot(labels, depth=len(class_weight_values))
    weights = tf.reduce_sum(class_weights * onehot_labels, axis=3)
    unweighted_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=fixedLogits, labels=labels, name="entropy")
    weighted_loss = unweighted_loss * weights
    loss = tf.reduce_mean(weighted_loss)

    # Extra constraint term; note it is fed the raw logits, not fixedLogits.
    loss_constraint = GetConstraintLoss(indexList, neighbourList, logits, image)
    loss = tf.add(loss, tf.reduce_mean(loss_constraint))
    tf.summary.scalar("entropy", loss)

    trainable_var = tf.trainable_variables()
    if FLAGS.debug:
        for var in trainable_var:
            utils.add_to_regularization_and_summary(var)
    train_op = train(loss, trainable_var)

    print("Setting up summary op...")
    summary_op = tf.summary.merge_all()

    print("Setting up image reader...")
    train_records, valid_records, test_records = scene_parsing.read_dataset(FLAGS.data_dir)
    print(len(train_records))
    print(len(valid_records))
    print(len(test_records))

    print("Setting up dataset reader")
    image_options = {'resize': False, 'resize_size': IMAGE_SIZE}
    if FLAGS.mode == 'train':
        train_dataset_reader = dataset.BatchDatset(train_records, image_options)
    # Used by both train (validation loss) and visualize modes.
    validation_dataset_reader = dataset.BatchDatset(valid_records, image_options)
    if FLAGS.mode == 'test':
        test_dataset_reader = dataset.BatchDatset(test_records, image_options)

    sess = tf.Session()

    print("Setting up Saver...")
    saver = tf.train.Saver(max_to_keep=20)
    summary_writer = tf.summary.FileWriter(FLAGS.logs_dir, sess.graph)

    try:
        sess.run(tf.global_variables_initializer())
        ckpt = tf.train.get_checkpoint_state(FLAGS.logs_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print("Model restored...")

        if FLAGS.timeline:
            options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            run_metadata = tf.RunMetadata()

        if FLAGS.mode == "train":
            # os.path.join keeps the checkpoint inside logs_dir even when the
            # flag has no trailing separator, so get_checkpoint_state finds it.
            checkpoint_prefix = os.path.join(FLAGS.logs_dir, "model.ckpt")
            lastTime = datetime.datetime.now()
            for itr in xrange(MAX_ITERATION):
                train_images, train_annotations = train_dataset_reader.next_batch(FLAGS.batch_size)
                train_images, train_annotations = RandomCrop(train_images, train_annotations)
                feed_dict = {image: train_images, annotation: train_annotations,
                             keep_probability: 1.0}

                if FLAGS.timeline:
                    sess.run(train_op, options=options, run_metadata=run_metadata,
                             feed_dict=feed_dict)
                    fetched_timeline = tf_timeline.Timeline(run_metadata.step_stats)
                    chrome_trace = fetched_timeline.generate_chrome_trace_format()
                    with open('timeline/timeline_step_%d.json' % itr, 'w') as f:
                        f.write(chrome_trace)
                else:
                    sess.run(train_op, feed_dict=feed_dict)

                if itr < 10:
                    # total_seconds() keeps sub-second precision; .seconds
                    # truncated fast steps to 0.
                    nowTime = datetime.datetime.now()
                    print((nowTime - lastTime).total_seconds(), 'sec, ', itr)
                    lastTime = nowTime

                if itr % 10 == 0:
                    train_loss, summary_str = sess.run([loss, summary_op], feed_dict=feed_dict)
                    print("Step: %d, Train_loss:%g" % (itr, train_loss))
                    summary_writer.add_summary(summary_str, itr)

                if itr % 5000 == 0:
                    valid_images, valid_annotations = validation_dataset_reader.next_batch(
                        FLAGS.batch_size)
                    valid_images, valid_annotations = RandomCrop(valid_images, valid_annotations)
                    valid_loss = sess.run(loss, feed_dict={image: valid_images,
                                                           annotation: valid_annotations,
                                                           keep_probability: 1.0})
                    print("%s ---> Validation_loss: %g" % (datetime.datetime.now(), valid_loss))
                    saver.save(sess, checkpoint_prefix, itr)

        elif FLAGS.mode == "visualize":
            valid_images, valid_annotations, valid_files = \
                validation_dataset_reader.get_random_batch(FLAGS.batch_size)
            pred = sess.run(pred_annotation, feed_dict={image: valid_images,
                                                        annotation: valid_annotations,
                                                        keep_probability: 1.0})
            valid_annotations = np.squeeze(valid_annotations, axis=3)
            pred = np.squeeze(pred, axis=3)

            for itr in range(FLAGS.batch_size):
                utils.save_image(valid_images[itr].astype(np.uint8), FLAGS.vis_dir,
                                 name="inp_" + valid_files[itr])
                utils.save_image(valid_annotations[itr].astype(np.uint8), FLAGS.vis_dir,
                                 name="gt_" + valid_files[itr])
                utils.save_image(pred[itr].astype(np.uint8), FLAGS.vis_dir,
                                 name="pred_" + valid_files[itr])
                print("Saved image: %d" % itr)

        elif FLAGS.mode == "test":
            # Floor division: true division yields a float, which range()
            # rejects on Python 3. If IMAGE_WIDTH is not a multiple of
            # IMAGE_WIDTH_CROP, the rightmost leftover columns of pred stay 0.
            num_tiles = IMAGE_WIDTH // IMAGE_WIDTH_CROP
            # NOTE(review): file names come from test_records[itr], which
            # assumes next_batch(1) yields records in list order without
            # shuffling — confirm against dataset.BatchDatset.
            for itr, record in enumerate(test_records):
                test_images, test_annotations = test_dataset_reader.next_batch(1)
                pred = np.zeros_like(np.squeeze(test_annotations, axis=3))
                for sub in range(num_tiles):
                    # NOTE(review): SubCrop is defined elsewhere; presumably it
                    # returns the sub-th IMAGE_WIDTH_CROP-wide column strip.
                    sub_images, sub_annotations = SubCrop(test_images, test_annotations, sub)
                    sub_pred = sess.run(pred_annotation, feed_dict={image: sub_images,
                                                                    annotation: sub_annotations,
                                                                    keep_probability: 1.0})
                    sub_pred = np.squeeze(sub_pred, axis=3)
                    pred[:, :, sub * IMAGE_WIDTH_CROP:(sub + 1) * IMAGE_WIDTH_CROP] = sub_pred
                print(itr)

                test_annotations = np.squeeze(test_annotations, axis=3)
                filename = record['filename']
                utils.save_image(test_images[0].astype(np.uint8), FLAGS.vis_dir,
                                 name="inp_" + filename)
                utils.save_image(test_annotations[0].astype(np.uint8), FLAGS.vis_dir,
                                 name="gt_" + filename)
                utils.save_image(pred[0].astype(np.uint8), FLAGS.vis_dir,
                                 name="pred_" + filename)
    finally:
        # Flush pending summaries and release the session even if a mode fails.
        summary_writer.close()
        sess.close()