def test_closure_grad(self):
    """Gradients through closure + compose match finite differences."""
    # Input graph: start node 0, accept node 1, and a self-loop on node 1.
    g1 = gtn.Graph()
    g1.add_node(True)
    g1.add_node(False, True)
    g1.add_arc(0, 1, 0, 0, 1.3)
    g1.add_arc(1, 1, 1, 1, 2.1)

    # Linear acceptor with four positions, each allowing label 0 or 1.
    g2 = gtn.Graph()
    g2.add_node(True)
    for _ in range(3):
        g2.add_node()
    g2.add_node(False, True)
    for src in range(4):
        for label in (0, 1):
            g2.add_arc(src, src + 1, label)

    def forward_fn(g, g2=g2):
        return gtn.forward_score(gtn.compose(closure(g), g2))

    # One explicit backward pass through the pipeline, then a numerical check.
    gtn.backward(forward_fn(g1))
    self.assertTrue(numerical_grad_check(forward_fn, g1, 1e-3, 1e-3))
def test_scalar_ops(self):
    """negate/add/subtract on single-arc graphs, values and gradients."""

    def scalar(weight):
        # Two-node graph whose single arc carries ``weight``.
        graph = gtn.Graph()
        graph.add_node(True)
        graph.add_node(False, True)
        graph.add_arc(0, 1, 0, 0, weight)
        return graph

    g1 = scalar(1.0)

    # negate: value -x, gradient -1.
    negated = gtn.negate(g1)
    self.assertEqual(negated.item(), -1.0)
    gtn.backward(negated)
    self.assertEqual(g1.grad().item(), -1.0)
    g1.zero_grad()

    g2 = scalar(3.0)

    # add: value x + y, gradient 1 for both operands.
    total = gtn.add(g1, g2)
    self.assertEqual(total.item(), 4.0)
    gtn.backward(total)
    self.assertEqual(g1.grad().item(), 1.0)
    self.assertEqual(g2.grad().item(), 1.0)
    g1.zero_grad()
    g2.zero_grad()

    # subtract: value x - y, gradients +1 and -1.
    difference = gtn.subtract(g1, g2)
    self.assertEqual(difference.item(), -2.0)
    gtn.backward(difference)
    self.assertEqual(g1.grad().item(), 1.0)
    self.assertEqual(g2.grad().item(), -1.0)
def test_loadsave(self):
    """gtn.save followed by gtn.load reproduces the saved graph."""
    import os

    # mkstemp returns an open OS-level file descriptor together with the
    # path. Only the path is used (gtn.save/gtn.load open it themselves), so
    # close the descriptor right away instead of leaking it, and delete the
    # temporary file once the test has finished.
    fd, tmpfile = tempfile.mkstemp()
    os.close(fd)
    self.addCleanup(os.remove, tmpfile)

    # Empty graph round trip.
    g = gtn.Graph()
    gtn.save(tmpfile, g)
    g2 = gtn.load(tmpfile)
    self.assertTrue(gtn.equal(g, g2))

    # Two start nodes, two accept nodes, weighted transducer arcs and one
    # epsilon output label.
    g = gtn.Graph()
    g.add_node(True)
    g.add_node(True)
    g.add_node()
    g.add_node()
    g.add_node(False, True)
    g.add_node(False, True)
    g.add_arc(0, 1, 0, 1, 1.1)
    g.add_arc(1, 2, 1, 2, 2.1)
    g.add_arc(2, 3, 2, 3, 3.1)
    g.add_arc(3, 4, 3, 4, 4.1)
    g.add_arc(4, 5, 4, gtn.epsilon, 5.1)
    gtn.save(tmpfile, g)
    g2 = gtn.load(tmpfile)
    self.assertTrue(gtn.equal(g, g2))
    self.assertTrue(gtn.isomorphic(g, g2))
def test_retain_graph(self):
    """backward frees the autograd graph unless asked to retain it."""

    def make_scalar():
        # Gradient-enabled two-node graph with one arc of weight 3.
        graph = gtn.Graph(True)
        graph.add_node(True)
        graph.add_node(False, True)
        graph.add_arc(0, 1, 0, 0, 3.0)
        return graph

    g1 = make_scalar()
    g2 = make_scalar()

    # Default: the graph is not retained, so a second backward fails.
    result = gtn.add(g1, g2)
    gtn.backward(result)
    with self.assertRaises(ValueError):
        gtn.backward(result)

    # With True as the second argument the graph is kept, so backward can run
    # again after clearing every gradient.
    g1.zero_grad()
    g2.zero_grad()
    result = gtn.add(g1, g2)
    gtn.backward(result, True)
    for graph in (g1, g2, result):
        graph.zero_grad()
    gtn.backward(result, True)
    self.assertTrue(g1.grad().item() == 1.0)
    self.assertTrue(g2.grad().item() == 1.0)
