# Note: `run_module` is assumed to be a test-local helper that executes the
# module and returns its outputs as a list of numpy arrays.
from typing import Any, Dict

import numpy as np

import tvm
from tvm import relay
from tvm.relay.transform import InferType, ToMixedPrecision


def verify_mixed_precision_output_close(
    mod: tvm.runtime.Module,
    mod_params: Dict[str, Any],
    mixed_precision_dtype="float16",
    rtol: float = 1e-3,
    atol: float = 0,
    keep_orig_output_dtype=False,
) -> tvm.runtime.Module:
    """Run `mod` in fp32 and after ToMixedPrecision, and check the outputs are close."""
    mod = InferType()(mod)
    result_fp32 = run_module(mod, mod_params)

    if not keep_orig_output_dtype:
        fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod)
        result_fp16 = run_module(fp16_mod, mod_params)
    else:
        with tvm.transform.PassContext(
            config={"relay.ToMixedPrecision.keep_orig_output_dtype": True}
        ):
            fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod)
            result_fp16 = run_module(fp16_mod, mod_params)

    # Ensure the results are close
    for fp32, fp16 in zip(result_fp32, result_fp16):
        np.testing.assert_allclose(fp32, fp16, rtol=rtol, atol=atol)

    if keep_orig_output_dtype:
        assert (
            np.array(result_fp16).dtype == np.array(result_fp32).dtype
        ), "output dtype does not match the original output dtype"

    return fp16_mod
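
# A minimal usage sketch for the helper above, assuming `run_module` is available.
# The operator, shapes, and tolerances are illustrative, not from a specific test.
def example_verify_dense_mixed_precision():
    data = relay.var("data", shape=(8, 16), dtype="float32")
    weight = relay.var("weight", shape=(32, 16), dtype="float32")
    mod = tvm.IRModule.from_expr(relay.nn.dense(data, weight))
    mod_params = {
        "data": np.random.uniform(-1, 1, size=(8, 16)).astype("float32"),
        "weight": np.random.uniform(-1, 1, size=(32, 16)).astype("float32"),
    }
    # Looser tolerances account for the reduced precision of float16 arithmetic.
    verify_mixed_precision_output_close(mod, mod_params, rtol=0.01, atol=0.01)
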
def test_do_not_convert_arange():
    """Arange is a red listed operation and therefore should never be fp16."""
    dtype = "float32"
    arange = relay.arange(relay.const(1, dtype), relay.const(128, dtype))
    mod = tvm.IRModule.from_expr(arange)
    out_mod = ToMixedPrecision("float16")(mod)
    orig_mod = tvm.relay.transform.InferType()(mod)
    assert tvm.ir.structural_equal(orig_mod, out_mod)


def test_do_not_convert_softmax():
    """Softmax is a red listed operation and therefore should never be fp16."""
    shape = [1, 2, 3]
    a = relay.var("a", shape=shape)
    b = relay.nn.softmax(a)
    mod = tvm.IRModule.from_expr(b)
    mod = tvm.relay.transform.InferType()(mod)
    out_mod = ToMixedPrecision("float16")(mod)
    orig_mod = tvm.relay.transform.InferType()(mod)
    assert tvm.ir.structural_equal(orig_mod, out_mod)
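
# For contrast, a sketch (name is illustrative) showing that a "green"-listed op
# such as nn.conv2d *is* converted by the pass, so the resulting main function
# returns float16 unless relay.ToMixedPrecision.keep_orig_output_dtype is set.
def example_conv2d_is_converted():
    data = relay.var("data", shape=(1, 3, 32, 32), dtype="float32")
    weight = relay.var("weight", shape=(8, 3, 3, 3), dtype="float32")
    conv = relay.nn.conv2d(data, weight, padding=(1, 1))
    mod = tvm.IRModule.from_expr(conv)
    fp16_mod = InferType()(ToMixedPrecision("float16")(mod))
    assert fp16_mod["main"].ret_type.dtype == "float16"
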
Example #4
# Assumes the `get_conv2d_nchw` and `verify_conv2d` helpers from the surrounding
# CUTLASS test code, plus FirstOrderGradient / InferType / ToMixedPrecision from
# tvm.relay.transform.
def test_conv2d_bwd():
    """Verify the conv2d backward pass (dgrad or wgrad) against a reference implementation."""
    IC = 16
    OC = 8
    dshape = (16, IC, 32, 32)
    wshape = (OC, IC, 3, 3)
    padding = (0, 0)
    strides = (1, 1)

    conv = get_conv2d_nchw(
        dshape,
        wshape,
        padding,
        strides=strides,
        out_dtype="float32",
        data_dtype="float32",
        weight_dtype="float32",
    )
    fwd_mod = InferType()(tvm.IRModule.from_expr(conv))

    # Note: TVM and CUTLASS wgrad results differ significantly when fp16 is used,
    # because CUTLASS wgrad accumulates in fp32 even when the output is fp16.
    use_fp16 = False
    verify_dgrad = False  # False to verify wgrad
    tol = 1e-5 if verify_dgrad else 1e-4  # Wgrad slightly less accurate

    if use_fp16:
        fwd_mod = ToMixedPrecision("float16")(fwd_mod)

    fwd_bwd_func = FirstOrderGradient()(fwd_mod)["main"]

    # FirstOrderGradient makes `main` return (forward_output, (grad_data, grad_weight)),
    # so select the data gradient (dgrad) or the weight gradient (wgrad) from the tuple.
    bwd_func = relay.Function(
        fwd_bwd_func.params,
        relay.TupleGetItem(
            relay.TupleGetItem(fwd_bwd_func.body, 1), 0 if verify_dgrad else 1
        ),
    )

    verify_conv2d(
        bwd_func,
        bwd_func,
        dshape,
        wshape,
        sm=80,
        atol=1e-2 if use_fp16 else tol,
        rtol=1e-2 if use_fp16 else tol,
        use_cudnn_ref=False,
        data_dtype="float32",
        weight_dtype="float32",
        use_vm=True,
    )
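
# The test above relies on helpers defined elsewhere in the CUTLASS test code.
# Below is a rough, assumed sketch of what `get_conv2d_nchw` might look like;
# `verify_conv2d` (which compiles the function with CUTLASS and compares against
# a reference) is omitted because it is considerably larger.
def get_conv2d_nchw_sketch(
    d_shape, w_shape, padding, strides=(1, 1),
    out_dtype="float32", data_dtype="float32", weight_dtype="float32",
):
    data = relay.var("data", shape=d_shape, dtype=data_dtype)
    weight = relay.var("weight", shape=w_shape, dtype=weight_dtype)
    return relay.nn.conv2d(
        data=data,
        weight=weight,
        kernel_size=w_shape[2:],
        channels=w_shape[0],
        padding=padding,
        strides=strides,
        out_dtype=out_dtype,
    )
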
def test_do_not_convert_summation():
    """Ops that could involve a large summation are not allowed in fp16."""
    shape = [1, 3, 16, 16]
    a = relay.var("a", shape=shape)
    ops = [
        relay.sum,
        relay.mean,
        relay.nn.global_avg_pool2d,
        lambda inp: relay.nn.adaptive_avg_pool2d(inp, (1, 1)),
    ]
    for op in ops:
        mod = tvm.IRModule.from_expr(op(a))
        out_mod = ToMixedPrecision("float16")(mod)
        orig_mod = tvm.relay.transform.InferType()(mod)
        assert tvm.ir.structural_equal(orig_mod, out_mod)
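
# A sketch (names are illustrative) of how these "never"-listed reductions interact
# with convertible ops: in a conv2d -> global_avg_pool2d chain the conv2d is still
# converted, but the pooling input is cast back to fp32, so the final dtype is unchanged.
def example_summation_op_keeps_fp32_output():
    data = relay.var("data", shape=(1, 3, 16, 16), dtype="float32")
    weight = relay.var("weight", shape=(8, 3, 3, 3), dtype="float32")
    conv = relay.nn.conv2d(data, weight, padding=(1, 1))
    mod = tvm.IRModule.from_expr(relay.nn.global_avg_pool2d(conv))
    fp16_mod = InferType()(ToMixedPrecision("float16")(mod))
    # The graph changes (conv2d now computes in fp16) ...
    assert not tvm.ir.structural_equal(InferType()(mod), fp16_mod)
    # ... but the pooling still runs in and returns float32.
    assert fp16_mod["main"].ret_type.dtype == "float32"
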
Example #6
def verify_mixed_precision_output_close(
    mod: tvm.runtime.Module,
    mod_params: Dict[str, Any],
    mixed_precision_dtype="float16",
    rtol: float = 1e-3,
    atol: float = 0,
) -> tvm.runtime.Module:

    mod = InferType()(mod)
    result_fp32 = run_module(mod, mod_params)
    fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod)
    result_fp16 = run_module(fp16_mod, mod_params)
    # Ensure the results are close
    for fp32, fp16 in zip(result_fp32, result_fp16):
        np.testing.assert_allclose(fp32, fp16, rtol=rtol, atol=atol)

    return fp16_mod