def generate_caption(self, img_path, beam_size=2):
     """Caption generator
     Args:
         image_path: path to the image
     Returns:
         caption: caption, generated for a given image
     """
     # TODO: use a frozen graph to avoid re-specifying the model
     g = tf.Graph()
     # change some parameters
     self.params.sample_gen = self.gen_method
     if not os.path.exists(img_path):
         raise ValueError("Image not found: {}".format(img_path))
     with g.as_default():
         # specify rnn_placeholders
         ann_lengths_ps = tf.placeholder(tf.int32, [None])
         captions_ps = tf.placeholder(tf.int32, [None, None])
         images_ps = tf.placeholder(tf.float32, [None, 224, 224, 3])
         with tf.variable_scope("cnn"):
             image_embeddings = vgg16(images_ps, trainable_fe=True,
                                      trainable_top=True)
         features = image_embeddings.fc2
         # image features: [b_size, 4096] -> [b_size, embed_size]
         images_fv = tf.layers.dense(features, self.params.embed_size,
                                     name='imf_emb')
         # will use methods from Decoder class
         decoder = Decoder(images_fv, captions_ps,
                           ann_lengths_ps, self.params, self.data_dict)
         with tf.variable_scope("decoder"):
             _, _ = decoder.decoder()  # for initialization
         # if use cluster vectors
         if self.params.use_c_v:
             # cluster vectors from "Diverse and Accurate Image Description.." paper.
             c_i = tf.placeholder(tf.float32, [None, 90])
             c_i_emb = tf.layers.dense(c_i, self.params.embed_size,
                                       name='cv_emb')
             # map cluster vectors into embedding space
             decoder.c_i = c_i_emb
             decoder.c_i_ph = c_i
         # image_id
         im_id = [img_path.split('/')[-1]]
         saver = tf.train.Saver(tf.trainable_variables())
     with tf.Session(graph=g) as sess:
         saver.restore(sess, self.checkpoint_path)
         im_shape = (224, 224)  # VGG16 input size
         image = np.expand_dims(load_image(img_path, im_shape), 0)
         if self.params.use_c_v:
             c_v = self._c_v_generator(image)
         else:
             c_v = None
         if self.gen_method == 'beam_search':
             sent = decoder.beam_search(sess, im_id, image,
                                        images_ps, c_v,
                                        beam_size=beam_size)
         elif self.gen_method == 'greedy':
             sent, _ = decoder.online_inference(sess, im_id, image,
                                                images_ps, c_v=c_v)
         return sent
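A minimal usage sketch, assuming a hypothetical CaptionGenerator wrapper that exposes this method together with the gen_method, params, data_dict, and checkpoint_path attributes used above (every name except generate_caption is a placeholder):

# hypothetical wrapper object; constructor arguments are illustrative only
cap_gen = CaptionGenerator(params=params,
                           data_dict=data_dict,
                           checkpoint_path="./checkpoints/model.ckpt",
                           gen_method="beam_search")
caption = cap_gen.generate_caption("./images/example.jpg", beam_size=3)
print(caption)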
Example #2
 def _build_feature_graph(self):
     im_embed = tf.Graph()
     with im_embed.as_default():
         self.input_img = tf.placeholder(tf.float32, [224, 224, 3])
         input_image = tf.expand_dims(self.input_img, 0)
         if self._img_embed_ == "resnet":
             image_embeddings = ResNet(50)
             is_training = tf.constant(False)
             features = image_embeddings(input_image, is_training)
             self._resnet_saver = tf.train.Saver()
         else:
             image_embeddings = vgg16(input_image)
             features = image_embeddings.fc2
     return im_embed, features, image_embeddings
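A hedged sketch of how a method on the same class might consume the returned graph; the method name and the weights_path attribute are assumptions modeled on example #5 below:

 def _embed_single_image(self, img_path):
     # hypothetical helper; restores backbone weights, then runs one image through it
     graph, features, image_embeddings = self._build_feature_graph()
     with tf.Session(graph=graph) as sess:
         if self._img_embed_ == "resnet":
             self._resnet_saver.restore(sess, self.weights_path)
         else:
             image_embeddings.load_weights(self.weights_path, sess)
         img = load_image(img_path, (224, 224))
         return sess.run(features, {self.input_img: img})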
Example #3
 def extract_features_from_dir(self,
                               data_dir,
                               save_pickle=True,
                               im_shape=(224, 224)):
     """
     Args:
         data_dir: image data directory
         model: tf.contrib.keras model, CNN, used for feature extraction
         save_pickle: bool, will serialize feature_dict and save it into
     ./pickle directory
         im_shape: desired images shape
     Returns:
         feature_dict: dictionary of the form {image_name: feature_vector}
     """
     feature_dict = {}
     try:
         with open("./pickles/" + data_dir.split('/')[-2] + '.pickle',
                   'rb') as rf:
             print("Loading prepared feature vector from {}".format(
                 "./pickles/" + data_dir.split('/')[-2] + '.pickle'))
             feature_dict = pickle.load(rf)
     except (IOError, pickle.UnpicklingError):
         print("Extracting features")
         if not os.path.exists("./pickles"):
             os.makedirs("./pickles")
         im_embed = tf.Graph()
         with im_embed.as_default():
             input_img = tf.placeholder(tf.float32,
                                        [None, im_shape[0], im_shape[1], 3])
             image_embeddings = vgg16(input_img)
             features = image_embeddings.fc2
             config = tf.ConfigProto()
             config.gpu_options.allow_growth = True
         with tf.Session(graph=im_embed, config=config) as sess:
             image_embeddings.load_weights(self.weights_path, sess)
             for img_path in tqdm(glob(data_dir + '*.jpg')):
                 img = load_image(img_path, im_shape)
                 img = np.expand_dims(img, axis=0)
                 f_vector = sess.run(features, {input_img: img})
                 # ex. COCO_val2014_0000000XXXXX.jpg
                 feature_dict[img_path.split('/')[-1]] = f_vector
         if save_pickle:
             with open("./pickles/" + data_dir.split('/')[-2] + '.pickle',
                       'wb') as wf:
                 pickle.dump(feature_dict, wf)
     return feature_dict
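A hedged usage sketch; extractor stands for whatever object exposes this method, and the dictionary keys follow the COCO file-name convention noted in the comment above:

# hypothetical call; feature vectors are keyed by image file name
feature_dict = extractor.extract_features_from_dir("./data/val2014/")
some_image = next(iter(feature_dict))    # e.g. a COCO_val2014_...jpg file name
f_vector = feature_dict[some_image]      # VGG16 fc2 features, shape [1, 4096]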
Example #4
def main(params):
    # load data; the Data class holds captions, images, and image features (if available)
    if params.gen_val_captions < 0:
        repartiton = False
    else:
        repartiton = True
    data = Data(params,
                True,
                params.image_net_weights_path,
                repartiton=repartiton,
                gen_val_cap=params.gen_val_captions)
    # load batch generator; repartition to use more val-set images during training
    gen_batch_size = params.batch_size
    batch_gen = data.load_train_data_generator(gen_batch_size,
                                               params.fine_tune)
    # whether to use pre-saved pretrained ImageNet features (stored in a pickle)
    # after fine-tuning, the feature extractor is saved in a tf checkpoint
    # caption generation after fine-tuning must be run with params.fine_tune=True
    pretrained = not params.fine_tune
    val_gen = data.get_valid_data(gen_batch_size,
                                  val_tr_unused=batch_gen.unused_cap_in,
                                  pretrained=pretrained)
    test_gen = data.get_test_data(gen_batch_size, pretrained=pretrained)
    # annotations vector of form <EOS>...<BOS><PAD>...
    ann_inputs_enc = tf.placeholder(tf.int32, [None, None])
    ann_inputs_dec = tf.placeholder(tf.int32, [None, None])
    ann_lengths = tf.placeholder(tf.int32, [None])
    if params.fine_tune:
        # when fine-tuning, feed raw images instead of precomputed feature vectors
        image_f_inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])
    else:
        # use prepared image features [batch_size, 4096] (fc2)
        image_f_inputs = tf.placeholder(tf.float32, [None, 4096])
    if params.use_c_v or (params.prior == 'GMM' or params.prior == 'AG'):
        c_i = tf.placeholder(tf.float32, [None, 90])
    else:
        c_i = ann_lengths  # dummy tensor
    # aliases kept for compatibility with earlier versions of the code
    image_batch, cap_enc, cap_dec, cap_len, cl_vectors = (
        image_f_inputs, ann_inputs_enc, ann_inputs_dec, ann_lengths, c_i)
    # image features; params.fine_tune means pre-saved ImageNet features are not used
    # the dummy placeholder below keeps the CNN graph buildable when not fine-tuning,
    # so the ImageNet weights can still be saved for further usage; it may be removed
    # in future releases
    image_f_inputs2 = tf.placeholder_with_default(tf.ones([1, 224, 224, 3]),
                                                  shape=[None, 224, 224, 3],
                                                  name='dummy_ps')
    if params.fine_tune:
        image_f_inputs2 = image_batch
    if params.mode == 'training' and params.fine_tune:
        cnn_dropout = params.cnn_dropout
        weights_regularizer = tf.contrib.layers.l2_regularizer(
            params.weight_decay)
    else:
        cnn_dropout = 1.0
        weights_regularizer = None
    with tf.variable_scope("cnn", regularizer=weights_regularizer):
        image_embeddings = vgg16(image_f_inputs2,
                                 trainable_fe=params.fine_tune_fe,
                                 trainable_top=params.fine_tune_top,
                                 dropout_keep=cnn_dropout)
    if params.fine_tune:
        features = image_embeddings.fc2
    else:
        features = image_batch
    # the CNN forward pass is expensive; tile features so the num_captions captions
    # of each image reuse a single forward pass
    if params.num_captions > 1 and params.mode == 'training':  # [b_s, 4096]
        features_tiled = tf.tile(tf.expand_dims(features, 1),
                                 [1, params.num_captions, 1])
        features = tf.reshape(features_tiled, [
            tf.shape(features)[0] * params.num_captions,
            params.cnn_feature_size
        ])  # [5 * b_s, 4096]
    # dictionary
    cap_dict = data.dictionary
    params.vocab_size = cap_dict.vocab_size
    # image features: [b_size, 4096] -> [b_size, embed_size]
    images_fv = layers.dense(features, params.embed_size, name='imf_emb')
    # images_fv = tf.Print(images_fv, [tf.shape(features), features[0][0:10],
    #                                   image_embeddings.imgs[0][:10], images_fv])
    # encoder: takes image features and the caption (...<BOS>) and produces latent z
    if not params.no_encoder:
        encoder = Encoder(images_fv, cap_enc, cap_len, params)
    # decoder: takes image features (and z) and produces x_logits for generation
    decoder = Decoder(images_fv, cap_dec, cap_len, params, cap_dict)
    if params.use_c_v or (params.prior == 'GMM' or params.prior == 'AG'):
        # cluster vectors from the "Diverse and Accurate Image Description ..." paper
        # the number of object classes is hardcoded for now
        # for GMM-CVAE it must be specified
        c_i_emb = layers.dense(cl_vectors, params.embed_size, name='cv_emb')
        # map cluster vectors into embedding space
        decoder.c_i = c_i_emb
        decoder.c_i_ph = cl_vectors
        if not params.no_encoder:
            encoder.c_i = c_i_emb
            encoder.c_i_ph = cl_vectors
    if not params.no_encoder:
        with tf.variable_scope("encoder"):
            qz, tm_list, tv_list = encoder.q_net()
        if params.prior == 'Normal':
            # KL divergence between normal distributions KL(q, p), see Kingma et al.
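            # closed form for a diagonal Gaussian q = N(mu, sigma^2) against p = N(0, I):
            # KL(q || p) = -0.5 * sum_j (1 + log(sigma_j^2) - mu_j^2 - sigma_j^2);
            # the small constant inside the log guards against log(0)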
            kld = -0.5 * tf.reduce_mean(
                tf.reduce_sum(
                    1 + tf.log(tf.square(qz.distribution.std) + 0.00001) -
                    tf.square(qz.distribution.mean) -
                    tf.square(qz.distribution.std), 1))
        elif params.prior == 'GMM':
            # initialize sigma as constant, mu drawn randomly
            # TODO: finish GMM loss implementation
            c_means, c_sigma = init_clusters(params.num_clusters,
                                             params.latent_size)
            decoder.cap_clusters = c_means
            kld = -0.5 * tf.reduce_mean(
                tf.reduce_sum(
                    1 + tf.log(tf.square(qz.distribution.std) + 0.00001) -
                    tf.square(qz.distribution.mean) -
                    tf.square(qz.distribution.std), 1))
        elif params.prior == 'AG':
            c_means, c_sigma = init_clusters(params.num_clusters,
                                             params.latent_size)
            decoder.cap_clusters = c_means
            kld_clusters = (
                0.5 + tf.log(qz.distribution.std + 0.00001)
                - tf.log(c_sigma + 0.00001)
                - (tf.square(qz.distribution.mean -
                             tf.matmul(tf.squeeze(c_i), c_means)) +
                   tf.square(qz.distribution.std)) / (2 * tf.square(c_sigma) + 0.0000001))
            kld = -0.5 * tf.reduce_sum(kld_clusters, 1)
    with tf.variable_scope("decoder"):
        if params.no_encoder:
            dec_model, x_logits, shpe, _ = decoder.px_z_fi({})
        else:
            dec_model, x_logits, shpe, _ = decoder.px_z_fi({'z': qz})
    # calculate rec. loss, mask padded part
    labels_flat = tf.reshape(cap_enc, [-1])
    ce_loss_padded = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=x_logits, labels=labels_flat)
    loss_mask = tf.sign(tf.to_float(labels_flat))
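    # tf.sign turns PAD tokens (assumed to have id 0) into 0 and real tokens into 1,
    # so padded positions contribute nothing to the averaged cross-entropy below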
    batch_loss = tf.div(tf.reduce_sum(tf.multiply(ce_loss_padded, loss_mask)),
                        tf.reduce_sum(loss_mask),
                        name="batch_loss")
    tf.losses.add_loss(batch_loss)
    rec_loss = tf.losses.get_total_loss()
    # kld weight annealing
    anneal = tf.placeholder_with_default(0, [])
    if params.fine_tune or params.restore:
        annealing = tf.constant(1.0)
    else:
        if params.ann_param > 1:
            annealing = (tf.tanh(
                (tf.to_float(anneal) - 1000 * params.ann_param) / 1000) +
                         1) / 2
        else:
            annealing = tf.constant(1.0)
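    # with ann_param > 1 the coefficient follows 0.5 * (tanh((step - 1000 * ann_param) / 1000) + 1):
    # close to 0 early in training, 0.5 at step 1000 * ann_param, and approaching 1 afterwards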
    # overall loss: reconstruction loss plus the annealed KL regularization term
    if not params.no_encoder:
        lower_bound = rec_loss + tf.multiply(tf.to_float(annealing),
                                             tf.to_float(kld)) / 10
    else:
        lower_bound = rec_loss
        kld = tf.constant(0.0)
    # optimization, can print global norm for debugging
    optimize, global_step, global_norm = optimizers.non_cnn_optimizer(
        lower_bound, params)
    optimize_cnn = tf.constant(0.0)
    if params.fine_tune and params.mode == 'training':
        optimize_cnn, _ = optimizers.cnn_optimizer(lower_bound, params)
    # cnn parameters update
    # model restore
    vars_to_save = tf.trainable_variables()
    if not params.fine_tune_fe or not params.fine_tune_top:
        cnn_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 'cnn')
        vars_to_save += cnn_vars
    saver = tf.train.Saver(vars_to_save,
                           max_to_keep=params.max_checkpoints_to_keep)
    # m_builder = tf.saved_model.builder.SavedModelBuilder('./saved_model')
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ])
        # train using batch generator, every iteration get
        # f(I), [batch_size, max_seq_len], seq_lengths
        if params.mode == "training":
            if params.logging:
                summary_writer = tf.summary.FileWriter(params.LOG_DIR,
                                                       sess.graph)
                summary_writer.add_graph(sess.graph)
            if not params.restore:
                print("Loading imagenet weights for futher usage")
                image_embeddings.load_weights(params.image_net_weights_path,
                                              sess)
            if params.restore:
                print("Restoring from checkpoint")
                saver.restore(
                    sess, "./checkpoints/{}.ckpt".format(params.checkpoint))
            for e in range(params.num_epochs):
                gs = tf.train.global_step(sess, global_step)
                gs_epoch = 0
                while True:

                    def stop_condition():
                        num_examples = gs_epoch * params.batch_size
                        if num_examples > params.num_ex_per_epoch:
                            return True
                        return False
                    for f_images_batch, captions_batch, cl_batch, c_v in \
                            batch_gen.next_batch(
                                use_obj_vectors=params.use_c_v,
                                num_captions=params.num_captions):
                        if params.num_captions > 1:
                            captions_batch, cl_batch, c_v = preprocess_captions(
                                captions_batch, cl_batch, c_v)
                        feed = {
                            image_f_inputs: f_images_batch,
                            ann_inputs_enc: captions_batch[1],
                            ann_inputs_dec: captions_batch[0],
                            ann_lengths: cl_batch,
                            anneal: gs
                        }
                        if params.use_c_v or (params.prior == 'GMM'
                                              or params.prior == 'AG'):
                            feed.update({c_i: c_v[:, 1:]})
                        gs = tf.train.global_step(sess, global_step)
                        feed.update({anneal: gs})
                        # if gs_epoch == 0:
                        # print(sess.run(debug_print, feed))
                        kl, rl, lb, _, _, ann = sess.run([
                            kld, rec_loss, lower_bound, optimize, optimize_cnn,
                            annealing
                        ], feed)
                        gs_epoch += 1
                        if gs % 500 == 0:
                            print("Epoch: {} Iteration: {} VLB: {} "
                                  "Rec Loss: {}".format(
                                      e, gs, np.mean(lb), rl))
                            if not params.no_encoder:
                                print("Annealing coefficient:"
                                      "{} KLD: {}".format(ann, np.mean(kl)))
                        if stop_condition():
                            break
                    if stop_condition():
                        break
                print("Epoch: {} Iteration: {} VLB: {} Rec Loss: {}".format(
                    e,
                    gs,
                    np.mean(lb),
                    rl,
                ))

                def validate():
                    val_rec = []
                    for f_images_batch, captions_batch, cl_batch, c_v in val_gen.next_batch(
                            use_obj_vectors=params.use_c_v,
                            num_captions=params.num_captions):
                        gs = tf.train.global_step(sess, global_step)
                        if params.num_captions > 1:
                            captions_batch, cl_batch, c_v = preprocess_captions(
                                captions_batch, cl_batch, c_v)
                        feed = {
                            image_f_inputs: f_images_batch,
                            ann_inputs_enc: captions_batch[1],
                            ann_inputs_dec: captions_batch[0],
                            ann_lengths: cl_batch,
                            anneal: gs
                        }
                        if params.use_c_v or (params.prior == 'GMM'
                                              or params.prior == 'AG'):
                            feed.update({c_i: c_v[:, 1:]})
                        rl = sess.run([rec_loss], feed_dict=feed)
                        val_rec.append(rl)
                    print("Validation reconstruction loss: {}".format(
                        np.mean(val_rec)))
                    print("-----------------------------------------------")

                validate()
                # save model
                if not os.path.exists("./checkpoints"):
                    os.makedirs("./checkpoints")
                save_path = saver.save(
                    sess, "./checkpoints/{}.ckpt".format(params.checkpoint))
                print("Model saved in file: %s" % save_path)
        # builder.add_meta_graph_and_variables(sess, ["main_model"])
        if params.use_hdf5 and params.fine_tune:
            batch_gen.h5f.close()
        # run inference
        if params.mode == "inference":
            inference.inference(params, decoder, val_gen, test_gen,
                                image_f_inputs, saver, sess)
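For reference, a minimal sketch of the hyperparameter namespace this script expects; the attribute names are taken from the code above, but the values shown are purely illustrative defaults:

from types import SimpleNamespace

# illustrative values only; the real project builds params from its own argument parser
params = SimpleNamespace(
    mode="training", fine_tune=False, fine_tune_fe=False, fine_tune_top=False,
    restore=False, logging=False, LOG_DIR="./logs",
    batch_size=64, num_epochs=20, num_ex_per_epoch=80000,
    embed_size=512, latent_size=150, cnn_feature_size=4096,
    num_captions=5, gen_val_captions=-1, use_c_v=False, use_hdf5=False,
    prior="Normal", num_clusters=60, no_encoder=False, ann_param=1,
    cnn_dropout=0.5, weight_decay=1e-5, checkpoint="model",
    max_checkpoints_to_keep=5, image_net_weights_path="./vgg16_weights.npz")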
Example #5
 def _extract_features_from_dir(self,
                                save_pickle=True,
                                im_shape=(224, 224)):
     """
     Args:
         data_dir: image data directory
         save_pickle: bool, will serialize feature_dict and save it into
     ./pickle directory
         im_shape: desired images shape
     Returns:
         feature_dict: dictionary of the form {image_name: feature_vector}
     """
     feature_dict = {}
     # Impaths in form folder/imname.jpg
     # ex. val2014/COCO_val2014_0000000XXXXX.jpg
     im_paths = list(self.train_captions.keys()) + list(
         self.val_captions.keys()) + list(self.test_captions.keys())
     if self.img_embed == "resnet":
         embed_file = os.path.join("./pickles/", "sc_img_embed_res.pickle")
     else:
         embed_file = os.path.join("./pickles/", "sc_img_embed_vgg.pickle")
     try:
         with open(embed_file, 'rb') as rf:
             print("Loading prepared feature vector from {}".format(
                 embed_file))
             feature_dict = pickle.load(rf)
     except (IOError, pickle.UnpicklingError):
         print("Extracting features")
         if not os.path.exists("./pickles"):
             os.makedirs("./pickles")
         im_embed = tf.Graph()
         with im_embed.as_default():
             input_img = tf.placeholder(tf.float32,
                                        [None, im_shape[0], im_shape[1], 3])
             if self.img_embed == "resnet":
                 image_embeddings = ResNet(50)
                 is_training = tf.constant(False)
                 features = image_embeddings(input_img, is_training)
                 saver = tf.train.Saver()
             else:
                 image_embeddings = vgg16(input_img)
                 features = image_embeddings.fc2
             # gpu_options = tf.GPUOptions(
             #     visible_device_list=self.params.gpu,
             #     allow_growth=True)
             config = tf.ConfigProto()
         with tf.Session(graph=im_embed, config=config) as sess:
             if self.img_embed == "resnet":
                 print("loading resnet weights")
                 saver.restore(sess, self.weights_path)
             else:
                 print("loading vgg16 imagenet weights")
                 image_embeddings.load_weights(self.weights_path, sess)
             for img_path in tqdm(im_paths):
                 # Get the full path
                 img_path_ = os.path.join(self.images_dir, img_path)
                 img = load_image(img_path_)
                 img = np.expand_dims(img, axis=0)
                 f_vector = sess.run(features, {input_img: img})
                 feature_dict[img_path] = f_vector
         if save_pickle:
             with open(embed_file, 'wb') as wf:
                 pickle.dump(feature_dict, wf)
     return feature_dict
Example #6
def main(params):
    # load data; the Data class holds captions, images, and image features (if available)
    if params.gen_val_captions < 0:
        repartiton = False
    else:
        repartiton = True
    data = Data(params,
                True,
                params.image_net_weights_path,
                repartiton=repartiton,
                gen_val_cap=params.gen_val_captions)
    # load batch generator; repartition to use more val-set images during training
    gen_batch_size = params.batch_size
    batch_gen = data.load_train_data_generator(gen_batch_size,
                                               params.fine_tune)
    # whether to use pre-saved pretrained ImageNet features (stored in a pickle)
    # after fine-tuning, the feature extractor is saved in a tf checkpoint
    # caption generation after fine-tuning must be run with params.fine_tune=True
    pretrained = not params.fine_tune
    val_gen = data.get_valid_data(gen_batch_size,
                                  val_tr_unused=batch_gen.unused_cap_in,
                                  pretrained=pretrained)
    test_gen = data.get_test_data(gen_batch_size, pretrained=pretrained)
    # annotations vector of form <EOS>...<BOS><PAD>...
    cap_enc = tf.placeholder(tf.int32, [None, None])
    cap_dec = tf.placeholder(tf.int32, [None, None])
    cap_len = tf.placeholder(tf.int32, [None])
    if params.fine_tune:
        # when fine-tuning, feed raw images instead of precomputed feature vectors
        image_batch = tf.placeholder(tf.float32, [None, 224, 224, 3])
    else:
        # use prepared image features [batch_size, 4096] (fc2)
        image_batch = tf.placeholder(tf.float32, [None, 4096])
    if params.use_c_v:
        cl_vectors = tf.placeholder(tf.float32, [None, 90])
    # image features; params.fine_tune means pre-saved ImageNet features are not used
    # the dummy placeholder below keeps the CNN graph buildable when not fine-tuning,
    # so the ImageNet weights can still be saved for further usage
    image_f_inputs2 = tf.placeholder_with_default(tf.ones([1, 224, 224, 3]),
                                                  shape=[None, 224, 224, 3],
                                                  name='dummy_ps')
    if params.fine_tune:
        image_f_inputs2 = image_batch
    cnn_dropout_keep = tf.placeholder_with_default(1.0, ())
    weights_regularizer = tf.contrib.layers.l2_regularizer(params.weight_decay)
    with tf.variable_scope("cnn", regularizer=weights_regularizer):
        image_embeddings = vgg16(image_f_inputs2,
                                 trainable_fe=params.fine_tune_fe,
                                 trainable_top=params.fine_tune_top,
                                 dropout_keep=cnn_dropout_keep)
    if params.fine_tune:
        features = image_embeddings.fc2
    else:
        features = image_batch
    # the CNN forward pass is expensive; tile features so the num_captions captions
    # of each image reuse a single forward pass
    if params.num_captions > 1 and params.mode == 'training':  # [b_s, 4096]
        features_tiled = tf.tile(tf.expand_dims(features, 1),
                                 [1, params.num_captions, 1])
        features = tf.reshape(features_tiled, [
            tf.shape(features)[0] * params.num_captions,
            params.cnn_feature_size
        ])  # [5 * b_s, 4096]
    # dictionary
    cap_dict = data.dictionary
    params.vocab_size = cap_dict.vocab_size
    # image features: [b_size, 4096] -> [b_size, embed_size]
    images_fv = layers.dense(features, params.embed_size, name='imf_emb')
    # decoder: takes image features and produces x_logits for generation
    decoder = Decoder(images_fv, cap_dec, cap_len, params, cap_dict)
    if params.use_c_v:
        # cluster vectors from "Diverse and Accurate Image Description.." paper.
        c_i_emb = layers.dense(cl_vectors, params.embed_size, name='cv_emb')
        # map cluster vectors into embedding space
        decoder.c_i = c_i_emb
        decoder.c_i_ph = cl_vectors
    with tf.variable_scope("decoder"):
        x_logits, _ = decoder.decoder()
    # calculate rec. loss, mask padded part
    labels_flat = tf.reshape(cap_enc, [-1])
    ce_loss_padded = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=x_logits, labels=labels_flat)
    loss_mask = tf.sign(tf.to_float(labels_flat))
    batch_loss = tf.div(tf.reduce_sum(tf.multiply(ce_loss_padded, loss_mask)),
                        tf.reduce_sum(loss_mask),
                        name="batch_loss")
    tf.losses.add_loss(batch_loss)
    rec_loss = tf.losses.get_total_loss()
    # optimization, can print global norm for debugging
    optimize, global_step, global_norm = optimizers.non_cnn_optimizer(
        rec_loss, params)
    optimize_cnn = tf.constant(0.0)
    if params.fine_tune and params.mode == 'training':
        optimize_cnn, _ = optimizers.cnn_optimizer(rec_loss, params)
    # cnn parameters update
    # model restore
    vars_to_save = tf.trainable_variables()
    if not params.fine_tune_fe or not params.fine_tune_top:
        cnn_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 'cnn')
        vars_to_save += cnn_vars
    saver = tf.train.Saver(vars_to_save,
                           max_to_keep=params.max_checkpoints_to_keep)
    # m_builder = tf.saved_model.builder.SavedModelBuilder('./saved_model')
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ])
        # train using batch generator, every iteration get
        # f(I), [batch_size, max_seq_len], seq_lengths
        if params.mode == "training":
            if params.logging:
                summary_writer = tf.summary.FileWriter(params.LOG_DIR,
                                                       sess.graph)
                summary_writer.add_graph(sess.graph)
            if not params.restore:
                print("Loading imagenet weights for futher usage")
                image_embeddings.load_weights(params.image_net_weights_path,
                                              sess)
            if params.restore:
                print("Restoring from checkpoint")
                saver.restore(
                    sess, "./checkpoints/{}.ckpt".format(params.checkpoint))
            for e in range(params.num_epochs):
                gs = tf.train.global_step(sess, global_step)
                gs_epoch = 0
                while True:

                    def stop_condition():
                        num_examples = gs_epoch * params.batch_size
                        if num_examples > params.num_ex_per_epoch:
                            return True
                        return False
                    for f_images_batch, captions_batch, cl_batch, c_v in \
                            batch_gen.next_batch(
                                use_obj_vectors=params.use_c_v,
                                num_captions=params.num_captions):
                        if params.num_captions > 1:
                            captions_batch, cl_batch, c_v = preprocess_captions(
                                captions_batch, cl_batch, c_v)
                        feed = {
                            image_batch: f_images_batch,
                            cap_enc: captions_batch[1],
                            cap_dec: captions_batch[0],
                            cap_len: cl_batch,
                            cnn_dropout_keep: params.cnn_dropout
                        }
                        if params.use_c_v:
                            feed.update({cl_vectors: c_v[:, 1:]})
                        gs = tf.train.global_step(sess, global_step)
                        # print(sess.run(debug_print, feed))
                        total_loss_, _, _ = sess.run(
                            [rec_loss, optimize, optimize_cnn], feed)
                        gs_epoch += 1
                        if gs % 500 == 0:
                            print(
                                "Iteration: {} Total training loss: {}".format(
                                    gs, total_loss_))
                        if stop_condition():
                            break
                    if stop_condition():
                        break
                print("Epoch: {} training loss {}".format(e, total_loss_))

                def validate():
                    val_rec = []
                    for f_images_batch, captions_batch, cl_batch, c_v in val_gen.next_batch(
                            use_obj_vectors=params.use_c_v,
                            num_captions=params.num_captions):
                        gs = tf.train.global_step(sess, global_step)
                        if params.num_captions > 1:
                            captions_batch, cl_batch, c_v = preprocess_captions(
                                captions_batch, cl_batch, c_v)
                        feed = {
                            image_batch: f_images_batch,
                            cap_enc: captions_batch[1],
                            cap_dec: captions_batch[0],
                            cap_len: cl_batch
                        }
                        if params.use_c_v:
                            feed.update({cl_vectors: c_v[:, 1:]})
                        rl = sess.run([rec_loss], feed_dict=feed)
                        val_rec.append(rl)
                    print("Validation loss: {}".format(np.mean(val_rec)))
                    print("-----------------------------------------------")

                validate()
                # save model
                if not os.path.exists("./checkpoints"):
                    os.makedirs("./checkpoints")
                save_path = saver.save(
                    sess, "./checkpoints/{}.ckpt".format(params.checkpoint))
                print("Model saved in file: %s" % save_path)
        # builder.add_meta_graph_and_variables(sess, ["main_model"])
        if params.use_hdf5 and params.fine_tune:
            batch_gen.h5f.close()
        # run inference
        if params.mode == "inference":
            inference.inference(params, decoder, val_gen, test_gen,
                                image_batch, saver, sess)