def test_fuse_getitem_only():
    """Check that ``rewrite_rules`` fuses a ``getitem`` applied to a ``getitem``.

    Every pair is ``(input_task, expected_task)``. The input indexes the
    result of ``(getitem, 'x', inner_index)`` with an outer index. The
    expected output is a single ``(getitem, 'x', fused_index)`` whose index
    addresses the same elements of ``'x'`` directly.

    This function uses only ``getitem``. It is named distinctly from
    ``test_fuse_getitem`` (the ``getarray`` variant) because two functions
    with the same name shadow each other at import time, and only the
    later one would ever be collected.
    """
    pairs = [
        # 1-D slice of a slice: the offsets add.
        ((getitem, (getitem, 'x', slice(1000, 2000)), slice(15, 20)),
         (getitem, 'x', slice(1015, 1020))),
        # 2-D slices: each axis is fused independently.
        ((getitem, (getitem, 'x', (slice(1000, 2000), slice(100, 200))),
          (slice(15, 20), slice(50, 60))),
         (getitem, 'x', (slice(1015, 1020), slice(150, 160)))),
        # An integer into a slice becomes an absolute integer.
        ((getitem, (getitem, 'x', slice(1000, 2000)), 10),
         (getitem, 'x', 1010)),
        # The inner integer index drops its axis, so the outer index only
        # addresses the remaining (sliced) axis.
        ((getitem, (getitem, 'x', (slice(1000, 2000), 10)), (slice(15, 20), )),
         (getitem, 'x', (slice(1015, 1020), 10))),
        # Same as above, with the integer on the leading axis.
        ((getitem, (getitem, 'x', (10, slice(1000, 2000))), (slice(15, 20), )),
         (getitem, 'x', (10, slice(1015, 1020)))),
        # A full outer slice leaves the inner slice unchanged.
        ((getitem, (getitem, 'x', (slice(1000, 2000), slice(100, 200))),
          (slice(None, None), slice(50, 60))),
         (getitem, 'x', (slice(1000, 2000), slice(150, 160)))),
        # A ``None`` in the inner index is carried through to the result.
        ((getitem, (getitem, 'x', (None, slice(None, None))),
          (slice(None, None), 5)),
         (getitem, 'x', (None, 5))),
        # Outer index shorter than inner: trailing inner slices are kept.
        ((getitem, (getitem, 'x', (slice(1000, 2000), slice(10, 20))),
          (slice(5, 10), )),
         (getitem, 'x', (slice(1005, 1010), slice(10, 20)))),
        # Outer index longer than inner: extra outer slices are appended.
        ((getitem, (getitem, 'x', (slice(1000, 2000), )),
          (slice(5, 10), slice(10, 20))),
         (getitem, 'x', (slice(1005, 1010), slice(10, 20)))),
    ]
    for inp, expected in pairs:
        result = rewrite_rules.rewrite(inp)
        # Include the input in the message so a failure names the case.
        assert result == expected, (inp, result, expected)
def test_fuse_getitem():
    """Check that ``rewrite_rules`` collapses nested indexing tasks on ``'x'``.

    Each entry pairs a nested task (``before``) with the single task it
    should be rewritten to (``after``):

    * Most cases nest ``getarray`` inside ``getarray`` and expect a single
      ``getarray``.
    * Some cases wrap an inner ``getarray`` in an outer ``getitem`` and
      still expect a single ``getarray``.
    * The final case nests ``getitem`` inside ``getitem`` and expects a
      single ``getitem``.
    """
    cases = [
        ((getarray, (getarray, 'x', slice(1000, 2000)), slice(15, 20)),
         (getarray, 'x', slice(1015, 1020))),

        ((getitem, (getarray, 'x', (slice(1000, 2000), slice(100, 200))),
          (slice(15, 20), slice(50, 60))),
         (getarray, 'x', (slice(1015, 1020), slice(150, 160)))),

        ((getarray, (getarray, 'x', slice(1000, 2000)), 10),
         (getarray, 'x', 1010)),

        ((getitem, (getarray, 'x', (slice(1000, 2000), 10)),
          (slice(15, 20),)),
         (getarray, 'x', (slice(1015, 1020), 10))),

        ((getarray, (getarray, 'x', (10, slice(1000, 2000))),
          (slice(15, 20),)),
         (getarray, 'x', (10, slice(1015, 1020)))),

        ((getarray, (getarray, 'x', (slice(1000, 2000), slice(100, 200))),
          (slice(None, None), slice(50, 60))),
         (getarray, 'x', (slice(1000, 2000), slice(150, 160)))),

        ((getarray, (getarray, 'x', (None, slice(None, None))),
          (slice(None, None), 5)),
         (getarray, 'x', (None, 5))),

        ((getarray, (getarray, 'x', (slice(1000, 2000), slice(10, 20))),
          (slice(5, 10),)),
         (getarray, 'x', (slice(1005, 1010), slice(10, 20)))),

        ((getitem, (getitem, 'x', (slice(1000, 2000),)),
          (slice(5, 10), slice(10, 20))),
         (getitem, 'x', (slice(1005, 1010), slice(10, 20)))),
    ]
    for before, after in cases:
        assert rewrite_rules.rewrite(before) == after
def test_hard_fuse_slice_cases():
    """Fuse an outer index over an inner index that contains ``None``.

    The inner ``getarray`` indexes ``'x'`` with ``(None, slice(None, None))``
    and the outer one indexes that result with ``(slice(None, None), 5)``.
    The rewrite must yield a single ``getarray`` on ``'x'`` with index
    ``(None, 5)``.
    """
    inner = (getarray, 'x', (None, slice(None, None)))
    nested = (getarray, inner, (slice(None, None), 5))
    fused = rewrite_rules.rewrite(nested)
    assert fused == (getarray, 'x', (None, 5))