def test_simplify_consecutive_cast():
    """SimplifyExpr should collapse consecutive casts into a single cast.

    The chain ``int8 -> int16 -> int32 -> cast_like(int64) -> cast_like(float32)``
    is checked at each stage. The simplified expression must be structurally
    equal to one ``cast`` from the original input straight to the final dtype.
    A chain that passes through a narrower intermediate type
    (``float32 -> int32 -> float32``) must be left unchanged.
    """
    src = relay.var("x", shape=(3, 4, 5), dtype="int8")
    like_i64 = relay.var("y", shape=(3, 4), dtype="int64")
    like_f32 = relay.var("z", shape=(3, ), dtype="float32")

    to_i16 = relay.cast(src, "int16")
    to_i32 = relay.cast(to_i16, "int32")
    to_i64 = relay.cast_like(to_i32, like_i64)
    to_f32 = relay.cast_like(to_i64, like_f32)

    # Each stage of the chain should reduce to a single cast of `src`. The
    # expected results keep `src`'s shape, so only the dtype of the "like"
    # operands matters here, not their (different) shapes.
    for chained, final_dtype in (
        (to_i32, "int32"),
        (to_i64, "int64"),
        (to_f32, "float32"),
    ):
        simplified = run_opt_pass(chained, relay.transform.SimplifyExpr())
        reference = run_infer_type(relay.cast(src, final_dtype))
        assert tvm.ir.structural_equal(simplified, reference)

    # float32 -> int32 -> float32 goes through a narrower intermediate type,
    # so the pair of casts must not be folded away.
    src = relay.var("x", shape=(3, 4, 5), dtype="float32")
    like_f32 = relay.var("y", shape=(3, 4), dtype="float32")
    narrowed = relay.cast_like(relay.cast(src, "int32"), like_f32)
    simplified = run_opt_pass(narrowed, relay.transform.SimplifyExpr())
    assert tvm.ir.structural_equal(simplified, run_infer_type(narrowed))


def test_simplify_cast():
    """SimplifyExpr should drop casts whose target dtype matches the input's.

    Two cases must both reduce to the bare, type-inferred input variable:
    a ``cast`` to the input's own dtype, and a ``cast_like`` against a tensor
    that has the same dtype but a different shape.
    """
    same_dtype = "int32"
    src = relay.var("data", shape=(3, 4, 5), dtype=same_dtype)
    noop_cast = relay.cast(src, same_dtype)
    ref = relay.var("dtype_like", shape=(2, 2, 2), dtype=same_dtype)
    noop_cast_like = relay.cast_like(src, ref)

    expected = run_infer_type(src)
    for redundant in (noop_cast, noop_cast_like):
        simplified = run_opt_pass(redundant, relay.transform.SimplifyExpr())
        assert tvm.ir.structural_equal(simplified, expected)


def test_cast_like_grad():
    """Run check_grad on a function that cast_likes float32 data to float64."""
    src = relay.var("data", shape=(10, 4), dtype="float32")
    ref = relay.var("like", shape=(1, ), dtype="float64")
    casted = relay.cast_like(src, ref)
    check_grad(relay.Function([src, ref], casted))