def trace_(func, example_inputs, *args, **kwargs):
    # Disable mixed precision: torch.jit.trace checks the traced output
    # against what is expected, and since mixed precision leads to a loss
    # of accuracy, leaving it on would raise warnings during torch.jit.trace.
    mix_state = core.get_mix_bf16_fp32()
    core.disable_mix_bf16_fp32()
    jit_m = orig_trace(func, example_inputs, *args, **kwargs)
    if core.get_jit_opt() and hasattr(jit_m, '_c'):
        # Fold conv + batchnorm when JIT optimization is enabled.
        jit_m = wrap_cpp_module(torch._C._jit_pass_fold_convbn(jit_m._c))
    if mix_state:
        core.enable_mix_bf16_fp32()
    return jit_m
def script_(obj, optimize=None, _frames_up=0, _rcb=None):
    # Temporarily restore the original torch.jit.script so that recursive
    # scripting inside orig_script does not re-enter this wrapper.
    torch.jit.script = orig_script
    jit_m = orig_script(obj, optimize=optimize, _frames_up=_frames_up + 1, _rcb=_rcb)
    torch.jit.script = script_

    mix_state = torch.bfloat16 if core.get_mix_bf16_fp32() \
        else torch.int8 if core.get_mix_int8_fp32() else None
    # Disable mixed precision during model fusion, since mixed precision
    # brings no benefit to the fusion pass but will lead to a loss of accuracy.
    core.disable_mix_bf16_fp32()
    core.disable_mix_int8_fp32()
    if core.get_jit_opt() and hasattr(jit_m, '_c'):
        jit_m = wrap_cpp_module(torch._C._jit_pass_fold_convbn(jit_m._c))
    # Restore the mixed-precision state recorded above.
    if mix_state == torch.bfloat16:
        core.enable_mix_bf16_fp32()
    elif mix_state == torch.int8:
        core.enable_mix_int8_fp32()
    return jit_m
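# A minimal sketch (not from the original snippet) of how the two wrappers
# above are assumed to be installed: the stock entry points are saved as
# orig_trace/orig_script and then monkey-patched, so user calls to
# torch.jit.trace and torch.jit.script transparently go through these hooks.
orig_trace = torch.jit.trace
orig_script = torch.jit.script
torch.jit.trace = trace_
torch.jit.script = script_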
def __enter__(self):
    if self.mixed_dtype == torch.bfloat16:
        core.enable_mix_bf16_fp32()
        core.disable_mix_int8_fp32()
    elif self.mixed_dtype == torch.int8:
        core.enable_mix_int8_fp32()
        core.disable_mix_bf16_fp32()
        if self.running_mode == 'inference':
            core.disable_int8_calibration()
        elif self.running_mode == 'calibration':
            core.enable_int8_calibration()
        else:
            assert False, \
                "int8 quantization only supports the 'inference' and 'calibration' running modes"
    else:
        core.disable_mix_int8_fp32()
        core.disable_mix_bf16_fp32()
    core.set_execution_mode(train=(self.running_mode == 'training'))
def __exit__(self, *args):
    if self.mixed_dtype == torch.int8:
        if self.running_mode == 'calibration':
            core.calibration_reset()
    # Restore the state that was saved before entering the context.
    if self.pre_calibration_state:
        core.enable_int8_calibration()
    else:
        core.disable_int8_calibration()
    if self.pre_mixed_dtype == torch.bfloat16:
        core.enable_mix_bf16_fp32()
        core.disable_mix_int8_fp32()
    elif self.pre_mixed_dtype == torch.int8:
        core.enable_mix_int8_fp32()
        core.disable_mix_bf16_fp32()
    else:
        core.disable_mix_int8_fp32()
        core.disable_mix_bf16_fp32()
    core.set_execution_mode(train=self.pre_running_mode)
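# Hypothetical usage sketch for the context manager whose __enter__/__exit__
# are shown above. The class name AutoMixPrecision and its constructor
# arguments are assumptions for illustration, not confirmed by this snippet.
import torch
import torch.nn as nn

model = nn.Linear(4, 4).eval()
x = torch.randn(1, 4)
with AutoMixPrecision(mixed_dtype=torch.bfloat16, running_mode='inference'):
    y = model(x)
# On exit, the previously saved mixed-dtype, calibration, and execution-mode
# state is restored, so code outside the block is unaffected.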
def enable_auto_mix_precision(mixed_dtype=torch.bfloat16, train=False, configure_file=None):
    if mixed_dtype == torch.bfloat16:
        core.enable_mix_bf16_fp32()
        core.disable_mix_int8_fp32()
    elif mixed_dtype in (torch.int8, torch.uint8):
        core.enable_mix_int8_fp32()
        core.disable_mix_bf16_fp32()
        if configure_file is not None:
            # The configure file holds the indicators recorded during
            # calibration, so calibration can be skipped here.
            core.disable_int8_calibration()
            with open(configure_file) as f:
                configures = json.load(f)
            core.load_indicators_file(configures)
        else:
            warnings.warn(
                "please do not forget to run calibration before the validation step")
    else:
        core.disable_mix_int8_fp32()
        core.disable_mix_bf16_fp32()
    core.set_execution_mode(train=train)
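# Hypothetical call sketches for the function above; the module alias `ipex`
# and the JSON file name are assumptions for illustration.
import torch
import intel_pytorch_extension as ipex

# bf16 inference: no calibration file is involved.
ipex.enable_auto_mix_precision(mixed_dtype=torch.bfloat16, train=False)

# int8 inference: load the indicators recorded by an earlier calibration run.
ipex.enable_auto_mix_precision(mixed_dtype=torch.int8,
                               configure_file='int8_configure.json')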
def test_mix_bf16_fp32(self):
    self.assertFalse(ipex.get_mix_bf16_fp32())
    ipex.enable_mix_bf16_fp32()
    self.assertTrue(ipex.get_mix_bf16_fp32())
    ipex.disable_mix_bf16_fp32()
    self.assertFalse(ipex.get_mix_bf16_fp32())
def enable_auto_mix_precision(mixed_dtype=torch.bfloat16):
    if mixed_dtype == torch.bfloat16:
        core.enable_mix_bf16_fp32()
    else:
        core.disable_mix_bf16_fp32()
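# Hypothetical usage of this earlier, bf16-only variant: any dtype other
# than torch.bfloat16 simply switches the bf16/fp32 mix back off.
enable_auto_mix_precision(mixed_dtype=torch.bfloat16)  # enable bf16 mixing
enable_auto_mix_precision(mixed_dtype=torch.float32)   # disable it again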