コード例 #1
0
  def test_logits_blankid_to_last(self):
    ''' Unit test case for the ctc_utils.logits_blankid_to_last interface.

    Covers two behaviours:
      1. an out-of-range blank_index raises ValueError with the expected message;
      2. blank_index=0 produces the expected re-ordered logits.
    '''
    with self.cached_session():

      # blank_index=10 must be rejected; the result is irrelevant, so it is
      # not bound to a name.
      with self.assertRaises(ValueError) as valueErr:
        ctc_utils.logits_blankid_to_last(
            logits=tf.constant(self.logits), blank_index=10)
      the_exception = valueErr.exception
      self.assertEqual(
          str(the_exception),
          'blank_index must be less than or equal to num_class - 1')

      logits = ctc_utils.logits_blankid_to_last(
          logits=tf.constant(self.logits), blank_index=0)
      # Expected output for blank_index=0.
      # NOTE(review): presumably self.logits with its column 0 moved to the
      # last position -- self.logits is defined outside this method, confirm.
      logits_transform = np.asarray(
        [[[0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553, 0.633766],
          [0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436, 0.111121],
          [0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688, 0.0357786],
          [0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533, 0.0663296],
          [0.196634, 0.123377, 0.50648837, 0.00903441, 0.00623107, 0.158235]],
         [[0.28562, 0.0831517, 0.0862751, 0.0816851, 0.161508, 0.30176],
          [0.397533, 0.0557226, 0.0546814, 0.0557528, 0.19549, 0.24082],
          [0.450868, 0.0389607, 0.038309, 0.0391602, 0.202456, 0.230246],
          [0.429522, 0.0326593, 0.0339046, 0.0326856, 0.190345, 0.280884],
          [0.315517, 0.0338439, 0.0393744, 0.0339315, 0.154046, 0.423286]]],
        dtype=np.float32)
      # Compare against the expected array; comparing the tensor with itself
      # would pass regardless of what the function returned.
      self.assertAllClose(logits.eval(), logits_transform)
コード例 #2
0
def ctc_decode_blankid_to_last(logits, sequence_length, blank_id=None):
    '''
    Moves the blank_label column to the end of the logit matrix,
    and adjusts the rank of sequence_length to 1
    param: logits, (B, T, C), output of ctc asr model
    param: sequence_length, (B) or (B, 1), sequence lengths
    param: blank_id, index of the blank label in logits;
           None means the default blank_id 0, same as espnet.
    return: logits_return, (T, B, C)
            sequence_length_return, (B)
            blank_id, the blank index that was used (0 when None was passed)
    '''

    # Switch from batch-major (B, T, C) to time-major (T, B, C).
    logits = tf.transpose(logits, [1, 0, 2])
    # blank_id=0 is used as default in Espnet,
    # while blank_id is set as C-1 in tf.nn.ctc_decoder
    if blank_id is None:
        blank_id = 0
    logits = ctc_utils.logits_blankid_to_last(logits=logits,
                                              blank_index=blank_id)

    # Flatten to rank 1. An axis-less tf.squeeze would also drop the batch
    # dimension when B == 1, turning a (1, 1) input into a scalar; reshape
    # keeps exactly one dimension for both (B,) and (B, 1) inputs.
    sequence_length_return = tf.reshape(sequence_length, [-1])

    return logits, sequence_length_return, blank_id
コード例 #3
0
ファイル: loss_utils.py プロジェクト: sumepr/delta
def ctc_data_transform(labels, logits, blank_index):
  '''
  Transform labels and logits according to blank_index.

  The logits go through ctc_utils.logits_blankid_to_last first; the size of
  axis 2 of that result is then passed as num_class to
  ctc_utils.labels_blankid_to_last together with the labels.
  return: (labels, logits), both after transformation
  '''
  shifted_logits = ctc_utils.logits_blankid_to_last(
      blank_index=blank_index, logits=logits)

  # The class count is read from the transformed logits, not from the input.
  class_count = shifted_logits.shape[2]
  shifted_labels = ctc_utils.labels_blankid_to_last(
      num_class=class_count, labels=labels, blank_index=blank_index)
  return shifted_labels, shifted_logits