def main(_): tv_utils.set_gpus_to_use() if FLAGS.input_image is None: logging.error("No input_image was given.") logging.info( "Usage: python demo.py --input_image data/test.png " "[--output_image output_image] [--logdir /path/to/weights] " "[--gpus GPUs_to_use] ") exit(1) if FLAGS.logdir is None: # Download and use weights from the MultiNet Paper if 'TV_DIR_RUNS' in os.environ: runs_dir = os.path.join(os.environ['TV_DIR_RUNS'], 'KittiSeg') else: runs_dir = 'RUNS' maybe_download_and_extract(runs_dir) logdir = os.path.join(runs_dir, default_run) else: logging.info("Using weights found in {}".format(FLAGS.logdir)) logdir = FLAGS.logdir # Loading hyperparameters from logdir hypes = tv_utils.load_hypes_from_logdir(logdir, base_path='hypes') logging.info("Hypes loaded successfully.") # Loading tv modules (encoder.py, decoder.py, eval.py) from logdir modules = tv_utils.load_modules_from_logdir(logdir) logging.info("Modules loaded successfully. Starting to build tf graph.") # Create tf graph and build module. with tf.Graph().as_default(): image_placehold = tf.placeholder(tf.float32, shape=[1, None, None, 3]) # build Tensorflow graph using the model from logdir prediction = core.build_inference_graph(hypes, modules, image=image_placehold) logging.info("Graph build successfully.") config = tf.ConfigProto() config.gpu_options.allow_growth = True sess = tf.Session(config=config) saver = tf.train.Saver() # Load weights from logdir core.load_weights(logdir, sess, saver) # eval_dict,_=kitti_eval.evaluate_without_crf(hypes,sess,image_placehold,prediction) eval_dict, _ = kitti_eval.evaluate_test(hypes, sess, image_placehold, prediction) for name, value in eval_dict: logging.info(' %s %s : % 0.04f ' % (name, '(raw)', value))
def train_loop(myhypes=None): utils.set_gpus_to_use() try: import tensorvision.train import tensorflow_fcn.utils except ImportError: logging.error("Could not import the submodules.") logging.error("Please execute:" "'git submodule update --init --recursive'") exit(1) if tf.app.flags.FLAGS.hypes is None: logging.error("No hype file is given.") logging.info("Usage: python train.py --hypes hypes/KittiClass.json") exit(1) with open(myhypes, 'r') as f: logging.info("f: %s", f) hypes = commentjson.load(f) utils.load_plugins() if tf.app.flags.FLAGS.mod is not None: import ast mod_dict = ast.literal_eval(tf.app.flags.FLAGS.mod) dict_merge(hypes, mod_dict) if 'TV_DIR_RUNS' in os.environ: os.environ['TV_DIR_RUNS'] = os.path.join(os.environ['TV_DIR_RUNS'], 'KittiSeg') utils.set_dirs(hypes, tf.app.flags.FLAGS.hypes) utils._add_paths_to_sys(hypes) train.maybe_download_and_extract(hypes) logging.info("Initialize training folder") # TODO initialize the train folder and copy some arg files to it--------------------------------yu train.initialize_training_folder(hypes) logging.info("Start training") train.do_training(hypes, trainable_scopes=FLAGS.trainable_scopes, exclude_scopes=FLAGS.checkpoint_exclude_scopes, checkpoint_path=FLAGS.checkpoint_path)
def main(_): """Run main function.""" if FLAGS.hypes is None: logging.error("No hypes are given.") logging.error("Usage: tv-train --hypes hypes.json") exit(1) with open(tf.app.flags.FLAGS.hypes, 'r') as f: logging.info("f: %s", f) hypes = json.load(f) utils.set_gpus_to_use() utils.load_plugins() utils.set_dirs(hypes, tf.app.flags.FLAGS.hypes) logging.info("Initialize training folder") initialize_training_folder(hypes) maybe_download_and_extract(hypes) logging.info("Start training") do_training(hypes)
def main(_): tv_utils.set_gpus_to_use() if FLAGS.input_image is None: logging.error("No input_image was given.") logging.info( "Usage: python demo.py --input_image data/test.png " "[--output_image output_image] [--logdir /path/to/weights] " "[--gpus GPUs_to_use] ") exit(1) if FLAGS.logdir is None: # Download and use weights from the MultiNet Paper if 'TV_DIR_RUNS' in os.environ: runs_dir = os.path.join(os.environ['TV_DIR_RUNS'], 'KittiSeg') else: runs_dir = 'RUNS' maybe_download_and_extract(runs_dir) logdir = os.path.join(runs_dir, default_run) else: logging.info("Using weights found in {}".format(FLAGS.logdir)) logdir = FLAGS.logdir # Loading hyperparameters from logdir hypes = tv_utils.load_hypes_from_logdir(logdir, base_path='hypes') logging.info("Hypes loaded successfully.") # Loading tv modules (encoder.py, decoder.py, eval.py) from logdir modules = tv_utils.load_modules_from_logdir(logdir) logging.info("Modules loaded successfully. Starting to build tf graph.") # Create tf graph and build module. with tf.Graph().as_default(): # Create placeholder for input # image_pl = tf.placeholder(tf.float32) # image = tf.expand_dims(image_pl, 0) image_placehold= tf.placeholder(tf.float32,shape=[1,None,None,3]) # build Tensorflow graph using the model from logdir prediction = core.build_inference_graph(hypes, modules, image=image_placehold) logging.info("Graph build successfully.") # Create a session for running Ops on the Graph. config = tf.ConfigProto() config.gpu_options.allow_growth = True sess = tf.Session(config=config) saver = tf.train.Saver() # Load weights from logdir core.load_weights(logdir, sess, saver) logging.info("Weights loaded successfully.") start_time = time.time() logging.info('process start') input_image = FLAGS.input_image logging.info("Starting inference using {} as input".format(input_image)) # Load and resize input image image = scp.misc.imread(input_image) image= scp.misc.imresize(image,[150,150,3]) # image2 = scp.misc.imread('/home/yu/projects/KittiSeg/data/demo/demo2.png') if hypes['jitter']['reseize_image']: # Resize input only, if specified in hypes image_height = hypes['jitter']['image_height'] image_width = hypes['jitter']['image_width'] image = scp.misc.imresize(image, size=(image_height, image_width), interp='cubic') # Run KittiSeg model on image image_to_input=np.expand_dims(image,axis=0) feed = {image_placehold: image_to_input} softmax = prediction['softmax'] output = sess.run([softmax], feed_dict=feed) logging.info('Finished in {}s\n'.format( time.time() - start_time)) start_time = time.time() print ('the output.shape =',output[0].shape) # TODO test time -----------------------------------------------------------yu for i in xrange(100): output = sess.run([softmax],feed_dict=feed) logging.info('Finished 100 time and average is in {}s\n'.format(( time.time() - start_time)/100.0)) # TODO test time ----------------------------------------------------------yu # Reshape output from flat vector to 2D Image shape = image.shape print ('the image.shape =',shape) #TODO if use CRF-----------------------------------------------------------------------------------yu use_crf=False street_prediction=np.empty([int(shape[0])*int(shape[1]),1]) if use_crf: output_image = output[0][:, :].reshape(shape[0], shape[1], 2) # output_image = post_crf.post_process_crf(image, output_image, 2) else: output_image = output[0][:, 1].reshape(shape[0], shape[1]) # output_image = output[0][:, 1].reshape(shape[0], shape[1]) print('the first output_image.shape =', output_image.shape) # output_image=output[0][:,:].reshape(shape[0],shape[1],2) print ('the second output_image.shape =',output_image.shape) # output_image = my_array[:,:,:,1].reshape(shape[0],shape[1]) # print (output_image.shape) # print (output_image.shape) start_time=time.time() # Plot confidences as red-blue overlay # rb_image = seg.make_overlay(image, output_image) # res=post_crf.post_process_crf(image,output_image,2) output_image1=output[0][:,0].reshape(shape[0], shape[1]) street_prediction=output_image > output_image1 output_image2=output_image.reshape(shape[0],shape[1],-1) output_image2=output_image2*255.0 # Accept all pixel with conf >= 0.5 as positive prediction # This creates a `hard` prediction result for class street threshold = 0.5 street_prediction = output_image > threshold # street_prediction = res > threshold # Plot the hard prediction as green overlay green_image = tv_utils.fast_overlay(image, street_prediction) # Save output images to disk. if FLAGS.output_image is None: output_base_name = input_image else: output_base_name = FLAGS.output_image raw_image_name = output_base_name.split('.')[0] + '_raw.png' rb_image_name = output_base_name.split('.')[0] + '_rb.png' green_image_name = output_base_name.split('.')[0] + '_green.png' # scp.misc.imshow(rb_image) #TODO save the image--------------------------------------------------------------------------yu save_image=False if save_image: # np.save(FLAGS.logdir+'182000' + '.png',output_image) cv2.imwrite(FLAGS.logdir+'um_000031' + '.png', green_image) # scp.misc.imsave(FLAGS.logdir+'182000' + '.png', output_image2) scp.misc.imshow(green_image) scp.misc.imsave('prediction_umm_000036.png',green_image)