コード例 #1
0
ファイル: decoders.py プロジェクト: srush/tf-fork
 def delta_weights_old(self, updates, weights):
   """
   Install `weights`, rebuild the weight FSTs from `updates`, and store the
   rho-composition of a base machine with a weight machine in self.temp_fst.

   updates - passed unchanged to full_create_weight_fst and, when
             self.non_neg is false, also to create_weight_fst
   weights - passed unchanged to set_weights

   When self.non_neg is true the composition starts from self.fst and uses
   self.just_weight_fst; otherwise it starts from the current self.temp_fst
   and uses self.weight_fst.
   """
   self.set_weights(weights)
   self.full_create_weight_fst(updates)

   if not self.non_neg:
     self.create_weight_fst(updates)
     # Read the operands only after create_weight_fst has run.
     base, weight_machine, last_flag = self.temp_fst, self.weight_fst, True
   else:
     print("Intersecting")
     base, weight_machine, last_flag = self.fst, self.just_weight_fst, False

   # The two paths differ in their operands and in the final flag; the meaning
   # of the boolean arguments is defined by fsa.rho_compose (not visible here).
   self.temp_fst = fsa.rho_compose(base, False, weight_machine, True, last_flag)
コード例 #2
0
ファイル: temp.py プロジェクト: srush/tf-fork
# Both machines use the same symbol table on their input and output sides.
sent_fst1.SetInputSymbols(s_table)
sent_fst1.SetOutputSymbols(s_table)

lm_fsa.SetInputSymbols(s_table)
lm_fsa.SetOutputSymbols(s_table)


# Arc-sort the LM on its input side before it is used as the second operand
# of the composition below.
openfst.ArcSortInput(lm_fsa)

# Output containers for Determinize and ShortestPath.  No container is
# allocated for inter_fst: rho_compose returns the composed machine itself.
short_fst = openfst.StdVectorFst()
inter_det_fst = openfst.StdVectorFst()


inter_fst = fsa.rho_compose(sent_fst1, False, lm_fsa, True, True)
openfst.Determinize(inter_fst, inter_det_fst)
print(inter_det_fst.NumStates())
fsa.print_fst(inter_det_fst)

# Extract the shortest path (third argument 1) into short_fst, then
# topologically sort it before printing.
openfst.ShortestPath(inter_det_fst, short_fst, 1)
print(short_fst.NumStates())
openfst.TopSort(short_fst)
fsa.print_fst(short_fst)


#-----


fst1 = openfst.StdVectorFst()
fst2 = openfst.StdVectorFst()
コード例 #3
0
ファイル: decoders.py プロジェクト: srush/tf-fork
  def set_heuristic_fsa(self, simple_tree, orig_uni_tree, simple_count, step_count, nts, last_sym):
    """
    Build the heuristic FSTs and the per-state tables derived from them.

    simple_tree - FSA of trans tree with zero weights
    orig_uni_tree - FSA of trans tree with unigram weights
    simple_count - FSA with counting
    step_count - FSA with 1.0 penalty for each step
    nts - map of all non-terms; element [1] of each entry is collected
          into self.nts
    last_sym - not read by this method; the pruner below is built from
               self.last_sym instead (NOTE(review): confirm that attribute
               is set elsewhere and that ignoring the argument is intended)

    Sets self.heuristic_fst, self.heuristic_count_fst,
    self.reverse_heuristic_fst, self.nts, self.topo_dist, self.cache_step,
    self.state_symbol and self.heuristic_pruner.
    """
    # The unigram tree is composed exactly as passed in.
    weighted = fsa.rho_compose(orig_uni_tree, False, simple_count, True, True)
    stepped = fsa.rho_compose(simple_tree, False, step_count, True, True)

    # Sanity checks: the two trees, and the two counters, must agree in size.
    assert orig_uni_tree.NumStates() == simple_tree.NumStates(), \
        "%s %s" % (orig_uni_tree.NumStates(), simple_tree.NumStates())
    assert step_count.NumStates() == simple_count.NumStates()

    count_fst = openfst.StdVectorFst()
    self.heuristic_count_fst = count_fst
    heuristic = openfst.StdVectorFst()
    self.heuristic_fst = heuristic

    openfst.Determinize(weighted, heuristic)
    openfst.Determinize(stepped, count_fst)
    # Apply each clean-up pass to the count machine first, then the weighted one.
    for clean_up in (openfst.Connect, openfst.TopSort):
      clean_up(count_fst)
      clean_up(heuristic)
    assert count_fst.NumStates() == heuristic.NumStates()

    reversed_fst = openfst.StdVectorFst()
    openfst.Reverse(heuristic, reversed_fst)
    self.reverse_heuristic_fst = reversed_fst

    self.nts = set(entry[1] for entry in nts)
    print("creating tables")

    # Per-state shortest distance in the count machine, passed through uninf
    # (defined elsewhere), truncated to int and shifted by one.
    raw_dist = openfst.FloatVector()
    openfst.ShortestDistance(count_fst, raw_dist, False)
    self.topo_dist = openfst.IntVector([int(uninf(d)) + 1 for d in raw_dist])

    n = len(self.topo_dist)
    self.cache_step = [None] * n
    self.state_symbol = openfst.IntVector(n)

    topo = self.topo_dist
    non_terms = self.nts
    limit = self.output_bound + 1
    for state in range(n):
      for arc in range(count_fst.NumArcs(state)):
        label = count_fst.GetOutput(state, arc)
        target = count_fst.GetNext(state, arc)
        # Skip targets outside the distance vector, targets farther than the
        # bound, and arcs whose output label is not a known non-terminal.
        if target >= n or topo[target] > limit or label not in non_terms:
          continue
        # A later qualifying arc into the same target overwrites the entry.
        self.cache_step[target] = (topo[target], label)
        self.state_symbol[target] = label - tree_extractor.SRC_NODE

    self.heuristic_pruner = openfst.BeamPrune(
        reversed_fst.NumStates(),
        self.output_bound + 1,
        tree_extractor.SRC_NODE,
        self.last_sym)
コード例 #4
0
ファイル: decode.py プロジェクト: srush/tf-fork
          #openfst.Minimize(det_lm_fsa)
        else:
          
          det_fsa = fsa
          
        openfst.ArcSortInput(fsa)
        openfst.ArcSortInput(lm_fsa)
        det_lm_fsa = lm_fsa
        print "Intersecting %s %s %s" % (det_fsa.NumStates(), det_lm_fsa.NumStates(), count_fsa.NumStates()) 
        #openfst.Prune(det_lm_fsa, 3.0)
        fsa2 = openfst.StdVectorFst()
        fsa3 = openfst.StdVectorFst()

        print "Intersect 1"
        #FSA.Intersect(det_fsa, det_lm_fsa, fsa2)
        fsa2 = FSA.rho_compose(det_fsa, False, det_lm_fsa, True, True)
#        tree_count_fsa = FSA.rho_compose(det_fsa, False, count_fsa, True, True)

        shortest = openfst.StdVectorFst()        

        bests = {}
        totals = {}

        if minimize:
          print "Minimizing FSA2"
          #det_fsa2 = openfst.StdVectorFst()

          #det_fsa2 = fsa2
          det_fsa2 = openfst.StdVectorFst()
          openfst.Determinize(fsa2, det_fsa2)
          openfst.Connect(det_fsa2)