def test_sliced_batch_axis():
    """
    Check that a slice of a batch axis is itself a batch axis.

    The source axis is created with name 'N'; this test relies on that name
    marking it as a batch axis (assumed naming convention of make_axis).
    """
    batch_axis = ng.make_axis(10, name='N')
    first_half = slice_axis(batch_axis, slice(0, 5))
    assert first_half.is_batch is True
def test_sliced_recurrent_axis():
    """
    Check that a slice of a recurrent axis is itself a recurrent axis.

    The source axis is created with name 'REC'; this test relies on that name
    marking it as recurrent (assumed naming convention of make_axis).
    """
    recurrent_axis = ng.make_axis(10, name='REC')
    first_half = slice_axis(recurrent_axis, slice(0, 5))
    assert first_half.is_recurrent is True
def test_sliced_axis_flip():
    """A full reversing slice (equivalent to ``[::-1]``) keeps all 10 elements."""
    axis = ng.make_axis(10)
    reverse_all = slice(None, None, -1)
    flipped = slice_axis(axis, reverse_all)
    assert flipped.length == 10
def test_sliced_axis_invalid_step():
    """Slicing with a step of 2 must be rejected with ValueError."""
    axis = ng.make_axis(10)
    every_other = slice(0, 5, 2)
    with pytest.raises(ValueError):
        slice_axis(axis, every_other)
def test_sliced_axis_negative_invalid():
    """
    A negative step with start < stop selects nothing.

    ``slice(0, 5, -1)`` walks backwards from 0 toward 5, so the result is an
    empty axis.
    """
    axis = ng.make_axis(10)
    empty = slice_axis(axis, slice(0, 5, -1))
    assert empty.length == 0
def test_sliced_axis_negative():
    """``slice(5, 0, -1)`` walks indices 5..1 backwards, giving length 5."""
    axis = ng.make_axis(10)
    backwards = slice_axis(axis, slice(5, 0, -1))
    assert backwards.length == 5
def test_sliced_axis_none_end():
    """An open-ended slice from 0 (``[0:]``) covers all 10 elements."""
    axis = ng.make_axis(10)
    to_end = slice_axis(axis, slice(0, None))
    assert to_end.length == 10
def test_sliced_axis_invalid():
    """A forward slice whose start exceeds its stop (``[5:0]``) is empty."""
    axis = ng.make_axis(10)
    empty = slice_axis(axis, slice(5, 0))
    assert empty.length == 0
def test_sliced_axis():
    """The basic forward slice ``[0:5]`` yields an axis of length 5."""
    axis = ng.make_axis(10)
    head = slice_axis(axis, slice(0, 5))
    assert head.length == 5