Exemplo n.º 1
0
def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True):
    """Greedily decodes `inputs` by taking the best class at each time step.

    A blank (index `num_classes - 1`) at any given time/batch position never
    emits an element, whatever `merge_repeated` is set to. When
    `merge_repeated` is true, a run of identical consecutive maxima emits a
    single element, so with `*` as the blank label the path `A B B * B * B`
    decodes to:

      * `A B B B` when `merge_repeated=True`;
      * `A B B B B` when `merge_repeated=False`.

    Args:
      inputs: 3-D `float` `Tensor` of shape
        `[max_time, batch_size, num_classes]` holding the logits.
      sequence_length: 1-D `int32` vector of shape `[batch_size]` holding the
        length of each sequence.
      merge_repeated: Boolean; defaults to True.

    Returns:
      A tuple `(decoded, neg_sum_logits)`:

      decoded: A one-element list whose single entry `decoded[0]` is a
        `SparseTensor` with
          `decoded[0].indices`: `(total_decoded_outputs, 2)` matrix whose
            rows are `[batch, time]`;
          `decoded[0].values`: `(total_decoded_outputs)` vector of the
            decoded classes;
          `decoded[0].dense_shape`: length-2 vector
            `[batch_size, max_decoded_length]`.

      neg_sum_logits: `float` matrix of shape `(batch_size x 1)` holding, for
        each decoded sequence, the negated sum of the largest logit at every
        time step.
    """
    # The raw op yields the three sparse components plus the path scores.
    indices, values, dense_shape, neg_sum_logits = (
        gen_ctc_ops.ctc_greedy_decoder(
            inputs, sequence_length, merge_repeated=merge_repeated))
    best_path = sparse_tensor.SparseTensor(indices, values, dense_shape)
    return [best_path], neg_sum_logits
Exemplo n.º 2
0
def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True):
  """Best-path (greedy) CTC decoding of the logits in `inputs`.

  At every time step the highest-scoring class is picked. If that class is
  the blank (index `num_classes - 1`), nothing is emitted for that step,
  independent of `merge_repeated`. With `merge_repeated=True`, consecutive
  steps sharing the same maximum emit only the first of them. For example,
  using `*` for the blank, `A B B * B * B` becomes

    * `A B B B` if `merge_repeated=True`;
    * `A B B B B` if `merge_repeated=False`.

  Args:
    inputs: 3-D `float` `Tensor` sized `[max_time, batch_size, num_classes]`;
      the logits.
    sequence_length: 1-D `int32` vector of size `[batch_size]` with the
      length of each sequence.
    merge_repeated: Boolean. Default: True.

  Returns:
    `(decoded, neg_sum_logits)` where

    decoded: list with exactly one `SparseTensor`, `decoded[0]`, whose
      `indices` is a `(total_decoded_outputs, 2)` matrix of `[batch, time]`
      rows, whose `values` is a `(total_decoded_outputs)` vector of decoded
      classes, and whose `dense_shape` is the size-2 vector
      `[batch_size, max_decoded_length]`.

    neg_sum_logits: `float` matrix `(batch_size x 1)`; for each sequence
      found, the negative of the sum of the greatest logit per timeframe.
  """
  op_result = gen_ctc_ops.ctc_greedy_decoder(
      inputs, sequence_length, merge_repeated=merge_repeated)
  ix, vals, shape, scores = op_result
  sparse_decoded = sparse_tensor.SparseTensor(ix, vals, shape)
  result = ([sparse_decoded], scores)
  return result
Exemplo n.º 3
0
def decode_ctc(
    logits: Union[list, np.ndarray],
    merge_repeated=True,
    alphabet: Union[None, np.ndarray, List[str]] = None,
    seq_lens: Optional[Union[List[int], np.ndarray]] = None,
):
    """Greedily CTC-decode a batch of logits and convert the result to rows.

    Args:
        logits: Logits as a nested list or ndarray (lists are converted with
            ``np.array``). Axes are swapped with ``(1, 0, 2)`` before decoding
            and ``logits.shape[1]`` is used as the default sequence length, so
            this presumably expects ``(batch, time, num_classes)`` — TODO
            confirm against callers.
        merge_repeated: Forwarded unchanged to ``ctc_greedy_decoder``.
        alphabet: Optional alphabet forwarded to ``_decoded_to_rows``; a list
            is converted to an ndarray first.
        seq_lens: Optional per-example sequence lengths (list or ndarray).
            When ``None`` or empty, every example is taken to span all
            ``logits.shape[1]`` time steps.

    Returns:
        The result of ``_decoded_to_rows(..., aslist=True)`` for the decoded
        sparse indices/values/shape.
    """
    if isinstance(alphabet, list):
        alphabet = np.array(alphabet)
    if isinstance(logits, list):
        logits = np.array(logits)

    # Explicit emptiness test: `not seq_lens` is ambiguous for multi-element
    # ndarrays and wrongly treats e.g. np.array([0]) as "not provided".
    if seq_lens is None or len(seq_lens) == 0:
        # `np.int` was removed in NumPy 1.24; int32 matches the decoder's
        # documented `sequence_length` dtype.
        seq_lens = np.full((logits.shape[0],),
                           fill_value=logits.shape[1],
                           dtype=np.int32)

    # NOTE(review): relies on a `ctc_greedy_decoder` that returns the four raw
    # op outputs (indices, values, shape, log-probs) as eager tensors — confirm
    # which implementation is imported here.
    decoded_ix, decoded_val, decoded_shape, _ = ctc_greedy_decoder(
        np.transpose(logits, (1, 0, 2)),
        seq_lens,
        merge_repeated=merge_repeated,
    )
    return _decoded_to_rows(decoded_ix.numpy(),
                            decoded_val.numpy(),
                            decoded_shape.numpy(),
                            alphabet=alphabet,
                            aslist=True)