예제 #1
0
    def test_sum(self):
        """Exercise gtn.union on zero, one, two and three input graphs.

        Despite the method name, every assertion below targets ``gtn.union``.
        """

        def one_arc_graph(label):
            # Start node 0 --label--> accepting node 1.
            graph = gtn.Graph()
            graph.add_node(True)
            graph.add_node(False, True)
            graph.add_arc(0, 1, label)
            return graph

        # The union of an empty list equals a fresh empty graph.
        self.assertTrue(gtn.equal(gtn.union([]), gtn.Graph()))

        # A union over a single graph reproduces that graph exactly.
        lone = one_arc_graph(1)
        self.assertTrue(gtn.equal(gtn.union([lone]), lone))

        # Two one-arc graphs: compared up to isomorphism against a graph
        # holding both start nodes first, then both accepting nodes.
        first = one_arc_graph(1)
        second = one_arc_graph(0)
        want = gtn.Graph()
        want.add_node(True)
        want.add_node(True)
        want.add_node(False, True)
        want.add_node(False, True)
        want.add_arc(0, 2, 1)
        want.add_arc(1, 3, 0)
        self.assertTrue(gtn.isomorphic(gtn.union([first, second]), want))

        # An empty graph among the inputs contributes no nodes or arcs.
        first = one_arc_graph(1)
        empty = gtn.Graph()
        loop = gtn.Graph()
        loop.add_node(True, True)
        loop.add_arc(0, 0, 2)
        want = gtn.Graph()
        want.add_node(True)
        want.add_node(False, True)
        want.add_node(True, True)
        want.add_arc(0, 1, 1)
        want.add_arc(2, 2, 2)
        self.assertTrue(gtn.isomorphic(gtn.union([first, empty, loop]), want))
예제 #2
0
 def forward_fn(g):
     """Return the forward score of the union of three Viterbi paths of ``g``.

     ``gtn.viterbi_path`` is called three separate times, so the union
     receives three independently computed path graphs.
     """
     best_paths = [gtn.viterbi_path(g) for _ in range(3)]
     return gtn.forward_score(gtn.union(best_paths))
예제 #3
0
def token_graph(token_list):
    """
    Build the closure of the union of one transition model per token.

    The model for the token at position ``i`` is a two-node transducer:
    the arc 0 -> 1 reads ``i`` and writes ``i``, and the self-loop on the
    accepting node 1 reads further ``i`` symbols while writing epsilon.
    A run of one or more identical tokens therefore collapses to a single
    output, e.g. [ab, ab, ab] transduces to [ab]. Only the positions in
    ``token_list`` are used as labels; the token values are not read.
    """

    def _single_token_model(label):
        model = gtn.Graph()
        model.add_node(True)
        model.add_node(False, True)
        # First occurrence emits the token ...
        model.add_arc(0, 1, label, label)
        # ... and any repeats are consumed silently.
        model.add_arc(1, 1, label, gtn.epsilon)
        return model

    return gtn.closure(
        gtn.union([_single_token_model(idx) for idx, _ in enumerate(token_list)])
    )
예제 #4
0
def lexicon_graph(word_pieces, letters_to_idx):
    """
    Constructs a graph which transduces letters to word pieces.

    Each word piece ``wp`` at index ``i`` becomes a linear chain of
    ``len(wp) + 1`` nodes. Arc ``e`` reads ``letters_to_idx[wp[e]]``. Only
    the final arc writes the word-piece index ``i``; every earlier arc
    writes epsilon. Only the final node is accepting. The per-piece chains
    are unioned and wrapped in ``gtn.closure``.

    Args:
        word_pieces: iterable of word pieces, each a sequence of letters.
        letters_to_idx: mapping from a letter to its input label.

    Returns:
        The lexicon transducer (a ``gtn.Graph``).

    Raises:
        KeyError: if a word piece contains a letter absent from
            ``letters_to_idx``.

    Note:
        An empty word piece yields a graph with only a non-accepting start
        node, so that piece can never be emitted.
    """
    lex = []
    for i, wp in enumerate(word_pieces):
        graph = gtn.Graph()
        graph.add_node(True)
        last = len(wp) - 1
        for e, letter in enumerate(wp):
            is_last = e == last
            # The final node accepts, and the final arc carries the
            # word-piece id; every intermediate arc emits epsilon.
            graph.add_node(False, is_last)
            graph.add_arc(
                e, e + 1, letters_to_idx[letter], i if is_last else gtn.epsilon
            )
        lex.append(graph)
    return gtn.closure(gtn.union(lex))