def test_viterbi_score_grad(self):
    """viterbi_score gradients are 1 on the best path's arcs, 0 elsewhere."""

    def best_path_grads(graph):
        # Differentiate the Viterbi score and return the per-arc gradients.
        gtn.backward(gtn.viterbi_score(graph))
        return graph.grad().weights_to_list()

    # Two positions, three parallel arcs each; the weight-3 arcs win.
    g = gtn.Graph()
    g.add_node(True)
    g.add_node()
    g.add_node(False, True)
    for src in (0, 1):
        for label in range(3):
            g.add_arc(src, src + 1, label, label, label + 1)
    self.assertEqual(best_path_grads(g), [0.0, 0.0, 1.0, 0.0, 0.0, 1.0])

    # Two start nodes: the best path is the single arc 1 -> 2.
    g = gtn.Graph()
    g.add_node(True)
    g.add_node(True)
    g.add_node(False, True)
    g.add_arc(0, 1, 0, 0, -5)
    g.add_arc(0, 2, 0, 0, 1)
    g.add_arc(1, 2, 0, 0, 2)
    self.assertEqual(best_path_grads(g), [0.0, 0.0, 1.0])

    # Two accept nodes: the best path is 0 -> 1 -> 2.
    g = gtn.Graph()
    g.add_node(True)
    g.add_node(False, True)
    g.add_node(False, True)
    g.add_arc(0, 1, 0, 0, 2)
    g.add_arc(0, 2, 0, 0, 2)
    g.add_arc(1, 2, 0, 0, 2)
    self.assertEqual(best_path_grads(g), [1.0, 0.0, 1.0])

    # A more complex graph given in text form.
    g_str = [
        "0 1",
        "3 4",
        "0 1 0 0 2",
        "0 2 1 1 1",
        "1 2 0 0 2",
        "2 3 0 0 1",
        "2 3 1 1 1",
        "1 4 0 0 2",
        "2 4 1 1 3",
        "3 4 0 0 2",
    ]
    grads = best_path_grads(create_graph_from_text(g_str))
    # The best score 7 is reached by more than one path: 0-1-2-4, and
    # 0-1-2-3-4 through either of the parallel 2->3 arcs. The test accepts
    # the two alignments below; which tied path is returned presumably
    # depends on gtn's tie-breaking.
    expected1 = [1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0]
    expected2 = [1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0]
    self.assertIn(grads, (expected1, expected2))
def test_viterbi_score(self):
    """viterbi_score returns the weight of the best accepting path."""
    # An empty graph has no path at all.
    self.assertEqual(gtn.viterbi_score(gtn.Graph()).item(), -math.inf)

    # Two positions, three parallel arcs each: best is 3 + 3.
    g = gtn.Graph()
    g.add_node(True)
    g.add_node()
    g.add_node(False, True)
    for src in (0, 1):
        for label in range(3):
            g.add_arc(src, src + 1, label, label, label + 1)
    self.assertEqual(gtn.viterbi_score(g).item(), 6.0)

    # Two start nodes: starting at node 1 gives the best score.
    g = gtn.Graph()
    g.add_node(True)
    g.add_node(True)
    g.add_node(False, True)
    g.add_arc(0, 1, 0, 0, -5)
    g.add_arc(0, 2, 0, 0, 1)
    g.add_arc(1, 2, 0, 0, 2)
    self.assertEqual(gtn.viterbi_score(g).item(), 2.0)

    # Two accept nodes: the longer path 0 -> 1 -> 2 wins.
    g = gtn.Graph()
    g.add_node(True)
    g.add_node(False, True)
    g.add_node(False, True)
    g.add_arc(0, 1, 0, 0, 2)
    g.add_arc(0, 2, 0, 0, 2)
    g.add_arc(1, 2, 0, 0, 2)
    self.assertEqual(gtn.viterbi_score(g).item(), 4.0)

    # A more complex graph given in text form.
    g_str = [
        "0 1",
        "3 4",
        "0 1 0 0 2",
        "0 2 1 1 1",
        "1 2 0 0 2",
        "2 3 0 0 1",
        "2 3 1 1 1",
        "1 4 0 0 2",
        "2 4 1 1 3",
        "3 4 0 0 2",
    ]
    complex_graph = create_graph_from_text(g_str)
    self.assertEqual(gtn.viterbi_score(complex_graph).item(), 7.0)
