コード例 #1
0
def main(args):
    """Load the optional config file, print it, and build the requested model.

    ``args`` is part of the entry-point signature but is not read here.
    Exits the process with status -1 when ``FLAGS.model`` is not 'res50'.
    """
    cfg_path = FLAGS.cfg_file
    if cfg_path:
        print('loading config setting')
        cfg_from_file(cfg_path, cfg)
    print_config(cfg)

    logger = log_helper.get_logger()
    model_type = FLAGS.model
    logger.info("show information about {}:".format(model_type))
    if model_type != 'res50':
        logger.error('wrong model type: {}'.format(model_type))
        sys.exit(-1)
    # Building the network is this entry point's whole job; the instance is
    # not used any further in this function.
    model = Res50DispNet(cfg, logger)
コード例 #2
0
ファイル: trainer.py プロジェクト: jyh2005xx/csc586
def main():
    """Train an OICR model end to end.

    Reads settings from ``get_oicr_config()``, builds the train/valid datasets
    and Pascal VOC metadata, constructs the ``OICR`` network and its
    ``OICRTrainer``, then hands everything to ``OICRTrainer.train``.
    """
    config = get_oicr_config()
    print_config(config)

    if config.set_seed:
        # Reproducible runs: fixed torch seed, deterministic non-benchmarked cuDNN.
        torch.manual_seed(0)
        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False

    # Floating-point tensors created without an explicit type default to
    # torch.cuda.FloatTensor, so a CUDA device is required.
    torch.set_default_tensor_type(torch.cuda.FloatTensor)
    # Explicit threshold/precision override the values set by the "full"
    # profile, exactly as three consecutive calls would.
    torch.set_printoptions(profile="full", threshold=5000, precision=10)

    # NOTE(review): Dataset is given 'valid' while PascalVOCMetaData is given
    # 'val' -- presumably each class's own split naming; verify.
    train_set = Dataset(config, mode='train')
    valid_set = Dataset(config, mode='valid')
    train_meta = PascalVOCMetaData(config, mode='train')
    valid_meta = PascalVOCMetaData(config, mode='val')

    network = OICR(config)
    trainer = OICRTrainer(config, network)

    # TODO: checkpoint resume and first-run metadata writing are not wired up.
    trainer.train(network, train_set, valid_set, train_meta, valid_meta)
コード例 #3
0
def main(args):
    """Run a frozen segmentation graph over every sample and save predictions.

    For each ``(image, fname)`` yielded by ``instance_generator(FLAGS.sample_path)``:
    the probability image from ``genPredProb`` is written to
    ``FLAGS.output_path`` and the image masked by ``image_process.prob_mask``
    is written to ``FLAGS.mask_path``. Both files keep the sample's basename.
    Exits the process with status -1 if any image cannot be written.
    ``args`` is part of the entry-point signature but is not read here.
    """
    if FLAGS.cfg_file:
        print('loading config setting')
        cfg_from_file(FLAGS.cfg_file, cfg)
    print_config(cfg)

    output_path = FLAGS.output_path
    mask_path = FLAGS.mask_path
    for directory in (output_path, mask_path):
        if not os.path.isdir(directory):
            os.makedirs(directory)

    image_h = cfg.IMAGE_HEIGHT
    image_w = cfg.IMAGE_WIDTH

    logger = log_helper.get_logger()

    logger.info("accessing tf graph")
    graph = load_graph(FLAGS.graph_name)

    if FLAGS.verbose:
        # List every op so the input/output tensor names below can be checked.
        for op in graph.get_operations():
            logger.info(op.name)

    # Tensor names in the loaded graph carry the 'import/' prefix.
    input_img = graph.get_tensor_by_name('import/input/image:0')
    pred = graph.get_tensor_by_name('import/output/prob:0')

    with tf.Session(graph=graph) as sess:
        total_time_elapsed = 0.0

        for image, fname in instance_generator(FLAGS.sample_path):
            logger.info("predicting for {}".format(fname))

            # Nothing is initialized or restored: everything is stored in the
            # graph_def. The image gets a leading batch axis of size 1.
            begin_ts = time.time()
            prediction = sess.run(pred, feed_dict={input_img: image[np.newaxis]})
            elapsed = time.time() - begin_ts
            logger.info("cost time: {} s".format(elapsed))
            total_time_elapsed += elapsed

            basename = os.path.basename(fname)

            # Per-pixel class probabilities rendered as an image.
            output_fname = os.path.join(output_path, basename)
            pred_img = np.reshape(prediction, (image_h, image_w, cfg.NUM_CLASSES))
            pred_prob = genPredProb(pred_img, cfg.NUM_CLASSES)
            if not cv2.imwrite(output_fname, pred_prob):
                logger.error('writing image to {} failed!'.format(output_fname))
                sys.exit(-1)

            # Reorder channels RGB -> BGR (OpenCV's order) before masking.
            mask_fname = os.path.join(mask_path, basename)
            r, g, b = cv2.split(image.astype(np.uint8))
            cv_img = cv2.merge([b, g, r])
            masked = image_process.prob_mask(cv_img, pred_prob)
            if not cv2.imwrite(mask_fname, masked):
                # Report the mask file that actually failed to be written.
                logger.error('writing image to {} failed!'.format(mask_fname))
                sys.exit(-1)

        print("total time elapsed: {} s".format(total_time_elapsed))
