def __enter__(self):
    """Enable the auto-mixed-precision mode selected by ``self.mixed_dtype``.

    - ``torch.bfloat16``: turn on bf16/fp32 mixing, turn off int8 mixing.
    - ``torch.int8``: turn on int8/fp32 mixing and configure calibration
      according to ``self.running_mode`` ('inference' or 'calibration').
    - anything else: turn both mixing modes off (pure fp32).

    Finally sets the backend execution mode to training iff
    ``self.running_mode == 'training'``.

    Raises:
        ValueError: if int8 is requested with a running mode other than
            'inference' or 'calibration'.
    """
    if self.mixed_dtype == torch.bfloat16:
        core.enable_mix_bf16_fp32()
        core.disable_mix_int8_fp32()
    elif self.mixed_dtype == torch.int8:
        core.enable_mix_int8_fp32()
        core.disable_mix_bf16_fp32()
        if self.running_mode == 'inference':
            core.disable_int8_calibration()
        elif self.running_mode == 'calibration':
            core.enable_int8_calibration()
        else:
            # Was `assert False, ...` with a typo ("suport"); an assert is
            # stripped under `python -O`, so raise explicitly instead —
            # consistent with the ValueError raised by the calibration
            # context manager below.
            raise ValueError(
                'int8 quantization only supports inference and calibration '
                'running modes')
    else:
        core.disable_mix_int8_fp32()
        core.disable_mix_bf16_fp32()
    core.set_execution_mode(train=(self.running_mode == 'training'))
def __exit__(self, *args):
    """Leave the auto-mixed-precision scope, restoring the saved state.

    Clears any per-run calibration data left by an int8 calibration pass,
    then re-applies the calibration flag, mixed-dtype flags, and execution
    mode that were captured before entry (``self.pre_*`` attributes).
    """
    # A completed int8 calibration pass leaves accumulated observer state
    # behind; wipe it before restoring the previous configuration.
    if self.mixed_dtype == torch.int8 and self.running_mode == 'calibration':
        core.calibration_reset()
    # Put the calibration switch back the way it was on entry.
    if self.pre_calibration_state:
        core.enable_int8_calibration()
    else:
        core.disable_int8_calibration()
    # Put the mixed-dtype switches back the way they were on entry.
    previous = self.pre_mixed_dtype
    if previous == torch.bfloat16:
        core.enable_mix_bf16_fp32()
        core.disable_mix_int8_fp32()
    elif previous == torch.int8:
        core.enable_mix_int8_fp32()
        core.disable_mix_bf16_fp32()
    else:
        core.disable_mix_int8_fp32()
        core.disable_mix_bf16_fp32()
    # NOTE(review): pre_running_mode is passed straight to `train=`, so it is
    # presumably the boolean captured from the previous execution mode —
    # confirm against __init__ (not visible in this chunk).
    core.set_execution_mode(train=self.pre_running_mode)
def __enter__(self):
    """Begin an int8 calibration scope.

    Raises:
        ValueError: if int8 auto mixed precision has not been enabled
            beforehand (calibration is meaningless without it).
    """
    int8_mix_active = core.get_mix_int8_fp32()
    if not int8_mix_active:
        raise ValueError(
            "please first run enable_auto_mix_precision(torch.int8) before int8 calibration"
        )
    core.enable_int8_calibration()