Code Example #1
def trace_(func, example_inputs, *args, **kwargs):
    # Disable mixed precision: torch.jit.trace checks the traced output
    # against the expected (eager-mode) output, and the accuracy loss from
    # mixed precision would raise warnings during torch.jit.trace.
    mix_state = core.get_mix_bf16_fp32()
    core.disable_mix_bf16_fp32()
    jit_m = orig_trace(func, example_inputs, *args, **kwargs)
    if core.get_jit_opt() and hasattr(jit_m, '_c'):
        jit_m = wrap_cpp_module(torch._C._jit_pass_fold_convbn(jit_m._c))
    if mix_state:
        core.enable_mix_bf16_fp32()
    return jit_m
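
The wrapper above is intended to replace torch.jit.trace itself. A minimal sketch of how such a monkey-patch is typically installed; the save-and-patch lines are an assumption for illustration, since the excerpt only shows the wrapper body:

import torch

orig_trace = torch.jit.trace  # keep a handle to the original entry point
torch.jit.trace = trace_      # route all tracing through the wrapper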
Code Example #2
def script_(obj, optimize=None, _frames_up=0, _rcb=None):
    torch.jit.script = orig_script
    jit_m = orig_script(obj, optimize=optimize, _frames_up=_frames_up+1, _rcb=_rcb)
    torch.jit.script = script_

    mix_state = torch.bfloat16 if core.get_mix_bf16_fp32() else torch.int8 if core.get_mix_int8_fp32() else None
    # Disable mixed precision during model fusion: it brings no benefit to
    # the fusion pass itself but can lead to loss of accuracy.
    core.disable_mix_bf16_fp32()
    core.disable_mix_int8_fp32()
    if core.get_jit_opt() and hasattr(jit_m, '_c'):
        jit_m = wrap_cpp_module(torch._C._jit_pass_fold_convbn(jit_m._c))
    if mix_state == torch.bfloat16:
        core.enable_mix_bf16_fp32()
    elif mix_state == torch.int8:
        core.enable_mix_int8_fp32()
    return jit_m
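
Note the temporary swap at the top of script_: the original torch.jit.script is restored before the call so that nested calls to torch.jit.script made while scripting submodules do not re-enter this wrapper, and the patch is reinstated immediately afterwards.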
Code Example #3
def __enter__(self):
    if self.mixed_dtype == torch.bfloat16:
        core.enable_mix_bf16_fp32()
        core.disable_mix_int8_fp32()
    elif self.mixed_dtype == torch.int8:
        core.enable_mix_int8_fp32()
        core.disable_mix_bf16_fp32()
        if self.running_mode == 'inference':
            core.disable_int8_calibration()
        elif self.running_mode == 'calibration':
            core.enable_int8_calibration()
        else:
            assert False, ('int8 quantization only supports the inference '
                           'and calibration running modes')
    else:
        core.disable_mix_int8_fp32()
        core.disable_mix_bf16_fp32()
    core.set_execution_mode(train=(self.running_mode == 'training'))
Code Example #4
def __exit__(self, *args):
    if self.mixed_dtype == torch.int8:
        if self.running_mode == 'calibration':
            core.calibration_reset()
    # restore previous state
    if self.pre_calibration_state:
        core.enable_int8_calibration()
    else:
        core.disable_int8_calibration()
    if self.pre_mixed_dtype == torch.bfloat16:
        core.enable_mix_bf16_fp32()
        core.disable_mix_int8_fp32()
    elif self.pre_mixed_dtype == torch.int8:
        core.enable_mix_int8_fp32()
        core.disable_mix_bf16_fp32()
    else:
        core.disable_mix_int8_fp32()
        core.disable_mix_bf16_fp32()
    core.set_execution_mode(train=self.pre_running_mode)
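
Examples #3 and #4 together form a context manager that sets and then restores the global mixed-precision state. A minimal usage sketch, assuming the enclosing class takes the mixed_dtype and running_mode attributes the two methods read (both the class name AutoMixPrecision and the constructor signature are assumptions for illustration):

import torch

# hypothetical construction; the source only shows __enter__/__exit__
with AutoMixPrecision(mixed_dtype=torch.bfloat16, running_mode='inference'):
    output = model(example_input)  # runs with BF16/FP32 mixing enabled
# on exit, the previous mixing, calibration, and execution state is restored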
Code Example #5
import json
import warnings

import torch

# `core` below refers to the extension's native binding module, provided by
# the surrounding package rather than imported here.

def enable_auto_mix_precision(mixed_dtype=torch.bfloat16,
                              train=False,
                              configure_file=None):
    if mixed_dtype == torch.bfloat16:
        core.enable_mix_bf16_fp32()
        core.disable_mix_int8_fp32()
    elif mixed_dtype == torch.int8 or mixed_dtype == torch.uint8:
        core.enable_mix_int8_fp32()
        core.disable_mix_bf16_fp32()
        if configure_file is not None:
            core.disable_int8_calibration()
            with open(configure_file) as f:
                configures = json.load(f)
            core.load_indicators_file(configures)
        else:
            warnings.warn(
                "please do not forget to run calibration before the "
                "validation step"
            )
    else:
        core.disable_mix_int8_fp32()
        core.disable_mix_bf16_fp32()
    core.set_execution_mode(train=train)
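
A hedged usage sketch of the entry point above, loading a previously saved int8 calibration result for inference (the file name is illustrative):

import torch

enable_auto_mix_precision(mixed_dtype=torch.int8,
                          train=False,
                          configure_file='int8_configure.json')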
Code Example #6
def test_mix_bf16_fp32(self):
    self.assertFalse(ipex.get_mix_bf16_fp32())
    ipex.enable_mix_bf16_fp32()
    self.assertTrue(ipex.get_mix_bf16_fp32())
    ipex.disable_mix_bf16_fp32()
    self.assertFalse(ipex.get_mix_bf16_fp32())
Code Example #7
def enable_auto_mix_precision(mixed_dtype=torch.bfloat16):
    if mixed_dtype == torch.bfloat16:
        core.enable_mix_bf16_fp32()
    else:
        core.disable_mix_bf16_fp32()
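
A one-line usage sketch of this simplified BF16-only variant; per the else branch, passing any other value (None here is illustrative) disables the mixing:

import torch

enable_auto_mix_precision(mixed_dtype=torch.bfloat16)  # enable BF16/FP32 mixing
enable_auto_mix_precision(mixed_dtype=None)            # disable it again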