def test_transitions(self):
    """make_transitions_graph builds the expected uni/bi/trigram graphs."""
    num_tokens = 4

    # Unigram: one start+accept node with a self-loop per token.
    transitions = make_transitions_graph(1, num_tokens)
    expected = gtn.Graph()
    expected.add_node(True, True)
    for tok in range(num_tokens):
        expected.add_arc(0, 0, tok)
    self.assertTrue(gtn.isomorphic(transitions, expected))

    # Bigram: start node, one node per token, full token-to-token arcs, and
    # an accept node reached by epsilon from every other node.
    transitions = make_transitions_graph(2, num_tokens)
    expected = gtn.Graph()
    expected.add_node(True, False)
    for tok in range(num_tokens):
        expected.add_node(False, False)
        expected.add_arc(0, tok + 1, tok)
    for prev in range(num_tokens):
        for tok in range(num_tokens):
            expected.add_arc(prev + 1, tok + 1, tok)
    end = expected.add_node(False, True)
    for src in range(end):
        expected.add_arc(src, end, gtn.epsilon)
    self.assertTrue(gtn.isomorphic(transitions, expected))

    # Trigram: additionally one node per (a, b) history.
    def pair_node(a, b):
        # Index of the node created for history (a, b).
        return num_tokens * a + b + num_tokens + 1

    transitions = make_transitions_graph(3, num_tokens)
    expected = gtn.Graph()
    expected.add_node(True, False)
    for tok in range(num_tokens):
        expected.add_node(False, False)
        expected.add_arc(0, tok + 1, tok)
    for a in range(num_tokens):
        for b in range(num_tokens):
            expected.add_node(False, False)
            expected.add_arc(a + 1, pair_node(a, b), b)
    for a in range(num_tokens):
        for b in range(num_tokens):
            for c in range(num_tokens):
                expected.add_arc(pair_node(a, b), pair_node(b, c), c)
    end_idx = expected.add_node(False, True)
    self.assertEqual(end_idx, num_tokens * num_tokens + num_tokens + 1)
    for src in range(end_idx):
        expected.add_arc(src, end_idx, gtn.epsilon)
    self.assertTrue(gtn.isomorphic(transitions, expected))
def make_transitions_graph(ngram, num_tokens, calc_grad=False):
    """Build an n-gram transition graph over tokens ``0 .. num_tokens - 1``.

    Every token history shorter than ``ngram`` tokens gets its own node; the
    empty history is node 0, the start node. For ``ngram == 1`` that single
    node is also accepting and carries one self-loop per token. For
    ``ngram > 1`` a final accept node is appended and every other node gets
    an epsilon arc into it.

    Args:
        ngram: order of the model (1 = unigram).
        num_tokens: number of distinct token labels.
        calc_grad: forwarded to ``gtn.Graph``.

    Returns:
        The transition ``gtn.Graph``.
    """
    graph = gtn.Graph(calc_grad)
    graph.add_node(True, ngram == 1)
    node_of = {(): 0}

    # Histories that still include the sentence start: grow them one token
    # at a time, linking each history to its one-token-longer extension.
    for order in range(1, ngram):
        for history in itertools.product(range(num_tokens), repeat=order):
            parent = node_of[history[:-1]]
            child = graph.add_node(False, ngram == 1)
            node_of[history] = child
            graph.add_arc(parent, child, history[-1])

    # Full n-grams: from history gram[:-1], emit gram[-1] and move to the
    # history gram[1:]  (i.e. p(gram[-1] | gram[:-1])).
    for gram in itertools.product(range(num_tokens), repeat=ngram):
        graph.add_arc(node_of[gram[:-1]], node_of[gram[1:]], gram[-1])

    if ngram > 1:
        # Sentence end: every node may finish via an epsilon arc.
        end = graph.add_node(False, True)
        for src in range(end):
            graph.add_arc(src, end, gtn.epsilon)
    return graph
def make_chain_graph(sequence):
    """Build a linear acceptor that matches exactly ``sequence``.

    Node 0 is the start node; the i-th label adds node ``i + 1`` and an arc
    ``i -> i + 1`` carrying that label. Only the last node accepts. An empty
    ``sequence`` yields a single node that is both start and accept, so the
    graph accepts only the empty sequence rather than nothing.

    Args:
        sequence: iterable of integer labels (generators are accepted).

    Returns:
        The chain ``gtn.Graph`` (created with calc_grad disabled).
    """
    # Materialise once so the final-node test can use the length even when
    # an iterator/generator is passed in.
    sequence = list(sequence)
    last = len(sequence)
    graph = gtn.Graph(False)
    graph.add_node(True, last == 0)
    for i, label in enumerate(sequence):
        graph.add_node(False, i + 1 == last)
        graph.add_arc(i, i + 1, label)
    return graph
