コード例 #1
0
    with open(args.vocabulary, 'r') as f:
        for line in f:
            vocab.append(line.strip())

    # get the image paths
    im_paths = glob.glob('./data/demo/*.jpg')
    print(im_paths)

    # read checkpoint file
    if args.ckpt:
        ckpt = tf.train.get_checkpoint_state(args.ckpt)
    else:
        raise ValueError

    # set config
    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True

    # init session
    saver = tf.train.Saver()
    with tf.Session(config=tfconfig) as sess:
        print('Restored from {}'.format(ckpt.model_checkpoint_path))
        saver.restore(sess, ckpt.model_checkpoint_path)

        # for n in tf.get_default_graph().as_graph_def().node:
        #     if 'input_feed' in n.name:
        #         print(n.name)

        for path in im_paths:
            test_im(sess, net, path, vocab)
コード例 #2
0
ファイル: demo.py プロジェクト: zhangbinxy/im2p-tensorflow
    im_paths = glob.glob('./data/demo/*.jpg')
    print(im_paths)

    # read checkpoint file
    if args.ckpt:
        ckpt = tf.train.get_checkpoint_state(args.ckpt)
    else:
        raise ValueError

    # set config
    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True

    # init session
    saver = tf.train.Saver()
    with tf.Session(config=tfconfig) as sess:
        print('Restored from {}'.format(ckpt.model_checkpoint_path))
        saver.restore(sess, ckpt.model_checkpoint_path)

        # for n in tf.get_default_graph().as_graph_def().node:
        #     if 'input_feed' in n.name:
        #         print(n.name)
        # for html visualization
        pre_results = {}
        save_path = './vis/data'
        for path in im_paths:
            test_im(sess, net, path, vocab, pre_results)

        # with open(save_path + '/results.json', 'w') as f:
        # json.dump(pre_results, f)