def ctc_align(prob, target):
    """Perform CTC alignment on torch sequence batches (using ocrolstm).

    Both inputs are unpacked as ``(b, l, d)``; presumably batch x length x
    depth -- TODO confirm against callers. Batch and depth sizes must
    match, and ``target`` may not be longer than ``prob``. The alignment is
    computed by ``cctc.ctc_align_targets_batch`` on CPU copies, and the
    result is converted with ``typeas`` to match ``prob``.

    Raises:
        AssertionError: if the sizes are incompatible or an input fails
            ``is_normalized``.
    """
    # NOTE(review): presumably the C extension needs contiguous CPU buffers;
    # .contiguous() is a no-op when the copy already is contiguous.
    prob_ = prob.cpu().contiguous()
    target_ = target.cpu().contiguous()
    b, l, d = prob_.size()
    bt, lt, dt = target_.size()
    assert bt == b, (bt, b)
    assert dt == d, (dt, d)
    # CTC alignment maps every target step to at least one prob step, so the
    # target cannot be longer than prob (l and lt were previously unused).
    assert lt <= l, (lt, l)
    assert is_normalized(prob), prob
    assert is_normalized(target_), target_
    result = torch.rand(1)  # placeholder; presumably resized/filled by cctc
    cctc.ctc_align_targets_batch(result, prob_, target_)
    return typeas(result, prob)
def ctc_align(prob, target):
    """Align a target sequence batch to a probability sequence batch (ocrolstm).

    Both ``prob`` and ``target`` are expected in BDL layout. Each must pass
    ``dlh.sequence_is_normalized``. Both are then turned into contiguous CPU
    copies in BLD layout, which must agree in batch and depth size, with
    ``prob`` at least as long as ``target``. The work is delegated to
    ``cctc.ctc_align_targets_batch``. Its output has its last two axes
    swapped back (BLD -> BDL) and is converted with ``dlh.typeas`` to match
    ``prob``.

    Raises:
        AssertionError: on non-normalized input or incompatible sizes.
    """
    import cctc

    for seq in (prob, target):
        assert dlh.sequence_is_normalized(seq), seq

    def as_bld(seq):
        # BDL -> BLD on the CPU; dlh.novar presumably unwraps a Variable.
        return dlh.novar(seq).permute(0, 2, 1).cpu().contiguous()

    p_bld = as_bld(prob)
    t_bld = as_bld(target)
    shapes = (p_bld.size(), t_bld.size())
    p_shape, t_shape = shapes
    assert p_shape[0] == t_shape[0], shapes  # same batch size
    assert p_shape[2] == t_shape[2], shapes  # same depth
    assert p_shape[1] >= t_shape[1], shapes  # prob at least as long as target

    aligned = torch.rand(1)  # placeholder; presumably resized/filled by cctc
    cctc.ctc_align_targets_batch(aligned, p_bld, t_bld)
    return dlh.typeas(aligned.permute(0, 2, 1).contiguous(), prob)
# Smoke test of the cctc extension: print a tensor before and after
# cctc.square (presumably an in-place square -- confirm).
a = torch.randn(3, 3)
print(a)
cctc.square(a)
print(a)


def rownorm(t):
    """Return ``t`` with every row divided by that row's sum.

    ``t`` is treated as a 2-D (rows, columns) tensor. The row sums are
    reshaped into a column and expanded explicitly. This works whether or
    not ``sum(1)`` keeps the reduced dimension: early PyTorch kept it, later
    releases drop it by default, which broke the old ``.repeat(1, ncols)``
    form.
    """
    return t / t.sum(1).view(-1, 1).expand_as(t)


def batch_rownorm(t):
    """Row-normalize each 2-D slice ``t[i]`` in place and return ``t``."""
    for i in range(t.size(0)):
        # select() returns a view, so copy_ writes the result back into t.
        mat = t.select(0, i)
        mat.copy_(rownorm(mat))
    return t


# Single-sequence alignment: 100-step prob against 20-step target, depth 17.
a = rownorm(torch.rand(100, 17))
b = rownorm(torch.rand(20, 17))
c = torch.rand(1, 1)
print(c)
cctc.ctc_align_targets(c, a, b)
print(c)

# Batched alignment with a batch of 3.
a = batch_rownorm(torch.rand(3, 100, 17))
b = batch_rownorm(torch.rand(3, 20, 17))
c = torch.rand(1)
# Same output as the Python 2 `print x, y, z` statement, valid on 2 and 3.
print(" ".join(str(size) for size in (a.size(), b.size(), c.size())))
print(c)
cctc.ctc_align_targets_batch(c, a, b)
print(c)