def build_lm_graph(ngram_counts, vocab):
    """Build an n-gram language-model acceptor with back-off arcs.

    Args:
        ngram_counts: one mapping per order; ``len(ngram_counts)`` is the LM
            order. Each mapping goes from an n-gram tuple of vocab ids to a
            ``(prob, backoff)`` pair, where ``backoff`` may be None. ``prob``
            and ``backoff`` are used directly as arc weights (presumably
            log-probabilities -- confirm against the caller).
        vocab: mapping from tokens to ids; must contain ``BOS`` and ``EOS``.

    Returns:
        A ``gtn.Graph`` whose nodes are n-gram histories. The ``(BOS,)``
        history is the start node and any history containing ``EOS``
        accepts. Arcs predicting ``EOS`` are epsilon-labelled. When an n-gram
        has a back-off weight and does not contain ``EOS``, an epsilon arc
        with that weight links its state to the state of its suffix
        ``ngram[1:]``.
    """
    graph = gtn.Graph(False)
    lm_order = len(ngram_counts)
    assert lm_order > 1, "build_lm_graph doesn't work for unigram LMs"
    node_for = {}

    def get_node(state):
        # Return the node for ``state``, creating it on first use.
        if state in node_for:
            return node_for[state]
        node = graph.add_node(state == (vocab[BOS],), vocab[EOS] in state)
        node_for[state] = node
        return node

    for counts in ngram_counts:
        for ngram, (prob, bckoff) in counts.items():
            # Keep at most (lm_order - 1) tokens of history after the arc.
            istate, ostate = ngram[:-1], ngram[1 - lm_order:]
            inode = get_node(istate)
            onode = get_node(ostate)
            # p(ngram[-1] | ngram[:-1])
            lbl = gtn.epsilon if ngram[-1] == vocab[EOS] else ngram[-1]
            graph.add_arc(inode, onode, lbl, lbl, prob)
            if bckoff is not None and vocab[EOS] not in ngram:
                graph.add_arc(onode, get_node(ngram[1:]), gtn.epsilon,
                              gtn.epsilon, bckoff)
    return graph
def test_autograd(self):
    """Graph retention and explicit output gradients in gtn.backward."""
    g1 = gtn.scalar_graph(3.0)
    g2 = gtn.scalar_graph(3.0)

    # By default the graph is released after one backward pass.
    result = gtn.add(g1, g2)
    gtn.backward(result)
    with self.assertRaises(ValueError):
        gtn.backward(result)

    # Retained graph: backward can be called a second time.
    for graph in (g1, g2):
        graph.zero_grad()
    result = gtn.add(g1, g2)
    gtn.backward(result, True)
    for graph in (result, g1, g2):
        graph.zero_grad()
    gtn.backward(result, True)
    self.assertEqual(g1.grad().item(), 1.0)
    self.assertEqual(g2.grad().item(), 1.0)

    # Explicit output gradients (deltas) propagate to both inputs.
    for graph in (g1, g2):
        graph.zero_grad()
    result = gtn.add(g1, g2)
    deltas = gtn.Graph()
    deltas.add_node(True)
    deltas.add_node(False, True)
    deltas.add_arc(0, 1, 0, 0, 7.0)
    gtn.backward(result, deltas)
    self.assertEqual(g1.grad().item(), 7.0)
    self.assertEqual(g2.grad().item(), 7.0)
def test_simple_decomposition(self):
    """Transducer loss/gradients match a hand-built decomposition graph."""
    T = 5
    tokens = ["a", "b", "ab", "ba", "aba"]
    scores = torch.randn((1, T, len(tokens)), requires_grad=True)
    labels = [[0, 1, 0]]
    transducer = Transducer(tokens=tokens, graphemes_to_idx={"a": 0, "b": 1})

    # Alignment graph containing every token decomposition of "aba". Each
    # token gets its own node with a self-loop so it can span several frames.
    alignments = gtn.Graph(False)
    alignments.add_node(True)

    def add_token_state(src, token, accept=False):
        node = alignments.add_node(False, accept)
        alignments.add_arc(src, node, token)
        alignments.add_arc(node, node, token)
        return node

    # ['a', 'b', 'a']
    after_a = add_token_state(0, 0)
    after_ab = add_token_state(after_a, 1)
    after_aba = add_token_state(after_ab, 0, accept=True)
    # ['a', 'ba']
    add_token_state(after_a, 3, accept=True)
    # ['ab', 'a'] -- reuses the final 'a' state of the first path
    after_ab_token = add_token_state(0, 2)
    alignments.add_arc(after_ab_token, after_aba, 0)
    # ['aba']
    add_token_state(0, 4, accept=True)

    emissions = gtn.linear_graph(T, len(tokens), True)
    emissions.set_weights(scores.data_ptr())
    expected_loss = gtn.subtract(
        gtn.forward_score(emissions),
        gtn.forward_score(gtn.intersect(emissions, alignments)),
    )

    loss = transducer(scores, labels)
    self.assertAlmostEqual(loss.item(), expected_loss.item(), places=5)

    loss.backward()
    gtn.backward(expected_loss)
    expected_grad = torch.tensor(
        emissions.grad().weights_to_numpy()).view((1, T, len(tokens)))
    self.assertTrue(
        torch.allclose(scores.grad, expected_grad, rtol=1e-4, atol=1e-5))
def make_chain_graph(seq, calc_grad=False):
    """Make a simple chain graph from an iterable of integers.

    Node 0 is the start node; the e-th label adds node ``e + 1`` and an arc
    ``e -> e + 1`` carrying that label. Only the last node accepts. An empty
    ``seq`` yields a single start node that also accepts, so the graph
    accepts only the empty sequence rather than nothing.

    Args:
        seq: any iterable of integer labels, including generators.
        calc_grad: forwarded to ``gtn.Graph``.

    Returns:
        The chain ``gtn.Graph``.
    """
    # Materialise once: the final-node check needs the length, which
    # generators and other plain iterators do not provide.
    seq = list(seq)
    n = len(seq)
    g = gtn.Graph(calc_grad)
    g.add_node(True, n == 0)
    for e, s in enumerate(seq):
        g.add_node(False, e + 1 == n)
        g.add_arc(e, e + 1, s)
    return g
