def test_sequence_last():
    """Check ``nd.SequenceLast`` on a large 3-D tensor.

    ``a`` is reshaped to ``(LARGE_X, 2, SMALL_Y)``: axis 0 is the sequence
    axis (the operator's default) and axis 1 is the batch axis of size 2.
    """
    a = nd.arange(0, LARGE_X * SMALL_Y * 2).reshape(LARGE_X, 2, SMALL_Y)

    # Without sequence_length the operator returns the final step, a[-1].
    b = nd.SequenceLast(a)
    # Compare as NumPy arrays; only the (2, SMALL_Y) result and the matching
    # (2, SMALL_Y) slice of ``a`` are copied to host, not the whole input.
    assert_almost_equal(b.asnumpy(), a[-1].asnumpy())
    assert b.shape == (2, SMALL_Y)

    # sequence_length has one entry per batch element (shape (batch_size,)).
    # [2, 3] selects step 2 (index 1) for batch 0 and step 3 (index 2) for
    # batch 1.
    b = nd.SequenceLast(a, sequence_length=mx.nd.array([2, 3]),
                        use_sequence_length=True)
    assert b.shape == (2, SMALL_Y)
    # Batch 0 must be taken from time-step index 1.
    assert b[0][-1] == a[1][0][-1]
    # Batch 1 must be taken from time-step index 2.
    assert b[1][-1] == a[2][1][-1]
Example #2
0
 def check_sequence_last():
     """Check ``nd.SequenceLast`` on a large 2-D input.

     ``a`` is reshaped to ``(LARGE_X, 2)``: axis 0 is the sequence axis and
     axis 1 is the batch axis of size 2, so each result is one scalar per batch.
     """
     a = nd.arange(0, LARGE_X * 2).reshape(LARGE_X, 2)
     # Without sequence_length the operator returns the final step, a[-1].
     b = nd.SequenceLast(a)
     assert_almost_equal(b.asnumpy(), a[-1].asnumpy())
     assert b.shape == (2,)
     # sequence_length has one entry per batch element (shape (batch_size,)).
     # [2, 3] selects step 2 (index 1) for batch 0 and step 3 (index 2) for
     # batch 1.
     # dtype="int64" is required for sequence_length to support large indices;
     # otherwise it defaults to float32 and errors.
     b = nd.SequenceLast(a, sequence_length=mx.nd.array([2, 3], dtype="int64"),
                         use_sequence_length=True)
     assert b.shape == (2,)
     # Batch 0 must be taken from time-step index 1.
     assert b[0] == a[1][0]
     # Batch 1 must be taken from time-step index 2.
     assert b[1] == a[2][1]