def test_ctc_loss():
    """Run ``npx.ctc_loss`` forward and backward with an INT_OVERFLOW-sized axis 1.

    The input is an all-ones tensor of shape ``(2, INT_OVERFLOW, 4)`` and the
    labels are an all-ones tensor of shape ``(INT_OVERFLOW, 2)``. The test checks:
    the loss shape and type, the gradient shape, and that the first gradient
    element equals 0.

    NOTE(review): the asserted loss shape ``(INT_OVERFLOW,)`` suggests that axis 1
    is the batch axis. Confirm this against the ``ctc_loss`` documentation.

    NOTE(review): this chunk later contains a second ``def test_ctc_loss`` that
    rebinds this name. As written, pytest only collects that later definition.
    Consider giving this test a distinct name.
    """
    data_shape = (2, INT_OVERFLOW, 4)
    data = np.ones(data_shape)
    data.attach_grad()

    with mx.autograd.record():
        labels = np.ones((INT_OVERFLOW, 2))
        loss = npx.ctc_loss(data, labels)

    assert loss.shape == (INT_OVERFLOW, )
    assert type(loss).__name__ == 'ndarray'

    loss.backward()
    assert data.grad.shape == data_shape
    # NOTE(review): the reason this element is expected to be exactly 0 is not
    # evident from this block. Confirm it against ctc_loss's gradient behaviour.
    assert data.grad[0][0][0] == 0
def test_ctc_loss():
    """Check ``npx.ctc_loss`` at the alphabet-size boundary.

    With an alphabet dimension of ``2**20``, the op is expected to raise
    ``ValueError``. With ``2**20 - 1``, forward and backward must both succeed.
    The loss must have shape ``(BAT,)`` and the gradient must match the input
    shape.
    """
    def _assert_size_rejected(data, label):
        # Private helper (no ``test_`` prefix): asserts the op refuses this size.
        assertRaises(ValueError, npx.ctc_loss, data, label)

    L_SEQ, L_ALP, L_LAB, BAT = 2**10, 2**20, 2**6, 2
    A = np.zeros((L_SEQ, BAT, L_ALP))
    label = np.random.randint(0, L_ALP, (BAT, L_LAB))
    # test for expected exception
    _assert_size_rejected(A, label)

    # Release the oversized tensors before allocating their replacements. A plain
    # rebinding evaluates the new np.zeros(...) while the old array is still
    # alive, which doubles peak memory in this large-tensor test.
    del A, label

    # now we shrink the size a little bit and test for an allowed case
    L_ALP = 2**20 - 1
    A = np.zeros((L_SEQ, BAT, L_ALP))
    label = np.random.randint(0, L_ALP, (BAT, L_LAB))
    A.attach_grad()
    with mx.autograd.record():
        B = npx.ctc_loss(A, label)
    assert B.shape == (BAT, )
    assert type(B[0]).__name__ == 'ndarray'

    B.backward()
    assert A.grad.shape == (L_SEQ, BAT, L_ALP)
    # Inspect the gradient, not ``A`` itself. ``A`` is an ndarray by
    # construction, so checking its type would say nothing about the backward pass.
    assert type(A.grad[0]).__name__ == 'ndarray'