コード例 #4
0
ファイル: infer.py プロジェクト: rvarun7777/tf_depth
def _save_recon_figure(output_fname, raw_image, disp, recon, recon_diff,
                       out_size, aspect_ratio):
    """Save a 2x2 comparison figure for one sample to ``output_fname``.

    Layout: raw image (top-left), disparity in grayscale (bottom-left),
    reconstruction (top-right) and reconstruction difference (bottom-right).
    ``raw_image``, ``recon`` and ``recon_diff`` are resized to ``out_size``,
    a ``(width, height)`` pair as cv2.resize expects. ``disp`` is drawn as given.
    """
    def _fit(img):
        return cv2.resize(img, out_size, interpolation=cv2.INTER_LINEAR)

    fig = plt.figure(figsize=(int(aspect_ratio * 8), 8))
    grid = gridspec.GridSpec(2, 2)
    panels = (
        (grid[0, 0], _fit(raw_image), 'raw_image', None),
        (grid[1, 0], disp, 'disparity', plt.cm.gray),
        (grid[0, 1], _fit(recon), 'reconstruct', None),
        (grid[1, 1], _fit(recon_diff), 'recon_diff', None),
    )
    for spec, data, title, cmap in panels:
        ax = plt.subplot(spec)
        ax.imshow(data, cmap=cmap)
        ax.set_title(title)
        # Hide this panel's axes explicitly; plt.gca() would only ever address
        # the most recently created subplot.
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    fig.savefig(output_fname)
    # One figure is created per sample, so release it immediately.
    plt.close(fig)


def main(args):
    """Run disparity inference on every sample and save the results.

    Loads the optional config and builds ``Res50DispNet``. The network is
    restored from ``FLAGS.ckpt_path``, using the moving-average shadow
    variables when ``FLAGS.use_avg`` is set.

    For each sample:
    * when ``FLAGS.output_path`` is set, the disparity image is written there,
      prefixed ``pp_`` when post-processing is on;
    * when ``FLAGS.recon_path`` is set, a 2x2 reconstruction figure is written
      to that directory.

    Exits with status -1 when reconstruction is requested without
    ``FLAGS.stereo_path`` or when the model type is unknown.
    ``args`` is part of the entry-point signature but is not read here.
    """
    if FLAGS.cfg_file:
        print('loading config setting')
        cfg_from_file(FLAGS.cfg_file, cfg)

    # Stereo input disables post-processing; post-processing feeds 2 images.
    do_pp = FLAGS.do_pp
    if FLAGS.do_stereo:
        do_pp = False
    cfg.DO_STEREO = bool(FLAGS.do_stereo)
    cfg.BATCH_SIZE = 2 if do_pp else 1

    print_config(cfg)

    output_path = FLAGS.output_path
    if output_path != '' and not os.path.isdir(output_path):
        os.makedirs(output_path)

    logger = log_helper.get_logger()
    stereo_path = FLAGS.stereo_path
    recon_path = FLAGS.recon_path
    do_recon = recon_path != ''
    if do_recon:
        if stereo_path == '':
            logger.error("to do reconstruction, stereo_path has to be set!")
            sys.exit(-1)
        if not os.path.isdir(recon_path):
            os.makedirs(recon_path)

    if FLAGS.model == 'res50':
        model = Res50DispNet(cfg, logger)
    else:
        logger.error('wrong model type: {}'.format(FLAGS.model))
        sys.exit(-1)

    if FLAGS.use_avg:
        # Restore the exponential-moving-average shadow variables.
        variable_averages = tf.train.ExponentialMovingAverage(cfg.MOVING_AVERAGE_DECAY)
        saver = tf.train.Saver(variable_averages.variables_to_restore())
    else:
        saver = tf.train.Saver(model.all_variables)

    # A resize_ratio of 0 or 1 means "keep the network resolution".
    # cv2.resize needs an integer (width, height) pair.
    ratio = FLAGS.resize_ratio or 1
    out_size = (int(ratio * cfg.IMAGE_WIDTH), int(ratio * cfg.IMAGE_HEIGHT))

    with tf.Session() as sess:
        logger.info("restoring model ......")
        saver.restore(sess, FLAGS.ckpt_path)
        total_time_elapsed = 0.0

        aspect_ratio = float(cfg.IMAGE_WIDTH) / cfg.IMAGE_HEIGHT
        samples = instance_generator(FLAGS.sample_path, cfg.IMAGE_WIDTH, cfg.IMAGE_HEIGHT,
                                     do_pp, stereo_path, cfg.DO_STEREO, do_recon)
        for image, fname in samples:
            if cfg.DO_STEREO or do_recon:
                # fname is a (left, right) pair and image holds both views.
                logger.info("inference for {} & {}".format(fname[0], fname[1]))
                feed_dict = {
                    model.left_image: image[0],
                    model.right_image: image[1]
                }
                fname = fname[0]
            else:
                logger.info("inference for {}".format(fname))
                feed_dict = {model.left_image: image}

            begin_ts = time.time()
            if do_recon:
                pre_disp, recon, recon_diff = sess.run([model.left_disparity[0],
                                                        model.left_reconstruction[0],
                                                        model.left_recon_diff[0]],
                                                       feed_dict=feed_dict)
                # Keep only the first sample of the batch.
                recon = recon[0, :, :, :]
                recon_diff = recon_diff[0, :, :, :]
            else:
                pre_disp = sess.run(model.left_disparity[0], feed_dict=feed_dict)
            elapsed = time.time() - begin_ts
            logger.info("cost time: {} s".format(elapsed))
            total_time_elapsed += elapsed

            if do_pp:
                disp = post_process_disparity(pre_disp.squeeze())
            else:
                disp = pre_disp[0].squeeze()

            if ratio != 1:
                disp = cv2.resize(disp, out_size, interpolation=cv2.INTER_LINEAR)

            basename = os.path.basename(fname)

            if output_path != '':
                prefix = "pp_" if do_pp else ""
                plt.imsave(os.path.join(output_path, prefix + basename), disp,
                           cmap=plt.cm.gray)

            # recon/recon_diff only exist when reconstruction was requested.
            if do_recon:
                _save_recon_figure(os.path.join(recon_path, basename), image[0][0],
                                   disp, recon, recon_diff, out_size, aspect_ratio)

        print("total time elapsed: {} s".format(total_time_elapsed))