예제 #5
0
        def test_sum_grad(self):
            """Check gradients through gtn.union, including an input without gradients."""

            def ab_chain(calc_grad=True):
                # Three-node chain 0 --0--> 1 --1--> 2 with node 2 accepting.
                graph = gtn.Graph(calc_grad)
                graph.add_node(True)
                graph.add_node()
                graph.add_node(False, True)
                graph.add_arc(0, 1, 0)
                graph.add_arc(1, 2, 1)
                return graph

            g1 = ab_chain()
            # Works with a no gradient graph: the flag is a constructor
            # argument, i.e. gtn.Graph(False).
            g2 = ab_chain(calc_grad=False)
            g3 = ab_chain()

            gtn.backward(gtn.forward_score(gtn.union([g1, g2, g3])))

            # Default arguments bind the other inputs at definition time,
            # so only ``g`` varies inside the numerical check.
            def forward_fn1(g, g2=g2, g3=g3):
                return gtn.forward_score(gtn.union([g, g2, g3]))

            self.assertTrue(numerical_grad_check(forward_fn1, g1, 1e-4, 1e-3))

            def forward_fn2(g, g1=g1, g2=g2):
                return gtn.forward_score(gtn.union([g1, g2, g]))

            self.assertTrue(numerical_grad_check(forward_fn2, g3, 1e-4, 1e-3))

            # g2 was built with gradients disabled, so reading its gradient
            # must fail.
            # NOTE(review): assumes the binding surfaces this as RuntimeError
            # — confirm.
            with self.assertRaises(RuntimeError):
                g2.grad()
예제 #6
0
 def forward_fn2(g, g1=g1, g2=g2):
     """Forward score of the union of the bound ``g1`` and ``g2`` followed by ``g``."""
     combined = gtn.union([g1, g2, g])
     return gtn.forward_score(combined)
예제 #7
0
 def forward_fn1(g, g2=g2, g3=g3):
     """Forward score of the union of ``g`` followed by the bound ``g2`` and ``g3``."""
     combined = gtn.union([g, g2, g3])
     return gtn.forward_score(combined)
예제 #8
0
# Recognizes "aba*": "ab" followed by zero or more "a"
g1 = gtn.Graph(False)
g1.add_node(True)
g1.add_node()
g1.add_node(False, True)
g1.add_arc(0, 1, 0)
g1.add_arc(1, 2, 1)
g1.add_arc(2, 2, 0)

# Recognizes "ba"
g2 = gtn.Graph(False)
g2.add_node(True)
g2.add_node()
g2.add_node(False, True)
g2.add_arc(0, 1, 1)
g2.add_arc(1, 2, 0)

# Recognizes "ac"
g3 = gtn.Graph(False)
g3.add_node(True)
g3.add_node()
g3.add_node(False, True)
g3.add_arc(0, 1, 0)
g3.add_arc(1, 2, 2)

# Label -> display symbol, used for both input and output labels.
symbols = {0: "a", 1: "b", 2: "c"}

gtn.draw(g1, "/tmp/union_g1.pdf", symbols, symbols)
gtn.draw(g2, "/tmp/union_g2.pdf", symbols, symbols)
gtn.draw(g3, "/tmp/union_g3.pdf", symbols, symbols)

# The union accepts any string accepted by g1, g2 or g3.
graph = gtn.union([g1, g2, g3])

gtn.draw(graph, "/tmp/union_graph.pdf", symbols, symbols)