def beam_search(self, v_feat, a_feat, beam_size=3, max_caption_length=15, length_normalization_factor=0.0,
                 include_unknown=False, sentence_word=False):
     batch_size = v_feat.size(0)
     assert batch_size == 1, 'Currently, the beam search only supports batch_size == 1'
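     # prime the LSTM hidden state with the concatenated audio + visual features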
     input_x = torch.cat([a_feat, v_feat], 1).unsqueeze(1)
     _, hidden_feat = self.rnn(input_x)  # initialize the LSTM
     x = Variable(torch.ones(batch_size, 1).type(torch.LongTensor) * self.start, requires_grad=False).cuda()  # <start>
     embeddings = self.embedder(x)
     input_x = torch.cat([embeddings, v_feat.unsqueeze(1)], 2)
     cap_gen = CaptionGenerator(embedder=self.embedder,
                                rnn=self.rnn,
                                classifier=self.classifier,
                                eos_id=self.end,
                                include_unknown=include_unknown,
                                unk_id=self.unk,
                                beam_size=beam_size,
                                max_caption_length=max_caption_length,
                                length_normalization_factor=length_normalization_factor,
                                batch_first=True)
     sentences, score = cap_gen.beam_search(input_x, v_feat.unsqueeze(1), hidden_feat)
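     # keep the top-ranked beam hypothesis (assumed best-first ordering)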
     sentence = sentences[0]
     sentence_word_ = [str(self.vocab_i2t[int(idx.cpu().numpy())]) for idx in sentence]
     # drop every EOS token from the decoded word list
     sentence_word_ = [w for w in sentence_word_ if w != 'EOS']
     sentence = [' '.join(sentence_word_) + '.']
     if sentence_word:
         return sentence, sentence_word_
     else:
         return sentence
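A minimal usage sketch for the method above, assuming a trained audio-visual captioning model that exposes beam_search(); the class name, checkpoint path, and feature sizes are placeholders:

# Hypothetical driver: model class, checkpoint, and feature dimensions are assumptions.
import torch
from torch.autograd import Variable

model = AVCaptionModel(vocab)                   # assumed class exposing beam_search()
model.load_state_dict(torch.load('av_caption.pth'))
model.eval().cuda()

v_feat = Variable(torch.randn(1, 2048)).cuda()  # one visual feature vector (batch_size must be 1)
a_feat = Variable(torch.randn(1, 128)).cuda()   # one audio feature vector
sentence, words = model.beam_search(v_feat, a_feat, beam_size=3, sentence_word=True)
print(sentence[0])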
def main():
    """
        run inference model on test set.
    """
    idx = h5py.File('./data/test/test_idx.h5', 'r')['labels']
    data = h5py.File('./data/test/test_caption.h5', 'r')
    model = HierarchicalModel(config, mode='inference')
    model.build()

    images = data['images']
    first_level_label_start_ix = data['first_layer_label_start_ix']
    first_level_label_end_ix = data['first_layer_label_end_ix']
    first_level_labels = data['first_layer_labels']
    second_level_label_start_ix = data['label_start_ix']
    second_level_label_end_ix = data['label_end_ix']
    # second_level_label_pos = data['label_position']
    second_level_labels = data['labels']
    generator = CaptionGenerator(model,
                                 model.level1_word2ix,
                                 model.level2_word2ix,
                                 beam_size_1level=5,
                                 beam_size_2level=2,
                                 encourage_1level=0.1,
                                 encourage_2level=0.9)

    result = []
    config_ = tf.ConfigProto(allow_soft_placement=True)
    config_.gpu_options.per_process_gpu_memory_fraction = 0.9
    config_.gpu_options.allow_growth = True
    with tf.Session(config=config_) as sess:
        tf.global_variables_initializer().run()

        for i in range(images.shape[0]):
            print('***************')
            # images_batch contains only one image.
            images_batch = images[i:i + 1, :, :, :]
            images_batch = crop_image(images_batch, False)
            prediction = generator.beam_search(sess, images_batch)
            print(i, idx[i], prediction)

            first_level_this = first_level_labels[
                first_level_label_start_ix[i]:first_level_label_end_ix[i]]
            second_level_this = []
            for j in range(first_level_label_start_ix[i],
                           first_level_label_end_ix[i]):
                second_level_this.append(
                    second_level_labels[second_level_label_start_ix[j]:
                                        second_level_label_end_ix[j]])
            decoded = decode_captions_2level(first_level_this,
                                             second_level_this,
                                             model.level1_model.idx_to_word,
                                             model.level2_model.idx_to_word)
            print(decoded)

            result.append({'image_id': int(idx[i]), 'caption': prediction})
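            # each entry follows the COCO caption results format: {"image_id": int, "caption": str}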
    with open('./data/test/result_resnet_tf.json', 'w', encoding='utf8') as f:
        json.dump(result, f)
    def generate(self,
                 img,
                 scale_size=256,
                 crop_size=224,
                 eos_token='EOS',
                 beam_size=3,
                 max_caption_length=20,
                 length_normalization_factor=0.0):

        preproc = [
            transforms.Scale(scale_size),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize(**normalize_values)
        ]

        if torch.is_tensor(img):
            preproc = [transforms.ToPILImage()] + preproc

        img_transform = transforms.Compose(preproc)
        cap_gen = CaptionGenerator(
            embedder=self.embedder,
            rnn=self.rnn,
            classifier=self.classifier,
            eos_id=self.vocab.index(eos_token),
            beam_size=beam_size,
            max_caption_length=max_caption_length,
            length_normalization_factor=length_normalization_factor)
        img = img_transform(img)
        if next(self.parameters()).is_cuda:
            img = img.cuda()
        img = Variable(img.unsqueeze(0), volatile=True)
        img_feats = self.cnn(img).unsqueeze(0)
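        # add a leading sequence dimension so the CNN features act as the first RNN input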
        sentences, score = cap_gen.beam_search(img_feats)
        sentences = [
            ' '.join([self.vocab[idx] for idx in sent]) for sent in sentences
        ]
        return sentences
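A minimal usage sketch for generate(), assuming the enclosing class is a trained CNN-RNN captioner; the class name, checkpoint path, and image file are placeholders:

# Hypothetical driver: class name, checkpoint, and image path are assumptions.
import torch
from PIL import Image

model = CaptionModel(vocab)                      # assumed class exposing generate()
model.load_state_dict(torch.load('caption.pth'))
model.eval()

img = Image.open('example.jpg').convert('RGB')   # a PIL image skips the ToPILImage step
sentences = model.generate(img, beam_size=3)
print(sentences[0])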
def main(argv):
    assert FLAGS.train_dir is not None, "train_dir is required"
    assert FLAGS.resnet_ckpt is not None, "resnet_ckpt is required"

    # data
    print('loading data...')
    (train_stems_list, train_stem_attrs_list, train_images, train_image2stem,
     train_stem2image) = utils.load_coco_data(config.data_root, 'train')
    (val_stems_list, val_stem_attrs_list, val_images, val_image2stem,
     val_stem2image) = utils.load_coco_data(config.data_root, 'val')

    # handling directories
    train_dir = os.path.join(config.model_root, FLAGS.train_dir)
    if not tf.gfile.IsDirectory(train_dir):
        tf.logging.info("Creating training directory: %s", train_dir)
        tf.gfile.MakeDirs(train_dir)

    log_dir = os.path.join(train_dir, 'log')
    if not tf.gfile.IsDirectory(log_dir):
        tf.logging.info("Creating log directory for training: %s", log_dir)
        tf.gfile.MakeDirs(log_dir)

    checkpoint = None
    if FLAGS.checkpoint is not None:
        checkpoint = os.path.join(config.model_root, FLAGS.checkpoint)
        assert os.path.exists(checkpoint), "checkpoint must exist if given."

    # model
    print('building model.')
    model = HierarchicalModel(config, mode=ModeKeys.TRAIN)
    loss = model.build()

    generator = CaptionGenerator(model,
                                 model.level1_word2ix,
                                 None,
                                 beam_size_1level=3,
                                 beam_size_2level=None,
                                 encourage_1level=0.0,
                                 encourage_2level=None,
                                 level2=False)

    # train_op
    with tf.name_scope('optimizer'):
        optimizer = tf.train.AdamOptimizer(learning_rate=config.learning_rate)

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            optim_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                           scope='level1')
            if config.train_resnet:
                optim_vars += tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                                scope='resnet')
            # deriv
            level1_grads = tf.gradients(loss, optim_vars)
            grads_and_vars = [(i, j) for i, j in zip(level1_grads, optim_vars)
                              if i is not None]
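            # clip each gradient element to [-0.1, 0.1] to keep updates stable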
            grads_and_vars = [(tf.clip_by_value(grad, -0.1, 0.1), var)
                              for grad, var in grads_and_vars]

            # todo: here check the batch-norm moving average/var
            # if config.train_resnet:
            #     optim_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='resnet')
            #     resnet_grads = tf.gradients(model.resnet.features, optim_vars)
            #     resnet_pairs = [(i, j) for i, j in zip(resnet_grads, optim_vars) if i is not None]
            #     grads_and_vars.extend(resnet_pairs)

            batchnorm_updates = tf.get_collection('resnet_update_ops')
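            # group the ResNet batch-norm moving-average updates so they run with each train step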
            batchnorm_updates_op = tf.group(*batchnorm_updates)
            apply_gradient_op = optimizer.apply_gradients(
                grads_and_vars=grads_and_vars)
            train_op = tf.group(apply_gradient_op, batchnorm_updates_op)

    # summary op
    print('************************')
    tf.summary.scalar('batch_loss', loss)
    for var in tf.trainable_variables():
        tf.summary.histogram(var.op.name, var)
    # for grad, var in grads_and_vars:
    #     tf.summary.histogram(var.op.name + '/gradient', grad)
    summary_op = tf.summary.merge_all()

    # stats:
    n_examples = len(train_stems_list)
    n_examples_val = len(val_stems_list)
    n_iters_per_epoch = int(np.ceil(float(n_examples) / config.batch_size))
    n_iters_val = int(np.ceil(float(n_examples_val) / config.batch_size))

    print("The number of epoch: %d" % config.n_epochs)
    print("Data size: %d" % n_examples)
    print("Batch size: %d" % config.batch_size)
    print("Iterations per epoch: %d" % n_iters_per_epoch)

    # tf session
    config_ = tf.ConfigProto(allow_soft_placement=True)
    config_.gpu_options.per_process_gpu_memory_fraction = 0.6
    config_.gpu_options.allow_growth = True
    with tf.Session(config=config_) as sess:
        tf.global_variables_initializer().run()
        summary_writer = tf.summary.FileWriter(log_dir,
                                               graph=tf.get_default_graph())
        saver = tf.train.Saver(max_to_keep=40)

        # pretrained
        if checkpoint is not None:
            print("Start training with checkpoint..")
            saver.restore(sess, checkpoint)

        # dynamic stats
        prev_loss_epo = np.inf
        curr_loss_epo = 0
        best_loss_val = np.inf
        curr_loss_val = 0
        i_global = 0

        start_t = time.time()
        for epo in range(config.n_epochs):
            # stochastic batching
            rand_idxs = list(np.random.permutation(n_examples))

            for it in range(n_iters_per_epoch):
                # next batch
                rand_idx = sorted(rand_idxs[it * config.batch_size:(it + 1) *
                                            config.batch_size])
                stems_batch, mask_batch = utils.list2batch(
                    [train_stems_list[i] for i in rand_idx])
                img_idx = train_stem2image[rand_idx]
                img_batch = utils.crop_image(train_images[img_idx], True)
                # print(decode_captions(captions_batch, model.level1_model.idx_to_word))

                feed_dict = {
                    model.level1_model.captions: stems_batch,
                    model.level1_model.mask: mask_batch,
                    model.level1_model.resnet.images: img_batch,
                    model.level1_model.resnet.is_training: config.train_resnet,
                    model.level1_model.keep_prob: 0.5
                }
                _, l = sess.run([train_op, loss], feed_dict)
                curr_loss_epo += l
                # print 'batch norm beta:', sess.run(test1)[:10]
                # print 'batch norm gamma:', sess.run(test2)[:10]

                # global iteration counts
                i_global += 1

                # write summary for tensorboard visualization
                if it % config.log_freq == 0:
                    summary = sess.run(summary_op, feed_dict)
                    summary_writer.add_summary(summary,
                                               epo * n_iters_per_epoch + it)

                # periodical display
                if it % config.print_freq == 0:
                    print(
                        "\nTrain loss at epoch %d & iteration %d (mini-batch): %.5f"
                        % (epo + 1, it + 1, l))
                    ground_truths = stems_batch[0]
                    decoded = utils.decode_captions(
                        ground_truths, model.level1_model.idx_to_word)
                    for j, gt in enumerate(decoded):
                        print("Ground truth %d: %s" % (j + 1, gt))
                        print(ground_truths)

                    predicted = generator.beam_search(sess,
                                                      img_batch[0:1, :, :, :])
                    print("Generated caption: %s\n" % predicted)
                    print('***************')

                # auto save
                if i_global % config.save_freq == 0:
                    saver.save(sess,
                               os.path.join(train_dir,
                                            'model_level1_auto_save'),
                               global_step=i_global)
                    print("model-auto-%s saved." % (i_global))

                # validate
                if i_global % config.valid_freq == 0:
                    curr_loss_val = 0
                    if config.print_bleu:
                        # TODO: some preparation for saving search result.
                        #all_gen_cap = np.ndarray((n_examples_val, 16))
                        pass

                    for it_val in range(n_iters_val):
                        idx_val = np.arange(it_val * config.batch_size,
                                            (it_val + 1) * config.batch_size)
                        stems_batch_val, mask_batch_val = utils.list2batch(
                            [val_stems_list[i] for i in idx_val])
                        img_idx_val = val_stem2image[idx_val]
                        img_batch_val = utils.crop_image(
                            val_images[img_idx_val], False)

                        feed_dict_val = {
                            model.level1_model.captions: stems_batch_val,
                            model.level1_model.mask: mask_batch_val,
                            model.level1_model.resnet.images: img_batch_val,
                            model.level1_model.resnet.is_training: False,
                            model.level1_model.keep_prob: 1.0
                        }
                        curr_loss_val += sess.run(loss, feed_dict_val)

                        if config.print_bleu:
                            # TODO: beam search and evaluate bleu.
                            pass

                    curr_loss_val /= n_iters_val

                    if curr_loss_val < best_loss_val:
                        best_loss_val = curr_loss_val
                        # better model
                        saver.save(sess,
                                   os.path.join(train_dir, 'model_level1_val'),
                                   global_step=i_global)
                        print('model-val-%s saved.' % (i_global))
                    else:
                        # TODO: early stop checking.
                        pass
            # end for(it)
            curr_loss_epo /= n_iters_per_epoch

            # epoch summary:
            print("Previous epoch loss: ", prev_loss_epo)
            print("Current epoch loss: ", curr_loss_epo)
            print("Elapsed time: ", time.time() - start_t)
            prev_loss_epo = curr_loss_epo
            curr_loss_epo = 0

            # save model's parameters
            saver.save(sess,
                       os.path.join(train_dir, 'model_level1_epo'),
                       global_step=epo + 1)
            print("model-epo-%s saved." % (epo + 1))
config = Config()
data = Preprocessor(config)
ctable = CharaterTable(data.train_captions + data.val_captions)
caption_len = FLAGS.caption_len

caption_model = CaptionModel(image_len=data.image_len,
                             caption_len=caption_len,
                             vocab_size=ctable.vocab_size,
                             ifpool=config.ifpool)

model_weights = FLAGS.model_weights
caption_model.build_inference_model(model_weights, beam_search=True)
caption_gen = CaptionGenerator(
    model=caption_model,
    ctable=ctable,
    caption_len=caption_len,
    beam_size=3,  # set beam_search size
    length_normalization_factor=0.5
)  # larger values favor longer sentences
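# For reference, a common length-normalization rule (a plausible reading of this
# parameter, not confirmed by this snippet) divides a hypothesis's summed
# log-probability by its length raised to the factor, so 0.0 keeps raw scores
# and larger values favor longer sentences. Illustrative sketch:
def normalized_score(log_prob_sum, length, length_normalization_factor=0.5):
    return log_prob_sum / (length ** length_normalization_factor)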

with codecs.open(FLAGS.save_result, 'w+', 'utf8') as f:
    for id in range(data.test_num):
        # print(id)
        # print(data.test_set[id].shape)
        result = caption_gen.beam_search(data.test_set[id])
        decode = ctable.decode(result[0], calc_argmax=False)
        f.write('{}'.format(9000 + id))
        for word in decode:
            f.write(' ' + word)
        if id % 10 == 0:
            print('[%.2f%%]' % ((id * 100.0) / data.test_num))
def main():
    # Configuration for hyper-parameters
    config = Config()
    beam_search_size = 5
    # Image Preprocessing
    transform = config.test_transform
    # Load vocabulary
    with open(os.path.join('../../coco', 'vocab.pkl'), 'rb') as f:
        vocab = pickle.load(f)
    # Build Models
    encoder = EncoderCNN(config.embed_size)
    encoder.eval()  # evaluation mode (BN uses moving mean/variance)
    decoder = DecoderRNN(config.embed_size, config.hidden_size, len(vocab),
                         config.num_layers)

    # Load the trained model parameters
    encoder.load_state_dict(
        torch.load(
            os.path.join('../../TrainedModels/TeacherCNN',
                         config.trained_encoder)))
    decoder.load_state_dict(
        torch.load(
            os.path.join('../../TrainedModels/TeacherLSTM',
                         config.trained_decoder)))
    # Build data loader
    image_path = os.path.join('../../coco', 'val2017')
    json_path = os.path.join('../../coco/annotations', 'captions_val2017.json')
    train_loader = get_data_loader(image_path,
                                   json_path,
                                   vocab,
                                   transform,
                                   1,
                                   shuffle=False,
                                   num_workers=config.num_threads)
    my_list = []
    img_ids = []
    loop_count = 0
    for i, (image_tensor, captions, lengths,
            img_id) in enumerate(train_loader):
        if img_id[0] in img_ids:
            continue
        loop_count += 1
        img_ids.append(img_id[0])
        image_tensor = Variable(image_tensor)
        state = (Variable(torch.zeros(config.num_layers, 1,
                                      config.hidden_size)),
                 Variable(torch.zeros(config.num_layers, 1,
                                      config.hidden_size)))
        # If use gpu
        if torch.cuda.is_available():
            encoder.cuda()
            decoder.cuda()
            state = [s.cuda() for s in state]
            image_tensor = image_tensor.cuda()
        cap_gen = CaptionGenerator(embedder=decoder.embed,
                                   rnn=decoder.lstm,
                                   classifier=decoder.linear,
                                   eos_id=2,
                                   beam_size=beam_search_size,
                                   max_caption_length=20,
                                   length_normalization_factor=0)
        # Generate caption from image
        feature = encoder(image_tensor)
        sentences, score = cap_gen.beam_search(feature)
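        # token ids: 1 is assumed to be <start> (skipped), 2 is <end> per eos_id=2 above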
        sampled_caption = []
        for word_id in sentences[-1]:
            if word_id == 1:
                continue
            if word_id == 2:
                break
            word = vocab.idx2word[word_id]
            sampled_caption.append(word)
        bestcaption = ' '.join(sampled_caption)
        mydic = {}
        mydic["image_id"] = img_id[0]
        mydic["caption"] = bestcaption
        my_list.append(mydic)
        if (i + 1) % 100 == 0:
            print("Image Count=" + str(loop_count))
    filename = './beam' + str(beam_search_size) + 'database_teacher.json'
    with open(filename, 'w') as myfile:
        json.dump(my_list, myfile)
def main():
    # Configuration for hyper-parameters
    config = Config()
    beam_search_size = 5
    # Image Preprocessing
    transform = config.test_transform
    # Load vocabulary
    with open(os.path.join(config.vocab_path, 'vocab.pkl'), 'rb') as f:
        vocab = pickle.load(f)
    # Build Models
    encoder = EncoderCNN(config.embed_size)
    encoder.eval()  # evaluation mode (BN uses moving mean/variance)
    decoder = DecoderRNN(config.embed_size, config.hidden_size, len(vocab),
                         config.num_layers)

    # Load the trained model parameters
    encoder.load_state_dict(
        torch.load(
            os.path.join('../TrainedModels/TeacherCNN',
                         config.trained_encoder)))
    decoder.load_state_dict(
        torch.load(
            os.path.join('../TrainedModels/TeacherLSTM',
                         config.trained_decoder)))
    # Build data loader
    image_path = os.path.join(config.image_path, 'train2017')
    json_path = os.path.join(config.caption_path, 'captions_train2017.json')
    train_loader = get_data_loader(image_path,
                                   json_path,
                                   vocab,
                                   transform,
                                   1,
                                   shuffle=False,
                                   num_workers=config.num_threads)
    my_list = []
    img_ids = []
    loop_count = 0
    for i, (image_tensor, captions, lengths,
            img_id) in enumerate(train_loader):
        if img_id[0] in img_ids:
            continue
        loop_count += 1
        img_ids.append(img_id[0])
        if loop_count == 5:
            break
        image_tensor = Variable(image_tensor)
        state = (Variable(torch.zeros(config.num_layers, 1,
                                      config.hidden_size)),
                 Variable(torch.zeros(config.num_layers, 1,
                                      config.hidden_size)))
        # If use gpu
        if torch.cuda.is_available():
            encoder.cuda()
            decoder.cuda()
            state = [s.cuda() for s in state]
            image_tensor = image_tensor.cuda()
        cap_gen = CaptionGenerator(embedder=decoder.embed,
                                   rnn=decoder.lstm,
                                   classifier=decoder.linear,
                                   eos_id=2,
                                   beam_size=beam_search_size,
                                   max_caption_length=20,
                                   length_normalization_factor=0)
        # Generate caption from image
        feature = encoder(image_tensor)
        sentences, score = cap_gen.beam_search(feature)
        get_topcider(img_id[0], sentences)
        if (i + 1) % 100 == 0:
            print('Completed generation captions for ' + str(loop_count) +
                  ' Images')
    filename = './beam' + str(beam_search_size) + 'database.txt'
    with open(filename, 'wb') as myfile:  # pickle requires a binary-mode file
        pickle.dump(mydatabaselist, myfile)
def main():
    model = Level1Model(config, mode='training')

    # idx = h5py.File('./data/val/val_idx.h5')['labels']
    data = h5py.File('./data/train/train_caption.h5', 'r')
    images = data['images']
    captions = data['first_layer_labels']
    caption_idx = data['first_layer_label2imgid']

    # val_data = h5py.File('./data/val/val_caption.h5')
    # val_images = data['images']
    # val_captions = data['first_layer_labels']
    # val_caption_idx = data['first_layer_label2imgid']

    # first_level_label_start_ix = data['first_layer_label_start_ix']
    # first_level_label_end_ix = data['first_layer_label_end_ix']
    # second_level_label_start_ix = data['label_start_ix']
    # second_level_label_end_ix = data['label_end_ix']
    # second_level_label_pos = data['label_position']
    # second_level_labels = data['labels']

    optimizer = tf.train.AdamOptimizer
    log_path = './model/level1_test/'
    pretrained_model = None

    generator = CaptionGenerator(model,
                                 model.level1_word2ix,
                                 None,
                                 beam_size_1level=3,
                                 beam_size_2level=None,
                                 encourage_1level=0.0,
                                 encourage_2level=None,
                                 level2=False)
    loss = model.build()
    n_examples = caption_idx.shape[0]
    # n_examples_val = val_caption_idx.shape[0]
    n_iters_per_epoch = int(np.ceil(float(n_examples) / config.batch_size))
    # n_iters_val = int(np.ceil(float(n_examples_val) / config.batch_size))
    # print [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
    print([
        v for v in tf.trainable_variables() if v.name.startswith(
            "resnet/block7/bottleneck22/b/batch_normalization/")
    ])
    test1 = [
        v for v in tf.trainable_variables() if v.name == (
            "resnet/block7/bottleneck22/b/batch_normalization/beta:0")
    ][0]
    test2 = [
        v for v in tf.trainable_variables() if v.name == (
            "resnet/block7/bottleneck22/b/batch_normalization/gamma:0")
    ][0]
    # test3 = [v for v in  tf.get_default_graph().as_graph_def().node if v.name == ("resnet/block7/bottleneck22/b/batch_normalization/moving_mean")][0]
    print(test1, test2)
    with tf.name_scope('optimizer'):
        optimizer = optimizer(learning_rate=0.0000004)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            optim_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                           scope='level1')
            if config.train_resnet:
                optim_vars += tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                                scope='resnet')
            level1_grads = tf.gradients(loss, optim_vars)
            grads_and_vars = [(i, j) for i, j in zip(level1_grads, optim_vars)
                              if i is not None]
            grads_and_vars = [(tf.clip_by_value(grad, -0.1, 0.1), var)
                              for grad, var in grads_and_vars]
            # # todo: here check the batch-norm moving average/var
            # if config.train_resnet:
            #     optim_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='resnet')
            #     resnet_grads = tf.gradients(model.resnet.features, optim_vars)
            #     resnet_pairs = [(i, j) for i, j in zip(resnet_grads, optim_vars) if i is not None]
            #     grads_and_vars.extend(resnet_pairs)

            # batchnorm_updates = tf.get_collection('resnet_update_ops')
            # batchnorm_updates_op = tf.group(*batchnorm_updates)
            train_op = optimizer.apply_gradients(grads_and_vars=grads_and_vars)
            # train_op = tf.group(apply_gradient_op, batchnorm_updates_op)

    # summary op
    print('************************')
    tf.summary.scalar('batch_loss', loss)
    for var in tf.trainable_variables():
        tf.summary.histogram(var.op.name, var)
    # for grad, var in grads_and_vars:
    #     tf.summary.histogram(var.op.name + '/gradient', grad)

    summary_op = tf.summary.merge_all()

    print "The number of epoch: %d" % config.n_epochs
    print "Data size: %d" % n_examples
    print "Batch size: %d" % config.batch_size
    print "Iterations per epoch: %d" % n_iters_per_epoch

    config_ = tf.ConfigProto(allow_soft_placement=True)
    config_.gpu_options.per_process_gpu_memory_fraction = 0.9
    config_.gpu_options.allow_growth = True
    with tf.Session(config=config_) as sess:
        tf.global_variables_initializer().run()
        summary_writer = tf.summary.FileWriter(log_path,
                                               graph=tf.get_default_graph())
        saver = tf.train.Saver(max_to_keep=40)

        if pretrained_model is not None:
            print "Start training with pretrained Model.."
            saver.restore(sess, pretrained_model)

        prev_loss = -1
        curr_loss = 0
        start_t = time.time()
        i_global = 0
        for e in range(config.n_epochs):
            rand_idxs = list(np.random.permutation(n_examples))
            for i in range(n_iters_per_epoch):
                i_global += 1
                rand_idx = sorted(rand_idxs[i * config.batch_size:(i + 1) *
                                            config.batch_size])
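                # h5py fancy indexing requires indices in increasing order, hence the sort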
                captions_batch = captions[rand_idx]
                img_idx = list(caption_idx[rand_idx])
                # print img_idx
                img_batch = crop_image(images[img_idx], True)
                # print decode_captions(captions_batch, model.level1_model.idx_to_word)
                # img_feature = sess.run(model.resnet.features, {model.resnet.images: img_batch})
                feed_dict = {
                    model.level1_model.captions: captions_batch,
                    model.level1_model.resnet.images: img_batch
                }
                _, l = sess.run([train_op, loss], feed_dict)
                # print 'batch norm beta:', sess.run(test1)[:10]
                # print 'batch norm gamma:', sess.run(test2)[:10]
                # print 'batch norm moving ave:', sess.run('resnet/block7/bottleneck22/b/batch_normalization/moving_mean:0')[:10]

                # l = sess.run(loss, feed_dict)
                curr_loss += l
                # write summary for tensorboard visualization
                if i % 1000 == 0:
                    summary = sess.run(summary_op, feed_dict)
                    summary_writer.add_summary(summary,
                                               e * n_iters_per_epoch + i)

                if (i + 1) % config.print_every == 0:
                    print "\nTrain loss at epoch %d & iteration %d (mini-batch): %.5f" % (
                        e + 1, i + 1, l)
                    # print img_idx, caption_idx == img_idx[0]
                    ground_truths = captions_batch[0]
                    decoded = decode_captions(ground_truths,
                                              model.level1_model.idx_to_word)
                    for j, gt in enumerate(decoded):
                        print "Ground truth %d: %s" % (j + 1, gt)
                        print ground_truths
                    predicted = generator.beam_search(sess,
                                                      img_batch[0:1, :, :, :])
                    decoded_predict = decode_captions(
                        np.asarray(predicted), model.level1_model.idx_to_word)
                    print "Generated caption: %s\n" % decoded_predict
                    print predicted
                    print '***************'
                if (i_global + 1) % 1000 == 0:
                    saver.save(sess,
                               os.path.join('./model',
                                            'model_level1_trained_bn'),
                               global_step=i_global + 1)
                    print "model-%s saved." % (i_global + 1)

            print "Previous epoch loss: ", prev_loss
            print "Current epoch loss: ", curr_loss
            print "Elapsed time: ", time.time() - start_t
            prev_loss = curr_loss
            curr_loss = 0