コード例 #5
0
# Order of the 17 values returned by depth_metrics().
_DEPTH_METRIC_RETURN_ORDER = (
    'rate', 'd1_all_inter', 'abs_rel_inter', 'sq_rel_inter', 'rmse_inter',
    'rmse_log_inter', 'a1_inter', 'a2_inter', 'a3_inter', 'd1_all', 'abs_rel',
    'sq_rel', 'rmse', 'rmse_log', 'a1', 'a2', 'a3')
# Column order used when printing (d1_all comes after rmse_log in each group).
_DEPTH_METRIC_PRINT_ORDER = (
    'rate', 'abs_rel_inter', 'sq_rel_inter', 'rmse_inter', 'rmse_log_inter',
    'd1_all_inter', 'a1_inter', 'a2_inter', 'a3_inter', 'abs_rel', 'sq_rel',
    'rmse', 'rmse_log', 'd1_all', 'a1', 'a2', 'a3')
# Column titles without / with a baseline disparity (FLAGS.base_path).
_EVAL_TITLES = (
    'ratio', 'abs_rel_i', 'sq_rel_i', 'rmse_i', 'rmse_log_i', 'd1_all_i',
    'a1_i', 'a2_i', 'a3_i', 'abs_rel', 'sq_rel', 'rmse', 'rmse_log',
    'd1_all', 'a1', 'a2', 'a3')
_EVAL_TITLES_WITH_BASE = (
    'ratio', 'abs_rel', 'sq_rel', 'rmse', 'rmse_log', 'd1_all', 'a1',
    'a2', 'a3', 'abs_rel_b', 'sq_rel_b', 'rmse_b', 'rmse_log_b',
    'd1_all_b', 'a1_b', 'a2_b', 'a3_b')


def _format_title(names):
    """Right-align each column name to width 10 and join them with ', '."""
    return ", ".join("{:>10}".format(name) for name in names)


def _format_values(values):
    """Format each value as a width-10, 3-decimal float and join with ', '."""
    return ", ".join("{:10.3f}".format(value) for value in values)


def main(args):
    """Evaluate predicted disparities against labels and print depth metrics.

    For every ``(image, label, fname)`` from ``instance_label_generator``, it
    runs the network and computes ``depth_metrics``. It then prints one row
    per sample and, at the end, the mean of every column. When
    ``FLAGS.base_path`` is set, the last element of ``image`` is passed to
    ``depth_metrics`` as a baseline disparity. When ``FLAGS.output_path`` is
    set, each disparity map is also saved there as a grayscale image.
    Exits with status -1 for an unknown model type.
    ``args`` is part of the entry-point signature but is not read here.
    """
    checkArgs()

    if FLAGS.cfg_file:
        print('loading config setting')
        cfg_from_file(FLAGS.cfg_file, cfg)
    cfg.DO_STEREO = FLAGS.stereo_path != ''

    base_path = None
    titles = _EVAL_TITLES
    if FLAGS.base_path != '':
        base_path = FLAGS.base_path
        titles = _EVAL_TITLES_WITH_BASE
    title_str = _format_title(titles)

    stereo_path = FLAGS.stereo_path if cfg.DO_STEREO else None

    # Post-processing (a batch of 2) only applies to monocular input.
    do_pp = FLAGS.do_pp and not cfg.DO_STEREO
    cfg.BATCH_SIZE = 2 if do_pp else 1

    print_config(cfg)

    output_path = FLAGS.output_path
    if output_path != '' and not os.path.isdir(output_path):
        os.makedirs(output_path)

    logger = log_helper.get_logger()
    if FLAGS.model == 'res50':
        model = Res50DispNet(cfg, logger)
    else:
        logger.error('wrong model type: {}'.format(FLAGS.model))
        sys.exit(-1)

    if FLAGS.use_avg:
        # Restore the exponential-moving-average shadow variables.
        variable_averages = tf.train.ExponentialMovingAverage(
            cfg.MOVING_AVERAGE_DECAY)
        saver = tf.train.Saver(variable_averages.variables_to_restore())
    else:
        saver = tf.train.Saver(model.all_variables)

    # Per-sample values of every metric, keyed by metric name.
    history = {key: [] for key in _DEPTH_METRIC_PRINT_ORDER}
    total_time_elapsed = 0
    with tf.Session() as sess:
        logger.info("restoring model ......")
        saver.restore(sess, FLAGS.ckpt_path)

        samples = instance_label_generator(
            FLAGS.sample_path,
            FLAGS.label_path,
            cfg.IMAGE_WIDTH,
            cfg.IMAGE_HEIGHT,
            FLAGS.do_pp,
            stereo_path,
            base_path=base_path)
        for image, label, fname in samples:
            if cfg.DO_STEREO:
                # fname is a (left, right) pair and image holds both views.
                logger.info("testing for {} & {}".format(fname[0], fname[1]))
                feed_dict = {
                    model.left_image: image[0],
                    model.right_image: image[1]
                }
                fname = fname[0]
            else:
                logger.info("testing for {}".format(fname))
                # With base_path, image[0] is the network input (and image[-1]
                # the baseline disparity, used below).
                left = image if base_path is None else image[0]
                feed_dict = {model.left_image: left}

            begin_ts = time.time()
            pre_disp = sess.run(model.left_disparity[0], feed_dict=feed_dict)
            elapsed = time.time() - begin_ts
            logger.info("cost time: {} s".format(elapsed))
            total_time_elapsed += elapsed

            if do_pp:
                disp = post_process_disparity(pre_disp.squeeze())
            else:
                disp = pre_disp[0].squeeze()

            base_disp = None if base_path is None else image[-1]

            # The KITTI focal length is looked up by label width.
            focal = KITTI_FOCAL[label.shape[1]]
            values = tuple(depth_metrics(label, disp, focal, KITTI_BASE, base_disp))
            if len(values) != len(_DEPTH_METRIC_RETURN_ORDER):
                raise ValueError('depth_metrics returned {} values, expected {}'.format(
                    len(values), len(_DEPTH_METRIC_RETURN_ORDER)))
            metrics = dict(zip(_DEPTH_METRIC_RETURN_ORDER, values))
            row = [metrics[key] for key in _DEPTH_METRIC_PRINT_ORDER]

            print(title_str)
            print(_format_values(row))

            for key, value in zip(_DEPTH_METRIC_PRINT_ORDER, row):
                history[key].append(value)

            # Save the disparity map for visual checking.
            if output_path != '':
                prefix = "pp_" if do_pp else ""
                output_fname = output_path + "/" + prefix + os.path.basename(fname)
                plt.imsave(output_fname, disp, cmap=plt.cm.gray)

        print("============total metric============")
        print(title_str)
        if history['rate']:
            print(_format_values(np.array(history[key]).mean()
                                 for key in _DEPTH_METRIC_PRINT_ORDER))
        else:
            # Averaging an empty list would only print NaNs.
            logger.error("no samples were evaluated; total metrics are undefined")

        print("total time elapsed: {} s".format(total_time_elapsed))
