Beispiel #1
0
def ctc_align(prob, target):
    """Perform CTC alignment on torch sequence batches (using ocrolstm).

    Inputs are in BLD layout; they are unpacked below as (batch, length, depth).

    Args:
        prob: batch of per-frame class probabilities, size (b, l, d).
        target: batch of target sequences, size (bt, lt, dt). It must have the
            same batch size and depth as ``prob`` and must not be longer
            than ``prob``.

    Returns:
        The alignment written by ``cctc.ctc_align_targets_batch``, passed
        through ``typeas`` with ``prob`` (presumably converting it to
        ``prob``'s tensor type -- confirm).
    """
    # The extension is fed CPU tensors. .contiguous() is a no-op for tensors
    # that are already contiguous. Otherwise it gives the extension a dense
    # memory layout (assumed to be required by cctc, which reads raw tensor
    # storage -- NOTE(review): confirm).
    prob_ = prob.cpu().contiguous()
    target_ = target.cpu().contiguous()
    b, l, d = prob_.size()
    bt, lt, dt = target_.size()
    assert bt == b, (bt, b)
    assert dt == d, (dt, d)
    # CTC cannot align a target onto fewer frames than it has labels.
    assert l >= lt, (l, lt)
    assert is_normalized(prob), prob
    assert is_normalized(target_), target_
    # Placeholder output tensor; cctc fills (and presumably resizes) it.
    result = torch.rand(1)
    cctc.ctc_align_targets_batch(result, prob_, target_)
    return typeas(result, prob)
Beispiel #2
0
def ctc_align(prob, target):
    """Align a batch of targets to a batch of probabilities via cctc (ocrolstm).

    Both ``prob`` and ``target`` are BDL sequence batches. They are converted
    to BLD before being handed to ``cctc.ctc_align_targets_batch``. The
    alignment it produces is permuted back to BDL and passed through
    ``dlh.typeas`` with ``prob`` before being returned.
    """
    import cctc

    for batch in (prob, target):
        assert dlh.sequence_is_normalized(batch), batch

    def as_bld(seq):
        # BDL -> BLD on the CPU with dense storage for the C extension.
        # dlh.novar presumably unwraps an autograd Variable -- confirm.
        return dlh.novar(seq).permute(0, 2, 1).cpu().contiguous()

    p_bld = as_bld(prob)
    t_bld = as_bld(target)
    shapes = (p_bld.size(), t_bld.size())
    assert p_bld.size(0) == t_bld.size(0), shapes  # same batch size
    assert p_bld.size(2) == t_bld.size(2), shapes  # same depth
    assert p_bld.size(1) >= t_bld.size(1), shapes  # target no longer than prob

    # Placeholder output. It is permuted as a 3-D tensor below, so the
    # extension is expected to resize and fill it.
    aligned = torch.rand(1)
    cctc.ctc_align_targets_batch(aligned, p_bld, t_bld)
    return dlh.typeas(aligned.permute(0, 2, 1).contiguous(), prob)
Beispiel #3
0
# Smoke test for the cctc extension: print a random 3x3 tensor before and
# after cctc.square (presumably squares it in place -- confirm).
a = torch.randn(3, 3)
print(a)
cctc.square(a)
print(a)


def rownorm(t):
    """Return ``t`` (a 2-D tensor) with each row divided by its sum.

    A new tensor is returned; ``t`` is not modified. Rows that sum to zero
    get no special handling.
    """
    # view(-1, 1) turns the row sums into an explicit (rows, 1) column,
    # whether or not sum() keeps the reduced dimension (current PyTorch drops
    # it by default). Expanding that column to t's shape makes the division
    # elementwise on matching shapes.
    row_sums = t.sum(1).view(-1, 1)
    return t / row_sums.expand_as(t)


def batch_rownorm(t):
    """Row-normalize every matrix of the 3-D batch ``t`` in place; return ``t``.

    Each ``t[i, j, :]`` is divided by its own sum, so every row of every
    ``t[i]`` sums to 1 afterwards. The same tensor object is returned.
    """
    # One vectorized division replaces the per-slice loop. The sums are a
    # freshly allocated tensor, so dividing t in place never reads values
    # that have already been modified. view(...) gives (batch, rows, 1)
    # whether or not sum() keeps the reduced dimension.
    row_sums = t.sum(2).view(t.size(0), t.size(1), 1)
    t.div_(row_sums.expand_as(t))
    return t


# Align one random 20-step target against a random 100-step probability
# sequence, both row-normalized over 17 classes.
a = rownorm(torch.rand(100, 17))
b = rownorm(torch.rand(20, 17))
c = torch.rand(1, 1)  # placeholder output, presumably resized/filled by cctc
print(c)
cctc.ctc_align_targets(c, a, b)
print(c)

# The same alignment, batched over 3 sequence pairs.
a = batch_rownorm(torch.rand(3, 100, 17))
b = batch_rownorm(torch.rand(3, 20, 17))
c = torch.rand(1)  # placeholder output, presumably resized/filled by cctc
print(a.size(), b.size(), c.size())
print(c)
cctc.ctc_align_targets_batch(c, a, b)
print(c)