import itertools
import time

import numpy as np
import tensorflow as tf

# The helpers used below (load_glove, get_visual_attributes, import_mscoco,
# the model classes, the *_checkpoint functions, FLAGS, and PRINT_STRING)
# come from this repository's own modules and must be imported accordingly.


def main(unused_argv):

    vocab, pretrained_matrix = load_glove(vocab_size=100000,
                                          embedding_size=300)
    attribute_map = get_visual_attributes()
    # Randomly initialized 2048-d embedding for each of the 1000 attributes.
    attribute_embeddings_map = np.random.normal(0, 0.1, [1000, 2048])
    with tf.Graph().as_default():

        (image_id, mean_features, object_features, input_seq, target_seq,
         indicator) = import_mscoco(mode="train",
                                    batch_size=FLAGS.batch_size,
                                    num_epochs=FLAGS.num_epochs,
                                    is_mini=FLAGS.is_mini)
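        # Up-Down (bottom-up and top-down attention) recurrent cell: 300
        # units, attending over 4096-d image features.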
        up_down_cell = UpDownCell(300, num_image_features=4096)
        attribute_image_captioner = AttributeImageCaptioner(
            up_down_cell, vocab, pretrained_matrix, attribute_map,
            attribute_embeddings_map)
        attribute_detector = AttributeDetector(1000)
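        # Predict the top-k visual attributes from the mean image features;
        # the captioner below is conditioned on these detections.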
        _, top_k_attributes = attribute_detector(mean_features)
        logits, ids = attribute_image_captioner(
            top_k_attributes,
            lengths=tf.reduce_sum(indicator, axis=1),
            mean_image_features=mean_features,
            mean_object_features=object_features,
            seq_inputs=input_seq)
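        # Per-token cross entropy, masked by the indicator so that padding
        # positions do not contribute to the loss.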
        tf.losses.sparse_softmax_cross_entropy(target_seq,
                                               logits,
                                               weights=indicator)
        loss = tf.losses.get_total_loss()

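        # Only the captioner's variables are trained here; the attribute
        # detector is restored from its own checkpoint and kept frozen.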
        global_step = tf.train.get_or_create_global_step()
        optimizer = tf.train.AdamOptimizer()
        learning_step = optimizer.minimize(
            loss,
            var_list=attribute_image_captioner.variables,
            global_step=global_step)

        captioner_saver = tf.train.Saver(
            var_list=attribute_image_captioner.variables + [global_step])
        attribute_detector_saver = tf.train.Saver(
            var_list=attribute_detector.variables)
        captioner_ckpt, captioner_ckpt_name = (
            get_up_down_attribute_checkpoint())
        attribute_detector_ckpt, attribute_detector_ckpt_name = (
            get_attribute_detector_checkpoint())
        with tf.Session() as sess:

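            # Adam's slot variables are always freshly initialized; model
            # variables are restored from checkpoints when available.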
            sess.run(tf.variables_initializer(optimizer.variables()))
            if captioner_ckpt is not None:
                captioner_saver.restore(sess, captioner_ckpt)
            else:
                sess.run(
                    tf.variables_initializer(
                        attribute_image_captioner.variables + [global_step]))
            if attribute_detector_ckpt is not None:
                attribute_detector_saver.restore(sess, attribute_detector_ckpt)
            else:
                sess.run(tf.variables_initializer(
                    attribute_detector.variables))
            captioner_saver.save(sess,
                                 captioner_ckpt_name,
                                 global_step=global_step)
            last_save = time.time()

            for i in itertools.count():

                time_start = time.time()
                try:
                    _target, _ids, _loss, _learning_step = sess.run(
                        [target_seq, ids, loss, learning_step])
                except tf.errors.OutOfRangeError:
                    # Raised by the input pipeline once num_epochs of data
                    # have been consumed.
                    break

                iteration = sess.run(global_step)

                print(
                    PRINT_STRING.format(
                        iteration, _loss,
                        list_of_ids_to_string(_ids[0, :].tolist(), vocab),
                        list_of_ids_to_string(_target[0, :].tolist(), vocab),
                        FLAGS.batch_size / (time.time() - time_start)))

                new_save = time.time()
                if new_save - last_save > 3600:  # save the model every hour
                    captioner_saver.save(sess,
                                         captioner_ckpt_name,
                                         global_step=global_step)
                    last_save = new_save

            captioner_saver.save(sess,
                                 captioner_ckpt_name,
                                 global_step=global_step)
            print("Finishing training.")
def main(unused_argv):

    vocab, pretrained_matrix = load_glove(vocab_size=100000,
                                          embedding_size=300)
    attribute_map = get_visual_attributes()
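    # Map each attribute to the id of the corresponding word in the caption
    # vocabulary, so the captioner can embed attributes as words.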
    attribute_to_word_lookup_table = vocab.word_to_id(
        attribute_map.reverse_vocab)

    with tf.Graph().as_default():

        (image_id, image_features, object_features, input_seq, target_seq,
         indicator, attributes) = import_mscoco(mode="train",
                                                batch_size=FLAGS.batch_size,
                                                num_epochs=FLAGS.num_epochs,
                                                is_mini=FLAGS.is_mini)

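        # Detect attributes for the whole image and for each object region;
        # both detection sets condition the captioner below.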
        attribute_detector = AttributeDetector(1000)
        _, image_attributes, object_attributes = attribute_detector(
            image_features, object_features)

        grounded_attribute_cell = GroundedAttributeCell(1024)
        attribute_captioner = AttributeCaptioner(
            grounded_attribute_cell, vocab, pretrained_matrix,
            attribute_to_word_lookup_table)
        logits, ids = attribute_captioner(lengths=tf.reduce_sum(indicator,
                                                                axis=1),
                                          mean_image_features=image_features,
                                          mean_object_features=object_features,
                                          seq_inputs=input_seq,
                                          image_attributes=image_attributes,
                                          object_attributes=object_attributes)

        tf.losses.sparse_softmax_cross_entropy(target_seq,
                                               logits,
                                               weights=indicator)
        loss = tf.losses.get_total_loss()

        global_step = tf.train.get_or_create_global_step()
        # Decay the learning rate by 0.8 every 3 epochs (586363 appears to
        # be the number of training examples per epoch).
        learning_rate = tf.train.exponential_decay(
            5e-4,
            global_step,
            3 * 586363 // FLAGS.batch_size,
            0.8,
            staircase=True)
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        learning_step = optimizer.minimize(
            loss,
            var_list=attribute_captioner.variables,
            global_step=global_step)

        detector_saver = tf.train.Saver(
            var_list=attribute_detector.variables + [global_step])
        detector_ckpt, detector_ckpt_name = get_attribute_detector_checkpoint()

        captioner_saver = tf.train.Saver(
            var_list=attribute_captioner.variables + [global_step])
        captioner_ckpt, captioner_ckpt_name = (
            get_grounded_attribute_checkpoint())

        with tf.Session() as sess:

            sess.run(tf.variables_initializer(optimizer.variables()))

            if detector_ckpt is not None:
                detector_saver.restore(sess, detector_ckpt)
            else:
                sess.run(
                    tf.variables_initializer(attribute_detector.variables +
                                             [global_step]))

            if captioner_ckpt is not None:
                captioner_saver.restore(sess, captioner_ckpt)
            else:
                sess.run(
                    tf.variables_initializer(attribute_captioner.variables +
                                             [global_step]))

            captioner_saver.save(sess,
                                 captioner_ckpt_name,
                                 global_step=global_step)
            last_save = time.time()

            for i in itertools.count():

                time_start = time.time()
                try:
                    _target, _ids, _loss, _learning_step = sess.run(
                        [target_seq, ids, loss, learning_step])
                except tf.errors.OutOfRangeError:
                    break

                iteration = sess.run(global_step)

                print(
                    PRINT_STRING.format(
                        iteration, _loss,
                        list_of_ids_to_string(_ids[0, :].tolist(), vocab),
                        list_of_ids_to_string(_target[0, :].tolist(), vocab),
                        FLAGS.batch_size / (time.time() - time_start)))

                new_save = time.time()
                if new_save - last_save > 3600:  # save the model every hour
                    captioner_saver.save(sess,
                                         captioner_ckpt_name,
                                         global_step=global_step)
                    last_save = new_save

            captioner_saver.save(sess,
                                 captioner_ckpt_name,
                                 global_step=global_step)
            print("Finishing training.")
        attribute_image_captioner = AttributeImageCaptioner(
            visual_sentinel_cell, vocab, pretrained_matrix, attribute_map,
            attribute_embeddings_map)
        attribute_detector = AttributeDetector(1000)
        _, top_k_attributes = attribute_detector(
            tf.reduce_mean(spatial_features, [1, 2]))
        logits, ids = attribute_image_captioner(
            top_k_attributes, spatial_image_features=spatial_features)

        captioner_saver = tf.train.Saver(var_list=remap_decoder_name_scope(
            attribute_image_captioner.variables))
        attribute_detector_saver = tf.train.Saver(
            var_list=attribute_detector.variables)
        captioner_ckpt, captioner_ckpt_name = (
            get_visual_sentinel_attribute_checkpoint())
        attribute_detector_ckpt, attribute_detector_ckpt_name = (
            get_attribute_detector_checkpoint())

        with tf.Session() as sess:

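            # Evaluation requires both pretrained checkpoints to exist.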
            assert (captioner_ckpt is not None
                    and attribute_detector_ckpt is not None)
            captioner_saver.restore(sess, captioner_ckpt)
            attribute_detector_saver.restore(sess, attribute_detector_ckpt)
            used_ids = set()
            json_dump = []

            for i in itertools.count():
                time_start = time.time()
                try:
                    _ids, _target_seq, _image_id = sess.run(
                        [ids, target_seq, image_id])
                except tf.errors.OutOfRangeError:
                    break
                # (The remainder of this evaluation loop is missing from the
                # source; it presumably decoded the ids into caption strings
                # and appended them to json_dump, skipping used_ids.)
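

# Separate script: pretrain the attribute detector with a multi-label
# sigmoid loss; the captioning scripts above consume its checkpoint.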
def main(unused_argv):

    attribute_map = get_visual_attributes()

    with tf.Graph().as_default():

        (image_id, image_features, object_features, input_seq, target_seq,
         indicator, attributes) = import_mscoco(mode="train",
                                                batch_size=FLAGS.batch_size,
                                                num_epochs=FLAGS.num_epochs,
                                                is_mini=FLAGS.is_mini)
        attribute_detector = AttributeDetector(1000)
        logits, image_detections, object_detections = attribute_detector(
            image_features, object_features)
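        # Multi-label sigmoid cross entropy: an image can exhibit many
        # attributes at once.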
        tf.losses.sigmoid_cross_entropy(attributes, logits)
        loss = tf.losses.get_total_loss()

        global_step = tf.train.get_or_create_global_step()
        optimizer = tf.train.AdamOptimizer()
        learning_step = optimizer.minimize(
            loss,
            var_list=attribute_detector.variables,
            global_step=global_step)

        detector_saver = tf.train.Saver(
            var_list=attribute_detector.variables + [global_step])
        detector_ckpt, detector_ckpt_name = (
            get_attribute_detector_checkpoint())
        with tf.Session() as sess:

            sess.run(tf.variables_initializer(optimizer.variables()))
            if detector_ckpt is not None:
                detector_saver.restore(sess, detector_ckpt)
            else:
                sess.run(
                    tf.variables_initializer(attribute_detector.variables +
                                             [global_step]))
            detector_saver.save(sess,
                                detector_ckpt_name,
                                global_step=global_step)
            last_save = time.time()

            for i in itertools.count():

                time_start = time.time()
                try:
                    _attributes, _detections, _loss, _ = sess.run(
                        [attributes, image_detections, loss, learning_step])
                except tf.errors.OutOfRangeError:
                    break

                iteration = sess.run(global_step)
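                # Indices of the ground-truth attributes for the first
                # example in the batch.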
                ground_truth_ids = np.where(_attributes[0, :] > 0.5)

                print(
                    PRINT_STRING.format(
                        FLAGS.batch_size / (time.time() - time_start),
                        iteration,
                        _loss,
                        str(
                            attribute_map.id_to_word(
                                _detections[0, :].tolist())),
                        str(
                            attribute_map.id_to_word(
                                ground_truth_ids[0].tolist())),
                    ))

                new_save = time.time()
                if new_save - last_save > 3600:  # save the model every hour
                    detector_saver.save(sess,
                                        detector_ckpt_name,
                                        global_step=global_step)
                    last_save = new_save

            detector_saver.save(sess,
                                detector_ckpt_name,
                                global_step=global_step)
            print("Finished training.")