def test_error_lengths(self, inplace):
    """Mismatched ``start``/``stop``/``axis`` lengths must raise ``ValueError``.

    Here ``start`` has a single entry while ``stop`` and ``axis`` have two.

    Args:
        inplace: if true, the in-place variant ``ops.slice_`` is exercised;
            otherwise ``ops.slice``.
    """
    model = pir.Ir()
    with model.main_graph():
        src = pir.variable(data)
        with pytest.raises(ValueError):
            slicer = ops.slice_ if inplace else ops.slice
            slicer(src, start=[2], stop=[3, 4], axis=[2, 1])
def test_identity_numerically(self, inplace):
    """A slice with no bounds returns values identical to the input.

    Args:
        inplace: if true, ``ops.slice_`` is used; otherwise ``ops.slice``.
    """
    model = pir.Ir()
    with model.main_graph():
        src = pir.variable(data)
        slicer = ops.slice_ if inplace else ops.slice
        out = slicer(src, axis=0)  # `axis=0` is redundant

    result = run_ir(model, out)
    assert_array_equal(result, data)
def test_identity_fn(self, inplace):
    """An unbounded slice leaves the main graph with only the input variable.

    After the call, the graph must hold exactly one tensor and one variable.

    Args:
        inplace: if true, ``ops.slice_`` is used; otherwise ``ops.slice``.
    """
    model = pir.Ir()
    graph = model.main_graph()
    with graph:
        src = pir.variable(data)
        slicer = ops.slice_ if inplace else ops.slice
        slicer(src, axis=0)  # `axis=0` is redundant

    assert len(graph.get_tensors()) == 1
    assert len(graph.get_variables()) == 1
def test_start_only(self, inplace):
    """Giving only ``start`` matches NumPy's ``data[1:]``.

    Args:
        inplace: if true, ``ops.slice_`` is used; otherwise ``ops.slice``.
    """
    model = pir.Ir()
    with model.main_graph():
        src = pir.variable(data)
        slicer = ops.slice_ if inplace else ops.slice
        out = slicer(src, start=1)

    result = run_ir(model, out)
    expected = data[1:]
    assert_array_equal(result, expected)
def test_axis(self, inplace):
    """Per-entry bounds apply to the axes listed in ``axis``, in that order.

    Axis 2 is sliced ``1:3`` and axis 1 is sliced ``2:4``, so the result must
    equal ``data[:, 2:4, 1:3]``.

    Args:
        inplace: if true, ``ops.slice_`` is used; otherwise ``ops.slice``.
    """
    model = pir.Ir()
    with model.main_graph():
        src = pir.variable(data)
        slicer = ops.slice_ if inplace else ops.slice
        out = slicer(src, start=[1, 2], stop=[3, 4], axis=[2, 1])

    result = run_ir(model, out)
    expected = data[:, 2:4, 1:3]
    assert_array_equal(result, expected)
def test_negative_start(self, inplace):
    """A negative ``start`` with a negative ``step`` matches ``data[-2::-1]``.

    Args:
        inplace: if true, ``ops.slice_`` is used; otherwise ``ops.slice``.
    """
    model = pir.Ir()
    with model.main_graph():
        src = pir.variable(data)
        slicer = ops.slice_ if inplace else ops.slice
        out = slicer(src, start=-2, step=-1)

    result = run_ir(model, out)
    expected = data[-2::-1]
    assert_array_equal(result, expected)
def test_step(self, inplace):
    """Per-axis steps, including a reversing step, match NumPy slicing.

    The expected result is ``data[1:3, 3:1:-1]``.

    Args:
        inplace: if true, ``ops.slice_`` is used; otherwise ``ops.slice``.
    """
    model = pir.Ir()
    with model.main_graph():
        src = pir.variable(data)
        slicer = ops.slice_ if inplace else ops.slice
        out = slicer(src, start=[1, 3], stop=[3, 1], step=[1, -1])

    result = run_ir(model, out)
    expected = data[1:3, 3:1:-1]
    assert_array_equal(result, expected)
def test_stop_only_multidim(self, inplace):
    """A multi-entry ``stop`` with no ``start`` matches ``data[:2, :3]``.

    Args:
        inplace: if true, ``ops.slice_`` is used; otherwise ``ops.slice``.
    """
    model = pir.Ir()
    with model.main_graph():
        src = pir.variable(data)
        slicer = ops.slice_ if inplace else ops.slice
        out = slicer(src, stop=[2, 3])

    result = run_ir(model, out)
    expected = data[:2, :3]
    assert_array_equal(result, expected)
def test_fn_numerically(self, inplace):
    """Fully specified scalar bounds on axis 0 match ``data[1:3]``.

    Args:
        inplace: if true, ``ops.slice_`` is used; otherwise ``ops.slice``.
    """
    model = pir.Ir()
    graph = model.main_graph()
    with graph:
        src = pir.variable(data)
        slicer = ops.slice_ if inplace else ops.slice
        out = slicer(src, start=1, stop=3, step=1, axis=0)

    result = run_ir(model, out)
    expected = data[1:3]
    assert_array_equal(result, expected)
def test_fn(self, inplace):
    """The graph gains exactly one slice op of the expected kind.

    Checks for a ``SliceInplace`` op when ``inplace`` is true and a ``Slice``
    op otherwise. The graph must end with two tensors (input and output) and
    one variable.

    Args:
        inplace: if true, ``ops.slice_`` is used; otherwise ``ops.slice``.
    """
    model = pir.Ir()
    graph = model.main_graph()
    with graph:
        src = pir.variable(data)
        slicer = ops.slice_ if inplace else ops.slice
        slicer(src, start=1, stop=3, step=1, axis=0)

    op_name, op_cls = (
        ("SliceInplace", _ir.op.SliceInplaceOp)
        if inplace
        else ("Slice", _ir.op.SliceOp)
    )
    assert contains_op_of_type(op_name, op_cls, graph)
    assert len(graph.get_tensors()) == 2
    assert len(graph.get_variables()) == 1