コード例 #6
0
ファイル: freeze_graph.py プロジェクト: rvarun7777/tf_depth
def _strip_bn_update_ops(graph_def):
    """Rewrite batch-norm update ops in ``graph_def`` in place for export.

    ``RefSwitch`` becomes ``Switch``, and its ``moving_*`` inputs are pointed at
    their ``/read`` tensors. ``AssignSub``/``AssignAdd`` become ``Sub``/``Add``
    and lose the ``use_locking`` attribute.
    """
    for node in graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            for index, name in enumerate(node.input):
                if 'moving_' in name:
                    node.input[index] = name + '/read'
        elif node.op in ('AssignSub', 'AssignAdd'):
            node.op = 'Sub' if node.op == 'AssignSub' else 'Add'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']


def _rewire_bn_to_moving_stats(graph_def):
    """Point batchnorm ``add``/``mul_2`` inputs at the moving statistics, in place.

    For nodes named ``<scope>/batchnorm/add`` or ``<scope>/batchnorm/mul_2``,
    every input containing ``cond/Merge`` is replaced by
    ``<scope>/moving_variance/read`` or ``<scope>/moving_mean/read``
    respectively. This prepares the graph for TensorRT UFF conversion.
    """
    stat_for_op = {'add': 'moving_variance', 'mul_2': 'moving_mean'}
    for node in graph_def.node:
        fields = node.name.split('/')
        # Names without a scope cannot be batchnorm sub-ops (and would make
        # fields[-2] raise IndexError).
        if len(fields) < 2 or fields[-2] != 'batchnorm':
            continue
        stat = stat_for_op.get(fields[-1])
        if stat is None:
            continue
        replacement = '/'.join(fields[:-2] + [stat, 'read'])
        for index, name in enumerate(node.input):
            if 'cond/Merge' in name:
                node.input[index] = replacement


def _write_graph_def(path, graph_def, binary):
    """Write ``graph_def`` to ``path`` as serialized bytes or as text."""
    with tf.gfile.GFile(path, "wb" if binary else "w") as f:
        f.write(graph_def.SerializeToString() if binary else str(graph_def))


