def generate_caption(self, img_path, beam_size=2):
     """Caption generator
     Args:
         image_path: path to the image
     Returns:
         caption: caption, generated for a given image
     """
     # TODO: use a frozen graph to avoid re-specifying the model here
     g = tf.Graph()
     # switch the sampling method to the one requested for generation
     self.params.sample_gen = self.gen_method
     if not os.path.exists(img_path):
         raise ValueError("Image not found: {}".format(img_path))
     with g.as_default():
         # specify rnn_placeholders
         ann_lengths_ps = tf.placeholder(tf.int32, [None])
         captions_ps = tf.placeholder(tf.int32, [None, None])
         images_ps = tf.placeholder(tf.float32, [None, 224, 224, 3])
         with tf.variable_scope("cnn"):
             image_embeddings = vgg16(images_ps, trainable_fe=True,
                                      trainable_top=True)
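         # fc2 is the 4096-dimensional penultimate VGG16 layer used as the image feature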
         features = image_embeddings.fc2
         # image features [b_size, f_size(4096)] -> [b_size, embed_size]
         images_fv = tf.layers.dense(features, self.params.embed_size,
                                     name='imf_emb')
         # will use methods from Decoder class
         decoder = Decoder(images_fv, captions_ps,
                           ann_lengths_ps, self.params, self.data_dict)
         with tf.variable_scope("decoder"):
             _, _ = decoder.decoder()  # for initialization
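             # building the decoder graph creates its variables so the Saver below can restore them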
         # if using cluster vectors
         if self.params.use_c_v:
             # cluster vectors from "Diverse and Accurate Image Description.." paper.
             c_i = tf.placeholder(tf.float32, [None, 90])
             c_i_emb = tf.layers.dense(c_i, self.params.embed_size,
                                       name='cv_emb')
             # map cluster vectors into embedding space
             decoder.c_i = c_i_emb
             decoder.c_i_ph = c_i
         # image_id
         im_id = [img_path.split('/')[-1]]
         saver = tf.train.Saver(tf.trainable_variables())
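         # the CNN is built with trainable_fe/trainable_top=True above, so its weights are
         # part of trainable_variables() and are restored from the checkpoint with the decoder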
     with tf.Session(graph=g) as sess:
         saver.restore(sess, self.checkpoint_path)
         im_shape = (224, 224)  # VGG16 input resolution
         image = np.expand_dims(load_image(img_path, im_shape), 0)
         # the image must be loaded before (optionally) computing its cluster vector
         if self.params.use_c_v:
             c_v = self._c_v_generator(image)
         else:
             c_v = None
         if self.gen_method == 'beam_search':
             sent = decoder.beam_search(sess, im_id, image,
                                        images_ps, c_v,
                                        beam_size=beam_size)
         elif self.gen_method == 'greedy':
             sent, _ = decoder.online_inference(sess, im_id, image,
                                                images_ps, c_v=c_v)
         else:
             raise ValueError(
                 "Unknown gen_method: {}".format(self.gen_method))
         return sent
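
# Usage sketch for generate_caption (hypothetical names): it is meant to be called on a
# trained captioner instance whose checkpoint_path, data_dict and gen_method are set, e.g.
#     captioner.gen_method = 'beam_search'
#     caption = captioner.generate_caption('images/example.jpg', beam_size=3)
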
def main(params):
    # load data; the Data class holds captions, images and image features (if available)
    repartiton = params.gen_val_captions >= 0
    data = Data(params,
                True,
                params.image_net_weights_path,
                repartiton=repartiton,
                gen_val_cap=params.gen_val_captions)
    # load the batch generator; repartitioning uses part of the validation set for training
    gen_batch_size = params.batch_size
    batch_gen = data.load_train_data_generator(gen_batch_size,
                                               params.fine_tune)
    # whether to use presaved pretrained ImageNet features (stored in a pickle file);
    # after fine-tuning the feature extractor is saved in the tf checkpoint, so
    # caption generation after fine-tuning must be run with params.fine_tune=True
    pretrained = not params.fine_tune
    val_gen = data.get_valid_data(gen_batch_size,
                                  val_tr_unused=batch_gen.unused_cap_in,
                                  pretrained=pretrained)
    test_gen = data.get_test_data(gen_batch_size, pretrained=pretrained)
    # annotations vector of form <EOS>...<BOS><PAD>...
    cap_enc = tf.placeholder(tf.int32, [None, None])
    cap_dec = tf.placeholder(tf.int32, [None, None])
    cap_len = tf.placeholder(tf.int32, [None])
    if params.fine_tune:
        # when fine-tuning, feed raw images instead of precomputed feature vectors
        image_batch = tf.placeholder(tf.float32, [None, 224, 224, 3])
    else:
        # use prepared image features [batch_size, 4096] (fc2)
        image_batch = tf.placeholder(tf.float32, [None, 4096])
    if params.use_c_v:
        cl_vectors = tf.placeholder(tf.float32, [None, 90])
    # params.fine_tune means raw images are fed instead of presaved ImageNet features;
    # when not fine-tuning, the CNN is still built on the dummy placeholder below so
    # that its ImageNet weights can be loaded and saved for further use
    image_f_inputs2 = tf.placeholder_with_default(tf.ones([1, 224, 224, 3]),
                                                  shape=[None, 224, 224, 3],
                                                  name='dummy_ps')
    if params.fine_tune:
        image_f_inputs2 = image_batch
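    # dropout keep-probability for the CNN; defaults to 1.0 (no dropout) unless fed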
    cnn_dropout_keep = tf.placeholder_with_default(1.0, ())
    weights_regularizer = tf.contrib.layers.l2_regularizer(params.weight_decay)
    with tf.variable_scope("cnn", regularizer=weights_regularizer):
        image_embeddings = vgg16(image_f_inputs2,
                                 trainable_fe=params.fine_tune_fe,
                                 trainable_top=params.fine_tune_top,
                                 dropout_keep=cnn_dropout_keep)
    if params.fine_tune:
        features = image_embeddings.fc2
    else:
        features = image_batch
    # the CNN forward pass is expensive, so run each image once and tile its features
    # across its reference captions instead of repeating the forward pass
    if params.num_captions > 1 and params.mode == 'training':  # [b_s, 4096]
        features_tiled = tf.tile(tf.expand_dims(features, 1),
                                 [1, params.num_captions, 1])
        features = tf.reshape(features_tiled, [
            tf.shape(features)[0] * params.num_captions,
            params.cnn_feature_size
        ])  # [5 * b_s, 4096]
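        # features now align row-for-row with the flattened caption batch
        # (num_captions captions per image)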
    # dictionary
    cap_dict = data.dictionary
    params.vocab_size = cap_dict.vocab_size
    # image features [b_size, f_size(4096)] -> [b_size, embed_size]
    images_fv = layers.dense(features, params.embed_size, name='imf_emb')
    # decoder, input_fv, get x, x_logits (for generation)
    decoder = Decoder(images_fv, cap_dec, cap_len, params, cap_dict)
    if params.use_c_v:
        # cluster vectors from "Diverse and Accurate Image Description.." paper.
        c_i_emb = layers.dense(cl_vectors, params.embed_size, name='cv_emb')
        # map cluster vectors into embedding space
        decoder.c_i = c_i_emb
        decoder.c_i_ph = cl_vectors
    with tf.variable_scope("decoder"):
        x_logits, _ = decoder.decoder()
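    # x_logits is expected to be flattened over time ([batch * max_len, vocab_size])
    # to match labels_flat in the loss below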
    # calculate rec. loss, mask padded part
    labels_flat = tf.reshape(cap_enc, [-1])
    ce_loss_padded = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=x_logits, labels=labels_flat)
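    # tf.sign on the label ids zeroes out padded positions, assuming the <PAD> id is 0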
    loss_mask = tf.sign(tf.to_float(labels_flat))
    batch_loss = tf.div(tf.reduce_sum(tf.multiply(ce_loss_padded, loss_mask)),
                        tf.reduce_sum(loss_mask),
                        name="batch_loss")
    tf.losses.add_loss(batch_loss)
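    # batch_loss is registered in the tf.losses collection; get_total_loss() returns it
    # together with the L2 regularization losses added via weights_regularizer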
    rec_loss = tf.losses.get_total_loss()
    # overall loss: masked cross-entropy reconstruction loss
    # optimization, can print global norm for debugging
    optimize, global_step, global_norm = optimizers.non_cnn_optimizer(
        rec_loss, params)
    # cnn parameters are updated by a separate optimizer, used only when fine-tuning;
    # otherwise optimize_cnn remains a no-op constant
    optimize_cnn = tf.constant(0.0)
    if params.fine_tune and params.mode == 'training':
        optimize_cnn, _ = optimizers.cnn_optimizer(rec_loss, params)
    # model restore
    vars_to_save = tf.trainable_variables()
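    # if part of the CNN is frozen, its variables are not in trainable_variables(),
    # so add them explicitly to keep them in the checkpoint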
    if not params.fine_tune_fe or not params.fine_tune_top:
        cnn_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 'cnn')
        vars_to_save += cnn_vars
    saver = tf.train.Saver(vars_to_save,
                           max_to_keep=params.max_checkpoints_to_keep)
    # m_builder = tf.saved_model.builder.SavedModelBuilder('./saved_model')
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
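    # allow_growth allocates GPU memory on demand instead of reserving it all upfront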
    with tf.Session(config=config) as sess:
        sess.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ])
        # train using batch generator, every iteration get
        # f(I), [batch_size, max_seq_len], seq_lengths
        if params.mode == "training":
            if params.logging:
                summary_writer = tf.summary.FileWriter(params.LOG_DIR,
                                                       sess.graph)
                summary_writer.add_graph(sess.graph)
            if not params.restore:
                print("Loading imagenet weights for futher usage")
                image_embeddings.load_weights(params.image_net_weights_path,
                                              sess)
            if params.restore:
                print("Restoring from checkpoint")
                saver.restore(
                    sess, "./checkpoints/{}.ckpt".format(params.checkpoint))
            for e in range(params.num_epochs):
                gs = tf.train.global_step(sess, global_step)
                gs_epoch = 0
                def stop_condition():
                    return gs_epoch * params.batch_size > params.num_ex_per_epoch

                while True:
                    for f_images_batch, captions_batch, cl_batch, c_v in \
                            batch_gen.next_batch(
                                use_obj_vectors=params.use_c_v,
                                num_captions=params.num_captions):
                        if params.num_captions > 1:
                            captions_batch, cl_batch, c_v = preprocess_captions(
                                captions_batch, cl_batch, c_v)
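                        # captions_batch[0] feeds the decoder inputs (cap_dec),
                        # captions_batch[1] the target labels (cap_enc)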
                        feed = {
                            image_batch: f_images_batch,
                            cap_enc: captions_batch[1],
                            cap_dec: captions_batch[0],
                            cap_len: cl_batch,
                            cnn_dropout_keep: params.cnn_dropout
                        }
                        if params.use_c_v:
                            feed.update({cl_vectors: c_v[:, 1:]})
                        gs = tf.train.global_step(sess, global_step)
                        # print(sess.run(debug_print, feed))
                        total_loss_, _, _ = sess.run(
                            [rec_loss, optimize, optimize_cnn], feed)
                        gs_epoch += 1
                        if gs % 500 == 0:
                            print(
                                "Iteraton: {} Total training loss: {} ".format(
                                    gs, total_loss_))
                        if stop_condition():
                            break
                    if stop_condition():
                        break
                print("Epoch: {} training loss {}".format(e, total_loss_))

                def validate():
                    val_rec = []
                    for f_images_batch, captions_batch, cl_batch, c_v in val_gen.next_batch(
                            use_obj_vectors=params.use_c_v,
                            num_captions=params.num_captions):
                        gs = tf.train.global_step(sess, global_step)
                        if params.num_captions > 1:
                            captions_batch, cl_batch, c_v = preprocess_captions(
                                captions_batch, cl_batch, c_v)
                        feed = {
                            image_batch: f_images_batch,
                            cap_enc: captions_batch[1],
                            cap_dec: captions_batch[0],
                            cap_len: cl_batch
                        }
                        if params.use_c_v:
                            feed.update({cl_vectors: c_v[:, 1:]})
                        rl = sess.run([rec_loss], feed_dict=feed)
                        val_rec.append(rl)
                    print("Validation loss: {}".format(np.mean(val_rec)))
                    print("-----------------------------------------------")

                validate()
                # save model
                if not os.path.exists("./checkpoints"):
                    os.makedirs("./checkpoints")
                save_path = saver.save(
                    sess, "./checkpoints/{}.ckpt".format(params.checkpoint))
                print("Model saved in file: %s" % save_path)
        # m_builder.add_meta_graph_and_variables(sess, ["main_model"])
        if params.use_hdf5 and params.fine_tune:
            batch_gen.h5f.close()
        # run inference
        if params.mode == "inference":
            inference.inference(params, decoder, val_gen, test_gen,
                                image_batch, saver, sess)
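
# Entry point sketch (hypothetical; the actual script is assumed to build `params`
# with its own argument parser / parameters helper before calling main):
#     if __name__ == '__main__':
#         params = parse_args()
#         main(params)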