Example #1
0
 def test_error_lengths(self, inplace):
     """Mismatched ``start``/``stop``/``axis`` lengths must raise ValueError.

     ``inplace`` selects ``ops.slice_`` (True) or ``ops.slice`` (False);
     presumably supplied by a pytest parametrization outside this view.
     """
     slice_fn = ops.slice_ if inplace else ops.slice
     ir = pir.Ir()
     with ir.main_graph():
         t = pir.variable(data)
         # `start` has one entry while `stop` and `axis` have two. Only the
         # slice call sits inside `pytest.raises`, so a ValueError from any
         # other statement cannot satisfy the check by accident.
         with pytest.raises(ValueError):
             slice_fn(t, start=[2], stop=[3, 4], axis=[2, 1])
Example #2
0
    def test_identity_numerically(self, inplace):
        """Slicing with only ``axis`` given returns the data unchanged.

        ``inplace`` chooses between ``ops.slice_`` and ``ops.slice``.
        """
        slice_fn = ops.slice_ if inplace else ops.slice
        ir = pir.Ir()
        with ir.main_graph():
            tensor = pir.variable(data)
            sliced = slice_fn(tensor, axis=0)  # `axis=0` is redundant
            host_result = run_ir(ir, sliced)

        assert_array_equal(host_result, data)
Example #3
0
    def test_identity_fn(self, inplace):
        """An identity slice leaves one tensor and one variable in the graph.

        Only ``axis`` is passed, so no actual slicing is requested. After the
        call, the main graph must contain exactly one tensor and one variable.
        ``inplace`` chooses between ``ops.slice_`` and ``ops.slice``.
        """
        slice_fn = ops.slice_ if inplace else ops.slice
        ir = pir.Ir()
        with ir.main_graph():
            t = pir.variable(data)
            # The result is not needed; only the call's effect on the graph
            # is checked below.
            slice_fn(t, axis=0)  # `axis=0` is redundant

        assert len(ir.main_graph().get_tensors()) == 1
        assert len(ir.main_graph().get_variables()) == 1
Example #4
0
    def test_start_only(self, inplace):
        """Slicing with only ``start=1`` matches numpy's ``data[1:]``.

        ``inplace`` chooses between ``ops.slice_`` and ``ops.slice``.
        """
        slice_fn = ops.slice_ if inplace else ops.slice
        ir = pir.Ir()
        with ir.main_graph():
            tensor = pir.variable(data)
            sliced = slice_fn(tensor, start=1)
            host_result = run_ir(ir, sliced)

        expected = data[1:]
        assert_array_equal(host_result, expected)
Example #5
0
    def test_axis(self, inplace):
        """Explicit, out-of-order ``axis`` lists pair with ``start``/``stop``.

        Axis 2 is sliced ``1:3`` and axis 1 is sliced ``2:4``, which must
        match numpy's ``data[:, 2:4, 1:3]``. ``inplace`` chooses between
        ``ops.slice_`` and ``ops.slice``.
        """
        slice_fn = ops.slice_ if inplace else ops.slice
        ir = pir.Ir()
        with ir.main_graph():
            tensor = pir.variable(data)
            sliced = slice_fn(tensor, start=[1, 2], stop=[3, 4], axis=[2, 1])
            host_result = run_ir(ir, sliced)

        expected = data[:, 2:4, 1:3]
        assert_array_equal(host_result, expected)
Example #6
0
    def test_negative_start(self, inplace):
        """A negative ``start`` with ``step=-1`` matches ``data[-2::-1]``.

        ``inplace`` chooses between ``ops.slice_`` and ``ops.slice``.
        """
        slice_fn = ops.slice_ if inplace else ops.slice
        ir = pir.Ir()
        with ir.main_graph():
            tensor = pir.variable(data)
            sliced = slice_fn(tensor, start=-2, step=-1)
            host_result = run_ir(ir, sliced)

        expected = data[-2::-1]
        assert_array_equal(host_result, expected)
Example #7
0
    def test_step(self, inplace):
        """Per-dimension steps, including a negative one, match numpy.

        The reference is ``data[1:3, 3:1:-1]``. ``inplace`` chooses between
        ``ops.slice_`` and ``ops.slice``.
        """
        slice_fn = ops.slice_ if inplace else ops.slice
        ir = pir.Ir()
        with ir.main_graph():
            tensor = pir.variable(data)
            sliced = slice_fn(tensor, start=[1, 3], stop=[3, 1], step=[1, -1])
            host_result = run_ir(ir, sliced)

        expected = data[1:3, 3:1:-1]
        assert_array_equal(host_result, expected)
Example #8
0
    def test_stop_only_multidim(self, inplace):
        """A multi-entry ``stop`` alone matches numpy's ``data[:2, :3]``.

        ``inplace`` chooses between ``ops.slice_`` and ``ops.slice``.
        """
        slice_fn = ops.slice_ if inplace else ops.slice
        ir = pir.Ir()
        with ir.main_graph():
            tensor = pir.variable(data)
            sliced = slice_fn(tensor, stop=[2, 3])
            host_result = run_ir(ir, sliced)

        expected = data[:2, :3]
        assert_array_equal(host_result, expected)
Example #9
0
    def test_fn_numerically(self, inplace):
        """Fully specified scalar slice arguments match numpy's ``data[1:3]``.

        ``inplace`` chooses between ``ops.slice_`` and ``ops.slice``.
        """
        slice_fn = ops.slice_ if inplace else ops.slice
        ir = pir.Ir()
        with ir.main_graph():
            tensor = pir.variable(data)
            sliced = slice_fn(tensor, start=1, stop=3, step=1, axis=0)
            host_result = run_ir(ir, sliced)

        expected = data[1:3]
        assert_array_equal(host_result, expected)
Example #10
0
    def test_fn(self, inplace):
        """Slicing adds the matching op and the expected tensors to the graph.

        With ``inplace`` the graph must contain a ``SliceInplaceOp`` named
        ``"SliceInplace"``; otherwise a ``SliceOp`` named ``"Slice"``. The
        graph must then hold two tensors and one variable.
        """
        ir = pir.Ir()
        g = ir.main_graph()
        with g:
            t = pir.variable(data)
            # Only the graph structure is inspected, so the result is unused.
            if inplace:
                ops.slice_(t, start=1, stop=3, step=1, axis=0)
            else:
                ops.slice(t, start=1, stop=3, step=1, axis=0)

        if inplace:
            assert contains_op_of_type("SliceInplace", _ir.op.SliceInplaceOp,
                                       g)
        else:
            assert contains_op_of_type("Slice", _ir.op.SliceOp, g)
        # Presumably the variable `t` plus the slice output — confirm.
        assert len(g.get_tensors()) == 2
        assert len(g.get_variables()) == 1