示例#1
0
  def testStateToOlabel(self):
    """Checks how `_state_to_olabel` folds per-state values into output labels.

    The inputs are built in linear space, passed through `log` before the
    call and through `exp` afterwards, so every expectation below is a plain
    sum of input values.
    """
    # 3 frames, 2 batch, 10 states (5 label, 5 blank).
    linear_states = [
        [[0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.20],
         [0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.30]],
        [[1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0],
         [2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3.0]],
        [[11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0],
         [21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0]],
    ]
    label_ids = ops.convert_to_tensor([
        [3, 4, 3, 4],
        [1, 1, 1, 0],
    ])
    vocab_size = 8

    log_olabel = ctc_ops._state_to_olabel(
        label_ids, vocab_size, math_ops.log(linear_states))
    linear_olabel = math_ops.exp(log_olabel)

    # Output index 0 must equal the sum of the last five states of each row.
    expected_blank = [
        [sum(row[5:]) for row in frame] for frame in linear_states
    ]
    self.assertAllClose(linear_olabel[:, :, 0], expected_blank)

    def expected_label_rows(frame):
      """Returns the expected output columns 1..7 for one frame."""
      first, second = frame
      return [
          # Labels [3, 4, 3, 4]: states 1 and 3 land on label 3, states 2
          # and 4 on label 4.
          [0.0, 0.0, first[1] + first[3], first[2] + first[4], 0.0, 0.0, 0.0],
          # Labels [1, 1, 1, 0]: states 1..3 land on label 1; state 4
          # (label 0) contributes to none of these columns.
          [second[1] + second[2] + second[3], 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
      ]

    expected_labels = [expected_label_rows(frame) for frame in linear_states]
    self.assertAllClose(linear_olabel[:, :, 1:], expected_labels)
示例#2
0
  def testStateToOlabel(self):
    """Verifies `_state_to_olabel` output against hand-summed state values.

    States are given in linear space and converted with `log` before the
    call; the result is mapped back with `exp` so that the expected values
    can be written as ordinary sums of the inputs.
    """
    num_labels = 8
    label_seqs = ops.convert_to_tensor([[3, 4, 3, 4], [1, 1, 1, 0]])

    # 3 frames, 2 batch, 10 states (5 label, 5 blank).
    frame_0 = [
        [0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.20],
        [0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.30],
    ]
    frame_1 = [
        [1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0],
        [2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3.0],
    ]
    frame_2 = [
        [11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0],
        [21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0],
    ]
    log_states = math_ops.log([frame_0, frame_1, frame_2])

    probs = math_ops.exp(
        ctc_ops._state_to_olabel(label_seqs, num_labels, log_states))

    # Column 0 of the output: the last five states of each row, summed.
    want_blank = [
        [0.16 + 0.17 + 0.18 + 0.19 + 0.20, 0.26 + 0.27 + 0.28 + 0.29 + 0.30],
        [1.6 + 1.7 + 1.8 + 1.9 + 2.0, 2.6 + 2.7 + 2.8 + 2.9 + 3.0],
        [16.0 + 17.0 + 18.0 + 19.0 + 20.0, 26.0 + 27.0 + 28.0 + 29.0 + 30.0],
    ]
    # Columns 1..7 of the output: states 1..4 of each row, grouped by the
    # label at the matching position in `label_seqs`.
    want_labels = [
        [[0.0, 0.0, 0.12 + 0.14, 0.13 + 0.15, 0.0, 0.0, 0.0],
         [0.22 + 0.23 + 0.24] + [0.0] * 6],
        [[0.0, 0.0, 1.2 + 1.4, 1.3 + 1.5, 0.0, 0.0, 0.0],
         [2.2 + 2.3 + 2.4] + [0.0] * 6],
        [[0.0, 0.0, 12.0 + 14.0, 13.0 + 15.0, 0.0, 0.0, 0.0],
         [22.0 + 23.0 + 24.0] + [0.0] * 6],
    ]

    self.assertAllClose(probs[:, :, 0], want_blank)
    self.assertAllClose(probs[:, :, 1:], want_labels)