def process(b):
    # Create emissions graph:
    emissions = gtn.linear_graph(T, C, inputs.requires_grad)
    cpu_data = inputs[b].cpu().contiguous()
    emissions.set_weights(cpu_data.data_ptr())
    target = make_chain_graph(targets[b])
    target.arc_sort(True)

    # Create token-to-grapheme decomposition graph:
    tokens_target = gtn.remove(gtn.project_output(gtn.compose(target, lexicon)))
    tokens_target.arc_sort()

    # Create alignment graph:
    alignments = gtn.project_input(
        gtn.remove(gtn.compose(tokens, tokens_target)))
    alignments.arc_sort()

    # Add transition scores:
    if transitions is not None:
        alignments = gtn.intersect(transitions, alignments)
        alignments.arc_sort()

    loss = gtn.forward_score(gtn.intersect(emissions, alignments))

    # Normalize if needed:
    if transitions is not None:
        norm = gtn.forward_score(gtn.intersect(emissions, transitions))
        loss = gtn.subtract(loss, norm)
    losses[b] = gtn.negate(loss)

    # Save for backward:
    if emissions.calc_grad:
        emissions_graphs[b] = emissions
def test_scalar_ops(self):
    g1 = gtn.Graph()
    g1.add_node(True)
    g1.add_node(False, True)
    g1.add_arc(0, 1, 0, 0, 1.0)

    # Test negate:
    res = gtn.negate(g1)
    self.assertEqual(res.item(), -1.0)
    gtn.backward(res)
    self.assertEqual(g1.grad().item(), -1.0)
    g1.zero_grad()

    g2 = gtn.Graph()
    g2.add_node(True)
    g2.add_node(False, True)
    g2.add_arc(0, 1, 0, 0, 3.0)

    # Test add:
    res = gtn.add(g1, g2)
    self.assertEqual(res.item(), 4.0)
    gtn.backward(res)
    self.assertEqual(g1.grad().item(), 1.0)
    self.assertEqual(g2.grad().item(), 1.0)
    g1.zero_grad()
    g2.zero_grad()

    # Test subtract:
    res = gtn.subtract(g1, g2)
    self.assertEqual(res.item(), -2.0)
    gtn.backward(res)
    self.assertEqual(g1.grad().item(), 1.0)
    self.assertEqual(g2.grad().item(), -1.0)
def process(b):
    # create emissions graph
    g_emissions = gtn.linear_graph(T, C, inputs.requires_grad)
    cpu_data = inputs[b].cpu().contiguous()
    g_emissions.set_weights(cpu_data.data_ptr())

    # create transitions graph
    g_transitions = ASGLossFunction.create_transitions_graph(
        transitions, calc_trans_grad)

    # create force-align criterion graph
    g_fal = ASGLossFunction.create_force_align_graph(targets[b])

    # compose the graphs
    g_fal_fwd = gtn.forward_score(
        gtn.intersect(gtn.intersect(g_fal, g_transitions), g_emissions))
    g_fcc_fwd = gtn.forward_score(
        gtn.intersect(g_emissions, g_transitions))
    g_loss = gtn.subtract(g_fcc_fwd, g_fal_fwd)

    scale = 1.0
    if reduction == "mean":
        L = len(targets[b])
        scale = 1.0 / L if L > 0 else scale
    elif reduction != "none":
        raise ValueError("invalid value for reduction '" + str(reduction) + "'")

    # Save for backward:
    losses[b] = g_loss
    scales[b] = scale
    emissions_graphs[b] = g_emissions
    transitions_graphs[b] = g_transitions
def test_simple_decomposition(self):
    T = 5
    tokens = ["a", "b", "ab", "ba", "aba"]
    scores = torch.randn((1, T, len(tokens)), requires_grad=True)
    labels = [[0, 1, 0]]

    transducer = Transducer(
        tokens=tokens, graphemes_to_idx={"a": 0, "b": 1})

    # Hand-construct the alignment graph with all of the decompositions
    alignments = gtn.Graph(False)
    alignments.add_node(True)

    # Add the path ['a', 'b', 'a']
    alignments.add_node()
    alignments.add_arc(0, 1, 0)
    alignments.add_arc(1, 1, 0)
    alignments.add_node()
    alignments.add_arc(1, 2, 1)
    alignments.add_arc(2, 2, 1)
    alignments.add_node(False, True)
    alignments.add_arc(2, 3, 0)
    alignments.add_arc(3, 3, 0)

    # Add the path ['a', 'ba']
    alignments.add_node(False, True)
    alignments.add_arc(1, 4, 3)
    alignments.add_arc(4, 4, 3)

    # Add the path ['ab', 'a']
    alignments.add_node()
    alignments.add_arc(0, 5, 2)
    alignments.add_arc(5, 5, 2)
    alignments.add_arc(5, 3, 0)

    # Add the path ['aba']
    alignments.add_node(False, True)
    alignments.add_arc(0, 6, 4)
    alignments.add_arc(6, 6, 4)

    emissions = gtn.linear_graph(T, len(tokens), True)
    emissions.set_weights(scores.data_ptr())

    expected_loss = gtn.subtract(
        gtn.forward_score(emissions),
        gtn.forward_score(gtn.intersect(emissions, alignments)),
    )
    loss = transducer(scores, labels)
    self.assertAlmostEqual(loss.item(), expected_loss.item(), places=5)

    loss.backward()
    gtn.backward(expected_loss)
    expected_grad = torch.tensor(emissions.grad().weights_to_numpy())
    expected_grad = expected_grad.view((1, T, len(tokens)))
    self.assertTrue(
        torch.allclose(scores.grad, expected_grad, rtol=1e-4, atol=1e-5))
def crf_loss(X, Y, potentials, transitions):
    feature_graph = gtn.compose(X, potentials)

    # Compute the unnormalized score of `(X, Y)`
    target_graph = gtn.compose(feature_graph, gtn.intersect(Y, transitions))
    target_score = gtn.forward_score(target_graph)

    # Compute the partition function
    norm_graph = gtn.compose(feature_graph, transitions)
    norm_score = gtn.forward_score(norm_graph)

    return gtn.subtract(norm_score, target_score)
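A minimal usage sketch for `crf_loss` above. The toy graphs (an observation chain `X`, a gold-label chain `Y`, a one-state emission transducer `potentials`, and a zero-weight `transitions` acceptor) are illustrative assumptions, not part of the original source.

import gtn

# Toy problem sizes (assumptions for illustration only).
num_obs, num_labels, obs, gold = 2, 2, [0, 1, 0], [1, 0, 1]

# X: chain acceptor over the observed symbols.
X = gtn.Graph(False)
X.add_node(True)
for t, o in enumerate(obs):
    X.add_node(False, t == len(obs) - 1)
    X.add_arc(t, t + 1, o)

# potentials: one-state transducer mapping observation symbols to labels
# with (here arbitrary) emission scores.
potentials = gtn.Graph()
potentials.add_node(True, True)
for o in range(num_obs):
    for l in range(num_labels):
        potentials.add_arc(0, 0, o, l, 0.1 * (o + l))

# Y: chain acceptor over the gold label sequence.
Y = gtn.Graph(False)
Y.add_node(True)
for t, l in enumerate(gold):
    Y.add_node(False, t == len(gold) - 1)
    Y.add_arc(t, t + 1, l)

# transitions: one-state acceptor over labels; all transition scores are
# zero here, so the loss reduces to an emission-only negative log-likelihood.
transitions = gtn.Graph()
transitions.add_node(True, True)
for l in range(num_labels):
    transitions.add_arc(0, 0, l)

loss = crf_loss(X, Y, potentials, transitions)
print(loss.item())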
def test_scalar_ops(self):
    g1 = gtn.scalar_graph(3.0)

    result = gtn.negate(g1)
    self.assertEqual(result.item(), -3.0)

    g2 = gtn.scalar_graph(4.0)

    result = gtn.add(g1, g2)
    self.assertEqual(result.item(), 7.0)

    result = gtn.subtract(g2, g1)
    self.assertEqual(result.item(), 1.0)
def forward_single(b):
    emissions = gtn.linear_graph(T, C, inputs.requires_grad)
    data = inputs[b].contiguous()
    emissions.set_weights(data.data_ptr())
    target = GTNLossFunction.make_target_graph(targets[b])

    # Score the target:
    target_score = gtn.forward_score(gtn.intersect(target, emissions))

    # Normalization term:
    norm = gtn.forward_score(emissions)

    # Compute the loss:
    loss = gtn.subtract(norm, target_score)

    # Save state for backward:
    losses[b] = loss
    emissions_graphs[b] = emissions
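The helper `GTNLossFunction.make_target_graph` is not shown here; a plausible sketch (an assumption for illustration, not the actual implementation) is a linear chain acceptor over the target token indices:

def make_target_graph(target):
    # Hypothetical helper: builds a chain acceptor that accepts exactly
    # the target token sequence (one arc per token, last node accepting).
    g = gtn.Graph(False)  # the target graph needs no gradient
    g.add_node(True)
    for i, token in enumerate(target):
        g.add_node(False, i == len(target) - 1)
        g.add_arc(i, i + 1, token)
    return g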
def seq_loss(batch_index):
    obs_fst = linearFstFromArray(
        arc_scores[batch_index].reshape(num_samples, -1))
    gt_fst = fromSequence(arc_labels[batch_index])

    # Compose each sequence fst individually: it seems like composition
    # only works for lattices
    denom_fst = obs_fst
    for seq_fst in seq_fsts:
        denom_fst = gtn.compose(denom_fst, seq_fst)
        denom_fst = gtn.project_output(denom_fst)

    num_fst = gtn.compose(denom_fst, gt_fst)

    loss = gtn.subtract(gtn.forward_score(num_fst),
                        gtn.forward_score(denom_fst))

    losses[batch_index] = loss
    obs_fsts[batch_index] = obs_fst
def test_scalar_ops_grad(self):
    g1 = gtn.scalar_graph(3.0)

    result = gtn.negate(g1)
    gtn.backward(result)
    self.assertEqual(g1.grad().item(), -1.0)
    g1.zero_grad()

    g2 = gtn.scalar_graph(4.0)

    result = gtn.add(g1, g2)
    gtn.backward(result)
    self.assertEqual(g1.grad().item(), 1.0)
    self.assertEqual(g2.grad().item(), 1.0)
    g1.zero_grad()
    g2.zero_grad()

    result = gtn.subtract(g1, g2)
    gtn.backward(result)
    self.assertEqual(g1.grad().item(), 1.0)
    self.assertEqual(g2.grad().item(), -1.0)
    g1.zero_grad()
    g2.zero_grad()

    result = gtn.add(gtn.add(g1, g2), g1)
    gtn.backward(result)
    self.assertEqual(g1.grad().item(), 2.0)
    self.assertEqual(g2.grad().item(), 1.0)
    g1.zero_grad()

    g2nograd = gtn.scalar_graph(4.0, False)

    result = gtn.add(g1, g2nograd)
    gtn.backward(result)
    self.assertEqual(g1.grad().item(), 1.0)
    self.assertRaises(RuntimeError, g2nograd.grad)
def test_ctc_criterion(self):
    # These test cases are taken from wav2letter: https://fburl.com/msom2e4v

    # Test case 1
    ctc = ctc_graph([0, 0], 1)
    emissions = emissions_graph([1.0, 0.0, 0.0, 1.0, 1.0, 0.0], 3, 2)
    loss = gtn.forward_score(gtn.compose(ctc, emissions))
    self.assertEqual(loss.item(), 0.0)

    # Should be 0 since scores are normalized
    z = gtn.forward_score(emissions)
    self.assertEqual(z.item(), 0.0)

    # Test case 2
    T = 3
    N = 4
    ctc = ctc_graph([1, 2], N - 1)
    emissions = emissions_graph([1.0] * (T * N), T, N)
    expected_loss = -math.log(0.25 * 0.25 * 0.25 * 5)
    loss = gtn.subtract(gtn.forward_score(gtn.compose(ctc, emissions)),
                        gtn.forward_score(emissions))
    self.assertAlmostEqual(-loss.item(), expected_loss)

    # Test case 3
    T = 5
    N = 6
    target = [0, 1, 2, 1, 0]

    # generate CTC graph
    ctc = ctc_graph(target, N - 1)

    # fmt: off
    emissions_vec = [
        0.633766, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553,
        0.111121, 0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436,
        0.0357786, 0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688,
        0.0663296, 0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533,
        0.458235, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107,
    ]
    # fmt: on
    emissions = emissions_graph(emissions_vec, T, N)

    # The log probabilities are already normalized,
    # so this should be close to 0
    z = gtn.forward_score(emissions)
    self.assertTrue(abs(z.item()) < 1e-5)

    loss = gtn.subtract(z, gtn.forward_score(gtn.compose(ctc, emissions)))
    expected_loss = 3.34211
    self.assertAlmostEqual(loss.item(), expected_loss, places=5)

    # Check the gradients
    gtn.backward(loss)

    # fmt: off
    expected_grad = [
        -0.366234, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553,
        0.111121, -0.411608, 0.278779, 0.0055756, 0.00569609, 0.010436,
        0.0357786, 0.633813, -0.678582, 0.00249248, 0.00272882, 0.0037688,
        0.0663296, -0.356151, 0.280111, 0.00283995, 0.0035545, 0.00331533,
        -0.541765, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107,
    ]
    # fmt: on

    all_close = True
    grad = emissions.grad()
    grad_weights = grad.weights_to_list()
    for i in range(T * N):
        g = grad_weights[i]
        all_close = all_close and (abs(expected_grad[i] - g) < 1e-5)
    self.assertTrue(all_close)

    # Test case 4
    # This test case is taken from the TensorFlow CTC implementation:
    # tinyurl.com/y9du5v5a
    T = 5
    N = 6
    target = [0, 1, 1, 0]

    # generate CTC graph
    ctc = ctc_graph(target, N - 1)

    # fmt: off
    emissions_vec = [
        0.30176, 0.28562, 0.0831517, 0.0862751, 0.0816851, 0.161508,
        0.24082, 0.397533, 0.0557226, 0.0546814, 0.0557528, 0.19549,
        0.230246, 0.450868, 0.0389607, 0.038309, 0.0391602, 0.202456,
        0.280884, 0.429522, 0.0326593, 0.0339046, 0.0326856, 0.190345,
        0.423286, 0.315517, 0.0338439, 0.0393744, 0.0339315, 0.154046,
    ]
    # fmt: on
    emissions = emissions_graph(emissions_vec, T, N)

    # The log probabilities are already normalized,
    # so this should be close to 0
    z = gtn.forward_score(emissions)
    self.assertTrue(abs(z.item()) < 1e-5)

    loss = gtn.subtract(z, gtn.forward_score(gtn.compose(ctc, emissions)))
    expected_loss = 5.42262
    self.assertAlmostEqual(loss.item(), expected_loss, places=4)

    # Check the gradients
    gtn.backward(loss)

    # fmt: off
    expected_grad = [
        -0.69824, 0.28562, 0.0831517, 0.0862751, 0.0816851, 0.161508,
        0.24082, -0.602467, 0.0557226, 0.0546814, 0.0557528, 0.19549,
        0.230246, 0.450868, 0.0389607, 0.038309, 0.0391602, -0.797544,
        0.280884, -0.570478, 0.0326593, 0.0339046, 0.0326856, 0.190345,
        -0.576714, 0.315517, 0.0338439, 0.0393744, 0.0339315, 0.154046,
    ]
    # fmt: on

    all_close = True
    grad = emissions.grad()
    grad_weights = grad.weights_to_list()
    for i in range(T * N):
        g = grad_weights[i]
        all_close = all_close and (abs(expected_grad[i] - g) < 1e-5)
    self.assertTrue(all_close)
def test_asg_criterion(self):
    # These test cases are taken from wav2letter: https://fburl.com/msom2e4v
    T = 5
    N = 6

    # fmt: off
    targets = [
        [2, 1, 5, 1, 3],
        [4, 3, 5],
        [3, 2, 2, 1],
    ]
    expected_loss = [
        7.7417464256287,
        6.4200420379639,
        8.2780694961548,
    ]
    emissions_vecs = [
        [
            -0.4340, -0.0254, 0.3667, 0.4180, -0.3805, -0.1707,
            0.1060, 0.3631, -0.1122, -0.3825, -0.0031, -0.3801,
            0.0443, -0.3795, 0.3194, -0.3130, 0.0094, 0.1560,
            0.1252, 0.2877, 0.1997, -0.4554, 0.2774, -0.2526,
            -0.4001, -0.2402, 0.1295, 0.0172, 0.1805, -0.3299,
        ],
        [
            0.3298, -0.2259, -0.0959, 0.4909, 0.2996, -0.2543,
            -0.2863, 0.3239, -0.3988, 0.0732, -0.2107, -0.4739,
            -0.0906, 0.0480, -0.1301, 0.3975, -0.3317, -0.1967,
            0.4372, -0.2006, 0.0094, 0.3281, 0.1873, -0.2945,
            0.2399, 0.0320, -0.3768, -0.2849, -0.2248, 0.3186,
        ],
        [
            0.0225, -0.3867, -0.1929, -0.2904, -0.4958, -0.2533,
            0.4001, -0.1517, -0.2799, -0.2915, 0.4198, 0.4506,
            0.1446, -0.4753, -0.0711, 0.2876, -0.1851, -0.1066,
            0.2081, -0.1190, -0.3902, -0.1668, 0.1911, -0.2848,
            -0.3846, 0.1175, 0.1052, 0.2172, -0.0362, 0.3055,
        ],
    ]
    emissions_grads = [
        [
            0.1060, 0.1595, -0.7639, 0.2485, 0.1118, 0.1380,
            0.1915, -0.7524, 0.1539, 0.1175, 0.1717, 0.1178,
            0.1738, 0.1137, 0.2288, 0.1216, 0.1678, -0.8057,
            0.1766, -0.7923, 0.1902, 0.0988, 0.2056, 0.1210,
            0.1212, 0.1422, 0.2059, -0.8160, 0.2166, 0.1300,
        ],
        [
            0.2029, 0.1164, 0.1325, 0.2383, -0.8032, 0.1131,
            0.1414, 0.2602, 0.1263, -0.3441, -0.3009, 0.1172,
            0.1557, 0.1788, 0.1496, -0.5498, 0.0140, 0.0516,
            0.2306, 0.1219, 0.1503, -0.4244, 0.1796, -0.2579,
            0.2149, 0.1745, 0.1160, 0.1271, 0.1350, -0.7675,
        ],
        [
            0.2195, 0.1458, 0.1770, -0.8395, 0.1307, 0.1666,
            0.2148, 0.1237, -0.6613, -0.1223, 0.2191, 0.2259,
            0.2002, 0.1077, -0.8386, 0.2310, 0.1440, 0.1557,
            0.2197, -0.1466, -0.5742, 0.1510, 0.2160, 0.1342,
            0.1050, -0.8265, 0.1714, 0.1917, 0.1488, 0.2094,
        ],
    ]
    # fmt: on

    transitions = gtn.Graph()
    transitions.add_node(True)
    for i in range(1, N + 1):
        transitions.add_node(False, True)
        transitions.add_arc(0, i, i - 1)  # p(i | <s>)
    for i in range(N):
        for j in range(N):
            transitions.add_arc(j + 1, i + 1, i)  # p(i | j)

    for b in range(len(targets)):
        target = targets[b]
        emissions_vec = emissions_vecs[b]
        emissions_grad = emissions_grads[b]

        fal = gtn.Graph()
        fal.add_node(True)
        for l in range(1, len(target) + 1):
            fal.add_node(False, l == len(target))
            fal.add_arc(l - 1, l, target[l - 1])
            fal.add_arc(l, l, target[l - 1])

        emissions = emissions_graph(emissions_vec, T, N, True)

        loss = gtn.subtract(
            gtn.forward_score(gtn.compose(emissions, transitions)),
            gtn.forward_score(
                gtn.compose(gtn.compose(fal, transitions), emissions)),
        )
        self.assertAlmostEqual(loss.item(), expected_loss[b], places=3)

        # Check the gradients
        gtn.backward(loss)

        all_close = True
        grad = emissions.grad()
        grad_weights = grad.weights_to_list()
        for i in range(T * N):
            g = grad_weights[i]
            all_close = all_close and (abs(emissions_grad[i] - g) < 1e-4)
        self.assertTrue(all_close)

        all_close = True
        # fmt: off
        trans_grad = [
            0.3990, 0.3396, 0.3486, 0.3922, 0.3504, 0.3155,
            0.3666, 0.0116, -1.6678, 0.3737, 0.3361, -0.7152,
            0.3468, 0.3163, -1.1583, -0.6803, 0.3216, 0.2722,
            0.3694, -0.6688, 0.3047, -0.8531, -0.6571, 0.2870,
            0.3866, 0.3321, 0.3447, 0.3664, -0.2163, 0.3039,
            0.3640, -0.6943, 0.2988, -0.6722, 0.3215, -0.1860,
        ]
        # fmt: on
        grad = transitions.grad()
        grad_weights = grad.weights_to_list()
        for i in range(N * N):
            g = grad_weights[i + N]
            all_close = all_close and (abs(trans_grad[i] - g) < 1e-4)
        self.assertTrue(all_close)
def main(
        out_dir=None, gpu_dev_id=None, num_samples=10, random_seed=None,
        learning_rate=1e-3, num_epochs=500,
        dataset_kwargs={}, dataloader_kwargs={}, model_kwargs={}):
    if out_dir is None:
        out_dir = os.path.join('~', 'data', 'output', 'seqtools', 'test_gtn')
    out_dir = os.path.expanduser(out_dir)

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    fig_dir = os.path.join(out_dir, 'figures')
    if not os.path.exists(fig_dir):
        os.makedirs(fig_dir)

    vocabulary = ['a', 'b', 'c', 'd', 'e']
    transition = np.array(
        [[0, 1, 0, 0, 0],
         [0, 0, 1, 1, 0],
         [0, 0, 0, 0, 1],
         [0, 1, 0, 0, 1],
         [0, 0, 0, 0, 0]], dtype=float)
    initial = np.array([1, 0, 1, 0, 0], dtype=float)
    final = np.array([0, 1, 0, 0, 1], dtype=float) / 10
    seq_params = (transition, initial, final)

    simulated_dataset = simulate(num_samples, *seq_params)
    label_seqs, obsv_seqs = tuple(zip(*simulated_dataset))
    seq_params = tuple(map(lambda x: -np.log(x), seq_params))

    dataset = torchutils.SequenceDataset(obsv_seqs, label_seqs, **dataset_kwargs)
    data_loader = torch.utils.data.DataLoader(dataset, **dataloader_kwargs)
    train_loader = data_loader
    val_loader = data_loader

    transition_weights = torch.tensor(transition, dtype=torch.float).log()
    initial_weights = torch.tensor(initial, dtype=torch.float).log()
    final_weights = torch.tensor(final, dtype=torch.float).log()

    model = libfst.LatticeCrf(
        vocabulary,
        transition_weights=transition_weights,
        initial_weights=initial_weights,
        final_weights=final_weights,
        debug_output_dir=fig_dir,
        **model_kwargs)

    gtn.draw(
        model._transition_fst, os.path.join(fig_dir, 'transitions-init.png'),
        isymbols=model._arc_symbols, osymbols=model._arc_symbols)
    gtn.draw(
        model._duration_fst, os.path.join(fig_dir, 'durations-init.png'),
        isymbols=model._arc_symbols, osymbols=model._arc_symbols)

    if True:
        for i, (inputs, targets, seq_id) in enumerate(train_loader):
            arc_scores = model.scores_to_arc(inputs)
            arc_labels = model.labels_to_arc(targets)

            batch_size, num_samples, num_classes = arc_scores.shape

            obs_fst = libfst.linearFstFromArray(
                arc_scores[0].reshape(num_samples, -1))
            gt_fst = libfst.fromSequence(arc_labels[0])
            d1_fst = gtn.compose(obs_fst, model._duration_fst)
            d1_fst = gtn.project_output(d1_fst)
            denom_fst = gtn.compose(d1_fst, model._transition_fst)
            # denom_fst = gtn.project_output(denom_fst)
            num_fst = gtn.compose(denom_fst, gt_fst)
            viterbi_fst = gtn.viterbi_path(denom_fst)
            pred_fst = gtn.remove(gtn.project_output(viterbi_fst))

            loss = gtn.subtract(
                gtn.forward_score(num_fst), gtn.forward_score(denom_fst))
            loss = torch.tensor(loss.item())

            if torch.isinf(loss).any():
                denom_alt = gtn.compose(obs_fst, model._transition_fst)
                d1_min = gtn.remove(gtn.project_output(d1_fst))
                denom_alt = gtn.compose(d1_min, model._transition_fst)
                num_alt = gtn.compose(denom_alt, gt_fst)
                gtn.draw(
                    obs_fst, os.path.join(fig_dir, 'observations-init.png'),
                    isymbols=model._arc_symbols, osymbols=model._arc_symbols)
                gtn.draw(
                    gt_fst, os.path.join(fig_dir, 'labels-init.png'),
                    isymbols=model._arc_symbols, osymbols=model._arc_symbols)
                gtn.draw(
                    d1_fst, os.path.join(fig_dir, 'd1-init.png'),
                    isymbols=model._arc_symbols, osymbols=model._arc_symbols)
                gtn.draw(
                    d1_min, os.path.join(fig_dir, 'd1-min-init.png'),
                    isymbols=model._arc_symbols, osymbols=model._arc_symbols)
                gtn.draw(
                    denom_fst, os.path.join(fig_dir, 'denominator-init.png'),
                    isymbols=model._arc_symbols, osymbols=model._arc_symbols)
                gtn.draw(
                    denom_alt, os.path.join(fig_dir, 'denominator-alt-init.png'),
                    isymbols=model._arc_symbols, osymbols=model._arc_symbols)
                gtn.draw(
                    num_fst, os.path.join(fig_dir, 'numerator-init.png'),
                    isymbols=model._arc_symbols, osymbols=model._arc_symbols)
                gtn.draw(
                    num_alt, os.path.join(fig_dir, 'numerator-alt-init.png'),
                    isymbols=model._arc_symbols, osymbols=model._arc_symbols)
                gtn.draw(
                    viterbi_fst, os.path.join(fig_dir, 'viterbi-init.png'),
                    isymbols=model._arc_symbols, osymbols=model._arc_symbols)
                gtn.draw(
                    pred_fst, os.path.join(fig_dir, 'pred-init.png'),
                    isymbols=model._arc_symbols, osymbols=model._arc_symbols)
                import pdb
                pdb.set_trace()

    # Train the model
    train_epoch_log = collections.defaultdict(list)
    val_epoch_log = collections.defaultdict(list)
    metric_dict = {
        'Avg Loss': metrics.AverageLoss(),
        'Accuracy': metrics.Accuracy()
    }

    criterion = model.nllLoss
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=1.00)

    model, last_model_wts = torchutils.trainModel(
        model, criterion, optimizer, scheduler, train_loader, val_loader,
        metrics=metric_dict, test_metric='Avg Loss',
        train_epoch_log=train_epoch_log, val_epoch_log=val_epoch_log,
        num_epochs=num_epochs)

    gtn.draw(
        model._transition_fst, os.path.join(fig_dir, 'transitions-trained.png'),
        isymbols=model._arc_symbols, osymbols=model._arc_symbols)
    gtn.draw(
        model._duration_fst, os.path.join(fig_dir, 'durations-trained.png'),
        isymbols=model._arc_symbols, osymbols=model._arc_symbols)

    torchutils.plotEpochLog(
        train_epoch_log, title="Train Epoch Log",
        fn=os.path.join(fig_dir, "train-log.png"))
tokens = token_graph(word_pieces)
gtn.draw(tokens, "tokens.pdf", idx_to_wp, idx_to_wp)

# Recognizes "abc":
abc = gtn.Graph(False)
abc.add_node(True)
abc.add_node()
abc.add_node()
abc.add_node(False, True)
abc.add_arc(0, 1, let_to_idx["a"])
abc.add_arc(1, 2, let_to_idx["b"])
abc.add_arc(2, 3, let_to_idx["c"])
gtn.draw(abc, "abc.pdf", idx_to_let)

# Compute the decomposition graph for "abc":
abc_decomps = gtn.remove(gtn.project_output(gtn.compose(abc, lex)))
gtn.draw(abc_decomps, "abc_decomps.pdf", idx_to_wp, idx_to_wp)

# Compute the alignment graph for "abc":
abc_alignments = gtn.project_input(
    gtn.remove(gtn.compose(tokens, abc_decomps)))
gtn.draw(abc_alignments, "abc_alignments.pdf", idx_to_wp)

# From here we can use the alignment graph with an emissions graph and
# transition graphs to compute the sequence-level criterion:
emissions = gtn.linear_graph(10, len(word_pieces), True)
loss = gtn.subtract(
    gtn.forward_score(emissions),
    gtn.forward_score(gtn.intersect(emissions, abc_alignments)))
print(f"Loss is {loss.item():.2f}")
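To continue this sketch, a backward pass recovers per-frame gradients with respect to the emission weights, following the same pattern used in the tests above:

# Gradients of the loss with respect to the emission weights
# (a 10 x len(word_pieces) table, flattened row-major).
gtn.backward(loss)
emission_grads = emissions.grad().weights_to_list()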