Code example #1
def __exit__(self, *args):
    # Calibration has finished: stop calibrating, commit the collected
    # indicators, and save the resulting int8 configuration as JSON.
    core.disable_int8_calibration()
    core.add_indicators()
    configures = core.get_int8_configures()
    with open(self.configure_file, 'w') as fp:
        json.dump(configures, fp, indent=4)
    return False
Code example #2
def __enter__(self):
    # Switch the backend into the requested mixed-precision mode.
    if self.mixed_dtype == torch.bfloat16:
        core.enable_mix_bf16_fp32()
        core.disable_mix_int8_fp32()
    elif self.mixed_dtype == torch.int8:
        core.enable_mix_int8_fp32()
        core.disable_mix_bf16_fp32()
        # int8 either reuses an existing calibration result (inference)
        # or records a new one (calibration).
        if self.running_mode == 'inference':
            core.disable_int8_calibration()
        elif self.running_mode == 'calibration':
            core.enable_int8_calibration()
        else:
            assert False, 'int8 quantization only supports the inference and calibration running modes'
    else:
        # No mixed precision requested: run plain fp32.
        core.disable_mix_int8_fp32()
        core.disable_mix_bf16_fp32()
    core.set_execution_mode(train=(self.running_mode == 'training'))
Code example #3
def __exit__(self, *args):
    if self.mixed_dtype == torch.int8:
        if self.running_mode == 'calibration':
            # Drop the calibration statistics gathered inside this context.
            core.calibration_reset()
    # restore the state captured before entering the context
    if self.pre_calibration_state:
        core.enable_int8_calibration()
    else:
        core.disable_int8_calibration()
    if self.pre_mixed_dtype == torch.bfloat16:
        core.enable_mix_bf16_fp32()
        core.disable_mix_int8_fp32()
    elif self.pre_mixed_dtype == torch.int8:
        core.enable_mix_int8_fp32()
        core.disable_mix_bf16_fp32()
    else:
        core.disable_mix_int8_fp32()
        core.disable_mix_bf16_fp32()
    core.set_execution_mode(train=self.pre_running_mode)
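
Code examples #2 and #3 are the __enter__ and __exit__ halves of one context manager: entering the with block switches the backend to the requested mixed-precision mode, and leaving it restores whatever state was active before. The snippets do not show the owning class, so the stand-in below is a minimal, self-contained sketch of the same save/switch/restore pattern; the names MixedPrecisionScope and _current_mode are illustrative only and are not part of the library.

_current_mode = 'fp32'  # stands in for the backend state toggled by the core.* calls

class MixedPrecisionScope:  # illustrative stand-in, not the library class
    def __init__(self, mode):
        self.mode = mode
        self.prev_mode = None

    def __enter__(self):
        global _current_mode
        self.prev_mode = _current_mode  # capture previous state (like pre_mixed_dtype)
        _current_mode = self.mode       # switch mode (like enable_mix_*_fp32())
        return self

    def __exit__(self, *args):
        global _current_mode
        _current_mode = self.prev_mode  # restore previous state (like example #3)
        return False                    # do not suppress exceptions

with MixedPrecisionScope('bf16'):
    assert _current_mode == 'bf16'      # mixed precision active inside the block
assert _current_mode == 'fp32'          # original state restored on exit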
Code example #4
def enable_auto_mix_precision(mixed_dtype=torch.bfloat16,
                              train=False,
                              configure_file=None):
    if mixed_dtype == torch.bfloat16:
        core.enable_mix_bf16_fp32()
        core.disable_mix_int8_fp32()
    elif mixed_dtype == torch.int8 or mixed_dtype == torch.uint8:
        core.enable_mix_int8_fp32()
        core.disable_mix_bf16_fp32()
        if configure_file is not None:
            # Reuse a previously saved calibration result instead of calibrating again.
            core.disable_int8_calibration()
            with open(configure_file) as f:
                configures = json.load(f)
            core.load_indicators_file(configures)
        else:
            warnings.warn(
                "please do not forget to run calibration before the validation step"
            )
    else:
        core.disable_mix_int8_fp32()
        core.disable_mix_bf16_fp32()
    core.set_execution_mode(train=train)
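
For completeness, here is a minimal usage sketch of enable_auto_mix_precision as defined in code example #4. The import path intel_pytorch_extension and the file name int8_configure.json are assumptions for illustration; the snippets above do not show how the module is imported or where the calibration file written by code example #1 is stored.

import torch
import intel_pytorch_extension as ipex  # assumed import path

# BF16 inference: no calibration file is required.
ipex.enable_auto_mix_precision(mixed_dtype=torch.bfloat16, train=False)

# INT8 inference: reuse the JSON configuration written by the calibration
# __exit__ shown in code example #1 (hypothetical file name).
ipex.enable_auto_mix_precision(mixed_dtype=torch.int8,
                               train=False,
                               configure_file='int8_configure.json')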