Example #1
def main():

    start_time = time.time()
    """ compute distance """
    if FLAGS.compute_distance:
        filename = os.path.join(FLAGS.feature_dir, 'query_images.npy')
        query_images = np.load(filename)
        if FLAGS.test_texts:
            filename = os.path.join(FLAGS.feature_dir, 'test_texts.npy')
        else:
            filename = os.path.join(FLAGS.feature_dir, 'test_images.npy')
        test_images = np.load(filename)

        num_query = len(query_images)
        num_test = len(test_images)
        mod2img_dist = np.zeros(shape=[num_query, num_test], dtype=np.float32)
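        # mod2img_dist[i, j] will hold the squared Euclidean distance between
        # query feature i and test feature j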

        if FLAGS.gpu_mode:
            ### compute distances on gpu
            graph = tf.Graph()
            with graph.as_default(), tf.device('/gpu:0'):
                query_images = tf.constant(query_images)
                test_images = tf.constant(test_images)

                idx_placeholder = tf.placeholder(tf.int32, shape=(None))
                query_image = tf.gather(query_images, idx_placeholder)
                diff = (query_image[:, tf.newaxis] - test_images)**2
                dist = tf.reduce_sum(diff, 2)
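                # broadcasting yields a (batch, num_test, feat_dim) tensor;
                # summing over the feature axis gives squared L2 distances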

            config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)
            config.gpu_options.allow_growth = True

            sess = tf.Session(graph=graph, config=config)

            num_iters = num_query // FLAGS.batch_size
            last_batch_size = num_query - num_iters * FLAGS.batch_size
            print('num_iters = %d, last_batch_size = %d' %
                  (num_iters, last_batch_size))
            ### full batches first, then the remainder (if any)
            for i in range(num_iters):
                idx_min = i * FLAGS.batch_size
                idx_max = (i + 1) * FLAGS.batch_size
                feed_dict = {idx_placeholder: np.arange(idx_min, idx_max)}
                mod2img_dist[idx_min:idx_max, :] = sess.run(
                    dist, feed_dict=feed_dict)
            if last_batch_size > 0:
                idx_min = num_iters * FLAGS.batch_size
                feed_dict = {idx_placeholder: np.arange(idx_min, num_query)}
                mod2img_dist[idx_min:num_query, :] = sess.run(
                    dist, feed_dict=feed_dict)
        else:
            ### compute distances on cpu
            for i in range(num_query):
                diff = query_images[i, :][np.newaxis, :] - test_images
                dist = np.sum(diff**2, 1)
                mod2img_dist[i, :] = dist

        ### save pre-computed distance
        if FLAGS.test_texts:
            filename = os.path.join(FLAGS.feature_dir, 'mod2text_dist.npy')
        else:
            if FLAGS.subset is not None:
                ### store subset into individual files
                filename = os.path.join(
                    FLAGS.feature_dir, 'mod2img_dist_' + FLAGS.subset + '.npy')
            else:
                filename = os.path.join(FLAGS.feature_dir, 'mod2img_dist.npy')
        np.save(filename, mod2img_dist)

        duration = time.time() - start_time
        print("elapsed time = %.2f second " % duration)
        dist = mod2img_dist

    else:
        if FLAGS.test_joint:
            filename = os.path.join(FLAGS.feature_dir, 'mod2text_dist.npy')
            dist1 = np.load(filename)
            if FLAGS.subset is not None:
                ### store subset into individual files
                filename = os.path.join(
                    FLAGS.feature_dir, 'mod2img_dist_' + FLAGS.subset + '.npy')
            else:
                filename = os.path.join(FLAGS.feature_dir, 'mod2img_dist.npy')
            dist2 = np.load(filename)
            dist = dist1 + dist2
        else:
            if FLAGS.test_texts:
                filename = os.path.join(FLAGS.feature_dir, 'mod2text_dist.npy')
            else:
                if FLAGS.subset is not None:
                    ### store subset into individual files
                    filename = os.path.join(
                        FLAGS.feature_dir,
                        'mod2img_dist_' + FLAGS.subset + '.npy')
                else:
                    filename = os.path.join(FLAGS.feature_dir,
                                            'mod2img_dist.npy')
            dist = np.load(filename)
    """ load groundtruth """
    if FLAGS.dataset == "fashion200k":
        filename = 'groundtruth/fashion200k_modif_pairs.npy'
        testset = fashion200k.fashion200k(path=FLAGS.data_path,
                                          split=FLAGS.data_split)
    elif FLAGS.dataset == "fashion_iq":
        if FLAGS.subset is None:
            filename = "groundtruth/fashion_iq_modif_pairs.npy"
        else:
            filename = "groundtruth/fashion_iq_modif_pairs_" + FLAGS.subset + ".npy"
        testset = fashion_iq.fashion_iq(path=FLAGS.data_path,
                                        split=FLAGS.data_split,
                                        subset=FLAGS.subset)
    elif FLAGS.dataset == 'shoes':
        filename = 'groundtruth/shoes_modif_pairs.npy'
        testset = shoes.shoes(path=FLAGS.data_path, split=FLAGS.data_split)
    else:
        raise ValueError("dataset is unknown.")
    groundtruth = np.load(filename)
    """ perform retrieval """
    start_time = time.time()
    ### generate source-query pairs at test time
    if FLAGS.dataset == "shoes":
        testset.generate_queries_()
        testset.generate_test_images_all_()
    elif FLAGS.dataset == "fashion200k":
        testset.generate_test_queries_()
    else:
        testset.generate_queries_(subset=FLAGS.subset)
        testset.generate_test_images_all_(subset=FLAGS.subset)

    gt_mask = groundtruth.astype(bool)
    order = np.arange(dist.shape[1])
    recall = np.ones(dist.shape)
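    # recall[i, k] ends up as 1 iff the best-ranked groundtruth item for
    # query i falls within the top (k + 1) retrieved results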

    for i in range(len(dist)):
        ### Note: the source image itself is excluded from the search results
        if FLAGS.dataset == "fashion_iq" or FLAGS.dataset == "shoes":
            idx = testset.database.index(testset.source_files[i])
        else:
            idx = testset.test_queries[i]['source_img_id']
        dist[i, idx] = INF

        rank = np.argsort(dist[i, :])
        gt_label = order[gt_mask[i, :]]
        indexes = []
        for j in range(len(gt_label)):
            indexes.append(np.where(rank == gt_label[j])[0][0])
        recall_atk = min(indexes)
        recall[i, 0:recall_atk] = 0

    recall_avg = np.sum(recall, 0) / len(dist) * 100
    print("recall: R@1 = %.2f, R@10 = %.2f, R@50 = %.2f" %
          (recall_avg[0], recall_avg[9], recall_avg[49]))

    r1_num = int(round(len(dist) * recall_avg[0] / 100))
    print("%d out of %d queries are correct at rank 1" % (r1_num, len(dist)))

    folder = "results/"

    if FLAGS.dataset == "shoes":
        filename = folder + 'results_shoes.log'
    elif FLAGS.dataset == "fashion200k":
        filename = folder + 'results_fashion200k.log'
    else:
        filename = folder + 'results_fashion_iq.log'

    with open(filename, 'a') as f:
        f.write(FLAGS.feature_dir + "\n")
        if FLAGS.test_texts:
            f.write('text retrieval: ')
        elif FLAGS.test_joint:
            f.write('joint retrieval: ')
        else:
            f.write('image retrieval: ')
        if FLAGS.subset is not None:
            f.write(FLAGS.subset + ' ')
        f.write("recall: R@1 = %.2f, R@10 = %.2f, R@50 = %.2f \n" %
                (recall_avg[0], recall_avg[9], recall_avg[49]))

    duration = time.time() - start_time
    print("elapsed time = %.2f second " % duration)
Example #2
def main():
  if FLAGS.dataset == "fashion200k":
    testset = fashion200k.fashion200k(path=FLAGS.data_path, split=FLAGS.data_split)
    trainset = fashion200k.fashion200k(path=FLAGS.data_path, split="train")
  elif FLAGS.dataset == "fashion_iq":
    testset = fashion_iq.fashion_iq(path=FLAGS.data_path, split=FLAGS.data_split, subset=FLAGS.subset)
    trainset = fashion_iq.fashion_iq(path=FLAGS.data_path, split="train", subset=FLAGS.subset)
  elif FLAGS.dataset == "shoes":
    testset = shoes.shoes(path=FLAGS.data_path, split=FLAGS.data_split)
    trainset = shoes.shoes(path=FLAGS.data_path, split="train")
  else: 
    raise ValueError("dataset is unknown.")

  if FLAGS.dataset != "fashion_iq" and FLAGS.dataset != "shoes":
    ### generate source-query pairs at test time
    testset.generate_test_queries_()
  elif FLAGS.dataset == "shoes":
    testset.generate_queries_()
  else:
    testset.generate_queries_(subset=FLAGS.subset)

  vocab = vocabulary.SimpleVocab()
  all_texts = trainset.get_all_texts()

  for text in all_texts:
    vocab.add_text_to_vocab(text)
  if FLAGS.remove_rare_words:
    vocab.threshold_rare_words()
  vocab_size = vocab.get_size()

  with tf.Graph().as_default():
    if FLAGS.dataset == "shoes":
      if FLAGS.query_images:
        dataset = tf.data.Dataset.from_tensor_slices((testset.source_files, testset.modify_texts))
        num_images = len(testset.source_files)
      else:
        testset.generate_test_images_all_()
        dataset = tf.data.Dataset.from_tensor_slices((testset.database, testset.database))
        num_images = len(testset.database)
    elif FLAGS.dataset == "fashion200k":
      if FLAGS.query_images:
        dataset = tf.data.Dataset.from_tensor_slices((testset.query_filenames, testset.modify_texts))
        num_images = len(testset.test_queries)
      else:
        dataset = tf.data.Dataset.from_tensor_slices((testset.filenames, testset.texts))
        num_images = len(testset.filenames)
    else:
      testset.generate_test_images_all_(subset=FLAGS.subset)
      if FLAGS.query_images:
        dataset = tf.data.Dataset.from_tensor_slices((testset.source_files, testset.modify_texts))
        num_images = len(testset.source_files)
      else:
        dataset = tf.data.Dataset.from_tensor_slices((testset.database, testset.database))
        num_images = len(testset.database)

    dataset = dataset.prefetch(1).map(eval_image_parse_function, num_parallel_calls=1)
    data_iterator = dataset.make_one_shot_iterator()
    batch_image, batch_text = data_iterator.get_next()

    images_placeholder = tf.placeholder(tf.float32, shape=(1, FLAGS.image_size, FLAGS.image_size, 3))
    texts_placeholder = tf.placeholder(tf.int32, shape=(1, None))
    seqlengths_placeholder = tf.placeholder(tf.int32, shape=(1))
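    # batch dimension is fixed to 1: features are extracted one image/text at a time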

    with tf.variable_scope(tf.get_variable_scope()):
      if FLAGS.query_images:
        cnn_features = _image_modify_model(images_placeholder, texts_placeholder, seqlengths_placeholder, vocab_size)
      elif FLAGS.test_texts:
        cnn_features = _text_model(texts_placeholder, seqlengths_placeholder, vocab_size)
      else:
        cnn_features = _image_model(images_placeholder)

    if math.isnan(FLAGS.moving_average_decay):
      vars_to_restore = tf.global_variables()
      vars_to_restore = [var for var in vars_to_restore if not "ogits" in var.name]
    else:
      vars_to_restore = tf.train.ExponentialMovingAverage(FLAGS.moving_average_decay).variables_to_restore()
      vars_to_restore = {k: v for k, v in vars_to_restore.items() if not "ogits" in k}    
      
    restorer = tf.train.Saver(vars_to_restore)

    config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
    config.gpu_options.allow_growth=True

    feed_dict = {
      images_placeholder: np.zeros((1, FLAGS.image_size, FLAGS.image_size, 3)),
      texts_placeholder: np.zeros((1, 10), dtype=int),
      seqlengths_placeholder: np.zeros((1), dtype=int)
    }
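    # zero-filled dummy feed for the first iterator run below; the dataset ops
    # do not actually depend on these placeholders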

    with tf.Session(config=config) as sess:
      ### restore model
      if FLAGS.exact_model_checkpoint:
        restore_dir = os.path.join(FLAGS.checkpoint_dir, FLAGS.exact_model_checkpoint)
      else:
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        restore_dir = ckpt.model_checkpoint_path
      if restore_dir:
        restorer.restore(sess, restore_dir)
        global_step = restore_dir.split('/')[-1].split('-')[-1]
        print('Successfully loaded model from %s at step=%s.' % (restore_dir, global_step))
      else:
        print('No checkpoint file found')
        return

      feature_size = cnn_features.get_shape().as_list()[1]
      print('feature dim is ' + str(feature_size))

      np_image_features = np.zeros(shape=[num_images, feature_size], dtype=np.float32)
      
      coord = tf.train.Coordinator()
      try:
        threads = []
        for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
          threads.extend(qr.create_threads(sess, coord=coord, daemon=True, start=True))
        index = 0

        print('Starting extraction on test data.\n')
        while index < num_images and not coord.should_stop():
          image_array, raw_text = sess.run([batch_image, batch_text], feed_dict=feed_dict)
          text_array = vocab.encode_text(raw_text.decode('utf-8'))
          lengths = len(text_array)

          if FLAGS.max_length is not None:
            lengths = min(FLAGS.max_length, lengths)
            max_length = FLAGS.max_length
          else:
            max_length = lengths

          feed_dict = {
            images_placeholder: image_array[np.newaxis,:,:,:], 
            texts_placeholder: np.array(text_array)[np.newaxis,:][:,0:max_length],
            seqlengths_placeholder: np.array([lengths])
          }

          np_image_feature = sess.run([cnn_features], feed_dict=feed_dict)
          np_image_features[index, :] = np_image_feature[0]
          sys.stdout.write('\r>> Extracting image features %d/%d.' % (index + 1, num_images))
          sys.stdout.flush()
          index += 1
        print('\nFinished extraction on %s data.\n' % FLAGS.data_split)

      except Exception as e:
        coord.request_stop(e)  
      coord.request_stop()
      coord.join(threads, stop_grace_period_secs=10)

  if not os.path.exists(FLAGS.feature_dir):
    os.mkdir(FLAGS.feature_dir)

  if FLAGS.query_images:
    filename = os.path.join(FLAGS.feature_dir, 'query_images.npy')
  elif FLAGS.test_texts:
    filename = os.path.join(FLAGS.feature_dir, 'test_texts.npy')
  else:
    filename = os.path.join(FLAGS.feature_dir, 'test_images.npy')
  os.makedirs(os.path.dirname(filename), exist_ok=True)
  print(filename)
  np.save(filename, np_image_features) 
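
The vocabulary.SimpleVocab object above is only exercised through add_text_to_vocab, threshold_rare_words, get_size and encode_text. A minimal sketch of a class with that interface, assuming whitespace tokenisation, a reserved id 0 for unknown/rare words and a hypothetical min_count threshold (the project's real implementation may differ):

from collections import Counter

class SimpleVocabSketch(object):
    """Hypothetical stand-in for vocabulary.SimpleVocab."""

    def __init__(self, min_count=5):
        self.min_count = min_count
        self.counts = Counter()
        self.word2id = {'<UNK>': 0}

    def add_text_to_vocab(self, text):
        for word in text.lower().split():
            self.counts[word] += 1
            if word not in self.word2id:
                self.word2id[word] = len(self.word2id)

    def threshold_rare_words(self):
        # remap words seen fewer than min_count times to the <UNK> id
        for word, count in self.counts.items():
            if count < self.min_count:
                self.word2id[word] = 0

    def get_size(self):
        # rare words keep their slots, so this remains a safe embedding-table size
        return len(self.word2id)

    def encode_text(self, text):
        return [self.word2id.get(word, 0) for word in text.lower().split()]
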
Example #3
  'data_split', "train", 'either "train" or "test".')
tf.app.flags.DEFINE_string(
  'subset', None, 'can be "dress" or "shirt" or "toptee".')
tf.app.flags.DEFINE_boolean(
  'remove_rare_words', False, 'whether to remove the rare words.')

FLAGS = tf.app.flags.FLAGS

########### read dataset
print("Construct dataset")
if FLAGS.dataset == "fashion200k":
  trainset = fashion200k.fashion200k(path=FLAGS.data_path, split=FLAGS.data_split)
elif FLAGS.dataset == "fashion_iq":
  trainset = fashion_iq.fashion_iq(path=FLAGS.data_path, split=FLAGS.data_split, subset=FLAGS.subset)
elif FLAGS.dataset == "shoes":
  trainset = shoes.shoes(path=FLAGS.data_path, split=FLAGS.data_split)
else:
  raise ValueError("dataset is unknown.")
num_images = len(trainset.filenames)

### initialize the relations between source and target
if FLAGS.dataset == "fashion_iq":
  trainset.generate_queries_(subset=FLAGS.subset)
  all_texts = trainset.get_all_texts(subset=FLAGS.subset)
elif FLAGS.dataset == "shoes":
  trainset.generate_queries_()
  all_texts = trainset.get_all_texts()
elif FLAGS.dataset == "fashion200k":
  ### initialize the relations between source and target
  trainset.caption_index_init_()
  all_texts = trainset.get_all_texts()
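
The remove_rare_words flag defined above feeds into the vocabulary thresholding used by the other scripts. A quick, hypothetical way to check how many words a given count threshold would affect, assuming all_texts is the list of caption strings returned by get_all_texts:

from collections import Counter

word_counts = Counter()
for text in all_texts:
    word_counts.update(text.lower().split())

threshold = 5  # hypothetical cut-off; the real value lives in the vocabulary code
rare = [word for word, count in word_counts.items() if count < threshold]
print('%d of %d words occur fewer than %d times' %
      (len(rare), len(word_counts), threshold))
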
Example #4
def main():

    ### prepare test set
    if FLAGS.dataset == "fashion200k":
        testset = fashion200k.fashion200k(path=FLAGS.data_path,
                                          split=FLAGS.data_split)
        filename = "groundtruth/fashion200k_modif_pairs.npy"
    elif FLAGS.dataset == "fashion_iq":
        testset = fashion_iq.fashion_iq(path=FLAGS.data_path,
                                        split=FLAGS.data_split,
                                        subset=FLAGS.subset)
        if FLAGS.subset is None:
            filename = "groundtruth/fashion_iq_modif_pairs.npy"
        else:
            filename = "groundtruth/fashion_iq_modif_pairs_" + FLAGS.subset + ".npy"
    elif FLAGS.dataset == "shoes":
        testset = shoes.shoes(path=FLAGS.data_path, split=FLAGS.data_split)
        filename = "groundtruth/shoes_modif_pairs.npy"
    else:
        raise ValueError("dataset is unknown.")

    ### generate source-query pairs at test time
    if FLAGS.dataset == "fashion200k":
        testset.generate_test_queries_()
        num_query = len(testset.test_queries)
        num_images = len(testset.filenames)
        groundtruth = np.full((num_query, num_images), False, dtype=bool)

        ### find the matching text pairs in the testset
        for i in range(num_query):
            ### groundtruth images are those that share the same target caption
            indices = [
                index for (index, caption) in enumerate(testset.texts)
                if caption == testset.test_queries[i]['target_caption']
            ]
            groundtruth[i, indices] = True
        np.save(filename, groundtruth)

    elif FLAGS.dataset == 'shoes':
        testset.generate_queries_()
        testset.generate_test_images_all_()
        database = testset.database
        num_images = len(database)
        num_query = len(testset.source_files)
        groundtruth = np.full((num_query, num_images), False, dtype=bool)

        for i in range(num_query):
            idx = database.index(testset.target_files[i])
            groundtruth[i, idx] = True
        print('num_images = %d; num_query = %d' % (num_images, num_query))
        np.save(filename, groundtruth)

    elif FLAGS.dataset == 'fashion_iq':
        testset.generate_queries_(subset=FLAGS.subset)
        testset.generate_test_images_all_(subset=FLAGS.subset)
        database = testset.database
        num_images = len(database)
        num_query = len(testset.source_files)
        groundtruth = np.full((num_query, num_images), False, dtype=bool)

        for i in range(num_query):
            idx = database.index(testset.target_files[i])
            groundtruth[i, idx] = True
        print('num_images = %d; num_query = %d' % (num_images, num_query))
        np.save(filename, groundtruth)
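
The boolean groundtruth matrices saved here are consumed together with the distance matrices produced by the first example. A compact sketch of that recall@K computation (ignoring the source-image exclusion the full evaluation script performs), assuming dist and groundtruth are the saved arrays with one query per row:

import numpy as np

def recall_at_k(dist, groundtruth, ks=(1, 10, 50)):
    # rank of every database item for each query (0 = closest)
    ranks = np.argsort(np.argsort(dist, axis=1), axis=1)
    # best (smallest) rank among the groundtruth items of each query
    best_gt_rank = np.where(groundtruth, ranks, dist.shape[1]).min(axis=1)
    return {k: float(np.mean(best_gt_rank < k)) * 100 for k in ks}

# usage sketch (hypothetical paths):
# dist = np.load('features/mod2img_dist.npy')
# groundtruth = np.load('groundtruth/shoes_modif_pairs.npy')
# print(recall_at_k(dist, groundtruth))
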
Example #5
def main():

    if FLAGS.dataset == "fashion_iq":
        trainset = fashion_iq.fashion_iq(path=FLAGS.data_path,
                                         split=FLAGS.data_split,
                                         subset=FLAGS.subset)
    elif FLAGS.dataset == "shoes":
        trainset = shoes.shoes(path=FLAGS.data_path, split=FLAGS.data_split)
    else:
        raise ValueError("dataset must be fashion_iq or shoes")

    ### initialize the relations between source and target
    if FLAGS.dataset == "fashion_iq":
        trainset.generate_queries_(subset=FLAGS.subset)
        all_texts = trainset.get_all_texts(subset=FLAGS.subset)
    else:
        trainset.generate_queries_()
        all_texts = trainset.get_all_texts()
    num_modif = trainset.num_modifiable_imgs
    max_steps = FLAGS.train_length

    vocab = vocabulary.SimpleVocab()
    for text in all_texts:
        vocab.add_text_to_vocab(text)  # add unseen words to the vocabulary
    if FLAGS.remove_rare_words:
        print('Remove rare words')
        vocab.threshold_rare_words()  # drop rare words from the vocabulary
    vocab_size = vocab.get_size()
    print("Number of samples = {}. Number of words = {}.".format(
        num_modif, vocab_size))

    # Build the training graph and input pipeline
    with tf.Graph().as_default():
        dataset = tf.data.Dataset.from_tensor_slices(
            (trainset.source_files, trainset.target_files,
             trainset.modify_texts))
        dataset = dataset.prefetch(FLAGS.batch_size).shuffle(num_modif).map(
            train_pair_image_parse_function,
            num_parallel_calls=FLAGS.threads).apply(
                batch_and_drop_remainder(FLAGS.batch_size)).repeat()
        data_iterator = dataset.make_one_shot_iterator()
        batch_source_image, batch_target_image, batch_text = data_iterator.get_next(
        )

        source_images_placeholder = tf.placeholder(tf.float32,
                                                   shape=(FLAGS.batch_size,
                                                          FLAGS.image_size,
                                                          FLAGS.image_size, 3))
        target_images_placeholder = tf.placeholder(tf.float32,
                                                   shape=(FLAGS.batch_size,
                                                          FLAGS.image_size,
                                                          FLAGS.image_size, 3))
        modify_texts_placeholder = tf.placeholder(tf.int32,
                                                  shape=(FLAGS.batch_size,
                                                         None))
        seqlengths_placeholder = tf.placeholder(tf.int32,
                                                shape=(FLAGS.batch_size))

        global_step = tf.train.get_or_create_global_step()

        if FLAGS.constant_lr:
            lr = FLAGS.init_learning_rate
        else:
            boundaries = [int(max_steps * 0.5)]
            values = [FLAGS.init_learning_rate, FLAGS.init_learning_rate * 0.1]
            print('boundaries = %s, values = %s ' % (boundaries, values))
            lr = tf.train.piecewise_constant(global_step, boundaries, values)
        opt = tf.train.AdamOptimizer(learning_rate=lr)

        with tf.variable_scope(tf.get_variable_scope()):
            total_loss, matching_loss = _build_model(
                source_images_placeholder, target_images_placeholder,
                modify_texts_placeholder, seqlengths_placeholder,
                vocab_size)  # build mô hình
            train_vars = tf.trainable_variables()
            batchnorm_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

            # do not train or update the logits layer
            train_vars = [var for var in train_vars if "ogits" not in var.name]
            batchnorm_updates = [
                op for op in batchnorm_updates if "ogits" not in op.name
            ]

            batchnorm_op = tf.group(*batchnorm_updates)
            updates_op = tf.assign(global_step, global_step + 1)

        if FLAGS.moving_average_decay:
            ema_op = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay).apply(train_vars)
        else:
            # keep the control dependency below well defined when EMA is disabled
            ema_op = tf.no_op()

        with tf.control_dependencies([batchnorm_op, updates_op, ema_op]):
            train_op = opt.minimize(loss=total_loss,
                                    global_step=tf.train.get_global_step(),
                                    var_list=train_vars)

        summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
        summary_op = tf.summary.merge(summaries)

        saver = tf.train.Saver(max_to_keep=6)
        init_op = tf.global_variables_initializer()

        config = tf.ConfigProto(allow_soft_placement=True,
                                log_device_placement=False)
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)

        # Fine-tune from a stage-1 checkpoint if one is provided
        sess.run(init_op)
        if FLAGS.checkpoint_dir_stage1:
            load_checkpoint = tf.train.latest_checkpoint(
                FLAGS.checkpoint_dir_stage1)
            print("Fine tuning from checkpoint: {}".format(load_checkpoint))
            vars_to_load = optimistic_restore_vars(load_checkpoint)
            finetuning_restorer = tf.train.Saver(var_list=vars_to_load)
            finetuning_restorer.restore(sess, load_checkpoint)

        # Otherwise fine-tune from a pretrained checkpoint
        elif FLAGS.pretrain_checkpoint_dir:
            print("Fine tuning from pretrained checkpoint: {}".format(
                FLAGS.pretrain_checkpoint_dir))
            checkpoint_vars = tf.train.list_variables(
                FLAGS.pretrain_checkpoint_dir)
            checkpoint_vars = [v[0] for v in checkpoint_vars]
            vars_can_be_load = []
            all_vars = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
            for var in all_vars:
                var_name = var.name.replace(":0", "")
                if (var_name in checkpoint_vars and var_name != "global_step"
                        and "ogits" not in var_name):
                    vars_can_be_load.append(var)
            pretrain_restorer = tf.train.Saver(var_list=vars_can_be_load)
            pretrain_restorer.restore(sess, FLAGS.pretrain_checkpoint_dir)

        summary_writer = tf.summary.FileWriter(FLAGS.checkpoint_dir,
                                               graph=sess.graph)

        feed_dict = {
            source_images_placeholder:
            np.zeros(
                (FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size, 3)),
            target_images_placeholder:
            np.zeros(
                (FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size, 3)),
            modify_texts_placeholder:
            np.zeros((FLAGS.batch_size, 10), dtype=int),
            seqlengths_placeholder:
            np.zeros((FLAGS.batch_size), dtype=int)
        }
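        # zero-filled dummy feed for the first iterator run below; real batches
        # are fed once the raw texts have been encoded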

        tf.train.start_queue_runners(sess=sess)
        start_time = time.time()

        # Start training

        while True:

            source_image_array, target_image_array, raw_text, step = sess.run(
                [
                    batch_source_image, batch_target_image, batch_text,
                    global_step
                ],
                feed_dict=feed_dict)
            text_array, lengths = vocab.encode_text2id_batch(raw_text)

            if FLAGS.max_length is not None:
                lengths = np.minimum(lengths, FLAGS.max_length)
                max_length = FLAGS.max_length
            else:
                max_length = max(lengths)

            feed_dict = {
                source_images_placeholder: source_image_array,
                target_images_placeholder: target_image_array,
                modify_texts_placeholder: text_array[:, 0:max_length],
                seqlengths_placeholder: lengths
            }

            _, loss_value, matching_loss_value, step = sess.run(
                [train_op, total_loss, matching_loss, global_step],
                feed_dict=feed_dict)

            if step % FLAGS.print_span == 0:
                duration = time.time() - start_time
                start_time = time.time()
                print(
                    "step = %d, total_loss = %.4f, matching_loss = %s, time = %.4f"
                    % (step, loss_value, str(matching_loss_value), duration))
                summary_str = sess.run(summary_op, feed_dict=feed_dict)
                summary_writer.add_summary(summary_str, step)

            if step > 0 and (step % FLAGS.save_length == 0
                             or step == max_steps):
                checkpoint_path = os.path.join(FLAGS.checkpoint_dir,
                                               'model.ckpt')
                saver.save(sess, checkpoint_path,
                           global_step=step)  # save a checkpoint

            if step >= max_steps:
                break
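
optimistic_restore_vars is called above but not shown in this snippet. A minimal sketch of what such a helper typically does, matching graph variables against the checkpoint by name and shape with standard TF1 checkpoint utilities (the project's own helper may differ):

import tensorflow as tf

def optimistic_restore_vars_sketch(checkpoint_path):
    # variable name -> shape, as stored in the checkpoint
    reader = tf.train.NewCheckpointReader(checkpoint_path)
    ckpt_shapes = reader.get_variable_to_shape_map()

    restorable = []
    for var in tf.global_variables():
        name = var.name.split(':')[0]
        if name in ckpt_shapes and var.get_shape().as_list() == ckpt_shapes[name]:
            restorable.append(var)
    return restorable

# usage sketch:
# vars_to_load = optimistic_restore_vars_sketch(load_checkpoint)
# tf.train.Saver(var_list=vars_to_load).restore(sess, load_checkpoint)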