from typing import Any, Dict

import numpy as np
import tvm
from tvm import relay
from tvm.relay.transform import FirstOrderGradient, InferType, ToMixedPrecision

# Note: run_module, get_conv2d_nchw, and verify_conv2d are helper utilities
# defined elsewhere in the test files; they are not importable from tvm itself.


def verify_mixed_precision_output_close(
    mod: tvm.runtime.Module,
    mod_params: Dict[str, Any],
    mixed_precision_dtype="float16",
    rtol: float = 1e-3,
    atol: float = 0,
    keep_orig_output_dtype=False,
) -> tvm.runtime.Module:
    mod = InferType()(mod)
    result_fp32 = run_module(mod, mod_params)

    if not keep_orig_output_dtype:
        fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod)
        result_fp16 = run_module(fp16_mod, mod_params)
    else:
        with tvm.transform.PassContext(
            config={"relay.ToMixedPrecision.keep_orig_output_dtype": True}
        ):
            fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod)
            result_fp16 = run_module(fp16_mod, mod_params)

    # Ensure the results are close
    for fp32, fp16 in zip(result_fp32, result_fp16):
        np.testing.assert_allclose(fp32, fp16, rtol=rtol, atol=atol)

    if keep_orig_output_dtype:
        assert (
            np.array(result_fp16).dtype == np.array(result_fp32).dtype
        ), "output type and original type mismatch"

    return fp16_mod
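

# A minimal usage sketch (not part of the original tests): assuming the helper
# above and the run_module utility are in scope, build a small conv2d module
# and check that its fp16-converted output stays close to the fp32 reference.
# The test name, shapes, and tolerances below are illustrative assumptions.
def test_conv2d_output_close_to_fp32():
    data_shape = (1, 3, 32, 32)
    weight_shape = (8, 3, 3, 3)
    data = relay.var("data", shape=data_shape, dtype="float32")
    weight = relay.var("weight", shape=weight_shape, dtype="float32")
    conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float32")
    mod = tvm.IRModule.from_expr(conv)

    mod_params = {
        "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"),
        "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"),
    }
    # The helper asserts closeness internally and returns the converted module.
    fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=1e-3)
    assert fp16_mod is not None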


def test_do_not_convert_arange():
    """Arange is a red listed operation and therefore should never be fp16."""
    dtype = "float32"
    arange = relay.arange(relay.const(1, dtype), relay.const(128, dtype))
    mod = tvm.IRModule.from_expr(arange)

    out_mod = ToMixedPrecision("float16")(mod)
    orig_mod = tvm.relay.transform.InferType()(mod)
    assert tvm.ir.structural_equal(orig_mod, out_mod)


def test_do_not_convert_softmax():
    """Softmax is a red listed operation and therefore should never be fp16."""
    shape = [1, 2, 3]
    a = relay.var("a", shape=shape)
    b = relay.nn.softmax(a)
    mod = tvm.IRModule.from_expr(b)
    mod = tvm.relay.transform.InferType()(mod)

    out_mod = ToMixedPrecision("float16")(mod)
    orig_mod = tvm.relay.transform.InferType()(mod)
    assert tvm.ir.structural_equal(orig_mod, out_mod)


def test_conv2d_bwd():
    IC = 16
    OC = 8
    dshape = (16, IC, 32, 32)
    wshape = (OC, IC, 3, 3)
    padding = (0, 0)
    strides = (1, 1)

    conv = get_conv2d_nchw(
        dshape,
        wshape,
        padding,
        strides=strides,
        out_dtype="float32",
        data_dtype="float32",
        weight_dtype="float32",
    )
    fwd_mod = InferType()(tvm.IRModule.from_expr(conv))

    # Note: large difference between tvm and cutlass Wgrad results when fp16 is used.
    # Cutlass wgrad uses fp32 accumulation even if the output is fp16.
    use_fp16 = False
    verify_dgrad = False  # False to verify wgrad
    tol = 1e-5 if verify_dgrad else 1e-4  # Wgrad slightly less accurate

    if use_fp16:
        fwd_mod = ToMixedPrecision("float16")(fwd_mod)

    fwd_bwd_func = FirstOrderGradient()(fwd_mod)["main"]

    bwd_func = relay.Function(
        fwd_bwd_func.params,
        relay.TupleGetItem(relay.TupleGetItem(fwd_bwd_func.body, 1), 0 if verify_dgrad else 1),
    )

    verify_conv2d(
        bwd_func,
        bwd_func,
        dshape,
        wshape,
        sm=80,
        atol=1e-2 if use_fp16 else tol,
        rtol=1e-2 if use_fp16 else tol,
        use_cudnn_ref=False,
        data_dtype="float32",
        weight_dtype="float32",
        use_vm=True,
    )
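

# Hedged sketch of the get_conv2d_nchw helper assumed by test_conv2d_bwd above.
# The actual helper is defined elsewhere in the CUTLASS test file; this version
# only illustrates the expected signature and shape/dtype handling.
def get_conv2d_nchw(
    d_shape,
    w_shape,
    padding,
    strides=(1, 1),
    out_dtype="float16",
    data_dtype="float16",
    weight_dtype="float16",
):
    # NCHW data and OIHW weight; kernel size and channel count come from w_shape.
    data = relay.var("data", shape=d_shape, dtype=data_dtype)
    weight = relay.var("weight", shape=w_shape, dtype=weight_dtype)
    return relay.nn.conv2d(
        data,
        weight,
        kernel_size=w_shape[2:],
        channels=w_shape[0],
        padding=padding,
        strides=strides,
        out_dtype=out_dtype,
    )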


def test_do_not_convert_summation():
    """Ops that could involve a large summation are not allowed in fp16."""
    shape = [1, 3, 16, 16]
    a = relay.var("a", shape=shape)
    ops = [
        relay.sum,
        relay.mean,
        relay.nn.global_avg_pool2d,
        lambda inp: relay.nn.adaptive_avg_pool2d(inp, (1, 1)),
    ]
    for op in ops:
        mod = tvm.IRModule.from_expr(op(a))
        out_mod = ToMixedPrecision("float16")(mod)
        orig_mod = tvm.relay.transform.InferType()(mod)
        assert tvm.ir.structural_equal(orig_mod, out_mod)


# Simpler variant of verify_mixed_precision_output_close without the
# keep_orig_output_dtype option handled by the version defined above.
def verify_mixed_precision_output_close(
    mod: tvm.runtime.Module,
    mod_params: Dict[str, Any],
    mixed_precision_dtype="float16",
    rtol: float = 1e-3,
    atol: float = 0,
) -> tvm.runtime.Module:
    mod = InferType()(mod)
    result_fp32 = run_module(mod, mod_params)
    fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod)
    result_fp16 = run_module(fp16_mod, mod_params)

    # Ensure the results are close
    for fp32, fp16 in zip(result_fp32, result_fp16):
        np.testing.assert_allclose(fp32, fp16, rtol=rtol, atol=atol)

    return fp16_mod