def gen_potentials(num_features, num_classes, calc_grad=False):
    """Make the unary potential graph.

    A single start+accept node carries one self-loop per (feature, class)
    pair, with input label ``feature`` and output label ``class``, added
    feature-major.
    """
    g = gtn.Graph(calc_grad)
    g.add_node(True, True)
    pairs = ((feat, cls) for feat in range(num_features)
             for cls in range(num_classes))
    for feat, cls in pairs:
        g.add_arc(0, 0, feat, cls)  # f(feat, cls)
    return g
def test_sum(self):
    """gtn.union of zero, one, two and three (one empty) graphs."""

    def single_arc(label):
        # Start node 0 -> accept node 1 with one arc labelled ``label``.
        graph = gtn.Graph()
        graph.add_node(True)
        graph.add_node(False, True)
        graph.add_arc(0, 1, label)
        return graph

    # Union of nothing is the empty graph.
    self.assertTrue(gtn.equal(gtn.union([]), gtn.Graph()))

    # Union of a single graph leaves it unchanged.
    g1 = single_arc(1)
    self.assertTrue(gtn.equal(gtn.union([g1]), g1))

    # Simple union keeps both start and both accept nodes.
    g1 = single_arc(1)
    g2 = single_arc(0)
    expected = gtn.Graph()
    expected.add_node(True)
    expected.add_node(True)
    expected.add_node(False, True)
    expected.add_node(False, True)
    expected.add_arc(0, 2, 1)
    expected.add_arc(1, 3, 0)
    self.assertTrue(gtn.isomorphic(gtn.union([g1, g2]), expected))

    # An empty operand contributes nothing.
    g1 = single_arc(1)
    g2 = gtn.Graph()
    g3 = gtn.Graph()
    g3.add_node(True, True)
    g3.add_arc(0, 0, 2)
    expected = gtn.Graph()
    expected.add_node(True)
    expected.add_node(False, True)
    expected.add_node(True, True)
    expected.add_arc(0, 1, 1)
    expected.add_arc(2, 2, 2)
    self.assertTrue(gtn.isomorphic(gtn.union([g1, g2, g3]), expected))
def fromTransitions(transition_weights, init_weights=None, final_weights=None,
                    transition_ids=None, calc_grad=True):
    """Instantiate a state machine from state transitions.

    The graph gets a dedicated start node, a dedicated accept node, and one
    node per state. Entry arcs (start -> state i) and exit arcs (state i ->
    accept) have an epsilon input label and a transition id as output label.
    State-to-state arcs use the transition id as both labels. Any arc whose
    weight equals the module-level ``zero`` is omitted.

    Parameters
    ----------
    transition_weights : array-like
        Matrix whose ``shape[0]`` is the number of states. Entry in row
        ``i``, column ``j`` weights the arc from state ``i`` to state ``j``.
        Presumably square (num_states x num_states); every column index must
        be a valid state index.
    init_weights : sequence of float, optional
        Per-state weight of the entry arc. Defaults to ``float(one)`` for
        every state.
    final_weights : sequence of float, optional
        Per-state weight of the exit arc. Defaults to ``float(one)`` for
        every state.
    transition_ids : mapping, optional
        Indexed by ``(BOS, i)``, ``(i, EOS)`` and ``(i, j)`` to obtain arc
        labels. Built with ``makeTransitionVocabulary`` when omitted. Every
        id is looked up, even for arcs that end up omitted.
    calc_grad : bool, optional
        Forwarded to ``gtn.Graph``.

    Returns
    -------
    fst : gtn.Graph
        The state machine, arc-sorted by output label.
    """
    num_states = transition_weights.shape[0]
    if transition_ids is None:
        _, transition_ids = makeTransitionVocabulary(transition_weights)
    if init_weights is None:
        init_weights = tuple(float(one) for _ in range(num_states))
    if final_weights is None:
        final_weights = tuple(float(one) for _ in range(num_states))

    fst = gtn.Graph(calc_grad=calc_grad)
    init_state = fst.add_node(start=True)
    final_state = fst.add_node(accept=True)

    # One node per state, together with its entry and exit arcs.
    node_list = []
    for i in range(num_states):
        node = fst.add_node()
        entry_id = transition_ids[BOS, i]
        entry_weight = init_weights[i]
        if entry_weight != zero:
            fst.add_arc(init_state, node, gtn.epsilon, entry_id, entry_weight)
        exit_id = transition_ids[i, EOS]
        exit_weight = final_weights[i]
        if exit_weight != zero:
            fst.add_arc(node, final_state, gtn.epsilon, exit_id, exit_weight)
        node_list.append(node)
    states = tuple(node_list)

    # State-to-state arcs, in row-major order.
    for i_cur, row in enumerate(transition_weights):
        src = states[i_cur]
        for i_next, weight in enumerate(row):
            dst = states[i_next]
            label = transition_ids[i_cur, i_next]
            if weight != zero:
                fst.add_arc(src, dst, label, label, weight)

    fst.arc_sort(olabel=True)
    return fst