def main(args):
    """Build the segmentation net and export its graph in several forms.

    The following are written to ``FLAGS.output_path``:
    * the patched whole graph, the ``output/prob`` inference subgraph and the
      UFF-prepared subgraph. These are binary or text according to
      ``FLAGS.whole_graph_bin``.
    * frozen (variables-to-constants) versions of the whole graph and the
      UFF-prepared graph, restored from ``FLAGS.ckpt_path``. These are binary
      or text according to ``FLAGS.infer_graph_bin``.

    Exits with status -1 for an unknown model type.
    ``args`` is part of the entry-point signature but is not read here.
    """
    import sys  # for the unknown-model exit below

    if FLAGS.cfg_file:
        print('loading config setting')
        cfg_from_file(FLAGS.cfg_file, cfg)
    cfg.BATCH_SIZE = 1
    print_config(cfg)

    output_path = FLAGS.output_path
    if not os.path.isdir(output_path):
        os.makedirs(output_path)

    output_name = FLAGS.output_name

    # The intermediate dumps are written with as_text=not whole_graph_bin, so
    # they share whole_graph_ext. The frozen graphs follow infer_graph_bin.
    whole_graph_ext = 'pb' if FLAGS.whole_graph_bin else 'pbtxt'
    infer_graph_ext = 'pb' if FLAGS.infer_graph_bin else 'pbtxt'
    whole_graph_name = "{}_whole.{}".format(output_name, whole_graph_ext)
    infer_graph_name = "{}_infer.{}".format(output_name, whole_graph_ext)
    uff_graph_name = "{}_uff.{}".format(output_name, whole_graph_ext)
    output_graph_path = "{}/{}.{}".format(output_path, output_name, infer_graph_ext)
    output_uff_graph_path = "{}/{}_uff.{}".format(output_path, output_name, infer_graph_ext)
    for name in (whole_graph_name, infer_graph_name, uff_graph_name,
                 output_graph_path, output_uff_graph_path):
        print(name)

    logger = log_helper.get_logger()
    if FLAGS.model == 'sq':
        model = SQSegNet(cfg, logger)
    elif FLAGS.model == 'erf':
        model = ERFSegNet(cfg, logger)
    else:
        logger.error('wrong model type: {}'.format(FLAGS.model))
        sys.exit(-1)

    output_node_names = ["output/prob"]

    if FLAGS.restore_avg:
        # Restore the exponential-moving-average shadow variables.
        variable_averages = tf.train.ExponentialMovingAverage(cfg.MOVING_AVERAGE_DECAY)
        saver = tf.train.Saver(variable_averages.variables_to_restore())
    else:
        saver = tf.train.Saver(model.all_variables)

    intermediate_as_text = not FLAGS.whole_graph_bin

    with tf.Session() as sess:
        whole_graph_def = sess.graph.as_graph_def()
        _strip_bn_update_ops(whole_graph_def)
        print("%d ops in the whole graph." % len(whole_graph_def.node))
        tf.train.write_graph(whole_graph_def, output_path,
                             whole_graph_name, as_text=intermediate_as_text)

        infer_graph_def = graph_util.extract_sub_graph(whole_graph_def, output_node_names)
        print("%d ops in the infer graph." % len(infer_graph_def.node))
        tf.train.write_graph(infer_graph_def, output_path,
                             infer_graph_name, as_text=intermediate_as_text)

        # Mutates infer_graph_def after it has been written out above.
        _rewire_bn_to_moving_stats(infer_graph_def)
        uff_graph_def = graph_util.extract_sub_graph(infer_graph_def, output_node_names)
        print("%d ops in the uff graph." % len(uff_graph_def.node))
        tf.train.write_graph(uff_graph_def, output_path,
                             uff_graph_name, as_text=intermediate_as_text)

        # Weights are needed only for freezing.
        saver.restore(sess, FLAGS.ckpt_path)

        output_graph_def = graph_util.convert_variables_to_constants(
            sess, whole_graph_def, output_node_names)
        output_uff_graph_def = graph_util.convert_variables_to_constants(
            sess, infer_graph_def, output_node_names)

        _write_graph_def(output_graph_path, output_graph_def, FLAGS.infer_graph_bin)
        print("%d ops in the output graph." % len(output_graph_def.node))

        _write_graph_def(output_uff_graph_path, output_uff_graph_def, FLAGS.infer_graph_bin)
        print("%d ops in the output uff graph." % len(output_uff_graph_def.node))
コード例 #7
0
def _fetch_and_print(sess, graph, node_name, feed_dict):
    """Evaluate tensor ``<node_name>:0`` and print its value, shape, max and min.

    Returns the fetched value. If ``sess.run`` raises, it prints
    "cannot be fetched!" and returns None instead.
    """
    tensor = graph.get_tensor_by_name(node_name + ':0')
    try:
        value = sess.run(tensor, feed_dict=feed_dict)
    except Exception:  # some nodes (e.g. control-flow internals) cannot be fetched
        print("cannot be fetched!")
        return None
    print(value)
    print(value.shape)
    print(value.max())
    print(value.min())
    return value


def _save_prob_channels(prob_value, height, width, path0, path1):
    """Save both channels of a 2-class probability map as 8-bit images.

    ``prob_value`` is reshaped to ``(height, width, 2)``; each channel is
    scaled by 255 and written to ``path0`` / ``path1`` respectively.
    """
    prob = prob_value.reshape((height, width, 2))
    for channel, path in enumerate((path0, path1)):
        cv2.imwrite(path, (prob[:, :, channel] * 255).astype(np.uint8))


def _save_feature_maps(value, output_dir, node_name):
    """Write every channel of ``value[0]`` as a min-max normalized 8-bit PNG.

    Files go to ``<output_dir>/<node_name with '/' -> '_'>/<channel>.png``.
    A channel whose value range is below 1e-8 is written as all zeros.
    ``value`` is indexed as ``[0, :, :, i]``, so it must be rank 4.
    """
    output_path = '{}/{}'.format(output_dir, node_name.replace('/', '_'))
    if not os.path.isdir(output_path):
        os.mkdir(output_path)
    for i in range(value.shape[3]):
        layer = value[0, :, :, i]
        low = layer.min()
        span = layer.max() - low
        if span < 1e-8:
            # New array: leave the fetched tensor untouched.
            scaled = np.zeros_like(layer)
        else:
            scaled = (layer - low) / span * 255
        cv2.imwrite('{}/{}.png'.format(output_path, i), scaled.astype(np.uint8))


