def get_auto_mix_precision():
    # Map the extension's global mix-precision state to the matching torch
    # dtype; None means pure fp32 execution.
    if core.get_mix_bf16_fp32():
        return torch.bfloat16
    elif core.get_mix_int8_fp32():
        return torch.int8
    else:
        return None
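
# A minimal sketch of how a caller might branch on the dtype returned above.
# `_run_with_mix`, `model`, and `x` are hypothetical names used only for
# illustration; the extension's real dispatch logic is not shown here.
def _run_with_mix(model, x):
    mix_dtype = get_auto_mix_precision()
    if mix_dtype is torch.bfloat16:
        # bf16 mix: inputs may be down-cast while weights stay in fp32.
        x = x.to(torch.bfloat16)
    elif mix_dtype is torch.int8:
        # int8 mix: inputs would instead go through a quantized path.
        pass
    # mix_dtype is None -> plain fp32 execution.
    return model(x)
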
def trace_(func, example_inputs, *args, **kwargs):
    # Disable mixed precision while tracing: torch.jit.trace checks the traced
    # graph's outputs against the eager outputs, and since mixed precision
    # leads to a loss of accuracy, leaving it enabled would raise warnings
    # during torch.jit.trace.
    mix_state = core.get_mix_bf16_fp32()
    core.disable_mix_bf16_fp32()
    jit_m = orig_trace(func, example_inputs, *args, **kwargs)
    if core.get_jit_opt() and hasattr(jit_m, '_c'):
        # Fold Conv/BatchNorm pairs in the traced graph when JIT optimization
        # is enabled.
        jit_m = wrap_cpp_module(torch._C._jit_pass_fold_convbn(jit_m._c))
    if mix_state:
        core.enable_mix_bf16_fp32()
    return jit_m
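
# A minimal, self-contained sketch of the kind of module that exercises the
# wrapper above once torch.jit.trace has been rebound to trace_ (script_ below
# applies the same pattern to torch.jit.script). The Conv2d/BatchNorm2d toy
# module is hypothetical and only illustrates the conv-bn folding target.
class _ConvBN(torch.nn.Module):
    def __init__(self):
        super(_ConvBN, self).__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)
        self.bn = torch.nn.BatchNorm2d(8)

    def forward(self, x):
        return self.bn(self.conv(x))

# Example (with the wrapper installed):
#     traced = torch.jit.trace(_ConvBN().eval(), torch.randn(1, 3, 32, 32))
# With jit opt enabled, each Conv2d/BatchNorm2d pair is folded into a single
# convolution before the traced module is returned.
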
def script_(obj, optimize=None, _frames_up=0, _rcb=None):
    # Temporarily restore the original torch.jit.script so that any scripting
    # done internally by orig_script does not re-enter this wrapper;
    # _frames_up is bumped by one because the wrapper adds an extra stack
    # frame between the caller and orig_script.
    torch.jit.script = orig_script
    jit_m = orig_script(obj, optimize=optimize, _frames_up=_frames_up+1, _rcb=_rcb)
    torch.jit.script = script_

    mix_state = torch.bfloat16 if core.get_mix_bf16_fp32() else torch.int8 if core.get_mix_int8_fp32() else None
    # Disable mixed precision during model fusion: it brings no inference
    # benefit at this stage but would cause a loss of accuracy.
    core.disable_mix_bf16_fp32()
    core.disable_mix_int8_fp32()
    if core.get_jit_opt() and hasattr(jit_m, '_c'):
        jit_m = wrap_cpp_module(torch._C._jit_pass_fold_convbn(jit_m._c))
    if mix_state == torch.bfloat16:
        core.enable_mix_bf16_fp32()
    elif mix_state == torch.int8:
        core.enable_mix_int8_fp32()
    return jit_m
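
# A minimal sketch of how both wrappers above are expected to be hooked in.
# The saved names orig_trace/orig_script are taken from this file; where the
# rebinding actually happens (import time, an enable() call, ...) is an
# assumption.
#
#     orig_trace, orig_script = torch.jit.trace, torch.jit.script
#     torch.jit.trace = trace_
#     torch.jit.script = script_
#
# From then on, torch.jit.script(model) runs through script_: the original
# scripting path builds the module, conv-bn folding is applied when jit opt is
# enabled, and the previous bf16/int8 mix state is restored before returning.
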
    def test_mix_bf16_fp32(self):
        self.assertFalse(ipex.get_mix_bf16_fp32())
        ipex.enable_mix_bf16_fp32()
        self.assertTrue(ipex.get_mix_bf16_fp32())
        ipex.disable_mix_bf16_fp32()
        self.assertFalse(ipex.get_mix_bf16_fp32())
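
    # A mirrored check for the int8 mix state; a sketch only. It assumes the
    # public ipex module exposes the same get/enable/disable helpers for
    # int8/fp32 mixing as the `core` bindings used above; verify these names
    # exist before relying on the test.
    def test_mix_int8_fp32(self):
        self.assertFalse(ipex.get_mix_int8_fp32())
        ipex.enable_mix_int8_fp32()
        self.assertTrue(ipex.get_mix_int8_fp32())
        ipex.disable_mix_int8_fp32()
        self.assertFalse(ipex.get_mix_int8_fp32())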