def __init__(self,
                 collection,
                 vocab_file,
                 feature,
                 language,
                 flag_shuffle=True,
                 method=None,
                 fluency_threshold=DEFAULT_FLUENCY_U,
                 rootpath=ROOT_PATH):
        self.language = language
        self.anno_file_path = utility.get_sent_file(collection, language,
                                                    rootpath)
        self.fluency_threshold = fluency_threshold
        self.method = method
        if method:
            self.sent_score_file = utility.get_sent_score_file(
                collection, language, rootpath)
            assert method in ['sample', 'filter', 'weighted']
            assert self.sent_score_file is not None
            assert fluency_threshold > 0
            if method == 'weighted':
                # No sampling or filtering here: the 'weighted' method applies
                # the fluency score inside the loss instead.
                self.method = method = None
        else:
            self.sent_score_file = None

        self.textbank = TextBank(vocab_file)
        assert self.textbank.vocab[TOKEN_PAD] == 0
        self.vf_reader = BigFile(
            utility.get_feat_dir(collection, feature, rootpath))
        self.vf_names = set(self.vf_reader.names)
        self.vf_size = self.vf_reader.ndims
        self.flag_shuffle = flag_shuffle
        self._load_data()
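A construction sketch for the provider this __init__ belongs to (the fluency-aware BucketDataProvider shown in full further below); the collection, vocabulary, and feature names are hypothetical, while ROOT_PATH and DEFAULT_FLUENCY_U come from the surrounding module:

provider = BucketDataProvider('flickr8ktrain', 'vocab.txt', 'resnet152',  # hypothetical names
                              language='en', flag_shuffle=True,
                              method='sample')  # probabilistically drop low-fluency sentences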
Example #2
  def __init__(self, config, model, length_normalization_factor=0.0):
    self.config = copy.deepcopy(config)
    self.config.batch_size = 1
    self.model = model

    self.textbank = TextBank(get_vocab_file(config.trainCollection,
                                            config.word_cnt_thr,
                                            config.rootpath))
    self.length_normalization_factor = length_normalization_factor
class BucketDataProvider(object):
    """TensorFlow Data Provider with Buckets"""
    def __init__(self, collection, vocab_file, feature, language,
                flag_shuffle=False,  fluency_threshold=DEFAULT_FLUENCY_U, rootpath=ROOT_PATH):
        self.language = language
        self.anno_file_path = utility.get_sent_file(collection, language, rootpath)
        self.fluency_threshold = fluency_threshold
        self.textbank = TextBank(vocab_file)
        assert self.textbank.vocab[TOKEN_PAD] == 0
        self.vf_reader = BigFile(utility.get_feat_dir(collection, feature, rootpath))
        self.vf_names = set(self.vf_reader.names)
        self.vf_size = self.vf_reader.ndims
        self.flag_shuffle = flag_shuffle
        self._load_data()

    def shuffle_data_queue(self):
        random.shuffle(self._data_queue)

    def generate_batches(self, batch_size, buckets):
        """Return a list generator of mini-batches of training data."""
        # create Batches
        batches = []
        for max_seq_len in buckets:
            batches.append(Batch(batch_size, max_seq_len, self.vf_size, self.textbank.vocab[TOKEN_BOS]))
        
        # shuffle if necessary
        if self.flag_shuffle:
            np.random.shuffle(self._data_queue)
        # scan data queue
        for data in self._data_queue:
            sentence = data['sentence']
            # Load visual features for this image/video id
            visual_features = np.array(self.vf_reader.read_one(data['image_id']))
            if len(sentence) >= buckets[-1]:
                feed_res = batches[-1].feed_and_vomit(visual_features, sentence)
                ind_buc = len(buckets) - 1
            else:
                for (ind_b, batch) in enumerate(batches):
                    if len(sentence) < batch.max_seq_len:
                        feed_res = batches[ind_b].feed_and_vomit(visual_features, sentence)
                        ind_buc = ind_b
                        break
            if feed_res:
                yield (ind_buc,) + feed_res
                batches[ind_buc].empty()

            
    def _load_data(self, verbose=True):
        logger.debug('Loading data')
        self._data_queue = []
        raw_annos = codecs.open(self.anno_file_path, 'r', 'utf-8').readlines()
        # Strip a possible UTF-8 BOM from each line
        annos = [an.encode('utf-8').decode('utf-8-sig') for an in raw_annos]

        for (ind_a, line) in enumerate(annos):
            data = {}
            sid, sent = line.strip().split(" ", 1)
            imgid = sid.strip().split("#", 1)[0]
            assert imgid in self.vf_names, '%s not in feature data' % imgid
            data['image_id'] = imgid

            # Encode the sentence as vocabulary token ids
            tokens = TextTool.tokenize(sent, self.language)
            data['sentence'] = self.textbank.encode_tokens(tokens, flag_add_bos=False)
            self._data_queue.append(data)
            if verbose and (ind_a + 1) % 20000 == 0:
                logger.debug('%d/%d annotations', ind_a + 1, len(annos))
        random.shuffle(self._data_queue)

        nr_of_images = len(set([data['image_id'] for data in self._data_queue]))
        logger.info('%d images, %d sentences from %s', nr_of_images,
                    len(self._data_queue), self.anno_file_path)
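A sketch of how generate_batches is typically driven, assuming ascending buckets and that Batch.feed_and_vomit returns a non-None result only once a batch is full; the collection and feature names are hypothetical:

provider = BucketDataProvider('msrvtt10ktrain', 'vocab.txt', 'resnet152',  # hypothetical
                              language='en', flag_shuffle=True)
buckets = [8, 12, 16, 20]  # ascending; the last bucket also catches longer sentences
for batch_data in provider.generate_batches(batch_size=64, buckets=buckets):
    ind_buc = batch_data[0]  # which bucket produced this batch
    feed = batch_data[1:]    # whatever feed_and_vomit emits (features, token ids, ...)
    # ... run one training step on `feed` ...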
Example #7
def main(unused_args):

  length_normalization_factor = FLAGS.length_normalization_factor

  # Load model configuration
  config_path = os.path.join(os.path.dirname(__file__), 'model_conf', FLAGS.model_name + '.py')
  config = utility.load_config(config_path)

  config.trainCollection = FLAGS.train_collection
  config.word_cnt_thr = FLAGS.word_cnt_thr
  config.rootpath = FLAGS.rootpath

  train_collection = FLAGS.train_collection
  test_collection = FLAGS.test_collection
  overwrite = FLAGS.overwrite
  feature = FLAGS.vf_name
  rootpath = FLAGS.rootpath  # used directly below when locating the image set


  img_set_file = os.path.join(rootpath, test_collection, 'VideoSets', '%s.txt' % test_collection)
  if not os.path.exists(img_set_file):
      img_set_file = os.path.join(rootpath, test_collection, 'ImageSets', '%s.txt' % test_collection)
  img_list = [line.strip() for line in open(img_set_file)]

  # Have the visual features ready
  vf_dir = utility.get_feat_dir(test_collection, feature, rootpath)
  vf_reader = BigFile(vf_dir)

  textbank = TextBank(utility.get_train_vocab_file(FLAGS))
  config.vocab_size = len(textbank.vocab)
  config.vf_size = int(open(os.path.join(vf_dir, 'shape.txt')).read().split()[1])

  model_dir = utility.get_model_dir(FLAGS)
  output_dir = utility.get_pred_dir(FLAGS)

  checkpoint_style = FLAGS.checkpoint_style

  if checkpoint_style == 'file':
    # Read the validated top models
    validation_output_dir = utility.get_sim_dir(FLAGS)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    eval_model_list_file = os.path.join(validation_output_dir, 'loss_info.txt')
    shutil.copy(eval_model_list_file, output_dir)
    test_iter_list = []
    for line in open(eval_model_list_file).readlines()[:FLAGS.top_k]:
      iter_current = int(line.strip().split()[0])
      test_iter_list.append(iter_current)

  elif checkpoint_style == 'iter_interval':
    test_iter_list = range(*[int(x) for x in FLAGS.eval_stat.split("-")])
  elif checkpoint_style == 'iter_num':
    test_iter_list = [FLAGS.iter_num]

  with_image_embedding = bool(FLAGS.with_image_embedding)
  g = tf.Graph()
  with g.as_default():
    model = InferenceWrapper(config=config,model_dir=model_dir,
                             gpu_memory_fraction=FLAGS.gpu_memory_fraction,
                             gpu=FLAGS.gpu,
                             with_image_embedding=with_image_embedding)
    model.build_model()
  
  for k, iter_n in enumerate(test_iter_list):
    model_path = os.path.join(model_dir, 'variables', 'model_%d.ckpt' % iter_n)
    if not os.path.exists(model_path + '.meta'):
      logger.error('Model path: %s', model_path)
      logger.error('Cannot load model file, exiting')
      sys.exit(1)

    top_one_pred_sent_file = os.path.join(output_dir, 'top%d' % k, 'top_one_pred_sent.txt')
    top_n_pred_sent_file = os.path.join(output_dir, 'top%d' % k, 'top_n_pred_sent.txt')

    if os.path.exists(top_one_pred_sent_file) and not overwrite:
      logger.info('%s exists. skip', top_one_pred_sent_file)
      continue

    if not os.path.exists(os.path.split(top_one_pred_sent_file)[0]):
      os.makedirs(os.path.split(top_one_pred_sent_file)[0])

    logger.info('save results to %s', top_one_pred_sent_file)

    # Load the trained model and decode
    generator = CaptionGenerator(config, model,
                                 length_normalization_factor=length_normalization_factor)
    model.load_model(model_path)

    fout_one_sent = codecs.open(top_one_pred_sent_file, 'w', 'utf-8')
    fout_n_sent = codecs.open(top_n_pred_sent_file, 'w', 'utf-8')

    for progress, img in enumerate(img_list):
        # Predict sentences from the visual feature
        visual_feature = np.array(vf_reader.read_one(img))
        sentences = generator.beam_search(visual_feature, FLAGS.beam_size)

        # output top one sentence info
        sent_score = sentences[0].score
        sent = ' '.join(sentences[0].words)
        fout_one_sent.write('%s %.3f %s\n' % (img, sent_score, sent))
        logger.debug('%s %.3f %s', img, sent_score, sent)

        # output top n sentences info
        fout_n_sent.write(img)
        for sentence in sentences:
            sent_score = sentence.score
            sent = ' '.join(sentence.words)
            fout_n_sent.write('\t' + '%.3f' % sent_score + '\t' + sent)
        fout_n_sent.write('\n')
      
        if (progress + 1) % 100 == 0:
          logger.info('%d images decoded', progress + 1)

    logger.info('%d images decoded in total', len(img_list))
 
    fout_one_sent.close()
    fout_n_sent.close()
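This decoding script presumably runs under the TF 1.x flags framework (tf.GPUOptions and tf.ConfigProto above suggest as much); a minimal, assumed entry point would be:

if __name__ == '__main__':
    tf.app.run()  # parses FLAGS, then calls main(unused_args)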
Example #8
def main(unused_args):
    train_collection = FLAGS.train_collection
    val_collection = FLAGS.val_collection
    overwrite = FLAGS.overwrite
    output_dir = utility.get_sim_dir(FLAGS)
    loss_info_file = os.path.join(output_dir, 'loss_info.txt')
    if os.path.exists(loss_info_file) and not overwrite:
        logger.info('%s exists. quit', loss_info_file)
        sys.exit(0)

    model_dir = utility.get_model_dir(FLAGS)
    config_path = os.path.join(os.path.dirname(__file__), 'model_conf',
                               FLAGS.model_name + '.py')
    config = utility.load_config(config_path)

    if FLAGS.fluency_method == 'None':
        FLAGS.fluency_method = None
    config.fluency_method = FLAGS.fluency_method
    config.use_weighted_loss = (config.fluency_method == 'weighted')

    textbank = TextBank(utility.get_train_vocab_file(FLAGS))
    config.vocab_size = len(textbank.vocab)
    config.vf_size = int(
        open(os.path.join(utility.get_val_feat_dir(FLAGS),
                          'shape.txt')).read().split()[1])

    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction)
    config_proto = tf.ConfigProto(
        intra_op_parallelism_threads=FLAGS.ses_threads,
        gpu_options=gpu_options,
        allow_soft_placement=True)

    with_image_embedding = FLAGS.with_image_embedding > 0
    with tf.Graph().as_default(), tf.Session(config=config_proto) as session:

        assert len(config.buckets) >= 1
        assert config.buckets[-1] == config.max_num_steps
        with tf.device('/gpu:%d' % FLAGS.gpu):
            with tf.variable_scope("LSTMModel", reuse=None):
                if with_image_embedding:
                    model = LSTMModel(
                        mode='eval',
                        num_steps=config.buckets[-1],
                        config=config,
                        model_dir=model_dir,
                        flag_with_saver=True)
                else:
                    # The no-image-embedding path is deprecated
                    logger.error('Please use with_image_embedding=1')
                    sys.exit(-1)
                model.build()

        model_path_list = []
        _dir = os.path.join(model_dir, 'variables')
        for _file in os.listdir(_dir):
            if _file.startswith('model_') and _file.endswith('.ckpt.meta'):
                iter_n = int(_file[6:-10])  # strip 'model_' prefix and '.ckpt.meta' suffix
                model_path = os.path.join(_dir, 'model_%d.ckpt' % iter_n)
                model_path_list.append((iter_n, model_path))

        data_provider = BucketDataProvider(val_collection,
                                           utility.get_train_vocab_file(FLAGS),
                                           feature=FLAGS.vf_name,
                                           language=FLAGS.language,
                                           flag_shuffle=False,
                                           method=config.fluency_method,
                                           rootpath=FLAGS.rootpath)
        iter2loss = {}
        for iter_n, model_path in model_path_list:
            loss_file = os.path.join(output_dir, 'model_%d.ckpt' % iter_n,
                                     'loss.txt')
            if os.path.exists(loss_file) and not overwrite:
                logger.info('load loss from %s', loss_file)
                loss = float(open(loss_file).readline().strip())
                iter2loss[iter_n] = loss
                continue
            if not os.path.exists(os.path.split(loss_file)[0]):
                os.makedirs(os.path.split(loss_file)[0])

            model.saver.restore(session, model_path)
            logger.info('Restored model from %s for evaluation', model_path)

            val_cost = run_epoch(session, config.batch_size,
                                 config.buckets[-1], config, model,
                                 data_provider)
            logger.info("Validation cost for checkpoint model_%d.ckpt is %.3f",
                        iter_n, val_cost)

            iter2loss[iter_n] = val_cost
            with open(loss_file, "w") as fw:
                fw.write('%g' % val_cost)
                fw.close()

    sorted_iter2loss = sorted(iter2loss.items(), key=lambda x: x[1])
    with open(loss_info_file, 'w') as fw:
        fw.write('\n'.join(
            ['%d %g' % (iter_n, loss) for (iter_n, loss) in sorted_iter2loss]))
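loss_info.txt is written as one '%d %g' pair per line, sorted by ascending validation loss; the decoding script's 'file' checkpoint style above reads the leading iterations back. A minimal sketch of that contract, with an illustrative path:

# Each line: "<iteration> <loss>", best model first.
with open('loss_info.txt') as f:                          # illustrative path
    top_iters = [int(line.split()[0]) for line in f][:3]  # e.g. FLAGS.top_k = 3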
Example #9
def main(unused_args):
    model_dir = utility.get_model_dir(FLAGS)
    if os.path.exists(model_dir) and not FLAGS.overwrite:
        logger.info('%s exists. quit', model_dir)
        sys.exit(0)

    # Load model configuration
    config_path = os.path.join(os.path.dirname(__file__), 'model_conf',
                               FLAGS.model_name + '.py')
    config = utility.load_config(config_path)

    FLAGS.vf_dir = os.path.join(FLAGS.rootpath, FLAGS.train_collection,
                                'FeatureData', FLAGS.vf_name)
    vocab_file = utility.get_vocab_file(FLAGS.train_collection,
                                        FLAGS.word_cnt_thr, FLAGS.rootpath)
    textbank = TextBank(vocab_file)
    config.vocab_size = len(textbank.vocab)
    config.vf_size = int(
        open(os.path.join(FLAGS.vf_dir, 'shape.txt')).read().split()[1])

    if hasattr(config, 'num_epoch_save'):
        num_epoch_save = config.num_epoch_save
    else:
        num_epoch_save = 1

    if FLAGS.fluency_method == 'None':
        FLAGS.fluency_method = None
    config.fluency_method = FLAGS.fluency_method
    config.use_weighted_loss = (config.fluency_method == 'weighted')

    # Freeze the image embedding only when the config explicitly asks for it
    train_image_embedding = getattr(config, 'train_image_embedding', True)
    if not train_image_embedding:
        assert 'freeze' in FLAGS.model_name
        logger.info('Not training image embedding')

    with_image_embedding = bool(FLAGS.with_image_embedding)
    # Start model training
    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction)
    config_proto = tf.ConfigProto(
        intra_op_parallelism_threads=FLAGS.ses_threads,
        gpu_options=gpu_options,
        allow_soft_placement=True)

    with tf.Graph().as_default(), tf.Session(config=config_proto) as session:
        initializer = tf.random_uniform_initializer(-config.init_scale,
                                                    config.init_scale)
        assert len(config.buckets) >= 1
        assert config.buckets[-1] == config.max_num_steps
        models = []
        with tf.device('/gpu:%d' % FLAGS.gpu):
            with tf.variable_scope("LSTMModel",
                                   reuse=None,
                                   initializer=initializer):
                if with_image_embedding:
                    m = LSTMModel(mode='train',
                                  num_steps=config.buckets[0],
                                  config=config,
                                  model_dir=model_dir,
                                  flag_with_saver=True,
                                  train_image_embedding=train_image_embedding)
                    #model_root=FLAGS.model_root)
                else:
                    # The no-image-embedding path is deprecated
                    logger.error('Please use with_image_embedding=1')
                    sys.exit(-1)
                m.build()
                models.append(m)

        pre_trained_iter = 0
        if FLAGS.pre_trained_model_path:
            pre_trained_iter = int(
                FLAGS.pre_trained_model_path.split('model_')[1].split('.')[0])
        hdlr = logging.FileHandler(
            os.path.join(m.model_dir, 'log%d.txt' % pre_trained_iter))
        hdlr.setLevel(logging.INFO)
        hdlr.setFormatter(logging.Formatter(formatter_log))
        logger.addHandler(hdlr)

        # Initialize all variables once, then optionally restore pre-trained parts
        if tf.__version__ < '1.0':
            tf.initialize_all_variables().run()
        else:
            tf.global_variables_initializer().run()

        if FLAGS.pre_trained_model_path:
            models[0].saver.restore(session, FLAGS.pre_trained_model_path)
            logger.info('Continue to train from %s',
                        FLAGS.pre_trained_model_path)
        elif FLAGS.pre_trained_imembedding_path:
            models[0].imemb_saver.restore(session,
                                          FLAGS.pre_trained_imembedding_path)
            logger.info('Init image-embedding from %s',
                        FLAGS.pre_trained_imembedding_path)
        elif FLAGS.pre_trained_lm_path:
            models[0].lm_saver.restore(session, FLAGS.pre_trained_lm_path)
            logger.info('Init language model from %s', FLAGS.pre_trained_lm_path)

        iters_done = 0
        data_provider = BucketDataProvider(FLAGS.train_collection,
                                           vocab_file,
                                           FLAGS.vf_name,
                                           language=FLAGS.language,
                                           method=config.fluency_method,
                                           rootpath=FLAGS.rootpath)

        for i in range(config.num_epoch):
            logger.info('epoch %d', i)
            data_provider.shuffle_data_queue()
            train_cost, iters_done = run_epoch(session,
                                               iters_done,
                                               config,
                                               models,
                                               data_provider,
                                               verbose=True)
            logger.info("Train cost for epoch %d is %.3f" % (i, train_cost))

            # Save the current model if necessary
            if (i + 1) % num_epoch_save == 0:
                models[0].saver.save(
                    session,
                    os.path.join(
                        m.variable_dir,
                        'model_%d.ckpt' % (iters_done + pre_trained_iter)))
                if with_image_embedding:
                    models[0].imemb_saver.save(
                        session,
                        os.path.join(m.variable_dir,
                                     'imembedding_model_%d.ckpt' % iters_done))
                logger.info("Model saved at iteration %d", iters_done)

    # Copy the configuration file into the checkpoint directory
    os.system("cp %s %s" % (config_path, model_dir))
    if FLAGS.pre_trained_model_path:
        os.system("echo %s > %s" %
                  (FLAGS.pre_trained_model_path,
                   os.path.join(model_dir, 'pre_trained_model_path.txt')))
    if FLAGS.pre_trained_imembedding_path:
        os.system(
            "echo %s > %s" %
            (FLAGS.pre_trained_imembedding_path,
             os.path.join(model_dir, 'pre_trained_imembedding_path.txt')))
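The pre-trained iteration is recovered from the checkpoint filename itself; a small sketch of that naming convention, using a hypothetical path:

path = 'variables/model_12000.ckpt'                            # hypothetical
pre_trained_iter = int(path.split('model_')[1].split('.')[0])  # -> 12000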
Example #10
class CaptionGenerator(object):
    """Class to generate captions from an image-to-text model."""
    def __init__(self, config, model, length_normalization_factor=0.0):
        self.config = copy.deepcopy(config)
        self.config.batch_size = 1
        self.model = model

        self.textbank = TextBank(
            get_vocab_file(config.trainCollection, config.word_cnt_thr,
                           config.rootpath))
        self.length_normalization_factor = length_normalization_factor

    def beam_search(self,
                    visual_feature,
                    beam_size,
                    max_steps=30,
                    tag2score=None):
        """Decode an image with a sentences."""
        assert visual_feature.shape[0] == self.config.vf_size
        #assert self.flag_load_model, 'Must call local_model first'
        # Get the initial logit and state
        initial_state = self.model.feed_visual_feature(visual_feature)
        # print(visual_feature)
        initial_beam = Caption(sentence=[self.textbank.vocab[TOKEN_BOS]],
                               state=initial_state[0],
                               logprob=0.0,
                               score=0.0,
                               metadata=[""])
        partial_captions = TopN(beam_size)
        partial_captions.push(initial_beam)
        complete_captions = TopN(beam_size)

        # Run beam search.
        for _ in range(max_steps - 1):
            partial_captions_list = partial_captions.extract()
            partial_captions.reset()
            input_feed = np.array(
                [c.sentence[-1] for c in partial_captions_list])
            state_feed = np.array([c.state for c in partial_captions_list])
            softmax, new_states, metadata = self.model.inference_step(
                input_feed, state_feed)
            for i, partial_caption in enumerate(partial_captions_list):
                word_probabilities = softmax[i]
                state = new_states[i]
                # For this partial caption, get the beam_size most probable next words.
                words_and_probs = list(enumerate(word_probabilities))
                words_and_probs.sort(key=lambda x: -x[1])
                words_and_probs = words_and_probs[0:beam_size]
                # Each next word gives a new partial caption.
                for w, p in words_and_probs:
                    if tag2score is not None and w in tag2score and w not in partial_caption.sentence:
                        p += tag2score[w]
                    if p < 1e-12:
                        continue  # Avoid log(0).
                    sentence = partial_caption.sentence + [w]
                    logprob = partial_caption.logprob + math.log(p)
                    score = logprob
                    if metadata:
                        metadata_list = partial_caption.metadata + [
                            metadata[i]
                        ]
                    else:
                        metadata_list = None
                    if w == self.textbank.vocab[TOKEN_BOS]:
                        if self.length_normalization_factor > 1e-6:
                            score /= len(
                                sentence)**self.length_normalization_factor
                        beam = Caption(sentence, state, logprob, score,
                                       metadata_list)
                        complete_captions.push(beam)
                    else:
                        beam = Caption(sentence, state, logprob, score,
                                       metadata_list)
                        partial_captions.push(beam)
            if partial_captions.size() == 0:
                # We have run out of partial candidates; happens when beam_size = 1.
                break

        # If we have no complete captions then fall back to the partial captions.
        # But never output a mixture of complete and partial captions because a
        # partial caption could have a higher score than all the complete captions.
        if not complete_captions.size():
            complete_captions = partial_captions

        captions = complete_captions.extract(sort=True)
        for i, caption in enumerate(captions):
            caption.words = self.textbank.decode_tokens(caption.sentence[1:])

        return captions
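A decoding sketch built on the class above, assuming a loaded inference model and a BigFile feature reader as in the decoding script earlier; the image id is hypothetical:

generator = CaptionGenerator(config, model, length_normalization_factor=0.5)
visual_feature = np.array(vf_reader.read_one('some_image_id'))  # hypothetical id
captions = generator.beam_search(visual_feature, beam_size=3)
best = captions[0]
print('%.3f %s' % (best.score, ' '.join(best.words)))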
class BucketDataProvider(object):
    """TensorFlow Data Provider with Buckets"""
    def __init__(self,
                 collection,
                 vocab_file,
                 feature,
                 language,
                 flag_shuffle=True,
                 method=None,
                 fluency_threshold=DEFAULT_FLUENCY_U,
                 rootpath=ROOT_PATH):
        self.language = language
        self.anno_file_path = utility.get_sent_file(collection, language,
                                                    rootpath)
        self.fluency_threshold = fluency_threshold
        self.method = method
        if method:
            self.sent_score_file = utility.get_sent_score_file(
                collection, language, rootpath)
            assert method in ['sample', 'filter', 'weighted']
            assert self.sent_score_file is not None
            assert fluency_threshold > 0
            if method == 'weighted':
                # No sampling or filtering here: the 'weighted' method applies
                # the fluency score inside the loss instead.
                self.method = method = None
        else:
            self.sent_score_file = None

        self.textbank = TextBank(vocab_file)
        assert self.textbank.vocab[TOKEN_PAD] == 0
        self.vf_reader = BigFile(
            utility.get_feat_dir(collection, feature, rootpath))
        self.vf_names = set(self.vf_reader.names)
        self.vf_size = self.vf_reader.ndims
        self.flag_shuffle = flag_shuffle
        self._load_data()

    def shuffle_data_queue(self):
        random.shuffle(self._data_queue)

    def generate_batches(self, batch_size, buckets):
        """Return a list generator of mini-batches of training data."""
        # create Batches
        batches = []
        for max_seq_len in buckets:
            batches.append(
                Batch(batch_size, max_seq_len, self.vf_size,
                      self.textbank.vocab[TOKEN_BOS]))

        # shuffle if necessary
        if self.flag_shuffle:
            np.random.shuffle(self._data_queue)
        # scan data queue
        for data in self._data_queue:
            if self.method:
                if data['sent_score'] < self.fluency_threshold:
                    if self.method == 'filter':
                        # Drop the sentence outright if sent_score < threshold
                        continue
                    elif self.method == 'sample':
                        # Keep with probability sent_score / fluency_threshold
                        x = random.uniform(0, self.fluency_threshold)
                        if x > data['sent_score']:
                            continue
            score = data['sent_score'] if self.sent_score_file else None
            sentence = data['sentence']
            # Load visual features
            visual_features = np.array(
                self.vf_reader.read_one(data['image_id']))
            if len(sentence) >= buckets[-1]:
                feed_res = batches[-1].feed_and_vomit(visual_features,
                                                      sentence, score)
                ind_buc = len(buckets) - 1
            else:
                for (ind_b, batch) in enumerate(batches):
                    if len(sentence) < batch.max_seq_len:
                        feed_res = batches[ind_b].feed_and_vomit(
                            visual_features, sentence, score)
                        ind_buc = ind_b
                        break
            if feed_res:
                yield (ind_buc, ) + feed_res
                batches[ind_buc].empty()

    def _load_data(self, verbose=True):
        logger.debug('Loading data')
        self._data_queue = []
        if self.sent_score_file is not None:
            sid2score = {}
            for line in open(self.sent_score_file):
                elem = line.strip().split('\t')
                sid = elem[0]
                score = float(elem[-1])
                sid2score[sid] = score
        annos = codecs.open(self.anno_file_path, 'r', 'utf-8').readlines()
        for (ind_a, line) in enumerate(annos):
            data = {}
            sid, sent = line.strip().split(" ", 1)
            imgid = sid.strip().split("#")[0]
            if imgid.endswith('.jpg') or imgid.endswith('.mp4'):
                imgid = imgid[:-4]
            assert imgid in self.vf_names, '%s not in feature data' % imgid
            data['image_id'] = imgid

            # Encode sentences
            tokens = TextTool.tokenize(sent, self.language)
            data['sentence'] = self.textbank.encode_tokens(tokens,
                                                           flag_add_bos=False)
            data['sent_score'] = (sid2score[sid]
                                  if self.sent_score_file and sid in sid2score
                                  else 1)
            self._data_queue.append(data)
            if verbose and (ind_a + 1) % 20000 == 0:
                logger.debug('%d/%d annotations', ind_a + 1, len(annos))
        random.shuffle(self._data_queue)

        nr_of_images = len(set([data['image_id']
                                for data in self._data_queue]))
        logger.info('%d images, %d sentences from %s', nr_of_images,
                    len(self._data_queue), self.anno_file_path)
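How the three fluency-guided modes map onto this provider, as a sketch; 'weighted' is reset to None in __init__, so its per-sentence score is only handed to feed_and_vomit for loss weighting. The collection and feature names are hypothetical:

# filter:   drop every sentence with sent_score < fluency_threshold
# sample:   keep such sentences with probability sent_score / fluency_threshold
# weighted: keep everything; the score is yielded with the batch for a weighted loss
provider = BucketDataProvider('tgiftrain', 'vocab.txt', 'resnet152',  # hypothetical
                              language='en', method='sample')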