Example #1
        def process(b):
            # Create emissions graph:
            emissions = gtn.linear_graph(T, C, inputs.requires_grad)
            cpu_data = inputs[b].cpu().contiguous()
            emissions.set_weights(cpu_data.data_ptr())
            target = make_chain_graph(targets[b])
            target.arc_sort(True)

            # Create token to grapheme decomposition graph
            tokens_target = gtn.remove(gtn.project_output(gtn.compose(target, lexicon)))
            tokens_target.arc_sort()

            # Create alignment graph:
            alignments = gtn.project_input(
                gtn.remove(gtn.compose(tokens, tokens_target))
            )
            alignments.arc_sort()

            # Add transition scores:
            if transitions is not None:
                alignments = gtn.intersect(transitions, alignments)
                alignments.arc_sort()

            loss = gtn.forward_score(gtn.intersect(emissions, alignments))

            # Normalize if needed:
            if transitions is not None:
                norm = gtn.forward_score(gtn.intersect(emissions, transitions))
                loss = gtn.subtract(loss, norm)

            losses[b] = gtn.negate(loss)

            # Save for backward:
            if emissions.calc_grad:
                emissions_graphs[b] = emissions
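The `make_chain_graph` helper is not defined in this snippet. A minimal sketch of a chain acceptor that would fit the usage above (one forward arc per target token); treat it as an assumption, not the original implementation:

def make_chain_graph(sequence):
    # Hypothetical sketch: a linear acceptor for `sequence`.
    g = gtn.Graph(False)
    g.add_node(True)  # start node
    for i, s in enumerate(sequence):
        # The last node is the accepting node.
        g.add_node(False, i == (len(sequence) - 1))
        g.add_arc(i, i + 1, s)
    return g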
Example #2
    def test_scalar_ops(self):
        g1 = gtn.Graph()
        g1.add_node(True)
        g1.add_node(False, True)
        g1.add_arc(0, 1, 0, 0, 1.0)

        # Test negate:
        res = gtn.negate(g1)
        self.assertEqual(res.item(), -1.0)
        gtn.backward(res)
        self.assertEqual(g1.grad().item(), -1.0)
        g1.zero_grad()

        g2 = gtn.Graph()
        g2.add_node(True)
        g2.add_node(False, True)
        g2.add_arc(0, 1, 0, 0, 3.0)

        # Test add:
        res = gtn.add(g1, g2)
        self.assertEqual(res.item(), 4.0)
        gtn.backward(res)
        self.assertEqual(g1.grad().item(), 1.0)
        self.assertEqual(g2.grad().item(), 1.0)
        g1.zero_grad()
        g2.zero_grad()

        # Test subtract:
        res = gtn.subtract(g1, g2)
        self.assertEqual(res.item(), -2.0)
        gtn.backward(res)
        self.assertEqual(g1.grad().item(), 1.0)
        self.assertEqual(g2.grad().item(), -1.0)
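Note that hand-building a scalar graph as above (two nodes joined by a single weighted arc) is equivalent to the `gtn.scalar_graph` convenience helper used in Examples #6 and #9:

g1 = gtn.scalar_graph(1.0)  # same graph as the g1 constructed above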
Example #3
        def process(b):
            # create emission graph
            g_emissions = gtn.linear_graph(T, C, inputs.requires_grad)
            cpu_data = inputs[b].cpu().contiguous()
            g_emissions.set_weights(cpu_data.data_ptr())

            # create transition graph
            g_transitions = ASGLossFunction.create_transitions_graph(
                transitions, calc_trans_grad)

            # create force align criterion graph
            g_fal = ASGLossFunction.create_force_align_graph(targets[b])

            # compose the graphs
            g_fal_fwd = gtn.forward_score(
                gtn.intersect(gtn.intersect(g_fal, g_transitions),
                              g_emissions))
            g_fcc_fwd = gtn.forward_score(
                gtn.intersect(g_emissions, g_transitions))
            g_loss = gtn.subtract(g_fcc_fwd, g_fal_fwd)
            scale = 1.0
            if reduction == "mean":
                L = len(targets[b])
                scale = 1.0 / L if L > 0 else scale
            elif reduction != "none":
                raise ValueError(
                    f"invalid value for reduction '{reduction}'")

            # Save for backward:
            losses[b] = g_loss
            scales[b] = scale
            emissions_graphs[b] = g_emissions
            transitions_graphs[b] = g_transitions
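Neither `create_transitions_graph` nor `create_force_align_graph` is shown here. A sketch of what the force-align graph plausibly looks like, modeled on the `fal` graph built inline in Example #11 (a forward arc plus a self-loop per target token); this is an assumption, not the original source:

def create_force_align_graph(target):
    # Hypothetical sketch of ASGLossFunction.create_force_align_graph.
    g = gtn.Graph(False)
    g.add_node(True)
    for l in range(1, len(target) + 1):
        g.add_node(False, l == len(target))
        g.add_arc(l - 1, l, target[l - 1])  # advance to the next token
        g.add_arc(l, l, target[l - 1])      # repeat the current token
    return g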
Example #4
    def test_simple_decomposition(self):
        T = 5
        tokens = ["a", "b", "ab", "ba", "aba"]
        scores = torch.randn((1, T, len(tokens)), requires_grad=True)
        labels = [[0, 1, 0]]
        transducer = Transducer(tokens=tokens,
                                graphemes_to_idx={
                                    "a": 0,
                                    "b": 1
                                })

        # Hand construct the alignment graph with all of the decompositions
        alignments = gtn.Graph(False)
        alignments.add_node(True)

        # Add the path ['a', 'b', 'a']
        alignments.add_node()
        alignments.add_arc(0, 1, 0)
        alignments.add_arc(1, 1, 0)
        alignments.add_node()
        alignments.add_arc(1, 2, 1)
        alignments.add_arc(2, 2, 1)
        alignments.add_node(False, True)
        alignments.add_arc(2, 3, 0)
        alignments.add_arc(3, 3, 0)

        # Add the path ['a', 'ba']
        alignments.add_node(False, True)
        alignments.add_arc(1, 4, 3)
        alignments.add_arc(4, 4, 3)

        # Add the path ['ab', 'a']
        alignments.add_node()
        alignments.add_arc(0, 5, 2)
        alignments.add_arc(5, 5, 2)
        alignments.add_arc(5, 3, 0)

        # Add the path ['aba']
        alignments.add_node(False, True)
        alignments.add_arc(0, 6, 4)
        alignments.add_arc(6, 6, 4)

        emissions = gtn.linear_graph(T, len(tokens), True)

        emissions.set_weights(scores.data_ptr())
        expected_loss = gtn.subtract(
            gtn.forward_score(emissions),
            gtn.forward_score(gtn.intersect(emissions, alignments)),
        )

        loss = transducer(scores, labels)
        self.assertAlmostEqual(loss.item(), expected_loss.item(), places=5)
        loss.backward()
        gtn.backward(expected_loss)

        expected_grad = torch.tensor(emissions.grad().weights_to_numpy())
        expected_grad = expected_grad.view((1, T, len(tokens)))
        self.assertTrue(
            torch.allclose(scores.grad, expected_grad, rtol=1e-4, atol=1e-5))
Example #5
def crf_loss(X, Y, potentials, transitions):
    feature_graph = gtn.compose(X, potentials)

    # Compute the unnormalized score of `(X, Y)`
    target_graph = gtn.compose(feature_graph, gtn.intersect(Y, transitions))
    target_score = gtn.forward_score(target_graph)

    # Compute the partition function
    norm_graph = gtn.compose(feature_graph, transitions)
    norm_score = gtn.forward_score(norm_graph)

    return gtn.subtract(norm_score, target_score)
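A minimal, self-contained sketch of how `crf_loss` might be invoked. All sizes, scores, and graph shapes below are illustrative assumptions, not part of the original example:

import gtn

T, F, L = 2, 2, 2  # time steps, feature symbols, label symbols (assumed)

# Observed features: one arc per (time step, feature), zero-initialized.
X = gtn.linear_graph(T, F, False)

# Potentials: a one-state transducer scoring each (feature, label) pair.
potentials = gtn.Graph(False)
potentials.add_node(True, True)
for f in range(F):
    for l in range(L):
        potentials.add_arc(0, 0, f, l, 0.1 * (f + l))  # arbitrary scores

# Transitions: a fully connected label-bigram acceptor.
transitions = gtn.Graph(False)
transitions.add_node(True)
for l in range(L):
    transitions.add_node(False, True)
    transitions.add_arc(0, l + 1, l)
for i in range(L):
    for j in range(L):
        transitions.add_arc(i + 1, j + 1, j)

# Gold label sequence as a chain acceptor.
labels = [0, 1]
Y = gtn.Graph(False)
Y.add_node(True)
for i, l in enumerate(labels):
    Y.add_node(False, i == len(labels) - 1)
    Y.add_arc(i, i + 1, l)

loss = crf_loss(X, Y, potentials, transitions)
print(loss.item())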
Example #6
    def test_scalar_ops(self):
        g1 = gtn.scalar_graph(3.0)

        result = gtn.negate(g1)
        self.assertEqual(result.item(), -3.0)

        g2 = gtn.scalar_graph(4.0)

        result = gtn.add(g1, g2)
        self.assertEqual(result.item(), 7.0)

        result = gtn.subtract(g2, g1)
        self.assertEqual(result.item(), 1.0)
Example #7
        def forward_single(b):
            emissions = gtn.linear_graph(T, C, inputs.requires_grad)
            data = inputs[b].contiguous()
            emissions.set_weights(data.data_ptr())

            target = GTNLossFunction.make_target_graph(targets[b])

            # Score the target:
            target_score = gtn.forward_score(gtn.intersect(target, emissions))

            # Normalization term:
            norm = gtn.forward_score(emissions)

            # Compute the loss:
            loss = gtn.subtract(norm, target_score)

            # Save state for backward:
            losses[b] = loss
            emissions_graphs[b] = emissions
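Here `norm` is the log partition function over all length-T label sequences, so `loss` equals -log p(target | inputs), the negative log-likelihood of the target. `make_target_graph` is assumed to return an acceptor for `targets[b]`, e.g. a chain graph like the `make_chain_graph` sketch after Example #1.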
Example #8
        def seq_loss(batch_index):
            obs_fst = linearFstFromArray(arc_scores[batch_index].reshape(
                num_samples, -1))
            gt_fst = fromSequence(arc_labels[batch_index])

            # Compose each sequence fst individually: it seems like composition
            # only works for lattices
            denom_fst = obs_fst
            for seq_fst in seq_fsts:
                denom_fst = gtn.compose(denom_fst, seq_fst)
                denom_fst = gtn.project_output(denom_fst)

            num_fst = gtn.compose(denom_fst, gt_fst)

            loss = gtn.subtract(gtn.forward_score(num_fst),
                                gtn.forward_score(denom_fst))

            losses[batch_index] = loss
            obs_fsts[batch_index] = obs_fst
Example #9
    def test_scalar_ops_grad(self):
        g1 = gtn.scalar_graph(3.0)

        result = gtn.negate(g1)
        gtn.backward(result)
        self.assertEqual(g1.grad().item(), -1.0)

        g1.zero_grad()

        g2 = gtn.scalar_graph(4.0)

        result = gtn.add(g1, g2)
        gtn.backward(result)
        self.assertEqual(g1.grad().item(), 1.0)
        self.assertEqual(g2.grad().item(), 1.0)

        g1.zero_grad()
        g2.zero_grad()

        result = gtn.subtract(g1, g2)
        gtn.backward(result)
        self.assertEqual(g1.grad().item(), 1.0)
        self.assertEqual(g2.grad().item(), -1.0)
        g1.zero_grad()
        g2.zero_grad()

        result = gtn.add(gtn.add(g1, g2), g1)
        gtn.backward(result)
        self.assertEqual(g1.grad().item(), 2.0)
        self.assertEqual(g2.grad().item(), 1.0)
        g1.zero_grad()

        g2nograd = gtn.scalar_graph(4.0, False)

        result = gtn.add(g1, g2nograd)
        gtn.backward(result)
        self.assertEqual(g1.grad().item(), 1.0)
        self.assertRaises(RuntimeError, g2nograd.grad)
Example #10
    def test_ctc_criterion(self):
        # These test cases are taken from wav2letter: https://fburl.com/msom2e4v

        # Test case 1
        ctc = ctc_graph([0, 0], 1)

        emissions = emissions_graph([1.0, 0.0, 0.0, 1.0, 1.0, 0.0], 3, 2)

        loss = gtn.forward_score(gtn.compose(ctc, emissions))
        self.assertEqual(loss.item(), 0.0)

        # Should be 0 since scores are normalized
        z = gtn.forward_score(emissions)
        self.assertEqual(z.item(), 0.0)

        # Test case 2
        T = 3
        N = 4
        ctc = ctc_graph([1, 2], N - 1)
        emissions = emissions_graph([1.0] * (T * N), T, N)

        # For target [1, 2] with T = 3 there are five valid alignments
        # ((1,1,2), (1,2,2), (B,1,2), (1,B,2), (1,2,B)), each with
        # probability (1/4)^3 under the uniform emissions:
        expected_loss = -math.log(0.25 * 0.25 * 0.25 * 5)

        loss = gtn.subtract(gtn.forward_score(gtn.compose(ctc, emissions)),
                            gtn.forward_score(emissions))
        self.assertAlmostEqual(-loss.item(), expected_loss)

        # Test case 3
        T = 5
        N = 6
        target = [0, 1, 2, 1, 0]

        # generate CTC graph
        ctc = ctc_graph(target, N - 1)

        # fmt: off
        emissions_vec = [
            0.633766, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553,
            0.111121, 0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436,
            0.0357786, 0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688,
            0.0663296, 0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533,
            0.458235, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107,
        ]
        # fmt: on

        emissions = emissions_graph(emissions_vec, T, N)

        # The log probabilities are already normalized,
        # so this should be close to 0
        z = gtn.forward_score(emissions)
        self.assertTrue(abs(z.item()) < 1e-5)

        loss = gtn.subtract(z, gtn.forward_score(gtn.compose(ctc, emissions)))
        expected_loss = 3.34211
        self.assertAlmostEqual(loss.item(), expected_loss, places=5)

        # Check the gradients
        gtn.backward(loss)

        # fmt: off
        expected_grad = [
            -0.366234, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553,
            0.111121, -0.411608, 0.278779, 0.0055756, 0.00569609, 0.010436,
            0.0357786, 0.633813, -0.678582, 0.00249248, 0.00272882, 0.0037688,
            0.0663296, -0.356151, 0.280111, 0.00283995, 0.0035545, 0.00331533,
            -0.541765, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107
        ]
        # fmt: on
        all_close = True
        grad = emissions.grad()
        grad_weights = grad.weights_to_list()
        for i in range(T * N):
            g = grad_weights[i]
            all_close = all_close and (abs(expected_grad[i] - g) < 1e-5)

        self.assertTrue(all_close)

        # Test case 4
        # This test case is taken from the TensorFlow CTC implementation:
        # tinyurl.com/y9du5v5a
        T = 5
        N = 6
        target = [0, 1, 1, 0]

        # generate CTC graph
        ctc = ctc_graph(target, N - 1)
        # fmt: off
        emissions_vec = [
            0.30176, 0.28562, 0.0831517, 0.0862751, 0.0816851, 0.161508,
            0.24082, 0.397533, 0.0557226, 0.0546814, 0.0557528, 0.19549,
            0.230246, 0.450868, 0.0389607, 0.038309, 0.0391602, 0.202456,
            0.280884, 0.429522, 0.0326593, 0.0339046, 0.0326856, 0.190345,
            0.423286, 0.315517, 0.0338439, 0.0393744, 0.0339315, 0.154046,
        ]
        # fmt: on

        emissions = emissions_graph(emissions_vec, T, N)

        # The log probabilities are already normalized,
        # so this should be close to 0
        z = gtn.forward_score(emissions)
        self.assertTrue(abs(z.item()) < 1e-5)

        loss = gtn.subtract(z, gtn.forward_score(gtn.compose(ctc, emissions)))
        expected_loss = 5.42262
        self.assertAlmostEqual(loss.item(), expected_loss, places=4)

        # Check the gradients
        gtn.backward(loss)
        # fmt: off
        expected_grad = [
            -0.69824, 0.28562, 0.0831517, 0.0862751, 0.0816851, 0.161508,
            0.24082, -0.602467, 0.0557226, 0.0546814, 0.0557528, 0.19549,
            0.230246, 0.450868, 0.0389607, 0.038309, 0.0391602, -0.797544,
            0.280884, -0.570478, 0.0326593, 0.0339046, 0.0326856, 0.190345,
            -0.576714, 0.315517, 0.0338439, 0.0393744, 0.0339315, 0.154046,
        ]
        # fmt: on

        all_close = True
        grad = emissions.grad()
        grad_weights = grad.weights_to_list()
        for i in range(T * N):
            g = grad_weights[i]
            all_close = all_close and (abs(expected_grad[i] - g) < 1e-5)
        self.assertTrue(all_close)
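The `ctc_graph` helper these tests rely on is not shown. Below is a sketch of the standard CTC alignment graph it plausibly constructs (optional blanks interleaved with the target, self-loops for repeats, skip arcs between distinct consecutive labels); treat it as an assumption about the helper, not its original source:

def ctc_graph(target, blank):
    # Hypothetical sketch of the CTC alignment graph for `target`.
    g = gtn.Graph(False)
    L = len(target)
    S = 2 * L + 1  # states: blank, target[0], blank, target[1], ..., blank
    for s in range(S):
        idx = (s - 1) // 2
        # Accept in the final blank state or the final target state.
        g.add_node(s == 0, s == S - 1 or s == S - 2)
        label = target[idx] if s % 2 else blank
        g.add_arc(s, s, label)            # emit the same symbol again
        if s > 0:
            g.add_arc(s - 1, s, label)    # advance one state
        if s % 2 and s > 1 and label != target[idx - 1]:
            g.add_arc(s - 2, s, label)    # skip the optional blank
    return g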
Example #11
    def test_asg_criterion(self):
        # These test cases are taken from wav2letter: https://fburl.com/msom2e4v
        T = 5
        N = 6

        # fmt: off
        targets = [
            [2, 1, 5, 1, 3],
            [4, 3, 5],
            [3, 2, 2, 1],
        ]

        expected_loss = [
            7.7417464256287,
            6.4200420379639,
            8.2780694961548,
        ]

        emissions_vecs = [
            [
                -0.4340, -0.0254, 0.3667, 0.4180, -0.3805, -0.1707,
                0.1060, 0.3631, -0.1122, -0.3825, -0.0031, -0.3801,
                0.0443, -0.3795, 0.3194, -0.3130, 0.0094, 0.1560,
                0.1252, 0.2877, 0.1997, -0.4554, 0.2774, -0.2526,
                -0.4001, -0.2402, 0.1295, 0.0172, 0.1805, -0.3299,
            ],
            [
                0.3298, -0.2259, -0.0959, 0.4909, 0.2996, -0.2543,
                -0.2863, 0.3239, -0.3988, 0.0732, -0.2107, -0.4739,
                -0.0906, 0.0480, -0.1301, 0.3975, -0.3317, -0.1967,
                0.4372, -0.2006, 0.0094, 0.3281, 0.1873, -0.2945,
                0.2399, 0.0320, -0.3768, -0.2849, -0.2248, 0.3186,
            ],
            [
                0.0225, -0.3867, -0.1929, -0.2904, -0.4958, -0.2533,
                0.4001, -0.1517, -0.2799, -0.2915, 0.4198, 0.4506,
                0.1446, -0.4753, -0.0711, 0.2876, -0.1851, -0.1066,
                0.2081, -0.1190, -0.3902, -0.1668, 0.1911, -0.2848,
                -0.3846, 0.1175, 0.1052, 0.2172, -0.0362, 0.3055,
            ],
        ]

        emissions_grads = [
            [
                0.1060, 0.1595, -0.7639, 0.2485, 0.1118, 0.1380,
                0.1915, -0.7524, 0.1539, 0.1175, 0.1717, 0.1178,
                0.1738, 0.1137, 0.2288, 0.1216, 0.1678, -0.8057,
                0.1766, -0.7923, 0.1902, 0.0988, 0.2056, 0.1210,
                0.1212, 0.1422, 0.2059, -0.8160, 0.2166, 0.1300,
            ],
            [
                0.2029, 0.1164, 0.1325, 0.2383, -0.8032, 0.1131,
                0.1414, 0.2602, 0.1263, -0.3441, -0.3009, 0.1172,
                0.1557, 0.1788, 0.1496, -0.5498, 0.0140, 0.0516,
                0.2306, 0.1219, 0.1503, -0.4244, 0.1796, -0.2579,
                0.2149, 0.1745, 0.1160, 0.1271, 0.1350, -0.7675,
            ],
            [
                0.2195, 0.1458, 0.1770, -0.8395, 0.1307, 0.1666,
                0.2148, 0.1237, -0.6613, -0.1223, 0.2191, 0.2259,
                0.2002, 0.1077, -0.8386, 0.2310, 0.1440, 0.1557,
                0.2197, -0.1466, -0.5742, 0.1510, 0.2160, 0.1342,
                0.1050, -0.8265, 0.1714, 0.1917, 0.1488, 0.2094,
            ],
        ]
        # fmt: on
        transitions = gtn.Graph()
        transitions.add_node(True)
        for i in range(1, N + 1):
            transitions.add_node(False, True)
            transitions.add_arc(0, i, i - 1)  # p(i | <s>)

        for i in range(N):
            for j in range(N):
                transitions.add_arc(j + 1, i + 1, i)  # p(i | j)

        for b in range(len(targets)):
            target = targets[b]
            emissions_vec = emissions_vecs[b]
            emissions_grad = emissions_grads[b]

            fal = gtn.Graph()
            fal.add_node(True)
            for l in range(1, len(target) + 1):
                fal.add_node(False, l == len(target))
                fal.add_arc(l - 1, l, target[l - 1])
                fal.add_arc(l, l, target[l - 1])

            emissions = emissions_graph(emissions_vec, T, N, True)

            loss = gtn.subtract(
                gtn.forward_score(gtn.compose(emissions, transitions)),
                gtn.forward_score(
                    gtn.compose(gtn.compose(fal, transitions), emissions)),
            )

            self.assertAlmostEqual(loss.item(), expected_loss[b], places=3)

            # Check the gradients
            gtn.backward(loss)

            all_close = True
            grad = emissions.grad()
            grad_weights = grad.weights_to_list()
            for i in range(T * N):
                g = grad_weights[i]
                all_close = all_close and (abs(emissions_grad[i] - g) < 1e-4)
            self.assertTrue(all_close)

        all_close = True
        # fmt: off
        trans_grad = [
            0.3990, 0.3396, 0.3486, 0.3922, 0.3504, 0.3155,
            0.3666, 0.0116, -1.6678, 0.3737, 0.3361, -0.7152,
            0.3468, 0.3163, -1.1583, -0.6803, 0.3216, 0.2722,
            0.3694, -0.6688, 0.3047, -0.8531, -0.6571, 0.2870,
            0.3866, 0.3321, 0.3447, 0.3664, -0.2163, 0.3039,
            0.3640, -0.6943, 0.2988, -0.6722, 0.3215, -0.1860,
        ]
        # fmt: on

        grad = transitions.grad()
        grad_weights = grad.weights_to_list()
        for i in range(N * N):
            g = grad_weights[i + N]
            all_close = all_close and (abs(trans_grad[i] - g) < 1e-4)
        self.assertTrue(all_close)
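Similarly, `emissions_graph` is assumed rather than shown. A plausible sketch consistent with this ASG test is below: a linear graph over T frames and N labels whose arc weights are set directly from the given list, with a pointer-based `set_weights` call mirroring the `data_ptr()` usage in Examples #1 and #3. Note the CTC tests in Example #10 appear to pass probabilities whose logs become the weights, so the helper there may differ; treat this as an assumption either way:

import numpy as np

def emissions_graph(weights, T, N, calc_grad=False):
    # Hypothetical sketch: T x N linear graph with the given arc weights.
    g = gtn.linear_graph(T, N, calc_grad)
    data = np.ascontiguousarray(weights, dtype=np.float32)
    # set_weights reads T * N contiguous float32 values from this address.
    g.set_weights(data.ctypes.data)
    return g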
Example #12
def main(out_dir=None,
         gpu_dev_id=None,
         num_samples=10,
         random_seed=None,
         learning_rate=1e-3,
         num_epochs=500,
         dataset_kwargs={},
         dataloader_kwargs={},
         model_kwargs={}):

    if out_dir is None:
        out_dir = os.path.join('~', 'data', 'output', 'seqtools', 'test_gtn')

    out_dir = os.path.expanduser(out_dir)

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    fig_dir = os.path.join(out_dir, 'figures')
    if not os.path.exists(fig_dir):
        os.makedirs(fig_dir)

    vocabulary = ['a', 'b', 'c', 'd', 'e']

    transition = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 1, 0], [0, 0, 0, 0, 1],
                           [0, 1, 0, 0, 1], [0, 0, 0, 0, 0]],
                          dtype=float)
    initial = np.array([1, 0, 1, 0, 0], dtype=float)
    final = np.array([0, 1, 0, 0, 1], dtype=float) / 10

    seq_params = (transition, initial, final)
    simulated_dataset = simulate(num_samples, *seq_params)
    label_seqs, obsv_seqs = tuple(zip(*simulated_dataset))
    seq_params = tuple(map(lambda x: -np.log(x), seq_params))

    dataset = torchutils.SequenceDataset(obsv_seqs, label_seqs,
                                         **dataset_kwargs)
    data_loader = torch.utils.data.DataLoader(dataset, **dataloader_kwargs)

    train_loader = data_loader
    val_loader = data_loader

    transition_weights = torch.tensor(transition, dtype=torch.float).log()
    initial_weights = torch.tensor(initial, dtype=torch.float).log()
    final_weights = torch.tensor(final, dtype=torch.float).log()

    model = libfst.LatticeCrf(vocabulary,
                              transition_weights=transition_weights,
                              initial_weights=initial_weights,
                              final_weights=final_weights,
                              debug_output_dir=fig_dir,
                              **model_kwargs)

    gtn.draw(model._transition_fst,
             os.path.join(fig_dir, 'transitions-init.png'),
             isymbols=model._arc_symbols,
             osymbols=model._arc_symbols)

    gtn.draw(model._duration_fst,
             os.path.join(fig_dir, 'durations-init.png'),
             isymbols=model._arc_symbols,
             osymbols=model._arc_symbols)

    if True:
        for i, (inputs, targets, seq_id) in enumerate(train_loader):
            arc_scores = model.scores_to_arc(inputs)
            arc_labels = model.labels_to_arc(targets)

            batch_size, num_samples, num_classes = arc_scores.shape

            obs_fst = libfst.linearFstFromArray(arc_scores[0].reshape(
                num_samples, -1))
            gt_fst = libfst.fromSequence(arc_labels[0])
            d1_fst = gtn.compose(obs_fst, model._duration_fst)
            d1_fst = gtn.project_output(d1_fst)
            denom_fst = gtn.compose(d1_fst, model._transition_fst)
            # denom_fst = gtn.project_output(denom_fst)
            num_fst = gtn.compose(denom_fst, gt_fst)
            viterbi_fst = gtn.viterbi_path(denom_fst)
            pred_fst = gtn.remove(gtn.project_output(viterbi_fst))

            loss = gtn.subtract(gtn.forward_score(num_fst),
                                gtn.forward_score(denom_fst))
            loss = torch.tensor(loss.item())

            if torch.isinf(loss).any():
                d1_min = gtn.remove(gtn.project_output(d1_fst))
                denom_alt = gtn.compose(d1_min, model._transition_fst)
                num_alt = gtn.compose(denom_alt, gt_fst)
                gtn.draw(obs_fst,
                         os.path.join(fig_dir, 'observations-init.png'),
                         isymbols=model._arc_symbols,
                         osymbols=model._arc_symbols)
                gtn.draw(gt_fst,
                         os.path.join(fig_dir, 'labels-init.png'),
                         isymbols=model._arc_symbols,
                         osymbols=model._arc_symbols)
                gtn.draw(d1_fst,
                         os.path.join(fig_dir, 'd1-init.png'),
                         isymbols=model._arc_symbols,
                         osymbols=model._arc_symbols)
                gtn.draw(d1_min,
                         os.path.join(fig_dir, 'd1-min-init.png'),
                         isymbols=model._arc_symbols,
                         osymbols=model._arc_symbols)
                gtn.draw(denom_fst,
                         os.path.join(fig_dir, 'denominator-init.png'),
                         isymbols=model._arc_symbols,
                         osymbols=model._arc_symbols)
                gtn.draw(denom_alt,
                         os.path.join(fig_dir, 'denominator-alt-init.png'),
                         isymbols=model._arc_symbols,
                         osymbols=model._arc_symbols)
                gtn.draw(num_fst,
                         os.path.join(fig_dir, 'numerator-init.png'),
                         isymbols=model._arc_symbols,
                         osymbols=model._arc_symbols)
                gtn.draw(num_alt,
                         os.path.join(fig_dir, 'numerator-alt-init.png'),
                         isymbols=model._arc_symbols,
                         osymbols=model._arc_symbols)
                gtn.draw(viterbi_fst,
                         os.path.join(fig_dir, 'viterbi-init.png'),
                         isymbols=model._arc_symbols,
                         osymbols=model._arc_symbols)
                gtn.draw(pred_fst,
                         os.path.join(fig_dir, 'pred-init.png'),
                         isymbols=model._arc_symbols,
                         osymbols=model._arc_symbols)
                import pdb
                pdb.set_trace()

    # Train the model
    train_epoch_log = collections.defaultdict(list)
    val_epoch_log = collections.defaultdict(list)
    metric_dict = {
        'Avg Loss': metrics.AverageLoss(),
        'Accuracy': metrics.Accuracy()
    }

    criterion = model.nllLoss
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=1,
                                                gamma=1.00)

    model, last_model_wts = torchutils.trainModel(
        model,
        criterion,
        optimizer,
        scheduler,
        train_loader,
        val_loader,
        metrics=metric_dict,
        test_metric='Avg Loss',
        train_epoch_log=train_epoch_log,
        val_epoch_log=val_epoch_log,
        num_epochs=num_epochs)

    gtn.draw(model._transition_fst,
             os.path.join(fig_dir, 'transitions-trained.png'),
             isymbols=model._arc_symbols,
             osymbols=model._arc_symbols)
    gtn.draw(model._duration_fst,
             os.path.join(fig_dir, 'durations-trained.png'),
             isymbols=model._arc_symbols,
             osymbols=model._arc_symbols)

    torchutils.plotEpochLog(train_epoch_log,
                            title="Train Epoch Log",
                            fn=os.path.join(fig_dir, "train-log.png"))
Example #13
    tokens = token_graph(word_pieces)
    gtn.draw(tokens, "tokens.pdf", idx_to_wp, idx_to_wp)

    # Recognizes "abc":
    abc = gtn.Graph(False)
    abc.add_node(True)
    abc.add_node()
    abc.add_node()
    abc.add_node(False, True)
    abc.add_arc(0, 1, let_to_idx["a"])
    abc.add_arc(1, 2, let_to_idx["b"])
    abc.add_arc(2, 3, let_to_idx["c"])
    gtn.draw(abc, "abc.pdf", idx_to_let)

    # Compute the decomposition graph for "abc":
    abc_decomps = gtn.remove(gtn.project_output(gtn.compose(abc, lex)))
    gtn.draw(abc_decomps, "abc_decomps.pdf", idx_to_wp, idx_to_wp)

    # Compute the alignment graph for "abc":
    abc_alignments = gtn.project_input(
        gtn.remove(gtn.compose(tokens, abc_decomps)))
    gtn.draw(abc_alignments, "abc_alignments.pdf", idx_to_wp)

    # From here we can use the alignment graph with an emissions graph and
    # transitions graphs to compute the sequence level criterion:
    emissions = gtn.linear_graph(10, len(word_pieces), True)
    loss = gtn.subtract(
        gtn.forward_score(emissions),
        gtn.forward_score(gtn.intersect(emissions, abc_alignments)))
    print(f"Loss is {loss.item():.2f}")