def test_decode():
    image_np = load_image_into_numpy_array(
        '/home/zisang/Documents/code/data/Flicker8k/Flicker8k_Dataset/2919459517_b8b858afa3.jpg'
    )
    if image_np is not None:
        faster_rcnn = FasterRcnnEncoder('../data/frozen_faster_rcnn.pb')
        box, feat = faster_rcnn.encode(image_np)

        # Build the vocabulary from the word-counts file.
        vocab = vocabulary.Vocabulary("../data/flickr8k/word_counts.txt")
        lstm = LSTMDecoder('../data/frozen_lstm.pb',
                           vocab,
                           max_caption_length=20)
        caption, attention = lstm.decode(feat)
        lstm.show_attention(caption, attention, box, image_np, './a.jpg')
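
`load_image_into_numpy_array` is referenced throughout these examples but never defined. A minimal sketch using PIL, assuming the helper simply reads the file into an HxWx3 uint8 array and returns None on failure (which the guard above checks for):

import numpy as np
from PIL import Image


def load_image_into_numpy_array(path):
    """Hypothetical helper: read an image into an HxWx3 uint8 RGB array."""
    try:
        with Image.open(path) as img:
            return np.array(img.convert('RGB'), dtype=np.uint8)
    except IOError:
        return None  # the caller treats None as a failed load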
Example #2
def main(_):

    with tf.variable_scope('placeholder'):
        image_feed = tf.placeholder(dtype=tf.string,
                                    shape=[],
                                    name="image_feed")

    images = tf.expand_dims(process_image(image_feed), 0)
    vgg_output = image_embedding.vgg_19_extract(images,
                                                trainable=False,
                                                is_training=False)
    """
    vgg_variables = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope="vgg_19")
    """
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, FLAGS.checkpoint_path)
        vocab = vocabulary.Vocabulary(FLAGS.vocab_file)
        lstm = LSTMDecoder(FLAGS.graph_path, vocab, max_caption_length=20)

        filenames = []
        for file_pattern in FLAGS.input_files.split(","):
            filenames.extend(tf.gfile.Glob(file_pattern))

        # Build the reshape op once, outside the loop; creating it per image
        # would keep growing the graph.
        context_op = tf.reshape(vgg_output, [-1, 196, 512])

        for filename in filenames:

            with tf.gfile.FastGFile(filename, "rb") as f:
                image_data = f.read()
            context = sess.run(context_op, feed_dict={image_feed: image_data})
            feat = np.squeeze(context)

            caption, attention = lstm.decode(feat)

            image = load_image_into_numpy_array(filename)
            if FLAGS.demo_method == "show":
                lstm.show_caption(caption, image)
            else:
                lstm.show_attention(caption, attention, image, "./pic.jpg")
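
`process_image` is not defined in this example either. A sketch under the assumption that it decodes and resizes to 224x224 (VGG-19 conv5 on a 224x224 input yields a 14x14x512 grid, i.e. the 196 regions of 512 dims reshaped above); the sizes are inferred from context, not taken from the original source:

import tensorflow as tf


def process_image(encoded_image, height=224, width=224):
    """Hypothetical preprocessing: decode an encoded image string,
    convert to float32 in [0, 1], and resize for VGG-19."""
    image = tf.image.decode_jpeg(encoded_image, channels=3)
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    return tf.image.resize_images(image, size=[height, width])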
Example #3
    def __init__(self, ndim=50):
        assert ndim in self._AVAILABLE_DIMS

        self.vocab = None
        self.W = None
        self.zipped_filename = "data/glove/glove.6B.zip"

        # Download datasets
        if not os.path.isfile(self.zipped_filename):
            data_dir = os.path.dirname(self.zipped_filename)
            print("Downloading GloVe vectors to {:s}".format(data_dir))
            self.zipped_filename = download_glove(data_dir)
        print("Loading vectors from {:s}".format(self.zipped_filename))

        words, W = parse_glove_file(self.zipped_filename, ndim)
        # Give the three special-token rows a nonzero value: the mean of
        # all real word vectors.
        mean_vec = np.mean(W[3:], axis=0)
        for i in range(3):
            W[i] = mean_vec
        self.W = W
        self.vocab = vocabulary.Vocabulary(words[3:])
        assert self.vocab.size == self.W.shape[0]
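
`download_glove` and `parse_glove_file` are assumed by the constructor above. A sketch of the parsing side, assuming the first three rows of W are reserved for special tokens (left as zeros here and overwritten with the mean vector by the caller; the token names are an assumption):

import zipfile
import numpy as np


def parse_glove_file(zipped_filename, ndim, num_special=3):
    """Hypothetical parser for glove.6B.<ndim>d.txt inside the zip."""
    words = ["<pad>", "<s>", "<unk>"]  # assumed special tokens
    vectors = []
    member = "glove.6B.{:d}d.txt".format(ndim)
    with zipfile.ZipFile(zipped_filename) as zf:
        for line in zf.open(member):
            parts = line.decode("utf-8").rstrip().split(" ")
            words.append(parts[0])
            vectors.append([float(x) for x in parts[1:]])
    W = np.zeros((num_special + len(vectors), ndim), dtype=np.float32)
    W[num_special:] = np.array(vectors, dtype=np.float32)
    return words, W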
Example #4
def run():
  """Runs evaluation in a loop, and logs summaries to TensorBoard."""
  # Create the evaluation directory if it doesn't exist.
  eval_dir = FLAGS.eval_dir
  if not tf.gfile.IsDirectory(eval_dir):
    tf.logging.info("Creating eval directory: %s", eval_dir)
    tf.gfile.MakeDirs(eval_dir)

  # Build the vocabulary from the word-counts file.
  vocab = vocabulary.Vocabulary(FLAGS.vocab_file)

  g = tf.Graph()
  with g.as_default():

    config = Config()
    config.input_file_pattern = FLAGS.input_file_pattern
    config.beam_size = FLAGS.beam_size

    # Build the model for evaluation.
    model = CaptionGenerator(config, mode="eval")
    model.build()

    # Create the Saver to restore model Variables.
    saver = tf.train.Saver()

    # Create the summary writer.
    summary_writer = tf.summary.FileWriter(eval_dir)

    g.finalize()

    # Run a new evaluation run every eval_interval_secs.
    while True:
      start = time.time()
      tf.logging.info("Starting evaluation at " + time.strftime(
          "%Y-%m-%d-%H:%M:%S", time.localtime()))
      run_once(model, vocab, saver, summary_writer)
      time_to_next_eval = start + FLAGS.eval_interval_secs - time.time()
      if time_to_next_eval > 0:
        time.sleep(time_to_next_eval)
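
`run_once` is not shown. A minimal sketch of what one evaluation pass might look like; FLAGS.checkpoint_dir, model.global_step, and model.total_loss are all assumptions, not taken from the original code:

import tensorflow as tf


def run_once(model, vocab, saver, summary_writer):
    """Hypothetical single pass: restore the newest checkpoint, evaluate,
    and log a summary for TensorBoard."""
    model_path = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
    if not model_path:
        tf.logging.info("No checkpoint found; skipping this round.")
        return
    with tf.Session() as sess:
        saver.restore(sess, model_path)
        global_step = tf.train.global_step(sess, model.global_step)
        # Simplified: a real pass would batch over the whole eval set (and
        # could use vocab to decode sample captions for inspection).
        loss = sess.run(model.total_loss)
        summary = tf.Summary()
        summary.value.add(tag="eval/total_loss", simple_value=float(loss))
        summary_writer.add_summary(summary, global_step)
        summary_writer.flush()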
Example #5
tf.flags.DEFINE_string('mode', 'att-nic', 'Can be att-nic or ours')
tf.flags.DEFINE_string(
    'vocab_path', '../output/vocabulary.csv',
    'Vocabulary file; use ../data/flickr8k/word_counts.txt for mode=ours')
tf.flags.DEFINE_string("faster_rcnn_path", "../data/frozen_faster_rcnn.pb",
                       "Faster R-CNN frozen graph")
tf.flags.DEFINE_string("region_lstm_path", "../data/frozen_lstm.pb",
                       "region-attention-based LSTM frozen graph")
tf.flags.DEFINE_string("att_nic_path", "../data/frozen_att_nic.pb",
                       "ATT-NIC frozen graph")

if FLAGS.mode == 'ours':
    faster_rcnn = FasterRcnnEncoder(FLAGS.faster_rcnn_path)
    # Build the vocabulary from the word-counts file.
    vocab = vocabulary.Vocabulary(FLAGS.vocab_path)
    lstm = LSTMDecoder(FLAGS.region_lstm_path, vocab, max_caption_length=20)
else:
    vocab = att_nic_vocab.Vocabulary(5000, FLAGS.vocab_path)
    att_nic = ATT_NIC(FLAGS.att_nic_path, vocab, max_caption_length=20)

ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}


def allowed_file(filename):
    # Case-insensitive extension check (handles e.g. "photo.JPG").
    return ('.' in filename and
            filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS)


app = Flask(__name__)
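
No route is attached to the Flask app in this excerpt. A hypothetical upload endpoint (mode='ours' path only; the route name and JSON shape are made up here) showing how the pieces above could be wired together:

import numpy as np
from PIL import Image
from flask import jsonify, request


@app.route('/caption', methods=['POST'])
def caption_image():
    """Hypothetical endpoint: accept an uploaded image, return a caption."""
    f = request.files.get('image')
    if f is None or not allowed_file(f.filename):
        return jsonify(error='expected a png/jpg/jpeg upload'), 400
    image_np = np.array(Image.open(f.stream).convert('RGB'))
    box, feat = faster_rcnn.encode(image_np)
    caption, attention = lstm.decode(feat)
    # decode()'s exact return type is not shown above; str() keeps the
    # response serializable either way.
    return jsonify(caption=str(caption))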
def main(_):
    # Build the inference graph.
    config = configuration.ModelConfig()
    config.input_file_pattern = FLAGS.input_file_pattern

    g = tf.Graph()
    with g.as_default():
        model = inference_wrapper.InferenceWrapper()
        restore_fn = model.build_graph_from_config(config,
                                                   FLAGS.checkpoint_path)
        complete_sentences = input_ops.batch_input_data(
            file_name_pattern=config.input_file_pattern,
            config=config,
            mode='test')

        # Local-variable init is required for the input pipeline's num_epochs.
        init_local = tf.local_variables_initializer()
        init_global = tf.global_variables_initializer()
    g.finalize()

    # Create the vocabulary.
    vocab = vocabulary.Vocabulary(FLAGS.vocab_file)
    tf.logging.info("Running Hypothesis & Reference generation.")

    # Be modest with GPU memory: allow growth instead of grabbing it all.
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True

    with tf.Session(graph=g, config=sess_config) as sess:
        # Initialize all variables; the local init covers num_epochs.
        sess.run(init_local)
        sess.run(init_global)
        # Load the model from checkpoint.
        restore_fn(sess)
        # Start the input queue runners.
        threads = tf.train.start_queue_runners(sess=sess)

        # Prepare the caption generator. Here we are implicitly using the default
        # beam search parameters. See caption_generator.py for a description of the
        # available beam search parameters.
        generator = caption_generator.CaptionGenerator(model,
                                                       vocab,
                                                       beam_size=5)

        # Stop cleanly once the input queue runs out of test examples.
        try:
            with open("models/BasicLSTM/test_log/hypo.txt",
                      "w") as hf, open("models/BasicLSTM/test_log/ref.txt",
                                       "w") as rf:
                counter = 0
                while True:
                    sentence_fetch = sess.run(complete_sentences)
                    print("---------------------------------------")
                    print("Test case No.{0}".format(counter))
                    reference = " ".join(map(vocab.id_to_word,
                                             sentence_fetch))
                    print("REF:\t", reference)
                    rf.write(reference + '\n')

                    captions = generator.beam_search(sess, sentence_fetch,
                                                     k=1)

                    caption = captions[0]
                    # Ignore the begin and end words.
                    sentence = [vocab.id_to_word(w)
                                for w in caption.sentence[1:-1]]
                    sentence = " ".join(sentence)
                    print("HYP:\t", sentence)
                    hf.write(sentence + '\n')
                    counter += 1
        except tf.errors.OutOfRangeError:
            tf.logging.info("Finished processing %d test cases.", counter)
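
The hypo.txt/ref.txt pair written above holds one sentence per line, so scoring it afterwards is straightforward; a sketch using NLTK's corpus_bleu with a single reference per hypothesis (NLTK is an assumption, not a dependency of the original code):

from nltk.translate.bleu_score import corpus_bleu

with open("models/BasicLSTM/test_log/ref.txt") as rf, \
        open("models/BasicLSTM/test_log/hypo.txt") as hf:
    references = [[line.split()] for line in rf]  # one reference per case
    hypotheses = [line.split() for line in hf]

print("BLEU-4: {:.4f}".format(corpus_bleu(references, hypotheses)))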