Example No. 1
def main(unused_argv):
    def _is_valid_num_shards(num_shards):
        """Returns True if num_shards is compatible with FLAGS.num_threads."""
        return num_shards < FLAGS.num_threads or not num_shards % FLAGS.num_threads
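        # e.g. with FLAGS.num_threads = 8 (an illustrative value): 4 shards pass
        # (4 < 8), 16 shards pass (16 % 8 == 0), and 12 shards would fail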

    assert _is_valid_num_shards(FLAGS.train_shards), (
        "Please make the FLAGS.num_threads commensurate with FLAGS.train_shards")
    assert _is_valid_num_shards(FLAGS.val_shards), (
        "Please make the FLAGS.num_threads commensurate with FLAGS.val_shards")
    assert _is_valid_num_shards(FLAGS.test_shards), (
        "Please make the FLAGS.num_threads commensurate with FLAGS.test_shards")
        
    # Create vocabulary from the glove embeddings.
    vocab, _ = load_glove(vocab_size=FLAGS.vocab_size, embedding_size=50)

    if not tf.gfile.IsDirectory(FLAGS.output_dir):
        tf.gfile.MakeDirs(FLAGS.output_dir)

    # Load image metadata from caption files.
    mscoco_train_dataset = _load_and_process_metadata(FLAGS.train_captions_file, FLAGS.train_image_dir)
    mscoco_val_dataset = _load_and_process_metadata(FLAGS.val_captions_file, FLAGS.val_image_dir)

    # Redistribute the MSCOCO data as follows:
    #   train_dataset = 99% of mscoco_train_dataset
    #   val_dataset = 1% of mscoco_train_dataset (for validation during training).
    #   test_dataset = 100% of mscoco_val_dataset (for final evaluation).
    train_cutoff = int(0.99 * len(mscoco_train_dataset))
    train_dataset = mscoco_train_dataset[:train_cutoff]
    val_dataset = mscoco_train_dataset[train_cutoff:]
    test_dataset = mscoco_val_dataset
    
    dataset_extractor = Extractor()
    _process_dataset("train", train_dataset, vocab, FLAGS.train_shards, dataset_extractor)
    _process_dataset("val", val_dataset, vocab, FLAGS.val_shards, dataset_extractor)
    _process_dataset("test", test_dataset, vocab, FLAGS.test_shards, dataset_extractor)
Example No. 2
def main(unused_argv):
    
    vocab, pretrained_matrix = load_glove(vocab_size=100000, embedding_size=300)
    with tf.Graph().as_default():

        image_id, image_features, indicator, word_ids, pointer_ids = import_mscoco(
            mode="train", batch_size=FLAGS.batch_size, num_epochs=FLAGS.num_epochs, is_mini=FLAGS.is_mini)
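        # summing the 0/1 indicator over time steps gives each caption's length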
        lengths = tf.reduce_sum(indicator, [1])
        show_and_tell_cell = ShowAndTellCell(300)
        best_first_image_captioner = BestFirstImageCaptioner(show_and_tell_cell, vocab, pretrained_matrix)
        word_logits, wids, pointer_logits, pids, ids, _lengths = best_first_image_captioner(
            mean_image_features=image_features,
            word_ids=word_ids, pointer_ids=pointer_ids, lengths=lengths)
        tf.losses.sparse_softmax_cross_entropy(pointer_ids, pointer_logits)
        tf.losses.sparse_softmax_cross_entropy(word_ids, word_logits)
        loss = tf.losses.get_total_loss()
        
        global_step = tf.train.get_or_create_global_step()
        optimizer = tf.train.AdamOptimizer()
        learning_step = optimizer.minimize(loss, var_list=best_first_image_captioner.variables, 
            global_step=global_step)

        captioner_saver = tf.train.Saver(var_list=best_first_image_captioner.variables + [global_step])
        captioner_ckpt, captioner_ckpt_name = get_best_first_checkpoint()
        with tf.Session() as sess:
            
            sess.run(tf.variables_initializer(optimizer.variables()))
            if captioner_ckpt is not None:
                captioner_saver.restore(sess, captioner_ckpt)
            else:
                sess.run(tf.variables_initializer(best_first_image_captioner.variables + [global_step]))
            captioner_saver.save(sess, captioner_ckpt_name, global_step=global_step)
            last_save = time.time()
            
            for i in itertools.count():
                
                time_start = time.time()
                try:
                    twids, tpids, _ids, _lengths, _loss, _learning_step = sess.run([
                        word_ids, pointer_ids, ids, lengths, loss, learning_step])
                except tf.errors.OutOfRangeError:
                    # the input pipeline has run out of epochs
                    break
                    
                iteration = sess.run(global_step)
                
                insertion_sequence = insertion_sequence_to_array(twids, tpids, _lengths, vocab)
                    
                print(PRINT_STRING.format(
                    iteration, _loss, 
                    list_of_ids_to_string(insertion_sequence[0], vocab), 
                    list_of_ids_to_string(twids[0, :].tolist(), vocab), 
                    FLAGS.batch_size / (time.time() - time_start)))
                
                new_save = time.time()
                if new_save - last_save > 3600: # save the model every hour
                    captioner_saver.save(sess, captioner_ckpt_name, global_step=global_step)
                    last_save = new_save
                    
            captioner_saver.save(sess, captioner_ckpt_name, global_step=global_step)
            print("Finished training.")
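Every training loop in these examples repeats the same save-if-an-hour-has-passed bookkeeping around captioner_saver. Below is a minimal sketch of how that pattern could be factored into a reusable helper; the class and its interface are hypothetical, not part of the original scripts.

import time

class PeriodicSaver(object):
    """Wraps a tf.train.Saver and saves at most once per `interval` seconds."""

    def __init__(self, saver, ckpt_name, interval=3600):
        self.saver = saver
        self.ckpt_name = ckpt_name
        self.interval = interval
        self.last_save = time.time()

    def maybe_save(self, sess, global_step=None):
        # save only if the interval has elapsed since the last checkpoint
        now = time.time()
        if now - self.last_save > self.interval:
            self.saver.save(sess, self.ckpt_name, global_step=global_step)
            self.last_save = now

Calling periodic_saver.maybe_save(sess, global_step=global_step) inside the loop would then replace the manual last_save checks.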
Example No. 3
def main(unused_argv):
    
    vocab, pretrained_matrix = load_glove(vocab_size=100000, embedding_size=300)
    with tf.Graph().as_default():

        image_id, running_ids, indicator, previous_id, next_id, pointer, image_features = (
            import_mscoco(mode="train", batch_size=BATCH_SIZE, num_epochs=100, is_mini=True))
        best_first_module = BestFirstModule(pretrained_matrix)
        pointer_logits, word_logits = best_first_module(
            image_features, running_ids, previous_id, indicators=indicator, pointer_ids=pointer)
        tf.losses.sparse_softmax_cross_entropy(pointer, pointer_logits)
        tf.losses.sparse_softmax_cross_entropy(next_id, word_logits)
        loss = tf.losses.get_total_loss()
        
        ids = tf.argmax(word_logits, axis=-1, output_type=tf.int32)
        
        global_step = tf.train.get_or_create_global_step()
        learning_rate = LEARNING_RATE
        learning_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, 
            var_list=best_first_module.variables, global_step=global_step)

        captioner_saver = tf.train.Saver(var_list=best_first_module.variables + [global_step])
        captioner_ckpt, captioner_ckpt_name = get_best_first_checkpoint()
        with tf.Session() as sess:
            
            if captioner_ckpt is not None:
                captioner_saver.restore(sess, captioner_ckpt)
            else:
                sess.run(tf.variables_initializer(best_first_module.variables + [global_step]))
            captioner_saver.save(sess, captioner_ckpt_name, global_step=global_step)
            last_save = time.time()
            _ids, _loss, _learning_step = sess.run([ids, loss, learning_step])
            
            for i in itertools.count():
                
                time_start = time.time()
                try:
                    _ids, _loss, _learning_step = sess.run([ids, loss, learning_step])
                except tf.errors.OutOfRangeError:
                    # the input pipeline has run out of epochs
                    break
                    
                iteration = sess.run(global_step)
                    
                print(PRINT_STRING.format(
                    iteration, _loss, list_of_ids_to_string(_ids.tolist(), vocab), 
                    BATCH_SIZE / (time.time() - time_start)))
                
                new_save = time.time()
                if new_save - last_save > 3600: # save the model every hour
                    captioner_saver.save(sess, captioner_ckpt_name, global_step=global_step)
                    last_save = new_save
                    
            captioner_saver.save(sess, captioner_ckpt_name, global_step=global_step)
            print("Finished training.")
Example No. 4
def main(unused_argv):

    vocab, pretrained_matrix = load_glove(vocab_size=100000,
                                          embedding_size=300)
    with tf.Graph().as_default():

        image_id, spatial_features, input_seq, target_seq, indicator = (
            import_mscoco(mode="train",
                          batch_size=BATCH_SIZE,
                          num_epochs=100,
                          is_mini=True))
        visual_sentinel_cell = VisualSentinelCell(300)
        image_captioner = ImageCaptioner(visual_sentinel_cell, vocab,
                                         pretrained_matrix)
        logits, ids = image_captioner(lengths=tf.reduce_sum(indicator, axis=1),
                                      spatial_image_features=spatial_features,
                                      seq_inputs=input_seq)
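        # the 0/1 indicator weights zero out the loss on padded time steps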
        tf.losses.sparse_softmax_cross_entropy(target_seq,
                                               logits,
                                               weights=indicator)
        loss = tf.losses.get_total_loss()
        learning_step = tf.train.GradientDescentOptimizer(1.0).minimize(
            loss, var_list=image_captioner.variables)

        captioner_saver = tf.train.Saver(var_list=image_captioner.variables)
        captioner_ckpt, captioner_ckpt_name = get_visual_sentinel_checkpoint()
        with tf.Session() as sess:
            sess.run(tf.variables_initializer(image_captioner.variables))
            if captioner_ckpt is not None:
                captioner_saver.restore(sess, captioner_ckpt)
            captioner_saver.save(sess, captioner_ckpt_name)
            last_save = time.time()
            for i in itertools.count():
                time_start = time.time()
                try:
                    _ids, _loss, _learning_step = sess.run(
                        [ids, loss, learning_step])
                except tf.errors.OutOfRangeError:
                    # the input pipeline has run out of epochs
                    break
                print(
                    PRINT_STRING.format(
                        i, _loss,
                        list_of_ids_to_string(_ids[0, :].tolist(), vocab),
                        BATCH_SIZE / (time.time() - time_start)))
                new_save = time.time()
                if new_save - last_save > 3600:  # save the model every hour
                    captioner_saver.save(sess, captioner_ckpt_name)
                    last_save = new_save

            captioner_saver.save(sess, captioner_ckpt_name)
            print("Finished training.")
Example No. 5
def main(unused_argv):

    image = load_image_from_path("images/image.jpg")[np.newaxis, ...]

    vocab, pretrained_matrix = load_glove(vocab_size=100, embedding_size=50)
    pos = get_parts_of_speech()
    pos_embeddings = np.random.normal(0, 0.1, [15, 50])
    with tf.Graph().as_default():

        inputs = tf.placeholder(tf.float32, shape=image.shape)
        box_extractor = BoxExtractor(get_faster_rcnn_config(), top_k_boxes=16)
        boxes, scores, cropped_inputs = box_extractor(inputs)
        feature_extractor = FeatureExtractor()
        mean_image_features = tf.reduce_mean(feature_extractor(inputs), [1, 2])
        mean_object_features = tf.reshape(
            tf.reduce_mean(feature_extractor(cropped_inputs), [1, 2]),
            [1, 16, 2048])
        image_captioner = PartOfSpeechImageCaptioner(UpDownCell(50), vocab,
                                                     pretrained_matrix,
                                                     UpDownCell(50),
                                                     UpDownCell(50), pos,
                                                     pos_embeddings)
        pos_logits, pos_logits_ids, word_logits, word_logits_ids = image_captioner(
            mean_image_features=mean_image_features,
            mean_object_features=mean_object_features)

        with tf.Session() as sess:

            box_saver = tf.train.Saver(var_list=box_extractor.variables)
            resnet_saver = tf.train.Saver(var_list=feature_extractor.variables)

            box_saver.restore(sess, get_faster_rcnn_checkpoint())
            resnet_saver.restore(sess, get_resnet_v2_101_checkpoint())
            sess.run(tf.variables_initializer(image_captioner.variables))

            results = sess.run(
                [pos_logits, pos_logits_ids, word_logits, word_logits_ids],
                feed_dict={inputs: image})

            assert (results[2].shape[0] == 1 and results[2].shape[1] == 3
                    and results[2].shape[3] == 100)
            tf.logging.info("Successfully passed test.")
Example No. 6
def main(unused_argv):

    vocab, pretrained_matrix = load_glove(vocab_size=100000,
                                          embedding_size=300)
    attribute_map = get_visual_attributes()
    attribute_embeddings_map = np.random.normal(0, 0.1, [1000, 2048])
    with tf.Graph().as_default():

        image_id, mean_features, object_features, input_seq, target_seq, indicator = import_mscoco(
            mode="train",
            batch_size=FLAGS.batch_size,
            num_epochs=FLAGS.num_epochs,
            is_mini=FLAGS.is_mini)
        up_down_cell = UpDownCell(300, num_image_features=4096)
        attribute_image_captioner = AttributeImageCaptioner(
            up_down_cell, vocab, pretrained_matrix, attribute_map,
            attribute_embeddings_map)
        attribute_detector = AttributeDetector(1000)
        _, top_k_attributes = attribute_detector(mean_features)
        logits, ids = attribute_image_captioner(
            top_k_attributes,
            lengths=tf.reduce_sum(indicator, axis=1),
            mean_image_features=mean_features,
            mean_object_features=object_features,
            seq_inputs=input_seq)
        tf.losses.sparse_softmax_cross_entropy(target_seq,
                                               logits,
                                               weights=indicator)
        loss = tf.losses.get_total_loss()

        global_step = tf.train.get_or_create_global_step()
        optimizer = tf.train.AdamOptimizer()
        learning_step = optimizer.minimize(
            loss,
            var_list=attribute_image_captioner.variables,
            global_step=global_step)

        captioner_saver = tf.train.Saver(
            var_list=attribute_image_captioner.variables + [global_step])
        attribute_detector_saver = tf.train.Saver(
            var_list=attribute_detector.variables)
        captioner_ckpt, captioner_ckpt_name = get_up_down_attribute_checkpoint()
        attribute_detector_ckpt, attribute_detector_ckpt_name = get_attribute_detector_checkpoint()
        with tf.Session() as sess:

            sess.run(tf.variables_initializer(optimizer.variables()))
            if captioner_ckpt is not None:
                captioner_saver.restore(sess, captioner_ckpt)
            else:
                sess.run(
                    tf.variables_initializer(
                        attribute_image_captioner.variables + [global_step]))
            if attribute_detector_ckpt is not None:
                attribute_detector_saver.restore(sess, attribute_detector_ckpt)
            else:
                sess.run(tf.variables_initializer(
                    attribute_detector.variables))
            captioner_saver.save(sess,
                                 captioner_ckpt_name,
                                 global_step=global_step)
            last_save = time.time()

            for i in itertools.count():

                time_start = time.time()
                try:
                    _target, _ids, _loss, _learning_step = sess.run(
                        [target_seq, ids, loss, learning_step])
                except tf.errors.OutOfRangeError:
                    # the input pipeline has run out of epochs
                    break

                iteration = sess.run(global_step)

                print(
                    PRINT_STRING.format(
                        iteration, _loss,
                        list_of_ids_to_string(_ids[0, :].tolist(), vocab),
                        list_of_ids_to_string(_target[0, :].tolist(), vocab),
                        FLAGS.batch_size / (time.time() - time_start)))

                new_save = time.time()
                if new_save - last_save > 3600:  # save the model every hour
                    captioner_saver.save(sess,
                                         captioner_ckpt_name,
                                         global_step=global_step)
                    last_save = new_save

            captioner_saver.save(sess,
                                 captioner_ckpt_name,
                                 global_step=global_step)
            print("Finished training.")
Example No. 7
from detailed_captioning.utils import get_resnet_v2_101_checkpoint
from detailed_captioning.utils import get_show_and_tell_checkpoint
from detailed_captioning.utils import remap_decoder_name_scope
from detailed_captioning.utils import list_of_ids_to_string
from detailed_captioning.utils import recursive_ids_to_string
from detailed_captioning.utils import coco_get_metrics
from detailed_captioning.utils import get_train_annotations_file
from detailed_captioning.inputs.mean_image_features_only import import_mscoco

PRINT_STRING = """({3:.2f} img/sec) iteration: {0:05d}\n    caption: {1}\n    label: {2}"""
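# With illustrative values, PRINT_STRING renders three lines such as:
#   (9.73 img/sec) iteration: 00042
#       caption: a dog runs on the beach
#       label: a dog running along the beach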
BATCH_SIZE = 10
BEAM_SIZE = 16

if __name__ == "__main__":

    vocab, pretrained_matrix = load_glove(vocab_size=100000,
                                          embedding_size=300)
    with tf.Graph().as_default():

        image_id, mean_features, input_seq, target_seq, indicator = (
            import_mscoco(mode="train",
                          batch_size=BATCH_SIZE,
                          num_epochs=1,
                          is_mini=True))
        image_captioner = ImageCaptioner(ShowAndTellCell(300),
                                         vocab,
                                         pretrained_matrix,
                                         trainable=False,
                                         beam_size=BEAM_SIZE)
        logits, ids = image_captioner(mean_image_features=mean_features)
        captioner_saver = tf.train.Saver(
            var_list=remap_decoder_name_scope(image_captioner.variables))
Example No. 8
def main(unused_argv):

    vocab, pretrained_matrix = load_glove(vocab_size=100000,
                                          embedding_size=300)
    with tf.Graph().as_default():

        image_id, mean_features, input_seq, target_seq, indicator = (
            import_mscoco(mode="train",
                          batch_size=BATCH_SIZE,
                          num_epochs=100,
                          is_mini=True))
        show_and_tell_cell = ShowAndTellCell(300)
        image_captioner = ImageCaptioner(show_and_tell_cell, vocab,
                                         pretrained_matrix)
        logits, ids = image_captioner(lengths=tf.reduce_sum(indicator, axis=1),
                                      mean_image_features=mean_features,
                                      seq_inputs=input_seq)
        tf.losses.sparse_softmax_cross_entropy(target_seq,
                                               logits,
                                               weights=indicator)
        loss = tf.losses.get_total_loss()

        global_step = tf.train.get_or_create_global_step()
        learning_rate = tf.train.exponential_decay(
            INITIAL_LEARNING_RATE,
            global_step, (TRAINING_EXAMPLES // BATCH_SIZE) * EPOCHS_PER_DECAY,
            DECAY_RATE,
            staircase=True)
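        # with staircase=True the rate is INITIAL_LEARNING_RATE * DECAY_RATE ** k,
        # where k increases by one after every EPOCHS_PER_DECAY epochs of steps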
        learning_step = tf.train.GradientDescentOptimizer(
            learning_rate).minimize(loss,
                                    var_list=image_captioner.variables,
                                    global_step=global_step)

        captioner_saver = tf.train.Saver(var_list=image_captioner.variables +
                                         [global_step])
        captioner_ckpt, captioner_ckpt_name = get_show_and_tell_checkpoint()
        with tf.Session() as sess:

            if captioner_ckpt is not None:
                captioner_saver.restore(sess, captioner_ckpt)
            else:
                sess.run(
                    tf.variables_initializer(image_captioner.variables +
                                             [global_step]))
            captioner_saver.save(sess,
                                 captioner_ckpt_name,
                                 global_step=global_step)
            last_save = time.time()

            for i in itertools.count():

                time_start = time.time()
                try:
                    _ids, _loss, _learning_step = sess.run(
                        [ids, loss, learning_step])
                except tf.errors.OutOfRangeError:
                    # the input pipeline has run out of epochs
                    break

                iteration = sess.run(global_step)

                print(
                    PRINT_STRING.format(
                        iteration, _loss,
                        list_of_ids_to_string(_ids[0, :].tolist(), vocab),
                        BATCH_SIZE / (time.time() - time_start)))

                new_save = time.time()
                if new_save - last_save > 3600:  # save the model every hour
                    captioner_saver.save(sess,
                                         captioner_ckpt_name,
                                         global_step=global_step)
                    last_save = new_save

            captioner_saver.save(sess,
                                 captioner_ckpt_name,
                                 global_step=global_step)
            print("Finished training.")
Example No. 9
def main(unused_argv):

    vocab, pretrained_matrix = load_glove(vocab_size=100000,
                                          embedding_size=300)
    attribute_map = get_visual_attributes()
    attribute_to_word_lookup_table = vocab.word_to_id(
        attribute_map.reverse_vocab)

    with tf.Graph().as_default():

        (image_id, image_features, object_features, input_seq, target_seq,
         indicator, attributes) = import_mscoco(mode="train",
                                                batch_size=FLAGS.batch_size,
                                                num_epochs=FLAGS.num_epochs,
                                                is_mini=FLAGS.is_mini)

        attribute_detector = AttributeDetector(1000)
        _, image_attributes, object_attributes = attribute_detector(
            image_features, object_features)

        grounded_attribute_cell = GroundedAttributeCell(1024)
        attribute_captioner = AttributeCaptioner(
            grounded_attribute_cell, vocab, pretrained_matrix,
            attribute_to_word_lookup_table)
        logits, ids = attribute_captioner(lengths=tf.reduce_sum(indicator,
                                                                axis=1),
                                          mean_image_features=image_features,
                                          mean_object_features=object_features,
                                          seq_inputs=input_seq,
                                          image_attributes=image_attributes,
                                          object_attributes=object_attributes)

        tf.losses.sparse_softmax_cross_entropy(target_seq,
                                               logits,
                                               weights=indicator)
        loss = tf.losses.get_total_loss()

        global_step = tf.train.get_or_create_global_step()
        learning_rate = tf.train.exponential_decay(
            5e-4,
            global_step,
            3 * 586363 // FLAGS.batch_size,
            0.8,
            staircase=True)
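        # decays the 5e-4 base rate by a factor of 0.8 roughly every 3 epochs
        # (586363 appears to be the number of training captions in the dataset)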
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        learning_step = optimizer.minimize(
            loss,
            var_list=attribute_captioner.variables,
            global_step=global_step)

        detector_saver = tf.train.Saver(var_list=attribute_detector.variables +
                                        [global_step])
        detector_ckpt, detector_ckpt_name = get_attribute_detector_checkpoint()

        captioner_saver = tf.train.Saver(
            var_list=attribute_captioner.variables + [global_step])
        captioner_ckpt, captioner_ckpt_name = get_grounded_attribute_checkpoint()

        with tf.Session() as sess:

            sess.run(tf.variables_initializer(optimizer.variables()))

            if detector_ckpt is not None:
                detector_saver.restore(sess, detector_ckpt)
            else:
                sess.run(
                    tf.variables_initializer(attribute_detector.variables +
                                             [global_step]))

            if captioner_ckpt is not None:
                captioner_saver.restore(sess, captioner_ckpt)
            else:
                sess.run(
                    tf.variables_initializer(attribute_captioner.variables +
                                             [global_step]))

            captioner_saver.save(sess,
                                 captioner_ckpt_name,
                                 global_step=global_step)
            last_save = time.time()

            for i in itertools.count():

                time_start = time.time()
                try:
                    _target, _ids, _loss, _learning_step = sess.run(
                        [target_seq, ids, loss, learning_step])
                except tf.errors.OutOfRangeError:
                    # the input pipeline has run out of epochs
                    break

                iteration = sess.run(global_step)

                print(
                    PRINT_STRING.format(
                        iteration, _loss,
                        list_of_ids_to_string(_ids[0, :].tolist(), vocab),
                        list_of_ids_to_string(_target[0, :].tolist(), vocab),
                        FLAGS.batch_size / (time.time() - time_start)))

                new_save = time.time()
                if new_save - last_save > 3600:  # save the model every hour
                    captioner_saver.save(sess,
                                         captioner_ckpt_name,
                                         global_step=global_step)
                    last_save = new_save

            captioner_saver.save(sess,
                                 captioner_ckpt_name,
                                 global_step=global_step)
            print("Finished training.")
Example No. 10
def main(unused_argv):

    vocab, pretrained_matrix = load_glove(vocab_size=100000,
                                          embedding_size=300)
    with tf.Graph().as_default():

        image_id, spatial_features, input_seq, target_seq, indicator = import_mscoco(
            mode="train",
            batch_size=FLAGS.batch_size,
            num_epochs=FLAGS.num_epochs,
            is_mini=FLAGS.is_mini)
        image_captioner = ImageCaptioner(SpatialAttentionCell(300), vocab,
                                         pretrained_matrix)
        logits, ids = image_captioner(lengths=tf.reduce_sum(indicator, axis=1),
                                      spatial_image_features=spatial_features,
                                      seq_inputs=input_seq)
        tf.losses.sparse_softmax_cross_entropy(target_seq,
                                               logits,
                                               weights=indicator)
        loss = tf.losses.get_total_loss()

        global_step = tf.train.get_or_create_global_step()
        optimizer = tf.train.AdamOptimizer()
        learning_step = optimizer.minimize(loss,
                                           var_list=image_captioner.variables,
                                           global_step=global_step)

        captioner_saver = tf.train.Saver(var_list=image_captioner.variables +
                                         [global_step])
        captioner_ckpt, captioner_ckpt_name = get_spatial_attention_checkpoint()
        with tf.Session() as sess:

            sess.run(tf.variables_initializer(optimizer.variables()))
            if captioner_ckpt is not None:
                captioner_saver.restore(sess, captioner_ckpt)
            else:
                sess.run(
                    tf.variables_initializer(image_captioner.variables +
                                             [global_step]))
            captioner_saver.save(sess,
                                 captioner_ckpt_name,
                                 global_step=global_step)
            last_save = time.time()

            for i in itertools.count():

                time_start = time.time()
                try:
                    _target, _ids, _loss, _learning_step = sess.run(
                        [target_seq, ids, loss, learning_step])
                except tf.errors.OutOfRangeError:
                    # the input pipeline has run out of epochs
                    break

                iteration = sess.run(global_step)

                print(
                    PRINT_STRING.format(
                        iteration, _loss,
                        list_of_ids_to_string(_ids[0, :].tolist(), vocab),
                        list_of_ids_to_string(_target[0, :].tolist(), vocab),
                        FLAGS.batch_size / (time.time() - time_start)))

                new_save = time.time()

                if new_save - last_save > 3600:
                    captioner_saver.save(sess,
                                         captioner_ckpt_name,
                                         global_step=global_step)
                    last_save = new_save

            captioner_saver.save(sess,
                                 captioner_ckpt_name,
                                 global_step=global_step)
            print("Finished training.")
Example No. 11
def main(unused_argv):
    def _is_valid_num_shards(num_shards):
        """Returns True if num_shards is compatible with FLAGS.num_threads."""
        return num_shards < FLAGS.num_threads or not num_shards % FLAGS.num_threads

    assert _is_valid_num_shards(FLAGS.train_shards), (
        "Please make the FLAGS.num_threads commensurate with FLAGS.train_shards")
    assert _is_valid_num_shards(FLAGS.val_shards), (
        "Please make the FLAGS.num_threads commensurate with FLAGS.val_shards")
    assert _is_valid_num_shards(FLAGS.test_shards), (
        "Please make the FLAGS.num_threads commensurate with FLAGS.test_shards")
        
    # Create vocabulary from the glove embeddings.
    vocab, _ = load_glove(vocab_size=FLAGS.vocab_size, embedding_size=FLAGS.embedding_size)
    tagger = load_tagger()

    if not tf.gfile.IsDirectory(FLAGS.output_dir):
        tf.gfile.MakeDirs(FLAGS.output_dir)

    # Load image metadata from caption files.
    mscoco_train_dataset = _load_and_process_metadata(FLAGS.train_captions_file, FLAGS.train_image_dir)
    mscoco_val_dataset = _load_and_process_metadata(FLAGS.val_captions_file, FLAGS.val_image_dir)

    # Redistribute the MSCOCO data as follows:
    #   train_dataset = 99% of mscoco_train_dataset
    #   val_dataset = 1% of mscoco_train_dataset (for validation during training).
    #   test_dataset = 100% of mscoco_val_dataset (for final evaluation).
    train_cutoff = int(0.99 * len(mscoco_train_dataset))
    train_dataset = mscoco_train_dataset[:train_cutoff]
    val_dataset = mscoco_train_dataset[train_cutoff:]
    test_dataset = mscoco_val_dataset
    
    # If needed crop the dataset to make it smaller
    max_train_size = len(train_dataset)
    if FLAGS.train_dataset_size < max_train_size:
        # Shuffle the ordering of images. Make the randomization repeatable.
        random.seed(12345)
        random.shuffle(train_dataset)
        train_dataset = train_dataset[:FLAGS.train_dataset_size]
        
    max_val_size = len(val_dataset)
    if FLAGS.val_dataset_size < max_val_size:
        # Shuffle the ordering of images. Make the randomization repeatable.
        random.seed(12345)
        random.shuffle(val_dataset)
        val_dataset = val_dataset[:FLAGS.val_dataset_size]
        
    max_test_size = len(test_dataset)
    if FLAGS.test_dataset_size < max_test_size:
        # Shuffle the ordering of images. Make the randomization repeatable.
        random.seed(12345)
        random.shuffle(test_dataset)
        test_dataset = test_dataset[:FLAGS.test_dataset_size]
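    # The three blocks above repeat the same deterministic-shuffle-and-truncate
    # logic; a hypothetical helper (not part of the original script) could
    # express it once, for example:
    #
    #     def maybe_subsample(dataset, target_size, seed=12345):
    #         if target_size < len(dataset):
    #             random.seed(seed)
    #             random.shuffle(dataset)
    #             dataset = dataset[:target_size]
    #         return dataset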

    # Create the model to extract image boxes
    box_extractor = BoxExtractor(get_faster_rcnn_config(), top_k_boxes=FLAGS.top_k_boxes, 
        trainable=False)
    image_tensor = tf.placeholder(tf.float32, name='image_tensor', shape=[None, 
        FLAGS.image_height, FLAGS.image_width, 3])
    boxes, scores, cropped_images = box_extractor(image_tensor)
    # Create the model to extract the image features
    feature_extractor = FeatureExtractor(is_training=False, global_pool=False)
    # Compute the ResNet-101 features
    image_features = feature_extractor(image_tensor)
    feature_batch = tf.shape(image_features)[0]
    feature_depth = tf.shape(image_features)[3]
    object_features = tf.reduce_mean(feature_extractor(cropped_images), [1, 2])
    object_features = tf.reshape(object_features, [feature_batch, FLAGS.top_k_boxes, feature_depth])

    with tf.Session() as sess:

        rcnn_saver = tf.train.Saver(var_list=box_extractor.variables)
        resnet_saver = tf.train.Saver(var_list=feature_extractor.variables)
        rcnn_saver.restore(sess, get_faster_rcnn_checkpoint())
        resnet_saver.restore(sess, get_resnet_v2_101_checkpoint())
        
        lock = threading.Lock()
        def run_model_fn(images):
            # serialize access to the shared session across the worker threads;
            # the with-statement releases the lock even if sess.run raises
            with lock:
                return sess.run([image_features, object_features],
                                feed_dict={image_tensor: images})
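        # Illustrative call (shapes follow from the graph above; values are assumptions):
        #   images = np.zeros([1, FLAGS.image_height, FLAGS.image_width, 3], np.float32)
        #   spatial_map, box_features = run_model_fn(images)
        #   box_features then has shape [1, FLAGS.top_k_boxes, feature_depth]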

        _process_dataset("train", train_dataset, vocab, tagger, FLAGS.train_shards, run_model_fn)
        _process_dataset("val", val_dataset, vocab, tagger, FLAGS.val_shards, run_model_fn)
        _process_dataset("test", test_dataset, vocab, tagger, FLAGS.test_shards, run_model_fn)