def test_sequence_last():
    """Check nd.SequenceLast on a (LARGE_X, 2, SMALL_Y) input.

    Axis 0 of ``a`` is the sequence axis and axis 1 is the batch axis
    (batch size 2), so every result of SequenceLast has shape (2, SMALL_Y).
    Covers both the default mode (last step) and the explicit
    ``sequence_length`` mode.
    """
    a = nd.arange(0, LARGE_X * SMALL_Y * 2).reshape(LARGE_X, 2, SMALL_Y)

    # Without sequence lengths the op should return the final step.
    b = nd.SequenceLast(a)
    # Compare host copies explicitly instead of relying on
    # assert_almost_equal accepting NDArray inputs.
    assert_almost_equal(b.asnumpy(), a[-1].asnumpy())
    assert b.shape == (2, SMALL_Y)

    # sequence_length has shape (batch_size,): [2, 3] means batch entry 0
    # ends after 2 steps (sequence index 1) and batch entry 1 after 3 steps
    # (sequence index 2).
    # Pass int64 explicitly: mx.nd.array defaults to float32, which errors
    # for large indices.
    b = nd.SequenceLast(a, sequence_length=mx.nd.array([2, 3], dtype="int64"),
                        use_sequence_length=True)
    assert b.shape == (2, SMALL_Y)
    # Each batch entry must pick its whole feature vector from the step at
    # its own last valid index, not just a single element of it.
    assert_almost_equal(b[0].asnumpy(), a[1][0].asnumpy())
    assert_almost_equal(b[1].asnumpy(), a[2][1].asnumpy())
def check_sequence_last():
    """Check nd.SequenceLast on a (LARGE_X, 2) input, with and without lengths.

    Axis 0 of ``data`` is the sequence axis and axis 1 is the batch axis
    (batch size 2), so each result holds one element per batch entry.
    """
    data = nd.arange(0, LARGE_X * 2).reshape(LARGE_X, 2)

    # Without sequence lengths the op should return the final step.
    last_step = nd.SequenceLast(data)
    assert_almost_equal(last_step.asnumpy(), data[-1].asnumpy())
    assert last_step.shape == (2,)

    # sequence_length has shape (batch_size,): [2, 3] means batch entry 0
    # ends after 2 steps (index 1) and batch entry 1 after 3 steps (index 2).
    # The dtype must be int64 so large indices are supported; mx.nd.array
    # otherwise defaults to float32, which errors.
    lengths = mx.nd.array([2, 3], dtype="int64")
    masked_last = nd.SequenceLast(data, sequence_length=lengths,
                                  use_sequence_length=True)
    # Batch entry 0 should pick the element at sequence index 1.
    assert masked_last[0] == data[1][0]