Example #1
0
def _ctc_loss_with_beam_search(logits,
                               sparse_labels,
                               seq_length,
                               top_path=1,
                               merge_repeated=False):
    """Build mean CTC loss, beam-search decoding, and top-1 edit distance.

    Args:
      logits: CTC inputs, passed unchanged to both `ctc_ops.ctc_loss` and
        `tf.nn.ctc_beam_search_decoder` (presumably time-major
        `[max_time, batch, num_classes]` -- TODO confirm against callers).
      sparse_labels: ground-truth labels (sparse); used as the labels of
        the CTC loss and as the `truth` of the edit distance.
      seq_length: per-example sequence lengths.
      top_path: forwarded to the beam-search decoder as `top_paths`.
      merge_repeated: forwarded to the beam-search decoder.

    Returns:
      A 4-tuple `(ctc_loss, top1_ed, pre_label_tensors, log_prob)`:
      the `reduce_mean` of the CTC loss, the `reduce_mean` of the edit
      distance between the first (top-1) decoded path and `sparse_labels`,
      the decoder's list of decoded sparse tensors, and its log
      probabilities.
    """
    per_example_loss = ctc_ops.ctc_loss(sparse_labels, logits, seq_length)
    mean_loss = math_ops.reduce_mean(per_example_loss)

    decoded_paths, path_log_probs = tf.nn.ctc_beam_search_decoder(
        logits, seq_length, merge_repeated=merge_repeated, top_paths=top_path)

    # Cast the top-1 path to int32, presumably so its dtype matches
    # `sparse_labels` for edit_distance -- confirm label dtype at call sites.
    best_path = math_ops.cast(decoded_paths[0], dtypes.int32)
    per_example_ed = array_ops.edit_distance(best_path, sparse_labels)
    mean_ed = math_ops.reduce_mean(per_example_ed)

    return mean_loss, mean_ed, decoded_paths, path_log_probs
  def _testEditDistanceST(self,
                          hypothesis_st,
                          truth_st,
                          normalize,
                          expected_output,
                          expected_shape,
                          expected_err_re=None):
    """Check `array_ops.edit_distance` on a pair of sparse inputs.

    Builds the edit-distance op for `hypothesis_st` against `truth_st`.
    If `expected_err_re` is given, asserts that evaluating the op raises an
    op error matching it. Otherwise asserts that the op's static shape
    equals `expected_shape` and that its evaluated value is close to
    `expected_output`.

    Args:
      hypothesis_st: hypothesis sparse input.
      truth_st: reference sparse input.
      normalize: forwarded as the `normalize` argument of edit_distance.
      expected_output: expected evaluated distances (success case only).
      expected_shape: expected static shape (success case only).
      expected_err_re: regex for the expected op error, or None when the
        op should evaluate successfully.
    """
    distance_op = array_ops.edit_distance(
        hypothesis=hypothesis_st, truth=truth_st, normalize=normalize)

    if expected_err_re is not None:
      # Failure path: evaluating the op must raise a matching op error.
      with self.assertRaisesOpError(expected_err_re):
        distance_op.eval()
      return

    # Success path: the static shape is checked before the op is run.
    # `.eval()` presumably relies on a default test session being active.
    self.assertEqual(distance_op.get_shape(), expected_shape)
    self.assertAllClose(distance_op.eval(), expected_output)
  def _testEditDistanceST(self,
                          hypothesis_st,
                          truth_st,
                          normalize,
                          expected_output,
                          expected_shape,
                          expected_err_re=None):
    """Run edit_distance on sparse inputs and verify its result or error.

    Args:
      hypothesis_st: hypothesis sparse input.
      truth_st: reference sparse input.
      normalize: forwarded as the `normalize` argument of edit_distance.
      expected_output: values the evaluated op should be close to when no
        error is expected.
      expected_shape: static shape the op should report when no error is
        expected.
      expected_err_re: if not None, a regex that the op error raised on
        evaluation must match; in that case shape and value are not checked.
    """
    expect_success = expected_err_re is None
    result_tensor = array_ops.edit_distance(
        hypothesis=hypothesis_st,
        truth=truth_st,
        normalize=normalize)

    if not expect_success:
      with self.assertRaisesOpError(expected_err_re):
        result_tensor.eval()
    else:
      # Static shape first, then the evaluated value (presumably within an
      # active test session -- confirm in the enclosing test case).
      self.assertEqual(result_tensor.get_shape(), expected_shape)
      actual = result_tensor.eval()
      self.assertAllClose(actual, expected_output)