def validate(shape, value, dtype):
    """Check SimplifyExpr on elementwise ops with a constant-filled operand.

    Two scenarios are exercised for add, multiply, subtract and divide, with
    the filled tensor placed on either side of the operator:

    1. The filled tensor is built from ``shape`` (or "like" the variable).
       The simplified expression must be structurally equal to the same op
       applied to ``relay.const(value, dtype)`` after type inference.
    2. The filled tensor is built from ``shape * 2``. The simplified
       expression must be structurally equal to the unsimplified one after
       type inference.

    Parameters
    ----------
    shape
        Shape used for the variable ``x`` and for the filled tensors.
    value
        Fill value; 0 and 1 additionally select the zeros/ones operators.
    dtype
        Element type of the variable, the filled tensors and the constants.
    """
    operand = relay.var("x", shape=shape, dtype=dtype)
    binary_ops = (relay.add, relay.multiply, relay.subtract, relay.divide)

    # Builders for the expression fed to SimplifyExpr.
    def filled_on_left(op, filled):
        return op(filled, operand)

    def filled_on_right(op, filled):
        return op(operand, filled)

    # Builders for the expected expression with a scalar constant.
    def const_on_left(op):
        return op(relay.const(value, dtype), operand)

    def const_on_right(op):
        return op(operand, relay.const(value, dtype))

    same_shape_fills = []
    if value == 0:
        same_shape_fills += [relay.zeros(shape, dtype), relay.zeros_like(operand)]
    # NOTE(review): this is a second `if`, not an `elif`, so value == 0 also
    # reaches the `else` branch and gets the full/full_like forms in addition
    # to zeros/zeros_like. Presumably intended as extra coverage -- confirm.
    if value == 1:
        same_shape_fills += [relay.ones(shape, dtype), relay.ones_like(operand)]
    else:
        same_shape_fills += [
            relay.full(relay.const(value, dtype), shape),
            relay.full_like(operand, relay.const(value, dtype)),
        ]

    builder_pairs = (
        (filled_on_left, const_on_left),
        (filled_on_right, const_on_right),
    )
    for op in binary_ops:
        for filled in same_shape_fills:
            for build_input, build_expected in builder_pairs:
                simplified = run_opt_pass(
                    build_input(op, filled), transform.SimplifyExpr()
                )
                reference = run_opt_pass(build_expected(op), transform.InferType())
                assert tvm.ir.structural_equal(simplified, reference)

    # Test the case in which x is broadcast to full's shape. The filled tensor
    # uses `shape * 2` (for a tuple shape this repeats the dimension sequence).
    broadcast_fills = []
    if value == 0:
        broadcast_fills.append(relay.zeros(shape * 2, dtype))
    # Same if/if/else structure as above: value == 0 also adds the full form.
    if value == 1:
        broadcast_fills.append(relay.ones(shape * 2, dtype))
    else:
        broadcast_fills.append(relay.full(relay.const(value, dtype), shape * 2))

    for op in binary_ops:
        for filled in broadcast_fills:
            for build_input in (filled_on_left, filled_on_right):
                simplified = run_opt_pass(
                    build_input(op, filled), transform.SimplifyExpr()
                )
                # The reference is the very same expression, only type-inferred.
                unchanged = run_opt_pass(
                    build_input(op, filled), transform.InferType()
                )
                assert tvm.ir.structural_equal(simplified, unchanged)
def test_concretize_ones_like():
    """SimplifyExpr on ``ones_like`` of a (3, 4, 5) int32 var must equal ``ones``."""
    elem_type = "int32"
    template = relay.var("shape_like", shape=(3, 4, 5), dtype=elem_type)
    like_expr = relay.ones_like(template)

    # Reference: the explicit-shape form after type inference.
    reference = run_infer_type(relay.ones((3, 4, 5), elem_type))
    simplified = run_opt_pass(like_expr, relay.transform.SimplifyExpr())

    assert tvm.ir.structural_equal(simplified, reference)
def test_closure():
    """LambdaLift on a function returning a closure must yield three functions."""
    mod = tvm.IRModule()

    inner_param = relay.var("x", shape=(2,))
    outer_param = relay.var("y", shape=(2,))

    # The inner function refers to the outer function's parameter.
    inner_func = relay.Function([inner_param], inner_param + outer_param)
    outer_func = relay.Function([outer_param], inner_func)

    # Call the outer function to obtain the inner one, then call that result.
    ones_arg = relay.ones(shape=(2,), dtype="float32")
    closure = outer_func(ones_arg)
    zeros_arg = relay.zeros(shape=(2,), dtype="float32")
    mod["main"] = relay.Function([], relay.Call(closure, [zeros_arg]))

    lifted = transform.LambdaLift()(mod)
    assert len(lifted.functions) == 3
def test_ones():
    """Run LazyGradientInit on ``x + ones`` and check its type and its output.

    After InferType and LazyGradientInit, ``main`` must still have the type
    ``(t) -> t`` for the (10, 10) float32 tensor type ``t``, and evaluating it
    must return the input plus a tensor of ones.
    """
    shape = (10, 10)
    dtype = "float32"
    tensor_ty = relay.TensorType(shape, dtype)

    mod = tvm.IRModule()
    param = relay.var("x", tensor_ty)
    mod["main"] = relay.Function([param], param + relay.ones(shape, dtype))

    mod = transform.InferType()(mod)
    mod = transform.LazyGradientInit()(mod)
    main_fn = mod["main"]

    assert mod["main"].checked_type == relay.FuncType([tensor_ty], tensor_ty)

    # Evaluate on random input and compare against a NumPy reference.
    data = rand(dtype, *shape)
    result = create_executor(mod=mod).evaluate(main_fn)(data)
    reference = data.numpy() + np.ones_like(data.numpy())
    assert_allclose(result.numpy(), reference)
def test_concretize_multiple():
    """SimplifyExpr must concretize several ``*_like`` ops in one expression.

    The input combines ``ones_like``, ``zeros_like`` and ``collapse_sum_like``
    over a (2, 3) and a (3,) float32 variable; the simplified result must be
    structurally equal to a tuple of ``collapse_sum_to`` calls on an explicit
    ``ones((2, 3))``.
    """
    x = relay.var("x", shape=(2, 3), dtype="float32")
    y = relay.var("y", shape=(3,), dtype="float32")

    # Expression handed to the pass, built entirely from "like" operators.
    summed = x + y
    ones_like_sum = relay.ones_like(summed)
    acc_x = relay.zeros_like(x)
    acc_y = relay.zeros_like(y)
    acc_x = acc_x + relay.collapse_sum_like(ones_like_sum, acc_x)
    acc_y = acc_y + relay.collapse_sum_like(ones_like_sum, acc_y)
    like_tuple = relay.Tuple([acc_x, acc_y])

    # Expected form with explicit shapes. The zeros terms have no counterpart
    # here: per the original author's note they are removed by
    # EliminateIdentity.
    ones_concrete = relay.ones((2, 3), "float32")
    concrete_tuple = relay.Tuple(
        [
            relay.collapse_sum_to(ones_concrete, (2, 3)),
            relay.collapse_sum_to(ones_concrete, (3,)),
        ]
    )

    expected = run_infer_type(concrete_tuple)
    actual = run_opt_pass(like_tuple, relay.transform.SimplifyExpr())
    assert tvm.ir.structural_equal(actual, expected)