def main(args):
    """Dump intermediate tensors of the segmentation net for one input image.

    Builds the model in a fresh graph and restores it from ``FLAGS.ckpt_path``.
    It then runs ``FLAGS.input_image`` through every node that ``output/prob``
    depends on and prints each fetched tensor with its shape, max and min.
    * With ``FLAGS.tname``, only that node is printed, and the two
      ``output/prob`` channels are saved to ``FLAGS.output_image0/1``.
    * Otherwise every node is printed. With ``FLAGS.output_dir``, the prob
      channels go to ``out0.png``/``out1.png`` there, and each node listed in
      ``output_layers`` is dumped channel by channel.

    Exits with status -1 for an unknown model type.
    ``args`` is part of the entry-point signature but is not read here.
    """
    import sys  # for the unknown-model exit below

    if FLAGS.cfg_file:
        print('loading config setting')
        cfg_from_file(FLAGS.cfg_file, cfg)
    cfg.BATCH_SIZE = 1
    print_config(cfg)

    if FLAGS.output_dir and not os.path.isdir(FLAGS.output_dir):
        os.makedirs(FLAGS.output_dir)

    logger = log_helper.get_logger()

    with tf.Graph().as_default() as graph:

        if FLAGS.model == 'sq':
            model = SQSegNet(cfg, logger)
        elif FLAGS.model == 'erf':
            model = ERFSegNet(cfg, logger)
        else:
            logger.error('wrong model type: {}'.format(FLAGS.model))
            sys.exit(-1)

        if FLAGS.restore_avg:
            # Restore the exponential-moving-average shadow variables.
            variable_averages = tf.train.ExponentialMovingAverage(
                cfg.MOVING_AVERAGE_DECAY)
            saver = tf.train.Saver(variable_averages.variables_to_restore())
        else:
            saver = tf.train.Saver(model.all_variables)

        graph_def = graph.as_graph_def()
        # Only the nodes output/prob depends on are walked below.
        sub_graph_def = tf.graph_util.extract_sub_graph(
            graph_def, ['output/prob'])
        all_inference_nodes = sub_graph_def.node
        all_nodes = graph_def.node

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

        print('======================================================')
        print('all nodes:')
        print('\n'.join([x.name for x in all_nodes]))
        print('======================================================')
        print('all inference nodes:')
        print('\n'.join([x.name for x in all_inference_nodes]))
        print('======================================================')
        print('all update ops:')
        print('\n'.join([x.name for x in update_ops]))
        print('======================================================')

        with tf.Session() as sess:
            print("retoring")
            saver.restore(sess, FLAGS.ckpt_path)

            img = np.array(skimage.io.imread(FLAGS.input_image), np.float32)
            print(img)
            input_img_ph = graph.get_tensor_by_name('input/image:0')
            is_training = graph.get_tensor_by_name('input/is_training:0')
            # The image is wrapped in a list to form a batch of one.
            feed_dict = {input_img_ph: [img], is_training: FLAGS.is_training}

            for node in all_inference_nodes:
                name = node.name
                if FLAGS.tname:
                    # Targeted mode: print only the requested node, but always
                    # save the output/prob channels.
                    if name == FLAGS.tname:
                        print(name)
                        if _fetch_and_print(sess, graph, name, feed_dict) is None:
                            continue
                    if name == 'output/prob':
                        tensor = graph.get_tensor_by_name(name + ':0')
                        _tensor = sess.run(tensor, feed_dict=feed_dict)
                        _save_prob_channels(_tensor, cfg.IMAGE_HEIGHT, cfg.IMAGE_WIDTH,
                                            FLAGS.output_image0, FLAGS.output_image1)
                    continue

                # Full mode: print every node.
                print(name)
                _tensor = _fetch_and_print(sess, graph, name, feed_dict)
                if _tensor is None:
                    continue

                print(
                    '======================================================'
                )

                if FLAGS.output_dir and name == 'output/prob':
                    _save_prob_channels(_tensor, cfg.IMAGE_HEIGHT, cfg.IMAGE_WIDTH,
                                        '{}/out0.png'.format(FLAGS.output_dir),
                                        '{}/out1.png'.format(FLAGS.output_dir))

                # NOTE(review): output_layers is not defined in this function;
                # presumably a module-level collection of node names.
                if FLAGS.output_dir and output_layers and name in output_layers:
                    _save_feature_maps(_tensor, FLAGS.output_dir, name)
コード例 #8
0
ファイル: train.py プロジェクト: rvarun7777/tf_depth
# Training-loop intervals, in steps.
_LOG_EVERY = 10  # log elapsed time and current loss
_SUMMARY_EVERY = 500  # write summaries (and validate, when enabled)
_CHECKPOINT_EVERY = 10000  # save a checkpoint

# Separator printed between tensors in the debug NaN dump.
_DEBUG_SEPARATOR = '========================================'

# Labels for the per-loss gradients stored after the layer tensor in each
# ``layer_n_grads`` entry, in the same order as they are zipped in ``main``.
_LAYER_GRAD_LABELS = ('img', 'img l1', 'img ssim', 'smth', 'cons', 'total')


