def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True):
    """Greedy (best path) CTC decoding of the given logits.

    At every time step the class with the largest logit is selected. When the
    selected class is the blank index `(num_classes - 1)`, nothing is emitted
    for that step, whatever the value of `merge_repeated`.

    With `merge_repeated=True`, runs of consecutive identical maximum indices
    are collapsed so that only the first one is emitted. For the sequence
    `A B B * B * B` (where '*' is the blank label) the result is:

      * `A B B B` if `merge_repeated=True`.
      * `A B B B B` if `merge_repeated=False`.

    Args:
      inputs: 3-D `float` `Tensor` sized `[max_time, batch_size, num_classes]`.
        The logits.
      sequence_length: 1-D `int32` vector of size `[batch_size]` holding the
        sequence lengths.
      merge_repeated: Boolean. Default: True.

    Returns:
      A tuple `(decoded, neg_sum_logits)` where

      decoded: A single-element list. `decoded[0]` is a `SparseTensor` holding
        the decoded outputs such that:

        `decoded[0].indices`: Indices matrix `(total_decoded_outputs, 2)`.
          Each row stores `[batch, time]`.
        `decoded[0].values`: Values vector of size `(total_decoded_outputs)`
          storing the decoded classes.
        `decoded[0].dense_shape`: Shape vector of size `(2)` with the values
          `[batch_size, max_decoded_length]`.
      neg_sum_logits: A `float` matrix `(batch_size x 1)` containing, for the
        sequence found, the negative of the sum of the greatest logit at each
        timeframe.
    """
    # The generated op yields exactly four outputs: the three components of
    # the sparse result followed by the per-batch score.
    indices, values, dense_shape, neg_sum_logits = (
        gen_ctc_ops.ctc_greedy_decoder(
            inputs, sequence_length, merge_repeated=merge_repeated))
    best_path = sparse_tensor.SparseTensor(indices, values, dense_shape)
    return [best_path], neg_sum_logits
def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True):
    """Run best-path (greedy) CTC decoding over a batch of logits.

    For each time step and batch entry the arg-max class is taken. A step
    whose arg-max is the blank index `(num_classes - 1)` never emits an
    element, independent of `merge_repeated`.

    If `merge_repeated` is `True`, consecutive steps that share the same
    arg-max class emit only once (the first occurrence). Example, with '*'
    denoting the blank label, for the sequence `A B B * B * B`:

      * `merge_repeated=True`  gives `A B B B`.
      * `merge_repeated=False` gives `A B B B B`.

    Args:
      inputs: 3-D `float` `Tensor` sized `[max_time, batch_size, num_classes]`.
        The logits.
      sequence_length: 1-D `int32` vector containing sequence lengths, having
        size `[batch_size]`.
      merge_repeated: Boolean. Default: True.

    Returns:
      A tuple `(decoded, neg_sum_logits)` where

      decoded: A single-element list whose only item is a `SparseTensor` with

        `indices`: matrix `(total_decoded_outputs, 2)`; rows are
          `[batch, time]`.
        `values`: vector of size `(total_decoded_outputs)` with the decoded
          classes.
        `dense_shape`: vector of size `(2)` equal to
          `[batch_size, max_decoded_length]`.
      neg_sum_logits: A `float` matrix `(batch_size x 1)` containing, for the
        sequence found, the negative of the sum of the greatest logit at each
        timeframe.
    """
    op_result = gen_ctc_ops.ctc_greedy_decoder(
        inputs, sequence_length, merge_repeated=merge_repeated)
    # Exactly four outputs are expected from the generated op.
    (sparse_indices, sparse_values, sparse_shape, neg_sum_logits) = op_result
    decoded = [
        sparse_tensor.SparseTensor(sparse_indices, sparse_values, sparse_shape)
    ]
    return (decoded, neg_sum_logits)
def decode_ctc(
    logits: Union[list, np.ndarray],
    merge_repeated=True,
    alphabet: Union[None, np.ndarray, List[str]] = None,
    seq_lens: Optional[List[int]] = None,
):
    """Greedy-decode a batch of CTC logits and return the decoded rows.

    Args:
      logits: Batch-major logits, i.e. `[batch_size, max_time, num_classes]`.
        They are transposed to the time-major layout that
        `ctc_greedy_decoder` expects. A (nested) list is converted to an
        `np.ndarray` first.
      merge_repeated: Forwarded unchanged to `ctc_greedy_decoder`.
      alphabet: Optional alphabet forwarded to `_decoded_to_rows`; a list is
        converted to an `np.ndarray`. Presumably used to map class indices to
        symbols -- confirm against `_decoded_to_rows`.
      seq_lens: Optional per-sample sequence lengths (one entry per batch
        element). When `None` or empty, every sample is assumed to span the
        full time axis (`logits.shape[1]`).

    Returns:
      Whatever `_decoded_to_rows(..., aslist=True)` returns for the decoded
      sparse result.
    """
    if isinstance(alphabet, list):
        alphabet = np.array(alphabet)
    if isinstance(logits, list):
        logits = np.array(logits)

    batch_size, max_time = logits.shape[0], logits.shape[1]
    # Explicit None/empty check: plain truthiness raises for multi-element
    # NumPy arrays, while None and [] still select the full-length default.
    if seq_lens is None or len(seq_lens) == 0:
        # np.int was removed in NumPy 1.24; the decoder wants int32 lengths.
        sequence_length = np.full(
            (batch_size,), fill_value=max_time, dtype=np.int32)
    else:
        sequence_length = np.asarray(seq_lens, dtype=np.int32)

    # ctc_greedy_decoder returns ([SparseTensor], neg_sum_logits); only the
    # single sparse best-path result is needed here.
    decoded, _ = ctc_greedy_decoder(
        np.transpose(logits, (1, 0, 2)),
        sequence_length,
        merge_repeated=merge_repeated,
    )
    best_path = decoded[0]
    return _decoded_to_rows(
        best_path.indices.numpy(),
        best_path.values.numpy(),
        best_path.dense_shape.numpy(),
        alphabet=alphabet,
        aslist=True,
    )