# Example #1
# 0
def main():
  """Cross-check cube pruning against the CYK beam-search decoder.

  Defines the command-line flags, parses ``sys.argv``, loads the model
  weights and n-gram LM from the command line, then for every forest read
  from stdin ("-") decodes it twice -- once with
  ``CYKDecoder.beam_search`` and once with ``CubePruning`` -- asserts that
  the two results agree, and prints both.
  """
  from ngram import Ngram
  from model import Model
  from forest import Forest

  flags.DEFINE_integer("beam", 100, "beam size", short_name="b")
  flags.DEFINE_integer("debuglevel", 0, "debug level")
  flags.DEFINE_boolean("mert", True, "output mert-friendly info (<hyp><cost>)")
  flags.DEFINE_boolean("cube", True, "using cube pruning to speedup")
  flags.DEFINE_integer("kbest", 1, "kbest output", short_name="k")
  flags.DEFINE_integer("ratio", 3, "the maximum items (pop from PQ): ratio*b", short_name="r")

  # Parse the flags; the leftover positional argv is not needed here.
  FLAGS(sys.argv)

  weights = Model.cmdline_model()
  lm = Ngram.cmdline_ngram()

  false_decoder = CYKDecoder(weights, lm)

  def non_local_scorer(cedge, ders):
    """Score the LM delta of combining sub-derivations `ders` under `cedge`.

    Returns ((model_score, feature_vector), alltrans, sig) as produced from
    ``false_decoder.deltLMScore``.
    """
    (lmsc, alltrans, sig) = false_decoder.deltLMScore(cedge.lhsstr, ders)
    fv = Vector()
    fv["lm"] = lmsc
    return ((weights.dot(fv), fv), alltrans, sig)

  cube_prune = CubePruning(FeatureScorer(weights), non_local_scorer, FLAGS.k, FLAGS.ratio)

  for forest in Forest.load("-", is_tforest=True, lm=lm):
    a = false_decoder.beam_search(forest, b=FLAGS.beam)
    b = cube_prune.run(forest.root)

    # `assert x, y` only tests x (y is merely the failure message), so the
    # original check never compared the two decoders. Compare explicitly.
    # NOTE(review): assumes beam_search returns (score, fv, ...) in the same
    # layout as the cube-pruning item's `.score` == (score, fv) -- confirm.
    assert abs(a[0] - b[0].score[0]) < 1e-6, (a[0], b[0].score[0])
    assert a[1] == b[0].score[1], (a[1], b[0].score[1])
    print(a)
    print(b[0])
# Example #2
# 0
# File: oracle.py  Project: srush/tf-fork
def oracle_extracter(forest, weights, false_decoder, k, ratio, extract=1):
  """Extract an oracle forest via cube pruning (reimplements forest.compute_oracle).

  Each hyperedge combination is scored by the BLEU of its assembled
  hypothesis against a reference length scaled to the head node's span,
  multiplied by the span width. When `false_decoder` is given, an "lm"
  feature from ``false_decoder.deltLMScore`` is also added.

  Args:
    forest: forest providing ``len()``, ``.bleu`` and ``.root``.
    weights: feature weights exposing ``.dot(fv)``.
    false_decoder: optional decoder supplying ``deltLMScore``; if falsy the
      LM feature is left out.
    k, ratio: passed straight through to ``CubePruning``.
    extract: passed to ``CubePruning.extract_kbest_forest``.

  Returns:
    (dec_forest, best): the extracted forest and the result of
    ``CubePruning.run`` on the forest root.
  """
  sent_len = len(forest)
  forest_bleu = forest.bleu

  def non_local_scorer(cedge, ders):
    """Return ((weighted_bleu, -model_score, fv), hyp, hyp) for `cedge` over `ders`."""
    local_bleu = forest_bleu.copy()

    # Fraction of the whole sentence covered by this edge's head node.
    span_frac = cedge.head.span_width() / float(sent_len)
    # Reference length proportional to the covered span.
    local_bleu.special_reflen = forest_bleu.single_reflen() * span_frac

    hyp = cedge.assemble(ders)
    # BLEU weighted by (span_frac * sent_len), i.e. the span width as a float.
    weighted_bleu = local_bleu.rescore(hyp) * (span_frac * sent_len)

    fv = Vector()
    if false_decoder:
      lm_score, _, _ = false_decoder.deltLMScore(cedge.lhsstr, ders)
      fv["lm"] = lm_score
    return ((weighted_bleu, -weights.dot(fv), fv), hyp, hyp)

  # find_min=False: presumably so higher BLEU is preferred -- confirm in CubePruning.
  decoder = CubePruning(BleuScorer(weights, 1.0, 0.0), non_local_scorer, k, ratio, find_min=False)

  tic = time.time()
  best = decoder.run(forest.root)
  print >> logs, "Cube Bleu %s" % (time.time() - tic)

  tic = time.time()
  dec_forest = decoder.extract_kbest_forest(forest, extract)
  print >> logs, "Extracting Forest %s" % (time.time() - tic)

  return dec_forest, best