def _dump_nan_debug_info(sess, feed_dict, grad, var, layer_n_grads, losses):
    """Print diagnostics for a training step whose gradients contain NaN.

    Prints the offending gradient and its variable. Then, for every debug
    layer, prints the layer activation followed by each of its per-loss
    gradients. Finally prints every loss value.

    Args:
        sess: session used to evaluate the tensors.
        feed_dict: feed of the training step that produced the NaN.
        grad: gradient tensor flagged by the NaN check.
        var: variable that ``grad`` belongs to.
        layer_n_grads: sequence of tuples ``(layer, g_img, g_img_l1,
            g_img_ssim, g_smth, g_cons, g_total)``. Gradient entries may be
            None (``tf.gradients`` yields None when a loss does not depend
            on the layer); those are skipped.
        losses: loss tensors to evaluate and print.
    """
    print('{}\'s gradients has nan'.format(var.op.name))
    grad_res, var_res = sess.run([grad, var], feed_dict=feed_dict)
    print(grad_res.shape)
    print(var_res.shape)
    print(_DEBUG_SEPARATOR)
    print(grad_res)
    print(_DEBUG_SEPARATOR)
    print(var_res)
    print(_DEBUG_SEPARATOR)

    print('all layers & grads:')
    for entry in layer_n_grads:
        layer, layer_grads = entry[0], entry[1:]
        print(layer.op.name)
        layer_res = sess.run(layer, feed_dict=feed_dict)
        print(layer_res.shape)
        print(layer_res)
        print(_DEBUG_SEPARATOR)
        for label, layer_grad in zip(_LAYER_GRAD_LABELS, layer_grads):
            if layer_grad is None:
                continue
            layer_grad_res = sess.run(layer_grad, feed_dict=feed_dict)
            print('{} gradient {}'.format(layer.op.name, label))
            print(layer_grad_res.shape)
            print(layer_grad_res)
            print(_DEBUG_SEPARATOR)

    print('all losses:')
    for loss in losses:
        loss_res = sess.run(loss, feed_dict=feed_dict)
        print('{}: {}'.format(loss.op.name, loss_res))


def _run_validation(sess, model, data_pipeline, val_step):
    """Return the mean of ``model.total_loss`` over ``val_step`` validation batches.

    Args:
        sess: active session.
        model: network exposing ``left_image``, ``right_image``,
            ``train_phase`` and ``total_loss``.
        data_pipeline: pipeline whose ``setup_validate`` has been called.
        val_step: number of validation batches to average over (> 0).
    """
    total_val_loss = 0.0
    for _ in range(val_step):
        val_image_batch, val_label_batch = data_pipeline.load_validate_batch()
        val_feed_dict = {
            model.left_image: val_image_batch,
            model.right_image: val_label_batch,
            # Feed train_phase explicitly (training feeds True) instead of
            # relying on the placeholder's default; presumably this switches
            # layers such as batch norm to inference mode -- confirm in model.
            model.train_phase: False,
        }
        total_val_loss += sess.run(model.total_loss, feed_dict=val_feed_dict)
    return total_val_loss / val_step


