def test_sum(self):
    """Exercise ``gtn.union`` on empty, single, pairwise and mixed inputs."""

    def two_node_chain(label):
        # Start node 0 -> accepting node 1 through one arc carrying `label`.
        chain = gtn.Graph()
        chain.add_node(True)
        chain.add_node(False, True)
        chain.add_arc(0, 1, label)
        return chain

    # Union of an empty list is the empty graph.
    self.assertTrue(gtn.equal(gtn.union([]), gtn.Graph()))

    # Union over a single graph must give back an equal graph.
    lone = two_node_chain(1)
    self.assertTrue(gtn.equal(gtn.union([lone]), lone))

    # Two disjoint chains; checked up to isomorphism, so node numbering in
    # `want` only has to describe the same structure.
    left = two_node_chain(1)
    right = two_node_chain(0)
    want = gtn.Graph()
    want.add_node(True)
    want.add_node(True)
    want.add_node(False, True)
    want.add_node(False, True)
    want.add_arc(0, 2, 1)
    want.add_arc(1, 3, 0)
    self.assertTrue(gtn.isomorphic(gtn.union([left, right]), want))

    # An empty graph among the operands adds nothing to the result.
    first = two_node_chain(1)
    blank = gtn.Graph()
    loop = gtn.Graph()
    loop.add_node(True, True)
    loop.add_arc(0, 0, 2)
    want = gtn.Graph()
    want.add_node(True)
    want.add_node(False, True)
    want.add_node(True, True)
    want.add_arc(0, 1, 1)
    want.add_arc(2, 2, 2)
    self.assertTrue(gtn.isomorphic(gtn.union([first, blank, loop]), want))
def forward_fn(g):
    """Return the forward score of the union of three Viterbi paths of ``g``.

    ``gtn.viterbi_path`` is invoked once per copy, so the union receives three
    separately computed path graphs rather than one object repeated.
    """
    paths = [gtn.viterbi_path(g) for _ in range(3)]
    return gtn.forward_score(gtn.union(paths))
def token_graph(token_list):
    """Construct the closure over all individual token transition models.

    For each token index ``i`` a two-node graph is built:

    * arc 0 -> 1 maps ``i`` to ``i`` (the first emission outputs the token);
    * a self-loop on node 1 maps ``i`` to epsilon (repeats are absorbed).

    So one or more consecutive emissions of the same token transduce to a
    single token, e.g. [ab, ab, ab] -> [ab]. Only the position of each entry
    in ``token_list`` is used; the entries themselves are not read.
    """

    def single_token(idx):
        model = gtn.Graph()
        model.add_node(True)
        model.add_node(False, True)
        model.add_arc(0, 1, idx, idx)
        model.add_arc(1, 1, idx, gtn.epsilon)
        return model

    models = [single_token(idx) for idx, _ in enumerate(token_list)]
    return gtn.closure(gtn.union(models))
def lexicon_graph(word_pieces, letters_to_idx):
    """Construct a graph which transduces letters to word pieces.

    Each word piece at index ``i`` becomes a linear chain of
    ``len(piece) + 1`` nodes. Every arc consumes one letter, mapped through
    ``letters_to_idx``. Every arc outputs epsilon except the last, which
    outputs ``i`` and ends on the accepting node. An empty word piece yields
    a graph with only a start node and no accepting node. The per-piece
    chains are unioned and then closed.
    """
    chains = []
    for piece_idx, piece in enumerate(word_pieces):
        chain = gtn.Graph()
        chain.add_node(True)
        final_pos = len(piece) - 1
        for pos, letter in enumerate(piece):
            if pos == final_pos:
                chain.add_node(False, True)
                out_label = piece_idx
            else:
                chain.add_node()
                out_label = gtn.epsilon
            chain.add_arc(pos, pos + 1, letters_to_idx[letter], out_label)
        chains.append(chain)
    return gtn.closure(gtn.union(chains))
