Example #1
def words2ids(words, feed_single=False, allow_all_zero=False, pad=True):
    """
  default params is suitable for bow
  for sequence method may need seg_method prhase and feed_single=True,
  @TODO feed_single is for simplicity, the best strategy may be try to use one level lower words
  like new-word -> phrase -> basic -> single cn

  #@TODO feed_single move to Segmentor.py to add support for seg with vocab 
  """
    if not feed_single:
        word_ids = [
            vocab.id(word) for word in words if vocab.has(word) or ENCODE_UNK
        ]
    else:
        word_ids = []
        for word in words:
            if vocab.has(word):
                word_ids.append(vocab.id(word))
            else:
                cns = gezi.get_single_cns(word)
                if cns:
                    for w in cns:
                        if vocab.has(w) or ENCODE_UNK:
                            word_ids.append(vocab.id(w))
                else:
                    if ENCODE_UNK:
                        word_ids.append(vocab.unk_id())

    if not allow_all_zero and not word_ids:
        word_ids.append(1)

    if pad:
        word_ids = gezi.pad(word_ids, TEXT_MAX_WORDS, 0)

    return word_ids
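
A minimal usage sketch of the lookup-then-pad pattern above. The tiny dict vocabulary and the toy_pad helper are illustrative stand-ins for the project's vocab object and gezi.pad, not part of the original code:

toy_vocab = {'<pad>': 0, '<unk>': 1, '机器': 2, '学习': 3}

def toy_pad(ids, max_len, pad_id=0):
    # right-pad (or truncate) to a fixed length; stands in for gezi.pad(word_ids, TEXT_MAX_WORDS, 0)
    return (ids + [pad_id] * max_len)[:max_len]

words = ['机器', '学习', '未登录词']
ids = [toy_vocab.get(w, toy_vocab['<unk>']) for w in words]  # unknown word falls back to UNK
print(toy_pad(ids, 6))  # -> [2, 3, 1, 0, 0, 0]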
Example #2
def words2ids(words,
              feed_single=True,
              allow_all_zero=False,
              pad=True,
              append_start=False,
              append_end=False,
              max_words=None,
              norm_digit=True,
              norm_all_digit=False):
    """
  default params is suitable for bow
  for sequence method may need seg_method prhase and feed_single=True,
  @TODO feed_single is for simplicity, the best strategy may be try to use one level lower words
  like new-word -> phrase -> basic -> single cn

  #@TODO feed_single move to Segmentor.py to add support for seg with vocab 
  norm_all_digit is not used mostly, since you can control this behavior when gen vocab 
  """
    if not feed_single:
        word_ids = [
            vocab.id(word) for word in words if vocab.has(word) or ENCODE_UNK
        ]
    else:
        word_ids = []
        for word in words:
            if norm_all_digit and word.isdigit():
                word_ids.append(vocab.id(NUM_MARK))
                continue
            if vocab.has(word):
                word_ids.append(vocab.id(word))
            elif not norm_all_digit and norm_digit and word.isdigit():
                word_ids.append(vocab.id(NUM_MARK))
            else:
                cns = gezi.get_single_cns(word)
                if cns:
                    for w in cns:
                        if vocab.has(w) or ENCODE_UNK:
                            word_ids.append(vocab.id(w))
                else:
                    if ENCODE_UNK:
                        word_ids.append(vocab.unk_id())

    if append_start:
        word_ids = [vocab.start_id()] + word_ids

    if append_end:
        word_ids = word_ids + [vocab.end_id()]

    if not allow_all_zero and not word_ids:
        word_ids.append(vocab.end_id())

    if pad:
        word_ids = gezi.pad(word_ids, max_words or TEXT_MAX_WORDS, 0)

    return word_ids
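
A self-contained sketch of the two additions in this variant, start/end markers and digit normalization. The ids and the toy vocab below are invented for illustration and do not correspond to the project's real vocab or NUM_MARK:

START_ID, END_ID, NUM_ID = 2, 3, 4
toy_vocab = {'价格': 5, '元': 6}

def toy_encode(words, append_start=False, append_end=False, norm_digit=True):
    ids = []
    for w in words:
        if w in toy_vocab:
            ids.append(toy_vocab[w])
        elif norm_digit and w.isdigit():
            ids.append(NUM_ID)  # collapse any pure-digit token to a single NUM_MARK-style id
    if append_start:
        ids = [START_ID] + ids
    if append_end:
        ids = ids + [END_ID]
    return ids

print(toy_encode(['价格', '128', '元'], append_start=True, append_end=True))  # -> [2, 5, 4, 6, 3]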
Example #3
def loggest_match_seg(word, vocab, encode_unk=False):
    """Greedily segment word into the longest in-vocab pieces of its single Chinese characters."""
    cns = gezi.get_single_cns(word)
    word_ids = []
    while True:
        id, cns = loggest_match(cns, vocab, encode_unk=encode_unk)
        if id != -1:
            word_ids.append(id)
        if not cns:
            break
    return word_ids
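
loggest_match itself is not shown in this snippet; a self-contained sketch of the greedy longest-match idea over a character list, assuming a plain set() as the vocabulary (the test string mirrors the '匍匐前进' example referenced further down):

def toy_longest_match_seg(chars, vocab):
    pieces = []
    while chars:
        # try the longest remaining prefix first, shrinking until it is in the vocab
        # (or only a single character is left)
        for n in range(len(chars), 0, -1):
            piece = ''.join(chars[:n])
            if piece in vocab or n == 1:
                pieces.append(piece)
                chars = chars[n:]
                break
    return pieces

print(toy_longest_match_seg(list('匍匐前进'), {'匍匐', '前进'}))  # -> ['匍匐', '前进']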
def words2ids(words,
              feed_single=True,
              allow_all_zero=False,
              pad=True,
              append_start=False,
              append_end=False,
              max_words=None,
              norm_digit=True,
              norm_all_digit=False,
              multi_grid=None,
              encode_unk=None,
              feed_single_en=False,
              digit_to_chars=False,
              unk_vocab_size=None):
    """
  default params is suitable for bow
  for sequence method may need seg_method prhase and feed_single=True,
  @TODO feed_single is for simplicity, the best strategy may be try to use one level lower words
  like new-word -> phrase -> basic -> single cn

  #@TODO feed_single move to Segmentor.py to add support for seg with vocab 
  norm_all_digit is not used mostly, since you can control this behavior when gen vocab 
  """
    multi_grid = multi_grid or MULTI_GRID
    encode_unk = encode_unk or ENCODE_UNK

    new_words = []
    if not feed_single:
        word_ids = [
            get_id(word, unk_vocab_size) for word in words
            if vocab.has(word) or encode_unk
        ]
    else:
        word_ids = []
        for word in words:
            if digit_to_chars and any(char.isdigit() for char in word):
                for w in word:
                    if not vocab.has(w) and unk_vocab_size:
                        word_ids.append(
                            gezi.hash(w) % unk_vocab_size + vocab.size())
                        new_words.append(w)
                    else:
                        if vocab.has(w) or encode_unk:
                            word_ids.append(vocab.id(w))
                            new_words.append(w)
                continue
            elif norm_all_digit and word.isdigit():
                word_ids.append(vocab.id(NUM_MARK))
                new_words.append(word)
                continue
            if vocab.has(word):
                word_ids.append(vocab.id(word))
                new_words.append(word)
            elif not norm_all_digit and norm_digit and word.isdigit():
                word_ids.append(vocab.id(NUM_MARK))
                new_words.append(word)
            else:
                #TODO might use trie to speed up longest match segment
                if (not multi_grid) or feed_single_en:
                    if not feed_single_en:
                        chars = gezi.get_single_cns(word)
                    else:
                        chars = word
                    if chars:
                        for w in chars:
                            if not vocab.has(w) and unk_vocab_size:
                                word_ids.append(
                                    gezi.hash(w) % unk_vocab_size +
                                    vocab.size())
                                new_words.append(w)
                            else:
                                if vocab.has(w) or encode_unk:
                                    word_ids.append(vocab.id(w))
                                    new_words.append(w)
                    else:
                        if unk_vocab_size:
                            word_ids.append(
                                gezi.hash(word) % unk_vocab_size +
                                vocab.size())
                            new_words.append(word)
                        else:
                            if encode_unk:
                                word_ids.append(vocab.unk_id())
                                new_words.append(word)
                else:
                    #test it!  print text2ids.ids2text(text2ids.text2ids('匍匐前进'))
                    word_ids += gezi.loggest_match_seg(
                        word,
                        vocab,
                        encode_unk=encode_unk,
                        unk_vocab_size=unk_vocab_size,
                        vocab_size=vocab.size())
                    # NOTICE new_words lost here!

    if append_start:
        word_ids = [vocab.start_id()] + word_ids

    if append_end:
        word_ids = word_ids + [vocab.end_id()]

    if not allow_all_zero and not word_ids:
        word_ids.append(vocab.end_id())

    if pad:
        word_ids = gezi.pad(word_ids, max_words or TEXT_MAX_WORDS, 0)

    return word_ids, new_words
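
When unk_vocab_size is set, out-of-vocabulary tokens are not all collapsed to one UNK id but hashed into a reserved id range just above the known vocabulary, via gezi.hash(w) % unk_vocab_size + vocab.size(). A self-contained sketch of that bucketing, with hashlib standing in for gezi.hash:

import hashlib

def toy_oov_id(word, vocab_size, unk_vocab_size):
    h = int(hashlib.md5(word.encode('utf-8')).hexdigest(), 16)  # stable hash, unlike Python's built-in hash()
    return h % unk_vocab_size + vocab_size  # id lands in [vocab_size, vocab_size + unk_vocab_size)

print(toy_oov_id('未登录词', vocab_size=50000, unk_vocab_size=10000))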
def main(_):
  prediction_file = FLAGS.prediction_file or sys.argv[1]

  assert prediction_file

  log_dir = os.path.dirname(prediction_file)
  log_dir = log_dir or './'
  print('prediction_file', prediction_file, 'log_dir', log_dir, file=sys.stderr)
  logging.set_logging_path(log_dir)

  sess = tf.Session()
  summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

  refs = prepare_refs()
  tokenizer = prepare_tokenization()
  ##TODO some problem running tokenizer..
  #refs = tokenizer.tokenize(refs)

  min_len = 10000
  min_len_image = None
  min_len_caption = None
  max_len = 0
  max_len_image = None
  max_len_caption = None
  sum_len = 0

  min_words = 10000
  min_words_image = None
  min_words_caption = None
  max_words = 0
  max_words_image = None
  max_words_caption = None
  sum_words = 0

  caption_metrics_file = FLAGS.caption_metrics_file or prediction_file.replace('evaluate-inference', 'caption_metrics')
  imgs = []
  captions = []
  infos = {}
  for line in open(prediction_file):
    l = line.strip().split('\t')
    img, caption, all_caption, all_score = l[0], l[1], l[-2], l[-1]
    img = img.replace('.jpg', '')
    img += '.jpg'
    imgs.append(img)

    infos[img] = '%s %s' % (all_caption.replace(' ', '|'), all_score.replace(' ', '|'))

    caption = caption.replace(' ', '').replace('\t', '')
    caption_words = [x.encode('utf-8') for x in jieba.cut(caption)]
    caption_str = ' '.join(caption_words)
    captions.append([caption_str])

    caption_len = len(gezi.get_single_cns(caption))
    num_words = len(caption_words)

    if caption_len < min_len:
      min_len = caption_len
      min_len_image = img 
      min_len_caption = caption
    if caption_len > max_len:
      max_len = caption_len
      max_len_image = img 
      max_len_caption = caption
    sum_len += caption_len

    if num_words < min_words:
      min_words = num_words
      min_words_image = img
      min_words_caption = caption_str
    if num_words > max_words:
      max_words = num_words
      max_words_image = img
      max_words_caption = caption_str
    sum_words += num_words

  results = dict(zip(imgs, captions))
  
  #results = tokenizer.tokenize(results)

  selected_results, selected_refs = translation_reorder_keys(results, refs)

  scorers = [
            (Bleu(4), ["bleu_1", "bleu_2", "bleu_3", "bleu_4"]),
            (Cider(), "cider"),
            (Meteor(), "meteor"),
            (Rouge(), "rouge_l")
        ]

  score_list = []
  metric_list = []
  scores_list = []

  print('img&predict&label:{}:{}{}{}'.format(selected_results.items()[0][0], '|'.join(selected_results.items()[0][1]), '---', '|'.join(selected_refs.items()[0][1])), file=sys.stderr)
  #print('avg_len:', sum_len / len(refs), 'min_len:', min_len, min_len_image, min_len_caption, 'max_len:', max_len, max_len_image, max_len_caption, file=sys.stderr)
  print('avg_len:', sum_len / refs_len, 'min_len:', min_len, min_len_image, min_len_caption, 'max_len:', max_len, max_len_image, max_len_caption, file=sys.stderr)
  print('avg_words', sum_words / refs_len, 'min_words:', min_words, min_words_image, min_words_caption, 'max_words:', max_words, max_words_image, max_words_caption, file=sys.stderr)
  
  for scorer, method in scorers:
    print('computing %s score...' % (scorer.method()), file=sys.stderr)
    score, scores = scorer.compute_score(selected_refs, selected_results)
    if type(method) == list:
      for i in range(len(score)):
        score_list.append(score[i])
        metric_list.append(method[i])
        scores_list.append(scores[i])
        print(method[i], score[i], file=sys.stderr)
    else:
      score_list.append(score)
      metric_list.append(method)
      scores_list.append(scores)
      print(method, score, file=sys.stderr)

  assert(len(score_list) == 7)

  avg_score = np.mean(np.array(score_list[3:]))
  score_list.insert(0, avg_score)
  metric_list.insert(0, 'avg')

  if caption_metrics_file:
    out = open(caption_metrics_file, 'w')
    print('image_id', 'caption', 'ref', '\t'.join(metric_list), 'infos', sep='\t', file=out)
    for i in range(len(selected_results)):
      key = selected_results.keys()[i] 
      result = selected_results[key][0]
      ref_str = '|'.join(selected_refs[key])
      bleu_1 = scores_list[0][i]
      bleu_2 = scores_list[1][i]
      bleu_3 = scores_list[2][i]
      bleu_4 = scores_list[3][i]
      cider = scores_list[4][i]
      meteor = scores_list[5][i]
      rouge_l = scores_list[6][i]
      avg = (bleu_4 + cider + meteor + rouge_l) / 4.
      print(key.split('.')[0], result, ref_str, avg, bleu_1, bleu_2, bleu_3, bleu_4, cider, meteor, rouge_l, infos[key], sep='\t', file=out)

  metric_list = ['trans_' + x for x in metric_list]
  metric_score_str = '\t'.join('%s:[%.4f]' % (name, result) for name, result in zip(metric_list, score_list))
  logging.info('%s\t%s'%(metric_score_str, os.path.basename(prediction_file)))

  # overall row: corpus-level scores appended to the caption metrics file
  print(key.split('.')[0], 'None', 'None', '\t'.join(map(str, score_list)), 'None', sep='\t', file=out)

  summary = tf.Summary()
  if score_list and 'ckpt' in prediction_file:
    try:
      epoch = float(os.path.basename(prediction_file).split('-')[1])
      # for fractional epochs like 0.01, 0.02 turn them into 1, 2; note this maps epoch 1 to 100
      epoch = int(epoch * 100)
      step = int(float(os.path.basename(prediction_file).split('-')[2].split('.')[0]))
      prefix = 'step' if FLAGS.write_step else 'epoch'
      melt.add_summarys(summary, score_list, metric_list, prefix=prefix)
      step = epoch if not FLAGS.write_step else step
      summary_writer.add_summary(summary, step)
      summary_writer.flush()
    except Exception:
      print(traceback.format_exc(), file=sys.stderr)
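
The summary block above goes through the project's melt.add_summarys helper. For reference, a hedged sketch of the equivalent plain TensorFlow 1.x calls for writing one scalar per metric (the function name and tag format here are assumptions, not the project's API):

import tensorflow as tf

def write_metric_summary(summary_writer, names, values, step, prefix='epoch'):
    summary = tf.Summary()
    for name, value in zip(names, values):
        # one scalar per metric, e.g. epoch/trans_bleu_4
        summary.value.add(tag='%s/%s' % (prefix, name), simple_value=float(value))
    summary_writer.add_summary(summary, step)
    summary_writer.flush()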