def test_fuse_getitem():
    """Nested ``getitem`` tasks on ``'x'`` fuse into a single ``getitem``.

    Each case is ``(inner_index, outer_index, fused_index)``.  The term
    ``(getitem, (getitem, 'x', inner_index), outer_index)`` must be rewritten
    by ``rewrite_rules`` into ``(getitem, 'x', fused_index)``.
    """
    cases = [
        # Slice of a slice: the outer bounds are shifted by the inner start.
        (slice(1000, 2000), slice(15, 20), slice(1015, 1020)),
        (
            (slice(1000, 2000), slice(100, 200)),
            (slice(15, 20), slice(50, 60)),
            (slice(1015, 1020), slice(150, 160)),
        ),
        # Integer into a slice: the integer is shifted by the inner start.
        (slice(1000, 2000), 10, 1010),
        ((slice(1000, 2000), 10), (slice(15, 20),), (slice(1015, 1020), 10)),
        ((10, slice(1000, 2000)), (slice(15, 20),), (10, slice(1015, 1020))),
        (
            (slice(1000, 2000), slice(100, 200)),
            (slice(None, None), slice(50, 60)),
            (slice(1000, 2000), slice(150, 160)),
        ),
        ((None, slice(None, None)), (slice(None, None), 5), (None, 5)),
        (
            (slice(1000, 2000), slice(10, 20)),
            (slice(5, 10),),
            (slice(1005, 1010), slice(10, 20)),
        ),
        (
            (slice(1000, 2000),),
            (slice(5, 10), slice(10, 20)),
            (slice(1005, 1010), slice(10, 20)),
        ),
    ]

    for inner_index, outer_index, fused_index in cases:
        term = (getitem, (getitem, 'x', inner_index), outer_index)
        assert rewrite_rules.rewrite(term) == (getitem, 'x', fused_index)
# Example #2
def test_fuse_getitem():
    """Nested ``getarray``/``getitem`` tasks on ``'x'`` fuse into one task.

    Each case is ``(outer_fn, inner_fn, inner_index, outer_index,
    fused_index)``.  The term
    ``(outer_fn, (inner_fn, 'x', inner_index), outer_index)`` must be
    rewritten by ``rewrite_rules`` into ``(inner_fn, 'x', fused_index)``;
    in every case below the fused task keeps the inner call's function.
    """
    cases = [
        (getarray, getarray,
         slice(1000, 2000),
         slice(15, 20),
         slice(1015, 1020)),
        (getitem, getarray,
         (slice(1000, 2000), slice(100, 200)),
         (slice(15, 20), slice(50, 60)),
         (slice(1015, 1020), slice(150, 160))),
        (getarray, getarray,
         slice(1000, 2000),
         10,
         1010),
        (getitem, getarray,
         (slice(1000, 2000), 10),
         (slice(15, 20),),
         (slice(1015, 1020), 10)),
        (getarray, getarray,
         (10, slice(1000, 2000)),
         (slice(15, 20),),
         (10, slice(1015, 1020))),
        (getarray, getarray,
         (slice(1000, 2000), slice(100, 200)),
         (slice(None, None), slice(50, 60)),
         (slice(1000, 2000), slice(150, 160))),
        (getarray, getarray,
         (None, slice(None, None)),
         (slice(None, None), 5),
         (None, 5)),
        (getarray, getarray,
         (slice(1000, 2000), slice(10, 20)),
         (slice(5, 10),),
         (slice(1005, 1010), slice(10, 20))),
        (getitem, getitem,
         (slice(1000, 2000),),
         (slice(5, 10), slice(10, 20)),
         (slice(1005, 1010), slice(10, 20))),
    ]

    for outer_fn, inner_fn, inner_index, outer_index, fused_index in cases:
        term = (outer_fn, (inner_fn, 'x', inner_index), outer_index)
        assert rewrite_rules.rewrite(term) == (inner_fn, 'x', fused_index)
# Example #3
def test_hard_fuse_slice_cases():
    """Fusing through a ``None`` index entry keeps that entry.

    ``(getarray, (getarray, 'x', (None, :)), (:, 5))`` must rewrite to
    ``(getarray, 'x', (None, 5))``.
    """
    inner = (getarray, 'x', (None, slice(None, None)))
    term = (getarray, inner, (slice(None, None), 5))
    expected = (getarray, 'x', (None, 5))
    assert rewrite_rules.rewrite(term) == expected