Code example #1
 def __enter__(self):
     # Switch the backend into the requested mixed-precision mode.
     if self.mixed_dtype == torch.bfloat16:
         core.enable_mix_bf16_fp32()
         core.disable_mix_int8_fp32()
     elif self.mixed_dtype == torch.int8:
         core.enable_mix_int8_fp32()
         core.disable_mix_bf16_fp32()
         # int8 runs either with fixed scales (inference) or while
         # collecting calibration statistics (calibration).
         if self.running_mode == 'inference':
             core.disable_int8_calibration()
         elif self.running_mode == 'calibration':
             core.enable_int8_calibration()
         else:
             assert False, 'int8 quantization only supports inference and calibration running modes'
     else:
         # No mixed precision requested: disable both modes.
         core.disable_mix_int8_fp32()
         core.disable_mix_bf16_fp32()
     core.set_execution_mode(train=(self.running_mode == 'training'))
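For context, here is a minimal usage sketch for the context manager whose __enter__ is shown above. The class name AutoMixPrecision, its constructor signature, and the model/input variables are assumptions for illustration, not confirmed API:

 # Hypothetical usage; AutoMixPrecision and its arguments are assumed names.
 import torch

 model = build_model().eval()        # placeholder model factory
 x = torch.randn(1, 3, 224, 224)     # placeholder input

 # Entering the block enables bf16/fp32 mixed execution; __exit__
 # (code example #2) restores the previous backend state.
 with AutoMixPrecision(mixed_dtype=torch.bfloat16, running_mode='inference'):
     y = model(x)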
Code example #2
 def __exit__(self, *args):
     # After a calibration run, reset the calibration bookkeeping in the core.
     if self.mixed_dtype == torch.int8:
         if self.running_mode == 'calibration':
             core.calibration_reset()
     # Restore the calibration, mixed-dtype, and execution state that was
     # in effect before __enter__.
     if self.pre_calibration_state:
         core.enable_int8_calibration()
     else:
         core.disable_int8_calibration()
     if self.pre_mixed_dtype == torch.bfloat16:
         core.enable_mix_bf16_fp32()
         core.disable_mix_int8_fp32()
     elif self.pre_mixed_dtype == torch.int8:
         core.enable_mix_int8_fp32()
         core.disable_mix_bf16_fp32()
     else:
         core.disable_mix_int8_fp32()
         core.disable_mix_bf16_fp32()
     core.set_execution_mode(train=self.pre_running_mode)
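__exit__ relies on state captured before entry: self.pre_calibration_state, self.pre_mixed_dtype, and self.pre_running_mode. A sketch of a constructor that would record that state follows. Apart from core.get_mix_int8_fp32(), which appears in code example #3, the getter names on core are assumptions:

 # Hypothetical __init__ recording the state that __exit__ restores.
 def __init__(self, mixed_dtype=torch.bfloat16, running_mode='inference'):
     self.mixed_dtype = mixed_dtype
     self.running_mode = running_mode
     # Snapshot the backend state before __enter__ changes it.
     if core.get_mix_bf16_fp32():                # assumed getter
         self.pre_mixed_dtype = torch.bfloat16
     elif core.get_mix_int8_fp32():              # shown in code example #3
         self.pre_mixed_dtype = torch.int8
     else:
         self.pre_mixed_dtype = None
     self.pre_calibration_state = core.get_int8_calibration()  # assumed getter
     self.pre_running_mode = core.get_train()                  # assumed getter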
Code example #3
 def __enter__(self):
     # int8 calibration is only meaningful once int8/fp32 mixed
     # execution has been enabled.
     if not core.get_mix_int8_fp32():
         raise ValueError(
             "please run enable_auto_mix_precision(torch.int8) before int8 calibration"
         )
     core.enable_int8_calibration()
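Tying the snippets together, a calibration-then-inference flow could look like the sketch below, using the context manager from code examples #1 and #2. The class name, model, and data variables are placeholders; per code example #3, a standalone calibration context would likewise require int8 mixed mode to be active first:

 # Hypothetical end-to-end flow; names are placeholders, not confirmed API.
 # 1. Calibration: running_mode='calibration' enables int8 calibration
 #    (see code example #1) and collects statistics over sample data.
 with AutoMixPrecision(mixed_dtype=torch.int8, running_mode='calibration'):
     for batch in calibration_loader:    # placeholder iterable
         model(batch)

 # 2. Inference: reuse the collected statistics with calibration disabled.
 with AutoMixPrecision(mixed_dtype=torch.int8, running_mode='inference'):
     output = model(sample_input)        # placeholder input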