Beispiel #1
0
def ctc_ref(acts, lbls, utt_lens, lbl_lens):
    """
    CTC reference implementation.

    Builds a fresh CPU CTC backend, allocates zeroed output buffers, runs the
    backend over the batch and returns the results.

    Args:
        acts: activations, unpacked as ``(max_t, batch, n_out)``, so it must
            be 3-D.
        lbls: label values; cast to ``np.int32`` before being passed on.
        utt_lens: per-utterance lengths given as a percentage of ``max_t``.
            They are converted to int32 frame counts via
            ``utt_lens * max_t / 100``.
        lbl_lens: per-utterance label lengths; cast to ``np.int32``.

    Returns:
        ``(costs, grads)``. ``costs`` is a 1-D array of length
        ``acts.shape[1]`` with ``acts``' dtype. ``grads`` has the same shape
        and dtype as ``acts``. Both are filled by the backend.
    """
    warp_ctc = CTC(on_device='cpu')
    max_t, bsz, _ = acts.shape
    grads = np.zeros_like(acts)
    costs = np.zeros(bsz, dtype=acts.dtype)
    # Percent-of-max_t lengths -> integer frame counts.
    utt_lens = (utt_lens * max_t / 100).astype(np.int32)
    # ctc_cpu in this file casts label buffers to int32 before calling the
    # same backend. Do the same here so int64 inputs are not handed over
    # unchanged. NOTE(review): presumably the backend reads these as C int
    # buffers; confirm against CTC.bind_to_cpu.
    lbls = lbls.astype(np.int32)
    lbl_lens = lbl_lens.astype(np.int32)
    warp_ctc.bind_to_cpu(acts,
                         lbls,
                         utt_lens,
                         lbl_lens,
                         grads,
                         costs,
                         n_threads=8)
    return costs, grads
Beispiel #2
0
def ctc_cpu(acts, lbls, utt_lens, lbl_lens, grads, costs, n_threads=8):
    """
    Run the CPU CTC backend, writing results into caller-provided buffers.

    A single module-level ``warp_ctc`` handle is created on first use and
    reused afterwards. ``costs`` and ``grads`` are zeroed and then filled in
    place. Nothing is returned.

    Args:
        acts: activations, unpacked as ``(max_t, batch, n_out)``, so it must
            be 3-D.
        lbls: label values; an ``np.int32`` copy is passed to the backend.
        utt_lens: per-utterance lengths as a percentage of ``max_t``;
            converted to int32 frame counts via ``utt_lens * max_t / 100``.
        lbl_lens: per-utterance label lengths; an ``np.int32`` copy is passed.
        grads: output buffer for gradients (overwritten).
        costs: output buffer for per-utterance costs (overwritten).
        n_threads: thread count forwarded to ``bind_to_cpu``.
    """
    global warp_ctc
    if warp_ctc is None:
        # Lazily construct the shared CPU handle on the first call.
        warp_ctc = CTC(on_device='cpu')

    # Reset both output buffers (costs first, then grads).
    for out_buf in (costs, grads):
        out_buf.fill(0.)

    n_steps, _, _ = acts.shape
    frame_counts = (utt_lens * n_steps / 100).astype(np.int32)
    warp_ctc.bind_to_cpu(acts,
                         lbls.astype(np.int32),
                         frame_counts,
                         lbl_lens.astype(np.int32),
                         grads,
                         costs,
                         n_threads=n_threads)