def test_sum_grad(self):
    """Check that gradients flow through ``gtn.union``, including a no-grad input.

    Three identical three-node chains are unioned. ``g2`` is built with
    gradient calculation disabled. Gradients for ``g1`` and ``g3`` are checked
    numerically, and asking ``g2`` for a gradient must raise.
    """
    g1 = gtn.Graph()
    g1.add_node(True)
    g1.add_node()
    g1.add_node(False, True)
    g1.add_arc(0, 1, 0)
    g1.add_arc(1, 2, 1)

    # Works with a no gradient graph.
    # Fix: the original `gtn.Graph()(False)` tried to *call* a Graph instance,
    # which raises TypeError; the intent is to construct a no-grad graph.
    g2 = gtn.Graph(False)
    g2.add_node(True)
    g2.add_node()
    g2.add_node(False, True)
    g2.add_arc(0, 1, 0)
    g2.add_arc(1, 2, 1)

    g3 = gtn.Graph()
    g3.add_node(True)
    g3.add_node()
    g3.add_node(False, True)
    g3.add_arc(0, 1, 0)
    g3.add_arc(1, 2, 1)

    gtn.backward(gtn.forward_score(gtn.union([g1, g2, g3])))

    # Default arguments bind the current graphs when each helper is defined.
    def forward_fn1(g, g2=g2, g3=g3):
        return gtn.forward_score(gtn.union([g, g2, g3]))

    self.assertTrue(numerical_grad_check(forward_fn1, g1, 1e-4, 1e-3))

    def forward_fn2(g, g1=g1, g2=g2):
        return gtn.forward_score(gtn.union([g1, g2, g]))

    self.assertTrue(numerical_grad_check(forward_fn2, g3, 1e-4, 1e-3))

    # Fix: `CHECK_THROWS` is a C++ test macro and is undefined in Python.
    # NOTE(review): assumes the binding surfaces the C++ error as RuntimeError
    # (pybind11's default mapping for std::exception) -- confirm.
    with self.assertRaises(RuntimeError):
        g2.grad()
def forward_fn2(g, g1=g1, g2=g2):
    """Return the forward score of ``union([g1, g2, g])``; ``g`` goes last."""
    combined = gtn.union([g1, g2, g])
    return gtn.forward_score(combined)
def forward_fn1(g, g2=g2, g3=g3):
    """Return the forward score of ``union([g, g2, g3])``; ``g`` goes first."""
    combined = gtn.union([g, g2, g3])
    return gtn.forward_score(combined)
# NOTE(review): this fragment begins partway through building `g1`; its
# gtn.Graph(...) construction and earlier add_node calls are outside this view.
# Presumably the node added next is node 2 (accepting). If node 0 is the start
# node, g1 recognizes "ab" followed by zero or more "a" -- confirm.
g1.add_node(False, True)
g1.add_arc(0, 1, 0)
g1.add_arc(1, 2, 1)
g1.add_arc(2, 2, 0)  # self-loop on node 2, label 0 ("a" in `symbols`)
# Recognizes "ba"
# `False` presumably disables gradient tracking for this graph -- confirm.
g2 = gtn.Graph(False)
g2.add_node(True)
g2.add_node()
g2.add_node(False, True)
g2.add_arc(0, 1, 1)
g2.add_arc(1, 2, 0)
# Recognizes "ac"
g3 = gtn.Graph(False)
g3.add_node(True)
g3.add_node()
g3.add_node(False, True)
g3.add_arc(0, 1, 0)
g3.add_arc(1, 2, 2)
# Label -> display name; passed twice to gtn.draw, presumably as the input and
# output symbol tables.
symbols = {0: "a", 1: "b", 2: "c"}
# Draw each operand (to PDF files under /tmp, judging by the paths).
gtn.draw(g1, "/tmp/union_g1.pdf", symbols, symbols)
gtn.draw(g2, "/tmp/union_g2.pdf", symbols, symbols)
gtn.draw(g3, "/tmp/union_g3.pdf", symbols, symbols)
# Union all three graphs and draw the result.
graph = gtn.union([g1, g2, g3])
gtn.draw(graph, "/tmp/union_graph.pdf", symbols, symbols)