Exemple #1
0
    def testStateToOlabelUniqueSinglePath(self):
        """Checks `_state_to_olabel_unique` on one-hot state probabilities.

        The state probabilities are moved to log space before the call and
        the result is exponentiated again, so every expectation below is
        written as a plain probability.
        """
        num_labels = 8
        label_ids = [
            [3, 4, 3],
            [1, 0, 0],
        ]

        # Layout: 3 frames, 2 batch, 8 states (4 label, 4 blank).
        #
        # There is only a single valid path for each sequence because the
        # frame lengths and the label lengths are the same.
        state_probs = [
            [[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
             [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
            [[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
             [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
            [[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
             [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
        ]

        # Index 0 on the last output axis is the blank slice; it is expected
        # to be zero for every frame and batch entry.
        expected_blank = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
        # Expected values for the remaining output indices (1 and above).
        expected_non_blank = [
            [[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
             [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
            [[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
             [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
            [[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
             [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
        ]

        label_tensor = ops.convert_to_tensor(label_ids)
        log_states = math_ops.log(state_probs)
        unique_labels = ctc_ops.ctc_unique_labels(label_tensor)
        log_olabel = ctc_ops._state_to_olabel_unique(
            label_tensor, num_labels, log_states, unique_labels)
        olabel_probs = math_ops.exp(log_olabel)

        self.assertAllClose(olabel_probs[:, :, 0], expected_blank)
        self.assertAllClose(olabel_probs[:, :, 1:], expected_non_blank)
    def testStateToOlabelUnique(self):
        """Checks `_state_to_olabel_unique` when labels repeat in a sequence.

        State values are converted to log space for the call and the output
        is exponentiated before comparison, so the expectations are written
        as sums of the original state values.
        """
        num_labels = 8
        label_ids = [
            [3, 4, 3, 4],
            [1, 1, 1, 0],
        ]

        # Layout: 3 frames, 2 batch, 10 states (5 label, 5 blank).
        state_values = [
            [[0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.20],
             [0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.30]],
            [[1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0],
             [2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3.0]],
            [[11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0],
             [21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0]],
        ]

        # The blank output (index 0 on the last axis) is expected to equal the
        # sum of the last five state values of each frame/batch entry.
        expected_blank = [
            [0.16 + 0.17 + 0.18 + 0.19 + 0.20,
             0.26 + 0.27 + 0.28 + 0.29 + 0.30],
            [1.6 + 1.7 + 1.8 + 1.9 + 2.0,
             2.6 + 2.7 + 2.8 + 2.9 + 3.0],
            [16.0 + 17.0 + 18.0 + 19.0 + 20.0,
             26.0 + 27.0 + 28.0 + 29.0 + 30.0],
        ]
        # For the remaining outputs, state values that belong to the same
        # label id are expected to be summed into that label's slot.
        expected_non_blank = [
            [[0.0, 0.0, 0.12 + 0.14, 0.13 + 0.15, 0.0, 0.0, 0.0],
             [0.22 + 0.23 + 0.24, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
            [[0.0, 0.0, 1.2 + 1.4, 1.3 + 1.5, 0.0, 0.0, 0.0],
             [2.2 + 2.3 + 2.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
            [[0.0, 0.0, 12.0 + 14.0, 13.0 + 15.0, 0.0, 0.0, 0.0],
             [22.0 + 23.0 + 24.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
        ]

        label_tensor = ops.convert_to_tensor(label_ids)
        log_states = math_ops.log(state_values)
        unique_labels = ctc_ops.ctc_unique_labels(label_tensor)
        log_olabel = ctc_ops._state_to_olabel_unique(
            label_tensor, num_labels, log_states, unique_labels)
        olabel_values = math_ops.exp(log_olabel)

        self.assertAllClose(olabel_values[:, :, 0], expected_blank)
        self.assertAllClose(olabel_values[:, :, 1:], expected_non_blank)
  def testStateToOlabelUnique(self):
    """Verifies `_state_to_olabel_unique` for sequences with repeated labels.

    The inputs are passed in log space and the output is exponentiated
    before the comparison, so each expectation is spelled out as a sum of
    the raw state values.
    """
    num_labels = 8
    raw_labels = [
        [3, 4, 3, 4],
        [1, 1, 1, 0],
    ]

    # Layout: 3 frames, 2 batch, 10 states (5 label, 5 blank).
    raw_states = [
        [[0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.20],
         [0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.30]],
        [[1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0],
         [2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3.0]],
        [[11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0],
         [21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0]],
    ]

    # Output index 0 (blank) should equal the sum of the last five state
    # values for each frame/batch entry.
    want_blank = [
        [0.16 + 0.17 + 0.18 + 0.19 + 0.20,
         0.26 + 0.27 + 0.28 + 0.29 + 0.30],
        [1.6 + 1.7 + 1.8 + 1.9 + 2.0,
         2.6 + 2.7 + 2.8 + 2.9 + 3.0],
        [16.0 + 17.0 + 18.0 + 19.0 + 20.0,
         26.0 + 27.0 + 28.0 + 29.0 + 30.0],
    ]
    # Output indices 1 and above: state values sharing a label id are
    # expected to be added together in that label's slot.
    want_labels = [
        [[0.0, 0.0, 0.12 + 0.14, 0.13 + 0.15, 0.0, 0.0, 0.0],
         [0.22 + 0.23 + 0.24, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
        [[0.0, 0.0, 1.2 + 1.4, 1.3 + 1.5, 0.0, 0.0, 0.0],
         [2.2 + 2.3 + 2.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
        [[0.0, 0.0, 12.0 + 14.0, 13.0 + 15.0, 0.0, 0.0, 0.0],
         [22.0 + 23.0 + 24.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
    ]

    labels_t = ops.convert_to_tensor(raw_labels)
    states_log = math_ops.log(raw_states)
    uniques = ctc_ops.ctc_unique_labels(labels_t)
    olabel_log = ctc_ops._state_to_olabel_unique(
        labels_t, num_labels, states_log, uniques)
    got = math_ops.exp(olabel_log)

    self.assertAllClose(got[:, :, 0], want_blank)
    self.assertAllClose(got[:, :, 1:], want_labels)