def main(args):
    """Train the disparity network selected by ``FLAGS.model``.

    Loads the optional config file into ``cfg``. In debug mode it shrinks the
    run and disables validation. It then builds the model, input pipeline,
    train op and summaries, and runs the training loop. The loop logs every
    ``_LOG_EVERY`` steps and writes summaries every ``_SUMMARY_EVERY`` steps
    (every step in debug mode). When validation is enabled it validates on
    the same schedule, and it checkpoints every ``_CHECKPOINT_EVERY`` steps
    and on the final step.

    In debug mode every step is first checked for NaN gradients. On the
    first hit, a full diagnostic dump is printed and the process exits
    with status -1.

    Args:
        args: command-line arguments passed by the app runner; unused here,
            since all configuration comes from ``FLAGS`` and ``cfg``.
    """
    checkArgs()

    if FLAGS.cfg_file:
        print('loading config setting')
        cfg_from_file(FLAGS.cfg_file, cfg)
    if FLAGS.debug:
        # Shrink the run so a debug session finishes quickly.
        cfg.BATCH_SIZE = 2
        cfg.NUM_EPOCH = 1
        cfg.TRAIN_QUEUE_CAPACITY = 10
        cfg.DO_VALIDATE = False
        print('set to DEBUG mode')
    print_config(cfg)

    if cfg.DO_VALIDATE:
        val_step = np.ceil(
            float(cfg.NUM_EXAMPLES_PER_EPOCH_FOR_EVAL) /
            cfg.VAL_BATCH_SIZE).astype(np.int32)
    steps_per_epoch = np.ceil(float(cfg.NUM_EXAMPLES) / cfg.BATCH_SIZE).astype(
        np.int32)
    max_step = steps_per_epoch * cfg.NUM_EPOCH
    if FLAGS.debug:
        max_step = 1000
    ckpt_path = FLAGS.ckpt_path
    if not os.path.isdir(ckpt_path):
        # makedirs also creates any missing parent directories.
        os.makedirs(ckpt_path)

    logger = log_helper.get_logger()
    if FLAGS.model == 'res50':
        model = Res50DispNet(cfg, logger)
    else:
        logger.error('wrong model type: {}'.format(FLAGS.model))
        sys.exit(-1)

    data_pipeline = TFLoadingPipeline(cfg, logger)
    data_pipeline.setup(FLAGS.sample_path,
                        FLAGS.label_path,
                        cfg.BATCH_SIZE,
                        cfg.TRAIN_QUEUE_CAPACITY,
                        augment=True)
    if cfg.DO_VALIDATE:
        data_pipeline.setup_validate(FLAGS.val_sample_path,
                                     FLAGS.val_label_path, cfg.VAL_BATCH_SIZE,
                                     cfg.EVAL_QUEUE_CAPACITY)

    global_step = tf.Variable(0, trainable=False)

    train_op, grads = get_train_op(cfg, model.total_loss, global_step,
                                   max_step)

    if FLAGS.debug:
        # One NaN-check tensor per (grad, var) pair returned by get_train_op.
        nan_checks = [tf.is_nan(grad) for grad, _ in grads]

        layers = model.layers[-8:]
        layers += model.left_reconstruction + model.right_reconstruction + model.left_disparity + model.right_disparity
        # With weight decay, model_losses carries one extra leading entry.
        if cfg.WEIGHT_DECAY is not None:
            _, img_loss, img_l1_loss, img_ssim_loss, smth_loss, cons_loss, total_loss = model.model_losses
        else:
            img_loss, img_l1_loss, img_ssim_loss, smth_loss, cons_loss, total_loss = model.model_losses
        # Per-layer gradients of each loss term; the order must match
        # _LAYER_GRAD_LABELS. Materialised as a list so it can be iterated
        # on both Python 2 and 3.
        layer_n_grads = list(
            zip(layers, tf.gradients(img_loss, layers),
                tf.gradients(img_l1_loss, layers),
                tf.gradients(img_ssim_loss, layers),
                tf.gradients(smth_loss, layers),
                tf.gradients(cons_loss, layers),
                tf.gradients(total_loss, layers)))

    saver = tf.train.Saver(var_list=model.all_variables, max_to_keep=10)

    # add summaries
    images = [
        model.left_image, model.right_image, model.left_disparity[0],
        model.right_disparity[0], model.left_reconstruction[0],
        model.right_reconstruction[0], model.left_recon_diff[0],
        model.right_recon_diff[0]
    ]
    add_summaries(images, 'image', cfg)
    add_summaries(model.trainables, 'hist')
    add_summaries(model.bn_variables, 'hist')
    add_summaries(model.bn_mean_variance, 'hist')
    add_summaries(model.losses, 'scala')
    summary_op = tf.summary.merge_all()

    if cfg.DO_VALIDATE:
        # add evaluation metric summaries
        val_loss_summary_op, val_loss_summary_ph = add_metric_summary(
            'val_loss')
        metric_summary_op = tf.summary.merge([val_loss_summary_op])

    with tf.Session() as sess:
        total_start_time = time.time()
        summary_writer = tf.summary.FileWriter(ckpt_path, sess.graph)

        # initialize
        logger.info('initializing model params...')
        sess.run(model.initializer)

        data_pipeline.start(sess)
        # Always stop the pipeline, including on errors and on the debug
        # sys.exit below.
        try:
            start_time = time.time()
            for step in xrange(max_step):

                image_batch, label_batch = data_pipeline.load_batch()
                feed_dict = {
                    model.left_image: image_batch,
                    model.train_phase: True,
                    model.right_image: label_batch
                }

                if FLAGS.debug:
                    nan_check_res = sess.run(nan_checks, feed_dict=feed_dict)
                    # Scan from the last gradient backwards; dump the first
                    # NaN found and abort.
                    for i, res in enumerate(nan_check_res[::-1]):
                        if res.any():
                            grad, var = grads[::-1][i]
                            _dump_nan_debug_info(sess, feed_dict, grad, var,
                                                 layer_n_grads, model.losses)
                            sys.exit(-1)

                # Only evaluate the (expensive) merged summaries on steps
                # where they are actually written.
                write_summary = FLAGS.debug or step % _SUMMARY_EVERY == 0
                if write_summary:
                    _, loss, summary_str = sess.run(
                        [train_op, model.total_loss, summary_op],
                        feed_dict=feed_dict)
                    summary_writer.add_summary(summary_str, step)
                else:
                    _, loss = sess.run([train_op, model.total_loss],
                                       feed_dict=feed_dict)

                if step and step % _LOG_EVERY == 0:
                    duration = time.time() - start_time
                    logger.info('step {}: {} sec elapsed, loss = {}'.format(
                        step, duration, loss))
                    start_time = time.time()

                # Validate on the summary schedule (never on step 0, never
                # in debug mode).
                if (not FLAGS.debug and cfg.DO_VALIDATE and step and
                        step % _SUMMARY_EVERY == 0):
                    logger.info('start validating......')
                    avg_loss = _run_validation(sess, model, data_pipeline,
                                               val_step)
                    logger.info("val loss: {}".format(avg_loss))

                    metric_summary_str = sess.run(
                        metric_summary_op,
                        feed_dict={val_loss_summary_ph: avg_loss})
                    summary_writer.add_summary(metric_summary_str, step)

                    logger.info(" end validating.... ")

                # Periodic checkpoint, plus one on the final step.
                if (step and step % _CHECKPOINT_EVERY == 0) or (step + 1) == max_step:
                    checkpoint_path = os.path.join(
                        ckpt_path, '{}_model.ckpt'.format(FLAGS.model))
                    saver.save(sess, checkpoint_path, global_step=step)

            # Done training
            logger.info('training complete')
        finally:
            data_pipeline.shutdown()
        total_end_time = time.time()
        logger.info('total time elapsed: {} h'.format(
            (total_end_time - total_start_time) / 3600.0))