Code Example #1
0
        # --- Fragment: evaluation graph for an attribute-conditioned image
        # captioner with a visual-sentinel (adaptive-attention-style) decoder.
        # The enclosing scope (presumably `with tf.Graph().as_default():`) and
        # the rest of the session loop are outside this excerpt.

        # Pull one epoch of MS-COCO batches from the project's input pipeline.
        # `spatial_features` feeds both a spatial mean-pool and the captioner's
        # `spatial_image_features` arg, so it presumably has shape
        # [batch, H, W, 2048] — TODO confirm against `import_mscoco`.
        image_id, spatial_features, input_seq, target_seq, indicator = import_mscoco(
            mode=FLAGS.mode,
            batch_size=FLAGS.batch_size,
            num_epochs=1,
            is_mini=FLAGS.is_mini)
        # Decoder RNN cell: 300-unit state, attends over 2048-dim image features.
        visual_sentinel_cell = VisualSentinelCell(300, num_image_features=2048)
        # Captioner conditioned on detected attributes as well as image features.
        attribute_image_captioner = AttributeImageCaptioner(
            visual_sentinel_cell, vocab, pretrained_matrix, attribute_map,
            attribute_embeddings_map)
        # Detect the top-k attributes (out of 1000) from globally pooled
        # features: mean over the two spatial axes ([1, 2]) of spatial_features.
        attribute_detector = AttributeDetector(1000)
        _, top_k_attributes = attribute_detector(
            tf.reduce_mean(spatial_features, [1, 2]))
        # Build the decode op: `ids` are the predicted caption token ids.
        logits, ids = attribute_image_captioner(
            top_k_attributes, spatial_image_features=spatial_features)

        # Two separate savers because the captioner and detector come from
        # different checkpoints; decoder variable names are remapped so a
        # checkpoint saved under a different name scope still restores.
        captioner_saver = tf.train.Saver(var_list=remap_decoder_name_scope(
            attribute_image_captioner.variables))
        attribute_detector_saver = tf.train.Saver(
            var_list=attribute_detector.variables)
        captioner_ckpt, captioner_ckpt_name = get_visual_sentinel_attribute_checkpoint(
        )
        attribute_detector_ckpt, attribute_detector_ckpt_name = get_attribute_detector_checkpoint(
        )

        with tf.Session() as sess:

            # Both checkpoints must already exist — this is eval-only code.
            # NOTE(review): `assert` is stripped under `python -O`; raising an
            # explicit error would be more robust.
            assert (captioner_ckpt is not None
                    and attribute_detector_ckpt is not None)
            captioner_saver.restore(sess, captioner_ckpt)
            attribute_detector_saver.restore(sess, attribute_detector_ckpt)
            # Presumably dedups image ids across batches and accumulates
            # per-image results for a JSON dump — the consuming loop is
            # outside this fragment; verify there.
            used_ids = set()
            json_dump = []
Code Example #2
0
                                          embedding_size=300)
    # (The line above is the tail of a call that begins outside this fragment —
    # presumably loading `vocab` / `pretrained_matrix` with 300-dim embeddings;
    # TODO confirm against the full source.)

    # --- Fragment: evaluation of a Show-and-Tell captioner with beam search
    # over mean-pooled image features, on the mini MS-COCO training split.
    with tf.Graph().as_default():

        # One epoch of batches; `mean_features` are globally pooled image
        # features (no spatial attention in this model).
        image_id, mean_features, input_seq, target_seq, indicator = (
            import_mscoco(mode="train",
                          batch_size=BATCH_SIZE,
                          num_epochs=1,
                          is_mini=True))
        # Frozen (trainable=False) decoder with a 300-unit Show-and-Tell cell;
        # decoding uses a beam of BEAM_SIZE.
        image_captioner = ImageCaptioner(ShowAndTellCell(300),
                                         vocab,
                                         pretrained_matrix,
                                         trainable=False,
                                         beam_size=BEAM_SIZE)
        # `ids` are the predicted caption token ids for each image in the batch.
        logits, ids = image_captioner(mean_image_features=mean_features)
        # Remap decoder variable names so a checkpoint saved under a different
        # name scope still restores into this graph.
        captioner_saver = tf.train.Saver(
            var_list=remap_decoder_name_scope(image_captioner.variables))
        captioner_ckpt, captioner_ckpt_name = get_show_and_tell_checkpoint()

        with tf.Session() as sess:

            # Eval-only: the checkpoint must already exist.
            # NOTE(review): `assert` is stripped under `python -O`.
            assert (captioner_ckpt is not None)
            captioner_saver.restore(sess, captioner_ckpt)
            # Presumably dedups image ids and accumulates per-image records
            # for a JSON dump — verify against the loop body (partly cut off).
            used_ids = set()
            json_dump = []

            # Iterate until the 1-epoch input pipeline is exhausted; the
            # `except` below presumably catches the end-of-epoch signal.
            for i in itertools.count():
                time_start = time.time()
                try:
                    _ids, _target_seq, _image_id = sess.run(
                        [ids, target_seq, image_id])
                except:  # NOTE(review): bare except — presumably meant to catch tf.errors.OutOfRangeError at end of epoch; it also swallows KeyboardInterrupt etc. Narrow it (handler body is outside this fragment).
Code Example #3
0
    # --- Fragment: evaluation of a "best-first" captioning module that jointly
    # predicts an insertion pointer into the running caption and the word to
    # insert there. The tail of the sess.run call and the rest of the loop are
    # outside this excerpt.
    with tf.Graph().as_default():

        # One epoch of best-first training examples: a partial caption
        # (`running_ids`), the previously inserted word (`previous_id`), and
        # ground-truth next word / pointer labels (`next_id`, `pointer`).
        image_id, running_ids, indicator, previous_id, next_id, pointer, image_features = (
            import_mscoco(mode="train",
                          batch_size=BATCH_SIZE,
                          num_epochs=1,
                          is_mini=True))
        best_first_module = BestFirstModule(pretrained_matrix)
        # Two heads: where to insert (pointer_logits) and what word to insert
        # (word_logits).
        pointer_logits, word_logits = best_first_module(image_features,
                                                        running_ids,
                                                        previous_id,
                                                        indicators=indicator)
        # Greedy decode of both heads (argmax over the last axis).
        ids = tf.argmax(word_logits, axis=-1, output_type=tf.int32)
        pointer_ids = tf.argmax(pointer_logits, axis=-1, output_type=tf.int32)
        # Remap decoder variable names so a checkpoint saved under a different
        # name scope still restores into this graph.
        captioner_saver = tf.train.Saver(
            var_list=remap_decoder_name_scope(best_first_module.variables))
        captioner_ckpt, captioner_ckpt_name = get_best_first_checkpoint()

        with tf.Session() as sess:

            # Eval-only: the checkpoint must already exist.
            # NOTE(review): `assert` is stripped under `python -O`.
            assert (captioner_ckpt is not None)
            captioner_saver.restore(sess, captioner_ckpt)
            # Presumably dedups image ids and accumulates per-image records —
            # the consuming code is outside this fragment; verify there.
            used_ids = set()
            json_dump = []

            # Iterate until the 1-epoch pipeline is exhausted (loop body and
            # termination handling are cut off below).
            for i in itertools.count():
                time_start = time.time()
                try:
                    # Fetch model predictions alongside the ground-truth labels
                    # for comparison (fetch list continues past this excerpt).
                    _caption, _ids, _next_id, _model_pointer, _label_pointer, _image_id = sess.run(
                        [
                            running_ids, ids, next_id, pointer_ids, pointer,