Example 1
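The snippets below are excerpts from k2's Python test suite and from training code built on it, so each assumes torch and k2 are already imported; the training examples additionally rely on helpers from their own codebase (e.g. get_tot_objf_and_num_frames, create_decoding_graph, clip_grad_value_). A minimal sketch of the scaffolding needed to run one of the test methods standalone (the class name here is a placeholder, not from the original):

    import unittest

    import torch
    import k2


    class TestGetTotScores(unittest.TestCase):
        # Paste one of the test methods below into this class body.
        pass


    if __name__ == '__main__':
        unittest.main()

Note that some snippets pass use_float_scores=... to k2.get_tot_scores while others pass use_double_scores=...; they appear to come from different k2 versions, so match the flag to the version installed.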
    def test_log_single_fsa(self):
        s = '''
            0 1 1 0.1
            0 2 2 0.2
            1 2 3 0.3
            1 3 4 0.4
            2 3 5 0.5
            3 4 -1 0
            4
        '''
        devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            devices.append(torch.device('cuda', 0))

        for device in devices:
            fsa = k2.Fsa.from_str(s).to(device)
            fsa.requires_grad_(True)
            fsa_vec = k2.create_fsa_vec([fsa])
            log_like = k2.get_tot_scores(fsa_vec,
                                         log_semiring=True,
                                         use_float_scores=True)
            assert log_like.dtype == torch.float32
            # The expected_log_like is computed using gtn.
            # See https://bit.ly/3oUiRx9
            expected_log_like = torch.tensor([1.8119014501571655]).to(device)
            assert torch.allclose(log_like, expected_log_like)

            # The expected_grad is computed using gtn.
            # See https://bit.ly/3oUiRx9
            expected_grad = torch.tensor([
                0.6710670590400696, 0.32893291115760803, 0.4017595648765564,
                0.2693074941635132, 0.7306925058364868, 1.0
            ]).to(device)

            scale = -1.75
            (scale * log_like).sum().backward()
            assert torch.allclose(fsa.scores.grad, scale * expected_grad)

            # now for double
            fsa.scores.grad = None
            log_like = k2.get_tot_scores(fsa_vec,
                                         log_semiring=True,
                                         use_float_scores=False)
            assert log_like.dtype == torch.float64
            expected_log_like = expected_log_like.to(torch.float64)
            assert torch.allclose(log_like, expected_log_like)

            scale = 10.25
            (scale * log_like).sum().backward()
            assert torch.allclose(fsa.scores.grad, scale * expected_grad)
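A quick hand-check of the 1.8119 figure: the FSA above has exactly three complete paths, with summed scores 0.5 (arcs 0, 3, 5), 0.9 (arcs 0, 2, 4, 5) and 0.7 (arcs 1, 4, 5), and in the log semiring the total score is their log-sum-exp. A sketch derived from the FSA text rather than taken from the test:

    import torch

    path_scores = torch.tensor([0.5, 0.9, 0.7])
    print(torch.logsumexp(path_scores, dim=0))  # tensor(1.8119)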
Example 2
    def test_tropical_single_fsa(self):
        # best path arc indexes are: 1, 3, 5, 10
        s = '''
            0 4 1 1
            0 1 1 1
            1 2 1 2
            1 3 1 3
            2 7 1 4
            3 7 1 5
            4 6 1 2
            4 8 1 3
            5 9 -1 4
            6 9 -1 3
            7 9 -1 5
            8 9 -1 6
            9
        '''
        devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            devices.append(torch.device('cuda'))

        for device in devices:
            fsa = k2.Fsa.from_str(s).to(device)
            fsa = k2.create_fsa_vec([fsa])
            fsa.requires_grad_(True)
            log_like = k2.get_tot_scores(fsa,
                                         log_semiring=False,
                                         use_float_scores=True)

            assert log_like == 14
            assert log_like.dtype == torch.float32

            scale = -10

            (scale * log_like).sum().backward()
            expected = torch.zeros(len(fsa.scores)).to(device)
            expected[torch.tensor([1, 3, 5, 10])] = 1
            assert torch.allclose(fsa.scores.grad, scale * expected)

            # now for double
            fsa.scores.grad = None
            log_like = k2.get_tot_scores(fsa,
                                         log_semiring=False,
                                         use_float_scores=False)
            assert log_like == 14
            assert log_like.dtype == torch.float64

            scale = -1.25
            (scale * log_like).sum().backward()
            assert torch.allclose(fsa.scores.grad, scale * expected)
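A hand-check of the value 14: in the tropical semiring the total score is the maximum over complete paths, and the FSA above has four of them. A sketch derived from the FSA text:

    import torch

    # 0->1->3->7->9 is the best path (arcs 1, 3, 5, 10): 1 + 3 + 5 + 5 = 14.
    # The others are 0->1->2->7->9 (12), 0->4->6->9 (6) and 0->4->8->9 (10).
    path_scores = torch.tensor([14., 12., 6., 10.])
    print(path_scores.max())  # tensor(14.)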
Example 3
    def test_two_fsas(self):
        s1 = '''
            0 1 1 1.0
            1 1 1 50.0
            1 2 2 2.0
            2 3 -1 3.0
            3
        '''

        s2 = '''
            0 1 1 1.0
            1 2 2 2.0
            2 3 -1 3.0
            3
        '''

        fsa1 = k2.Fsa.from_str(s1)
        fsa2 = k2.Fsa.from_str(s2)

        fsa1.requires_grad_(True)
        fsa2.requires_grad_(True)

        fsa_vec = k2.create_fsa_vec([fsa1, fsa2])

        log_prob = torch.tensor(
            [[[0.1, 0.2, 0.3], [0.04, 0.05, 0.06], [0.0, 0.0, 0.0]],
             [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.0, 0.0, 0.0]]],
            dtype=torch.float32,
            requires_grad=True)

        supervision_segments = torch.tensor([[0, 0, 3], [1, 0, 2]],
                                            dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(log_prob, supervision_segments)
        out_fsa = k2.intersect_dense(fsa_vec,
                                     dense_fsa_vec,
                                     output_beam=100000)
        assert out_fsa.shape == (2, None, None), 'There should be two FSAs!'

        scores = k2.get_tot_scores(out_fsa,
                                   log_semiring=False,
                                   use_float_scores=True)
        scores.sum().backward()

        # `expected` results are computed using gtn.
        # See https://bit.ly/3oYObeb
        #  expected_scores_out_fsa = torch.tensor(
        #      [1.2, 2.06, 3.0, 1.2, 50.5, 2.0, 3.0])

        expected_grad_fsa1 = torch.tensor([1.0, 1.0, 1.0, 1.0])
        expected_grad_fsa2 = torch.tensor([1.0, 1.0, 1.0])
        print("fsa2 is ", fsa2.__str__())
        # TODO(dan):: fix this..
        #  expected_grad_log_prob = torch.tensor([
        #      0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0, 0, 0, 0.0, 1.0, 0.0, 0.0, 1.0,
        #      0.0, 0.0, 0.0, 1.0
        #  ]).reshape_as(log_prob)

        # assert torch.allclose(out_fsa.scores, expected_scores_out_fsa)
        assert torch.allclose(expected_grad_fsa1, fsa1.scores.grad)
        assert torch.allclose(expected_grad_fsa2, fsa2.scores.grad)
Example 4
    def test_simple(self):
        s = '''
            0 1 1 1.0
            1 1 1 50.0
            1 2 2 2.0
            2 3 -1 3.0
            3
        '''
        fsa = k2.Fsa.from_str(s)
        fsa.requires_grad_(True)
        fsa_vec = k2.create_fsa_vec([fsa])
        log_prob = torch.tensor([[[0.1, 0.2, 0.3], [0.04, 0.05, 0.06]]],
                                dtype=torch.float32,
                                requires_grad=True)

        supervision_segments = torch.tensor([[0, 0, 2]], dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(log_prob, supervision_segments)
        out_fsa = k2.intersect_dense(fsa_vec,
                                     dense_fsa_vec,
                                     output_beam=100000)
        scores = k2.get_tot_scores(out_fsa,
                                   log_semiring=False,
                                   use_float_scores=True)
        scores.sum().backward()

        # `expected` results are computed using gtn.
        # See https://bit.ly/3oYObeb
        expected_scores_out_fsa = torch.tensor([1.2, 2.06, 3.0])
        expected_grad_fsa = torch.tensor([1.0, 0.0, 1.0, 1.0])
        expected_grad_log_prob = torch.tensor([0.0, 1.0, 0.0, 0.0, 0.0,
                                               1.0]).reshape_as(log_prob)
        assert torch.allclose(out_fsa.scores, expected_scores_out_fsa)
        assert torch.allclose(expected_grad_fsa, fsa.scores.grad)
        assert torch.allclose(expected_grad_log_prob, log_prob.grad)
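The expected scores can be checked by hand: each arc of out_fsa carries its FSA arc score plus the log-prob of the symbol it emits at that frame, and the final (-1) arc scores against the extra terminating frame that k2.DenseFsaVec appends. A hand-derived check of the values used above:

    # Arc 0->1 (score 1.0) emits symbol 1 at frame 0: log_prob[0][0][1] = 0.2.
    # Arc 1->2 (score 2.0) emits symbol 2 at frame 1: log_prob[0][1][2] = 0.06.
    # The final arc (score 3.0) adds 0.0 from the appended terminating frame.
    print([1.0 + 0.2, 2.0 + 0.06, 3.0 + 0.0])  # [1.2, 2.06, 3.0]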
Example 5
    def forward(self, log_probs: torch.Tensor, targets: torch.Tensor,
                input_lengths: torch.Tensor,
                target_lengths: torch.Tensor) -> torch.Tensor:

        # now log_probs is [N, T, C], i.e. batch_size x seq_length x alphabet_size
        log_probs = log_probs.permute(1, 0, 2).cpu()
        supervision_segments = torch.stack(
            (torch.arange(input_lengths.shape[0]),
             torch.zeros(input_lengths.shape[0]), input_lengths),
            1).to(torch.int32)
        indices = torch.argsort(supervision_segments[:, 2], descending=True)
        supervision_segments = supervision_segments[indices]

        dense_fsa_vec = k2.DenseFsaVec(log_probs, supervision_segments)
        decoding_graph = self.graph_compiler.compile(targets.cpu(),
                                                     target_lengths)
        decoding_graph = k2.index(decoding_graph,
                                  indices.to(torch.int32)).to(log_probs.device)

        target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec, 10.0)
        tot_scores = k2.get_tot_scores(target_graph,
                                       log_semiring=True,
                                       use_double_scores=True)
        (tot_score, tot_frames,
         all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                                   supervision_segments[:, 2])
        return -tot_score
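The supervision_segments tensor built here has one row per sequence with columns (sequence_idx, start_frame, num_frames), and the rows are reordered by decreasing num_frames before constructing k2.DenseFsaVec. A standalone sketch of that construction with hypothetical input lengths:

    import torch

    input_lengths = torch.tensor([17, 95, 42])  # hypothetical frame counts
    supervision_segments = torch.stack(
        (torch.arange(input_lengths.shape[0]),
         torch.zeros_like(input_lengths), input_lengths), 1).to(torch.int32)
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    print(supervision_segments[indices])
    # tensor([[ 1,  0, 95],
    #         [ 2,  0, 42],
    #         [ 0,  0, 17]], dtype=torch.int32)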
Example 6
    def test_autograd(self):
        s0 = '''
            0 1 1 0.1
            0 2 2 0.2
            1 3 -1 0.3
            1 2 2 0.4
            2 3 -1 0.5
            3
        '''

        s1 = '''
            0 2 -1 0.6
            0 1 1 0.7
            1 2 -1 0.8
            2
        '''

        s2 = '''
            0 1 1 1.1
            1 2 -1 1.2
            2
        '''
        devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            devices.append(torch.device('cuda', 0))

        for device in devices:
            fsa0 = k2.Fsa.from_str(s0).to(device).requires_grad_(True)
            fsa1 = k2.Fsa.from_str(s1).to(device).requires_grad_(True)
            fsa2 = k2.Fsa.from_str(s2).to(device).requires_grad_(True)

            fsa_vec = k2.create_fsa_vec([fsa0, fsa1, fsa2])
            fsa = k2.union(fsa_vec)
            fsa_vec = k2.create_fsa_vec([fsa])
            log_like = k2.get_tot_scores(fsa_vec,
                                         log_semiring=True,
                                         use_double_scores=False)
            # expected log_like and gradients are computed using gtn.
            # See https://bit.ly/35uVaUv
            log_like.backward()

            expected_log_like = torch.tensor([3.1136]).to(log_like)
            assert torch.allclose(log_like, expected_log_like)

            expected_grad_fsa0 = torch.tensor([
                0.18710044026374817, 0.08949274569749832, 0.06629786640405655,
                0.12080258131027222, 0.21029533445835114
            ]).to(device)

            expected_grad_fsa1 = torch.tensor([
                0.08097638934850693, 0.19916976988315582, 0.19916976988315582
            ]).to(device)

            expected_grad_fsa2 = torch.tensor(
                [0.4432605803012848, 0.4432605803012848]).to(device)

            assert torch.allclose(fsa0.grad, expected_grad_fsa0)
            assert torch.allclose(fsa1.grad, expected_grad_fsa1)
            assert torch.allclose(fsa2.grad, expected_grad_fsa2)
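A hand-check of the 3.1136 figure: k2.union joins the three FSAs under a new start state, so the total score in the log semiring is the log-sum-exp over every complete path of the three components. A sketch with the path sums read off the FSA texts:

    import torch

    path_scores = torch.tensor([0.4, 1.0, 0.7,  # the three paths of fsa0
                                0.6, 1.5,       # the two paths of fsa1
                                2.3])           # the single path of fsa2
    print(torch.logsumexp(path_scores, dim=0))  # tensor(3.1136)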
Example 7
    def test_tropical_single_fsa(self):
        # best path arc indexes are: 1, 3, 5, 10
        s = '''
            0 4 1 1
            0 1 1 1
            1 2 1 2
            1 3 1 3
            2 7 1 4
            3 7 1 5
            4 6 1 2
            4 8 1 3
            5 9 -1 4
            6 9 -1 3
            7 9 -1 5
            8 9 -1 6
            9
        '''
        fsa = k2.Fsa.from_str(s)
        fsa = k2.create_fsa_vec([fsa])
        fsa.requires_grad_(True)
        log_like = k2.get_tot_scores(fsa,
                                     log_semiring=False,
                                     use_float_scores=True)

        assert log_like == 14
        assert log_like.dtype == torch.float32

        log_like.sum().backward()
        expected = torch.zeros(len(fsa.scores))
        expected[torch.tensor([1, 3, 5, 10])] = 1
        assert torch.allclose(fsa.scores.grad, expected)

        # now for double
        fsa.scores.grad = None
        log_like = k2.get_tot_scores(fsa,
                                     log_semiring=False,
                                     use_float_scores=False)
        assert log_like == 14
        assert log_like.dtype == torch.float64
        log_like.sum().backward()
        assert torch.allclose(fsa.scores.grad, expected)
Example 8
    def test_two_fsas_long_pruned(self):
        # as test_two_fsas_long in intersect_dense_test.py,
        # but with pruned intersection
        s1 = '''
            0 1 1 1.0
            1 1 1 50.0
            1 2 2 2.0
            2 3 -1 3.0
            3
        '''

        s2 = '''
            0 1 1 1.0
            1 2 2 2.0
            2 3 -1 3.0
            3
        '''

        devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            devices.append(torch.device('cuda', 0))
        for device in devices:
            fsa1 = k2.Fsa.from_str(s1)
            fsa2 = k2.Fsa.from_str(s2)

            fsa1.requires_grad_(True)
            fsa2.requires_grad_(True)

            fsa_vec = k2.create_fsa_vec([fsa1, fsa2])
            log_prob = torch.rand((2, 100, 3),
                                  dtype=torch.float32,
                                  device=device,
                                  requires_grad=True)

            supervision_segments = torch.tensor([[0, 0, 95], [1, 20, 50]],
                                                dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(log_prob, supervision_segments)
            fsa_vec = fsa_vec.to(device)
            out_fsa = k2.intersect_dense_pruned(fsa_vec,
                                                dense_fsa_vec,
                                                search_beam=100,
                                                output_beam=100,
                                                min_active_states=1,
                                                max_active_states=10)
            assert out_fsa.shape == (2, None,
                                     None), 'There should be two FSAs!'

            scores = k2.get_tot_scores(out_fsa,
                                       log_semiring=False,
                                       use_double_scores=False)
            scores.sum().backward()
Example 9
def get_objf(batch, model, device, L, symbols, training, optimizer=None):
    feature = batch['features']
    supervisions = batch['supervisions']
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'], supervisions['start_frame'],
         supervisions['num_frames']), 1).to(torch.int32)
    texts = supervisions['text']
    assert feature.ndim == 3

    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    feature = feature.to(device)
    if training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

    # TODO(haowen): create decoding graph at the beginning of training
    decoding_graph = create_decoding_graph(texts, L, symbols)
    decoding_graph.to_(device)
    decoding_graph.scores.requires_grad_(False)
    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    # Note: asserting is_cuda() here would fail on a CPU device; checking
    # that the graph and the network output share `device` is sufficient.
    assert decoding_graph.device == device
    assert nnet_output.device == device
    target_graph = k2.intersect_dense_pruned(decoding_graph, dense_fsa_vec, 10,
                                             10000, 0)
    tot_scores = -k2.get_tot_scores(target_graph, True, False).sum()
    if training:
        optimizer.zero_grad()
        tot_scores.backward()
        clip_grad_value_(model.parameters(), 5.0)
        optimizer.step()

    objf = tot_scores.detach().cpu()
    total_objf = objf.item()
    total_frames = nnet_output.shape[0]

    return total_objf, total_frames
Example 10
    def test_two_fsas_long(self):
        # as test_two_fsas, but generate long DenseFsaVec for easier profiling.
        s1 = '''
            0 1 1 1.0
            1 1 1 50.0
            1 2 2 2.0
            2 3 -1 3.0
            3
        '''

        s2 = '''
            0 1 1 1.0
            1 2 2 2.0
            2 3 -1 3.0
            3
        '''

        devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            devices.append(torch.device('cuda', 0))
        for device in devices:
            fsa1 = k2.Fsa.from_str(s1)
            fsa2 = k2.Fsa.from_str(s2)

            fsa1.requires_grad_(True)
            fsa2.requires_grad_(True)

            fsa_vec = k2.create_fsa_vec([fsa1, fsa2])
            log_prob = torch.rand((2, 500, 3),
                                  dtype=torch.float32,
                                  device=device,
                                  requires_grad=True)

            supervision_segments = torch.tensor([[0, 0, 490], [1, 0, 300]],
                                                dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(log_prob, supervision_segments)
            fsa_vec = fsa_vec.to(device)
            out_fsa = k2.intersect_dense(fsa_vec,
                                         dense_fsa_vec,
                                         output_beam=100000)
            assert out_fsa.shape == (2, None,
                                     None), 'There should be two FSAs!'

            scores = k2.get_tot_scores(out_fsa,
                                       log_semiring=False,
                                       use_float_scores=True)
            scores.sum().backward()
Example 11
    def test_two_dense(self):
        s = '''
            0 1 1 1.0
            1 1 1 50.0
            1 2 2 2.0
            2 3 -1 3.0
            3
        '''

        fsa = k2.Fsa.from_str(s)
        fsa.requires_grad_(True)
        fsa_vec = k2.create_fsa_vec([fsa])
        log_prob = torch.tensor(
            [[[0.1, 0.2, 0.3], [0.04, 0.05, 0.06], [0.0, 0.0, 0.0]],
             [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.0, 0.0, 0.0]]],
            dtype=torch.float32,
            requires_grad=True)

        supervision_segments = torch.tensor([[0, 0, 2], [1, 0, 3]],
                                            dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(log_prob, supervision_segments)
        out_fsa = k2.intersect_dense_pruned(fsa_vec,
                                            dense_fsa_vec,
                                            search_beam=100000,
                                            output_beam=100000,
                                            min_active_states=0,
                                            max_active_states=10000)
        assert out_fsa.shape == (2, None, None), 'There should be two FSAs!'

        scores = k2.get_tot_scores(out_fsa,
                                   log_semiring=False,
                                   use_double_scores=False)
        scores.sum().backward()

        # `expected` results are computed using gtn.
        # See https://bit.ly/3oYObeb
        expected_scores_out_fsa = torch.tensor(
            [1.2, 2.06, 3.0, 1.2, 50.5, 2.0, 3.0])
        expected_grad_fsa = torch.tensor([2.0, 1.0, 2.0, 2.0])
        expected_grad_log_prob = torch.tensor([
            0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0, 0, 0, 0.0, 1.0, 0.0, 0.0, 1.0,
            0.0, 0.0, 0.0, 1.0
        ]).reshape_as(log_prob)

        assert torch.allclose(out_fsa.scores, expected_scores_out_fsa)
        assert torch.allclose(expected_grad_fsa, fsa.scores.grad)
        assert torch.allclose(expected_grad_log_prob, log_prob.grad)
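The expected values can be reproduced by hand. The first supervision covers 2 frames and yields the same three arcs as test_simple; the second covers 3 frames, so its best path must take the 1->1 self-loop (score 50.0) to consume the middle frame. A hand-derived check of the tensors above:

    # Sequence 0: 1.0+0.2, 2.0+0.06, 3.0+0.0.
    # Sequence 1: 1.0+0.2, 50.0+0.5, 2.0+0.0, 3.0+0.0.
    print([1.2, 2.06, 3.0] + [1.2, 50.5, 2.0, 3.0])
    # Arcs 0, 2 and 3 of fsa lie on both best paths (gradient 2.0 each);
    # the self-loop is used once (gradient 1.0), matching expected_grad_fsa.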
Example 12
    def test_compose(self):
        s = '''
            0 1 11 1 1.0
            0 2 12 2 2.5
            1 3 -1 -1 0
            2 3 -1 -1 2.5
            3
        '''
        a_fsa = k2.Fsa.from_str(s).requires_grad_(True)

        s = '''
            0 1 1 1 1.0
            0 2 2 3 3.0
            1 2 3 2 2.5
            2 3 -1 -1 2.0
            3
        '''
        b_fsa = k2.Fsa.from_str(s).requires_grad_(True)

        ans = k2.compose(a_fsa, b_fsa, inner_labels='inner')
        ans = k2.connect(ans)

        # Convert a single FSA to an FsaVec.
        # This retains the requires_grad property of `ans`.
        ans.__dict__['arcs'] = _k2.create_fsa_vec([ans.arcs])

        scores = k2.get_tot_scores(ans,
                                   log_semiring=True,
                                   use_double_scores=False)
        # The reference values for `scores`, `a_fsa.grad` and `b_fsa.grad`
        # are computed using GTN.
        # See https://bit.ly/3heLAJq
        assert scores.item() == 10
        scores.backward()
        assert torch.allclose(a_fsa.grad, torch.tensor([0., 1., 0., 1.]))
        assert torch.allclose(b_fsa.grad, torch.tensor([0., 1., 0., 1.]))
        print(ans)
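The value 10 can also be verified by hand: after k2.connect only one composed path survives. Arc 0->2 of a_fsa (output label 2, score 2.5) matches arc 0->2 of b_fsa (input label 2, score 3.0), followed by the two final arcs (2.5 and 2.0); the other pairing hits a dead end, which is why the gradient is 1 exactly on arcs 1 and 3 of each FSA. With a single surviving path, the log-semiring total is just the path sum:

    print(2.5 + 3.0 + 2.5 + 2.0)  # 10.0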
Example 13
def get_objf(batch: Dict,
             model: AcousticModel,
             device: torch.device,
             graph_compiler: CtcTrainingGraphCompiler,
             training: bool,
             optimizer: Optional[torch.optim.Optimizer] = None):
    feature = batch['features']
    supervisions = batch['supervisions']
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         torch.floor_divide(supervisions['start_frame'],
                            model.subsampling_factor),
         torch.floor_divide(supervisions['num_frames'],
                            model.subsampling_factor)), 1).to(torch.int32)
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    texts = supervisions['text']
    texts = [texts[idx] for idx in indices]
    assert feature.ndim == 3

    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    if training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

    decoding_graph = graph_compiler.compile(texts).to(device)

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    assert decoding_graph.device == device
    assert nnet_output.device == device
    # TODO(haowen): with a small `beam`, we may get empty `target_graph`,
    # thus `tot_scores` will be `inf`. Definitely we need to handle this later.
    target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec, 10.0)

    tot_scores = k2.get_tot_scores(target_graph,
                                   log_semiring=True,
                                   use_double_scores=True)

    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2])

    if training:
        optimizer.zero_grad()
        (-tot_score).backward()
        clip_grad_value_(model.parameters(), 5.0)
        optimizer.step()

    ans = (-tot_score.detach().cpu().item(), tot_frames.cpu().item(),
           all_frames.cpu().item())
    return ans
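A sketch of the batch structure this function consumes, inferred from the field accesses above (the values are hypothetical):

    import torch

    batch = {
        'features': torch.randn(2, 100, 40),  # [N, T, C] at entry
        'supervisions': {
            'sequence_idx': torch.tensor([0, 1]),
            'start_frame': torch.tensor([0, 0]),
            'num_frames': torch.tensor([100, 80]),
            'text': ['HELLO WORLD', 'FOO BAR'],
        },
    }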
Example 14
    def test_tropical_multiple_fsas(self):
        # best path:
        #  states: 0 -> 1 -> 3 -> 7 -> 9
        #  arcs:     1 -> 3 -> 5 -> 10
        #  scores: 1 + 3 + 5 + 5 = 14
        s1 = '''
            0 4 1 1
            0 1 1 1
            1 2 1 2
            1 3 1 3
            2 7 1 4
            3 7 1 5
            4 6 1 2
            4 8 1 3
            5 9 -1 4
            6 9 -1 3
            7 9 -1 5
            8 9 -1 6
            9
        '''

        #  best path:
        #   states: 0 -> 2 -> 3 -> 4 -> 5
        #   arcs:     1 -> 4 -> 5 -> 7
        #   scores: 6 + 4 + 3 + 0 = 13
        s2 = '''
            0 1 1 1
            0 2 2 6
            1 2 3 3
            1 3 4 2
            2 3 5 4
            3 4 6 3
            3 5 -1 2
            4 5 -1 0
            5
        '''

        #  best path:
        #   states: 0 -> 2 -> 3
        #   arcs:     1 -> 3
        #   scores: 100 + 5.5 = 105.5
        s3 = '''
            0 1 1 10
            0 2 2 100
            1 3 -1 3.5
            2 3 -1 5.5
            3
        '''

        devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            devices.append(torch.device('cuda', 0))
        for device in devices:
            fsa1 = k2.Fsa.from_str(s1).to(device)
            fsa2 = k2.Fsa.from_str(s2).to(device)
            fsa3 = k2.Fsa.from_str(s3).to(device)

            fsa1.requires_grad_(True)
            fsa2.requires_grad_(True)
            fsa3.requires_grad_(True)

            fsa_vec = k2.create_fsa_vec([fsa1, fsa2, fsa3])

            log_like = k2.get_tot_scores(fsa_vec,
                                         log_semiring=False,
                                         use_float_scores=True)
            assert log_like.dtype == torch.float32
            expected_log_like = torch.tensor([14, 13, 105.5]).to(device)
            assert torch.allclose(log_like, expected_log_like)

            scale = torch.tensor([-10, -20, -30.]).to(log_like)
            (log_like * scale).sum().backward()

            fsa1_best_arc_indexes = torch.tensor([1, 3, 5, 10]).to(device)
            assert torch.allclose(
                fsa1.scores.grad[fsa1_best_arc_indexes],
                scale[0] * torch.ones(4, dtype=torch.float32).to(device))
            assert fsa1.scores.grad.sum() == 4 * scale[0]

            fsa2_best_arc_indexes = torch.tensor([1, 4, 5, 7]).to(device)
            assert torch.allclose(
                fsa2.scores.grad[fsa2_best_arc_indexes],
                scale[1] * torch.ones(4, dtype=torch.float32).to(device))
            assert fsa2.scores.grad.sum() == 4 * scale[1]

            fsa3_best_arc_indexes = torch.tensor([1, 3]).to(device)
            assert torch.allclose(
                fsa3.scores.grad[fsa3_best_arc_indexes],
                scale[2] * torch.ones(2, dtype=torch.float32).to(device))
            assert fsa3.scores.grad.sum() == 2 * scale[2]

            # now for double
            fsa1.scores.grad = None
            fsa2.scores.grad = None
            fsa3.scores.grad = None
            log_like = k2.get_tot_scores(fsa_vec,
                                         log_semiring=False,
                                         use_float_scores=False)

            assert log_like.dtype == torch.float64
            expected_log_like = expected_log_like.to(torch.float64)
            assert torch.allclose(log_like, expected_log_like)

            scale = torch.tensor([-1.25, -2.5, 3.5]).to(log_like)
            (scale * log_like).sum().backward()

            assert torch.allclose(
                fsa1.scores.grad[fsa1_best_arc_indexes],
                scale[0] * torch.ones(4, dtype=torch.float32).to(device))
            assert fsa1.scores.grad.sum() == 4 * scale[0]

            assert torch.allclose(
                fsa2.scores.grad[fsa2_best_arc_indexes],
                scale[1] * torch.ones(4, dtype=torch.float32).to(device))
            assert fsa2.scores.grad.sum() == 4 * scale[1]

            assert torch.allclose(
                fsa3.scores.grad[fsa3_best_arc_indexes],
                scale[2] * torch.ones(2, dtype=torch.float32).to(device))
            assert fsa3.scores.grad.sum() == 2 * scale[2]
Example 15
    def test_log_multiple_fsas(self):
        s1 = '''
            0 1 1 0.1
            0 2 2 0.2
            1 2 3 0.3
            1 3 4 0.4
            2 3 5 0.5
            3 4 -1 0
            4
        '''

        s2 = '''
            0 3 3 0.1
            0 1 1 0.2
            0 2 2 0.3
            1 2 2 0.4
            1 3 3 0.5
            2 3 3 0.6
            2 4 4 0.7
            3 4 4 0.8
            3 5 -1 0.9
            4 5 -1 1.0
            5
        '''
        devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            devices.append(torch.device('cuda', 0))
        for device in devices:
            fsa1 = k2.Fsa.from_str(s1).to(device)
            fsa2 = k2.Fsa.from_str(s2).to(device)

            fsa1.requires_grad_(True)
            fsa2.requires_grad_(True)

            fsa_vec = k2.create_fsa_vec([fsa1, fsa2])
            log_like = k2.get_tot_scores(fsa_vec,
                                         log_semiring=True,
                                         use_float_scores=True)
            assert log_like.dtype == torch.float32
            # The expected_log_likes are computed using gtn.
            # See https://bit.ly/3oUiRx9
            expected_log_like = torch.tensor(
                [1.8119014501571655, 4.533502578735352]).to(device)
            assert torch.allclose(log_like, expected_log_like)

            scale = torch.tensor([1.25, -5.25]).to(log_like)
            (scale * log_like).sum().backward()

            # The expected_grads are computed using gtn.
            # See https://bit.ly/3oUiRx9
            expected_grad_fsa1 = torch.tensor([
                0.6710670590400696, 0.32893291115760803, 0.4017595648765564,
                0.2693074941635132, 0.7306925058364868, 1.0
            ]).to(device)

            expected_grad_fsa2 = torch.tensor([
                0.10102888941764832, 0.5947467088699341, 0.3042244613170624,
                0.410660058259964, 0.1840866357088089, 0.5283515453338623,
                0.18653297424316406, 0.5783339142799377, 0.2351330667734146,
                0.764866828918457
            ]).to(device)

            assert torch.allclose(fsa1.scores.grad,
                                  scale[0] * expected_grad_fsa1)
            assert torch.allclose(fsa2.scores.grad,
                                  scale[1] * expected_grad_fsa2)

            # now for double
            fsa1.scores.grad = None
            fsa2.scores.grad = None

            log_like = k2.get_tot_scores(fsa_vec,
                                         log_semiring=True,
                                         use_float_scores=False)
            assert log_like.dtype == torch.float64
            expected_log_like = expected_log_like.to(torch.float64)
            assert torch.allclose(log_like, expected_log_like)

            scale = torch.tensor([-10.25, 8.25]).to(log_like)
            (scale * log_like).sum().backward()

            assert torch.allclose(fsa1.scores.grad,
                                  scale[0] * expected_grad_fsa1)
            assert torch.allclose(fsa2.scores.grad,
                                  scale[1] * expected_grad_fsa2)
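The 4.533502578735352 figure for fsa2 can be reproduced the same way as in the single-FSA test: enumerate the ten complete paths of s2, sum the arc scores along each, and take the log-sum-exp. A sketch with the path sums read off the FSA text:

    import torch

    path_scores = torch.tensor(
        [1.0, 1.9, 2.1, 3.0, 2.3, 1.6, 2.5, 1.8, 2.7, 2.0])
    print(torch.logsumexp(path_scores, dim=0))  # tensor(4.5335)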