def test_batch_flatten():
    """Exercise ``npx.batch_flatten`` forward and backward on a wide input.

    A tensor of ones with shape ``(2, 1, INT_OVERFLOW)`` is flattened inside
    an autograd recording scope. The test then asserts that:

    * the output has shape ``(2, INT_OVERFLOW)`` and its first element is 1;
    * after ``backward()``, the input gradient has the input's shape and its
      first element is 1.
    """
    in_shape = (2, 1, INT_OVERFLOW)
    flat_shape = (2, INT_OVERFLOW)

    data = np.ones(in_shape)
    # Gradient storage must be attached before recording so that
    # ``data.grad`` is available after the backward pass.
    data.attach_grad()
    with mx.autograd.record():
        flat = npx.batch_flatten(data)

    # Forward pass: the two trailing axes are merged into one.
    assert flat.shape == flat_shape
    assert flat[0][0] == 1

    # Backward pass: the gradient keeps the original (unflattened) shape.
    flat.backward()
    assert data.grad.shape == in_shape
    assert data.grad[0][0][0] == 1
def flatten_pred(pred):
    """Move axis 1 of ``pred`` to the end, then batch-flatten the result.

    Args:
        pred: A 4-D array-like exposing ``transpose`` (four axes are passed
            to it). Presumably laid out as (batch, channels, height, width)
            -- confirm against callers.

    Returns:
        The result of ``npx.batch_flatten`` applied to ``pred`` with its axes
        reordered to ``(0, 2, 3, 1)``.
    """
    # Reorder axes so the original axis 1 becomes the last one.
    channels_last = pred.transpose(0, 2, 3, 1)
    return npx.batch_flatten(channels_last)