def build_setence_graph(sentence, vocab):
    """Build a linear acceptor for the whitespace-split words of ``sentence``.

    Words not found in ``vocab`` are mapped to ``vocab[UNK]``. Node 0 is the
    start node, each word adds one node and one arc labelled with its id, and
    only the last node accepts. An empty sentence yields a single start node
    that also accepts, so the graph accepts only the empty sequence rather
    than nothing.

    Args:
        sentence: text to tokenize with ``str.split()``.
        vocab: mapping from words to integer ids; must contain ``UNK`` if
            any word is out of vocabulary.

    Returns:
        The sentence ``gtn.Graph`` (created with calc_grad disabled).
    """
    graph = gtn.Graph(False)
    sidx = [vocab[w] if w in vocab else vocab[UNK] for w in sentence.split()]
    last = len(sidx) - 1
    prev = graph.add_node(True, not sidx)
    for e, idx in enumerate(sidx):
        cur = graph.add_node(False, e == last)
        graph.add_arc(prev, cur, idx)
        prev = cur
    return graph
def test_calc_grad(self):
    """Only graphs with calc_grad enabled receive gradients.

    calc_grad is toggled after construction: g1 starts disabled and is
    switched on, g2 starts enabled and is switched off.
    """
    g1 = gtn.Graph(False)
    g1.calc_grad = True
    g1.add_node(True)
    g1.add_node(False, True)
    g1.add_arc(0, 1, 1, 1, 1.0)

    g2 = gtn.Graph(True)
    g2.calc_grad = False
    g2.add_node(True)
    g2.add_node(False, True)
    # Start -> accept arc, so g2 is a well-formed scalar graph like g1 (a
    # self-loop on node 0 would leave the accept node unreachable).
    g2.add_arc(0, 1, 1, 1, 1.0)

    result = gtn.add(g1, g2)
    gtn.backward(result)
    self.assertEqual(g1.grad().item(), 1.0)
    # No gradient is stored for g2, so requesting it raises.
    with self.assertRaises(RuntimeError):
        g2.grad()
def create_force_align_graph(target):
    """Build the force-alignment graph for ``target``.

    One node per target token after the start node. Each token is entered
    with one arc and may repeat via a self-loop; only the last node accepts.
    Arcs are sorted with ``arc_sort(True)`` (presumably by output label).
    """
    g_fal = gtn.Graph(False)
    g_fal.add_node(True)
    num_tokens = len(target)
    for pos, token in enumerate(target, start=1):
        g_fal.add_node(False, pos == num_tokens)
        g_fal.add_arc(pos - 1, pos, token)  # enter this token
        g_fal.add_arc(pos, pos, token)  # stay on it for more frames
    g_fal.arc_sort(True)
    return g_fal
def gen_transitions(num_classes, calc_grad=False):
    """Make a bigram transition graph.

    Nodes ``0 .. num_classes - 1`` stand for the previous class and node
    ``num_classes`` is the <s> start node; every node accepts. Each arc is
    labelled with the class it moves to.
    """
    g = gtn.Graph(calc_grad)
    for _ in range(num_classes):
        g.add_node(False, True)
    bos = g.add_node(True, True)
    for cur in range(num_classes):
        g.add_arc(bos, cur, cur)  # s(<s>, cur)
        for nxt in range(num_classes):
            g.add_arc(cur, nxt, nxt)  # s(cur, nxt)
    return g
def test_comparisons(self):
    """gtn.equal depends on node order; gtn.isomorphic does not."""

    def start_to_accept():
        # Start node 0 -> accept node 1 with label 0.
        graph = gtn.Graph()
        graph.add_node(True)
        graph.add_node(False, True)
        graph.add_arc(0, 1, 0)
        return graph

    g1 = start_to_accept()
    g2 = start_to_accept()
    self.assertTrue(gtn.equal(g1, g2))
    self.assertTrue(gtn.isomorphic(g1, g2))

    # Same structure with the node indices swapped.
    swapped = gtn.Graph()
    swapped.add_node(False, True)
    swapped.add_node(True)
    swapped.add_arc(1, 0, 0)
    self.assertFalse(gtn.equal(g1, swapped))
    self.assertTrue(gtn.isomorphic(g1, swapped))
