def __init__(self, transformer, op):
    """
    Initialise a GPU warp-ctc kernel for ``op``.

    NOTE(review): this is a module-level function, but it calls
    ``super(CTCKernel, self)`` and so only works when bound as a
    ``CTCKernel`` method. It looks like a stray copy of
    ``CTCKernel.__init__``; confirm whether it is intended.

    Arguments:
        transformer: passed to the base-class initialiser; its ``runtime``
            supplies the stream used for the kernel.
        op: the CTC op. ``op.tensor_description()`` becomes ``self.costs``,
            and ``op.call_info()`` must yield exactly five entries.
    """
    super(CTCKernel, self).__init__(transformer)
    self.warp_ctc = CTC(on_device='gpu')

    self.at_runtime = self.transformer.runtime
    self.stream = self.at_runtime.stream

    # The op's own output description holds the per-utterance costs.
    self.costs = op.tensor_description()

    # Inputs/outputs in the order returned by op.call_info().
    self.activs, self.lbls, self.uttlens_pct, self.lbl_lens, self.grads = \
        op.call_info()

    # The activation axes are unpacked as (max_t, bsz, nout).
    self.max_t, self.bsz, self.nout = self.activs.axes.lengths
def ctc_ref(acts, lbls, utt_lens, lbl_lens, n_threads=8):
    """
    CTC reference implementation.

    Runs warp-ctc on the CPU with a freshly created binding. Returns newly
    allocated cost and gradient arrays; the caller's inputs are not modified
    (converted copies are passed to warp-ctc).

    Arguments:
        acts: activations, unpacked as ``(max_t, bsz, nout)``. ``max_t`` is
            the length that the ``utt_lens`` percentages are scaled against.
        lbls: label values. Cast to int32 before they reach warp-ctc, the
            same as in ``ctc_cpu`` and the GPU kernel.
        utt_lens: per-utterance lengths as a percentage of ``max_t``.
            Converted here to int32 frame counts.
        lbl_lens: per-utterance label lengths. Cast to int32.
        n_threads: number of CPU threads passed to warp-ctc. The default of
            8 matches the value previously hard-coded.

    Returns:
        ``(costs, grads)``, where ``costs`` has shape ``(bsz,)`` and dtype
        ``acts.dtype``, and ``grads`` is shaped like ``acts``.
    """
    warp_ctc = CTC(on_device='cpu')
    max_t, bsz, nout = acts.shape

    # Output buffers handed to warp-ctc, which presumably fills them in place.
    grads = np.zeros_like(acts)
    costs = np.zeros(bsz, dtype=acts.dtype)

    # Percent-of-max_t -> integer frame counts.
    utt_lens = (utt_lens * max_t / 100).astype(np.int32)
    # Match the int32 conversion done by ctc_cpu / CTCKernel.execute so the
    # native binding always receives int32 label data.
    lbls = lbls.astype(np.int32)
    lbl_lens = lbl_lens.astype(np.int32)

    warp_ctc.bind_to_cpu(acts, lbls, utt_lens, lbl_lens, grads, costs,
                         n_threads=n_threads)
    return costs, grads
def ctc_cpu(acts, lbls, utt_lens, lbl_lens, grads, costs, n_threads=8):
    """
    Compute CTC costs and gradients on the CPU into caller-provided buffers.

    ``costs`` and ``grads`` are zeroed and then passed to warp-ctc as output
    arrays, presumably filled in place. Nothing is returned.

    Relies on a module-level ``warp_ctc`` name, defined outside this view and
    presumably initialised to ``None``. The CPU binding is created on the
    first call and reused afterwards.

    Arguments:
        acts: activations. Must unpack into exactly three dimensions, and the
            first one (``max_t``) scales ``utt_lens``.
        lbls: label values, converted to int32.
        utt_lens: per-utterance lengths as a percentage of ``max_t``.
        lbl_lens: per-utterance label lengths, converted to int32.
        grads: gradient output buffer (zeroed here).
        costs: cost output buffer (zeroed here).
        n_threads: number of CPU threads passed to warp-ctc.
    """
    global warp_ctc
    if warp_ctc is None:
        warp_ctc = CTC(on_device='cpu')

    for out_buf in (costs, grads):
        out_buf.fill(0.)

    # Only the first extent is used, but unpacking all three enforces 3-D input.
    max_t, _, _ = acts.shape

    frame_counts = (utt_lens * max_t / 100).astype(np.int32)
    int_lbls = lbls.astype(np.int32)
    int_lbl_lens = lbl_lens.astype(np.int32)

    warp_ctc.bind_to_cpu(acts, int_lbls, frame_counts, int_lbl_lens,
                         grads, costs, n_threads=n_threads)
class CTCKernel(GPUKernel):
    """
    GPU kernel that evaluates a CTC op through the warp-ctc GPU binding.

    At construction the kernel holds tensor descriptions. ``bind_buffers``
    swaps them for the underlying tensors, and ``execute`` launches warp-ctc
    on the runtime's stream.
    """

    def __init__(self, transformer, op):
        """
        Capture the op's tensor descriptions and the runtime stream.

        Arguments:
            transformer: passed to ``GPUKernel.__init__``; its ``runtime``
                supplies the stream and scratch-buffer allocator.
            op: the CTC op. ``op.tensor_description()`` becomes ``costs``,
                and ``op.call_info()`` must yield exactly five entries.
        """
        super(CTCKernel, self).__init__(transformer)
        self.warp_ctc = CTC(on_device='gpu')

        self.at_runtime = self.transformer.runtime
        self.stream = self.at_runtime.stream

        # The op's own output description holds the per-utterance costs.
        self.costs = op.tensor_description()

        # Inputs/outputs in the order returned by op.call_info().
        self.activs, self.lbls, self.uttlens_pct, self.lbl_lens, self.grads = \
            op.call_info()

        # The activation axes are unpacked as (max_t, bsz, nout).
        self.max_t, self.bsz, self.nout = self.activs.axes.lengths

    def bind_buffers(self):
        """Replace each stored tensor description with its bound tensor."""
        for attr in ('activs', 'lbls', 'uttlens_pct', 'lbl_lens', 'grads',
                     'costs'):
            description = getattr(self, attr)
            setattr(self, attr, self.tensor_view_from_td(description).tensor)
        super(CTCKernel, self).bind_buffers()

    def execute(self):
        """
        Zero the outputs, build warp-ctc's int32 host-side arguments, size
        and fetch a scratch workspace, then launch warp-ctc on the stream.
        """
        self.grads.fill(0.)
        self.costs.fill(0.)

        # ``.get()`` presumably copies device data to a host array;
        # warp-ctc's length/label arguments are built as flat int32 arrays.
        pct = self.uttlens_pct.get().ravel()
        warp_utt_lens = (pct * self.max_t / 100.).astype(np.int32)
        warp_lbls, warp_lbl_lens = [
            tensor.get().ravel().astype(np.int32)
            for tensor in (self.lbls, self.lbl_lens)
        ]

        scratch_size = self.warp_ctc.get_gpu_workspace_size(
            warp_lbl_lens, warp_utt_lens, self.nout, self.bsz)
        runtime = self.at_runtime
        runtime.set_scratch_size(scratch_size)
        workspace = runtime.scratch_buffer(scratch_size)

        self.warp_ctc.bind_to_gpu(self.activs, self.grads,
                                  warp_lbls, warp_lbl_lens, warp_utt_lens,
                                  self.costs, workspace, scratch_size,
                                  self.stream)