def test_ctc_loss_int_overflow():
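    # ctc_loss takes data of shape (seq_len, batch, alphabet) and labels of
    # shape (batch, label_len); here the batch dimension is INT_OVERFLOW to
    # exercise indexing beyond the int32 range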
    A = np.ones((2, INT_OVERFLOW, 4))
    A.attach_grad()
    with mx.autograd.record():
        B = npx.ctc_loss(A, np.ones((INT_OVERFLOW, 2)))
    assert B.shape == (INT_OVERFLOW, )
    assert type(B).__name__ == 'ndarray'
    B.backward()
    assert A.grad.shape == (2, INT_OVERFLOW, 4)
    assert A.grad[0][0][0] == 0


def test_ctc_loss():
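    # boundary test for ctc_loss's input size check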
    def test_ctc_loss_size_check(A, label):
        assertRaises(ValueError, npx.ctc_loss, A, label)

    L_SEQ, L_ALP, L_LAB, BAT = 2**10, 2**20, 2**6, 2
    A = np.zeros((L_SEQ, BAT, L_ALP))
    label = np.random.randint(0, L_ALP, (BAT, L_LAB))
    # L_SEQ * BAT * L_ALP = 2**31 elements, which the size check should reject
    test_ctc_loss_size_check(A, label)
    # shrink the alphabet dimension by one so the total element count falls
    # below 2**31 and the same call is accepted
    L_ALP = 2**20 - 1
    A = np.zeros((L_SEQ, BAT, L_ALP))
    label = np.random.randint(0, L_ALP, (BAT, L_LAB))
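    # forward/backward on the accepted size: expect one loss value per batch
    # element and a gradient matching the data shape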
    A.attach_grad()
    with mx.autograd.record():
        B = npx.ctc_loss(A, label)
    assert B.shape == (BAT, )
    assert type(B[0]).__name__ == 'ndarray'
    B.backward()
    assert A.grad.shape == (L_SEQ, BAT, L_ALP)
    assert type(A.grad[0]).__name__ == 'ndarray'