def test_input_grad(self):
    """Gradients passed explicitly to backward flow to the inputs."""

    def scalar(weight):
        # Gradient-enabled two-node graph holding one arc of ``weight``.
        graph = gtn.Graph(True)
        graph.add_node(True)
        graph.add_node(False, True)
        graph.add_arc(0, 1, 0, 0, weight)
        return graph

    g1 = scalar(3.0)
    g2 = scalar(3.0)
    result = gtn.add(g1, g2)

    # Output gradient of 7 instead of the implicit 1.
    deltas = gtn.Graph()
    deltas.add_node(True)
    deltas.add_node(False, True)
    deltas.add_arc(0, 1, 0, 0, 7.0)
    gtn.backward(result, deltas)

    self.assertTrue(g1.grad().item() == 7.0)
    self.assertTrue(g2.grad().item() == 7.0)
def make_lexicon_graph(word_pieces: List, graphemes_to_idx: Dict) -> gtn.Graph:
    """Constructs a graph which transduces letters to word pieces.

    Node 0 is both start and accept. Word piece ``i`` becomes a loop that
    leaves node 0 and returns to it, consuming one grapheme per arc. All arcs
    but the last output epsilon, and the last one outputs ``i``. The graph
    is arc-sorted before being returned.
    """
    graph = gtn.Graph(False)
    graph.add_node(True, True)

    def add_word_piece(piece, piece_id):
        # Spell all but the final grapheme through fresh nodes (no output),
        # then close the loop to node 0 while emitting the piece index.
        src = 0
        for grapheme in piece[:-1]:
            dst = graph.add_node()
            graph.add_arc(src, dst, graphemes_to_idx[grapheme], gtn.epsilon)
            src = dst
        graph.add_arc(src, 0, graphemes_to_idx[piece[-1]], piece_id)

    for piece_id, piece in enumerate(word_pieces):
        add_word_piece(piece, piece_id)
    graph.arc_sort()
    return graph
def make_target_graph(target):
    """
    Construct the target graph for the sequence in target.

    Each token in target can align to one or more input frames: node ``k``
    is entered by an arc labelled ``target[k - 1]`` and keeps a self-loop
    with the same label. Only the last node accepts.
    """
    g = gtn.Graph(False)
    g.add_node(True)
    last = len(target)
    for node in range(1, last + 1):
        token = target[node - 1]
        g.add_node(False, node == last)
        g.add_arc(node - 1, node, token)  # first frame of this token
        g.add_arc(node, node, token)  # any further frames
    g.arc_sort(True)
    return g
def setUp(self):
    """Build the shared fixture graph ``self.g``.

    Five nodes (0 start, 4 accept) and five arcs, added with a mix of
    positional and keyword forms of ``add_arc``.
    NOTE(review): node 4 is marked accept but has no incoming arc -- confirm
    that the tests using this fixture intend that.
    """
    graph = gtn.Graph(False)
    graph.add_node(True)
    for _ in range(3):
        graph.add_node()
    graph.add_node(False, True)

    graph.add_arc(0, 1, 0)
    graph.add_arc(0, 2, 1)
    graph.add_arc(src_node=1, dst_node=2, label=0)
    graph.add_arc(1, 1, ilabel=1, olabel=2, weight=2.1)
    graph.add_arc(2, 3, 2)
    self.g = graph
def test_asg_viterbi_path(self):
    """Viterbi path through emissions composed with ASG transitions."""
    # Test adapted from wav2letter https://tinyurl.com/yc6nxex9
    T = 4
    N = 3

    # fmt: off
    inputs = [
        0, 0, 7,
        5, 4, 3,
        5, 8, 5,
        5, 4, 3,
    ]
    trans = [
        0, 2, 0,
        0, 0, 2,
        2, 0, 0,
    ]
    expected_path = [2, 1, 1, 0]
    # fmt: on

    # Node 0 is <s>; node k + 1 means "previous class was k".
    transitions = gtn.Graph()
    transitions.add_node(True)
    for cls in range(N):
        transitions.add_node(False, True)
        transitions.add_arc(0, cls + 1, cls)  # p(cls | <s>)
    for cur in range(N):
        for prev in range(N):
            # p(cur | prev), read from row ``cur`` of ``trans``.
            transitions.add_arc(prev + 1, cur + 1, cur, cur,
                                trans[cur * N + prev])

    emissions = emissions_graph(inputs, T, N, True)
    path = gtn.viterbi_path(gtn.compose(emissions, transitions))
    self.assertEqual(path.labels_to_list(), expected_path)
def test_grad_available(self):
    """is_grad_available flips to True after a backward pass."""
    g = gtn.Graph()
    g.add_node(True)
    g.add_node()
    g.add_node(False, True)
    for src in range(2):
        for label in range(3):
            g.add_arc(src, src + 1, label, label, label + 1)

    self.assertFalse(g.is_grad_available())
    gtn.backward(gtn.forward_score(g))
    self.assertTrue(g.is_grad_available())
