def test_simplify_consecutive_cast():
    """SimplifyExpr folds a chain of consecutive casts into one cast.

    Starting from an int8 variable, the chains int8->int16->int32,
    ...->int64 (via cast_like) and ...->float32 (via cast_like) must each
    collapse into a single cast from the int8 source. A chain whose
    intermediate step is a narrowing cast (float32->int32) must be left as-is.
    """
    src = relay.var("x", shape=(3, 4, 5), dtype="int8")
    like_i64 = relay.var("y", shape=(3, 4), dtype="int64")
    like_f32 = relay.var("z", shape=(3,), dtype="float32")

    to_i16 = relay.cast(src, "int16")
    to_i32 = relay.cast(to_i16, "int32")
    to_i64 = relay.cast_like(to_i32, like_i64)
    to_f32 = relay.cast_like(to_i64, like_f32)

    chains = (
        (to_i32, "int32"),
        (to_i64, "int64"),
        (to_f32, "float32"),
    )
    for chain, final_dtype in chains:
        simplified = run_opt_pass(chain, relay.transform.SimplifyExpr())
        reference = run_infer_type(relay.cast(src, final_dtype))
        assert tvm.ir.structural_equal(simplified, reference)

    # cannot simplify the narrow cast: float32 -> int32 -> float32 must
    # survive SimplifyExpr unchanged.
    src = relay.var("x", shape=(3, 4, 5), dtype="float32")
    like = relay.var("y", shape=(3, 4), dtype="float32")
    narrowed = relay.cast_like(relay.cast(src, "int32"), like)
    simplified = run_opt_pass(narrowed, relay.transform.SimplifyExpr())
    reference = run_infer_type(narrowed)
    assert tvm.ir.structural_equal(simplified, reference)
def test_simplify_cast():
    """SimplifyExpr removes casts whose target dtype equals the input dtype.

    Both ``cast(data, "int32")`` and ``cast_like(data, <int32 tensor>)`` on
    int32 ``data`` should simplify to ``data`` itself.
    """
    same_dtype = "int32"
    data = relay.var("data", shape=(3, 4, 5), dtype=same_dtype)
    like = relay.var("dtype_like", shape=(2, 2, 2), dtype=same_dtype)

    redundant_casts = (
        relay.cast(data, same_dtype),
        relay.cast_like(data, like),
    )
    expected = run_infer_type(data)

    for cast_expr in redundant_casts:
        simplified = run_opt_pass(cast_expr, relay.transform.SimplifyExpr())
        assert tvm.ir.structural_equal(simplified, expected)
def test_cast_like_grad():
    """Run check_grad on a function that applies cast_like.

    The function casts a (10, 4) float32 input to the dtype of a (1,)
    float64 reference tensor.
    """
    src = relay.var("data", shape=(10, 4), dtype="float32")
    ref = relay.var("like", shape=(1,), dtype="float64")
    check_grad(relay.Function([src, ref], relay.cast_like(src, ref)))