def test_compose_grad(self):
    """Arc gradients of both operands after backward through compose."""
    # Four positions; each allows labels 0..2 with weight equal to the label.
    first = gtn.Graph()
    first.add_node(True)
    for _ in range(3):
        first.add_node()
    first.add_node(False, True)
    for src in range(4):
        for label in range(3):
            first.add_arc(src, src + 1, label, label, label)

    # Accepts one or more 0s followed by one or more 1s.
    second = gtn.Graph()
    second.add_node(True)
    second.add_node()
    second.add_node(False, True)
    second.add_arc(0, 1, 0, 0, 3.5)
    second.add_arc(1, 1, 0, 0, 2.5)
    second.add_arc(1, 2, 1, 1, 1.5)
    second.add_arc(2, 2, 1, 1, 4.5)

    composed = gtn.compose(first, second)
    gtn.backward(composed)

    grads_first = [1, 0, 0, 1, 1, 0, 1, 2, 0, 0, 2, 0]
    grads_second = [1, 2, 3, 2]
    self.assertEqual(grads_first, first.grad().weights_to_list())
    self.assertEqual(grads_second, second.grad().weights_to_list())
def build_graph(ngrams: List, disable_backoff: bool = False) -> gtn.Graph:
    """Returns a gtn Graph based on the ngrams.

    Nodes are n-gram histories (tuples of token ids). For ``ngram > 1`` the
    ``(START_IDX,)`` history is the start node, and every history containing
    ``END_IDX`` is merged into the single accepting ``(END_IDX,)`` state.
    For a unigram model every node is both start and accept. Each n-gram
    adds one arc from ``gram[:-1]`` to its output history, labelled with
    ``gram[-1]``, or with epsilon when that is ``END_IDX``.

    Args:
        ngrams: one collection of n-gram tuples per order; ``len(ngrams)``
            is the model order. Lower orders are presumably listed first --
            the well-formedness check below requires ``gram[1:]`` to already
            be a known state.
        disable_backoff: when False, each newly created non-final node gets
            an epsilon arc to the node of its longest proper suffix that
            already exists.

    Returns:
        The n-gram acceptor (created with calc_grad disabled).

    Raises:
        ValueError: if, for some n-gram without ``END_IDX`` in ``gram[1:]``,
            the suffix ``gram[1:]`` is not already a state.
    """
    graph = gtn.Graph(False)
    ngram = len(ngrams)
    state_to_node = {}

    def get_node(state: tuple) -> int:
        # Return the node for ``state``, creating it on first use.
        node = state_to_node.get(state, None)
        if node is not None:
            return node
        start = state == tuple([START_IDX]) if ngram > 1 else True
        end = state == tuple([END_IDX]) if ngram > 1 else True
        node = graph.add_node(start, end)
        state_to_node[state] = node
        if not disable_backoff and not end:
            # Add back off when adding node.
            for n in range(1, len(state) + 1):
                backoff_node = state_to_node.get(state[n:], None)
                # Epsilon transition to the back-off state.
                if backoff_node is not None:
                    graph.add_arc(node, backoff_node, gtn.epsilon)
                    break
        return node

    for grams in ngrams:
        for gram in grams:
            istate = gram[:-1]
            # History after emitting gram[-1]: its last (ngram - 1) tokens,
            # or the whole gram when it is shorter than that. The start index
            # must be clamped at 0: a negative index would slice from the
            # end and drop leading tokens (e.g. a bigram in a 4-gram model
            # would keep only its last token).
            ostate = gram[max(len(gram) - ngram + 1, 0):]
            inode = get_node(istate)
            if END_IDX not in gram[1:] and gram[1:] not in state_to_node:
                raise ValueError(
                    "Ill formed counts: if (x, y_1, ..., y_{n-1}) is above "
                    "the n-gram threshold, then (y_1, ..., y_{n-1}) must be "
                    "above the (n-1)-gram threshold")
            if END_IDX in ostate:
                # Merge all state having </s> into one as final graph generated
                # will be similar.
                ostate = tuple([END_IDX])
            onode = get_node(ostate)
            # p(gram[-1] | gram[:-1])
            graph.add_arc(inode, onode,
                          gtn.epsilon if gram[-1] == END_IDX else gram[-1])
    return graph
def ctc_graph(target, blank):
    """Build the CTC alignment graph for ``target``.

    There are ``2 * len(target) + 1`` nodes: even nodes are blanks and odd
    node ``2k + 1`` is ``target[k]``. Arcs into a node carry that node's
    label. Each node has a self-loop and an arc from its predecessor. A label
    node may also be reached from two nodes back (skipping the blank) when
    its label differs from the previous target label. Node 0 is the start;
    the last two nodes accept.
    """
    num_nodes = 2 * len(target) + 1
    ctc = gtn.Graph()
    for node in range(num_nodes):
        is_label_node = node % 2 == 1
        pos = (node - 1) // 2
        ctc.add_node(node == 0, node >= num_nodes - 2)
        label = target[pos] if is_label_node else blank
        ctc.add_arc(node, node, label)  # repeat
        if node > 0:
            ctc.add_arc(node - 1, node, label)  # advance one node
        if is_label_node and node > 1 and label != target[pos - 1]:
            ctc.add_arc(node - 2, node, label)  # skip the separating blank
    return ctc