Beispiel #1
0
 def __enter__(self):
     """Switch ``core`` into the mixed-precision mode configured on this object.

     Reads ``self.mixed_dtype`` and ``self.running_mode``:

     * ``torch.bfloat16``: enable BF16/FP32 mixing, disable INT8/FP32 mixing.
     * ``torch.int8``: enable INT8/FP32 mixing, disable BF16/FP32 mixing, and
       turn INT8 calibration off for ``'inference'`` or on for
       ``'calibration'``.
     * anything else: disable both mixing modes.

     Finally sets the execution mode to training iff ``self.running_mode`` is
     ``'training'``.

     Raises:
         ValueError: if ``self.mixed_dtype`` is ``torch.int8`` and
             ``self.running_mode`` is neither ``'inference'`` nor
             ``'calibration'``. The check runs before any ``core`` state is
             touched.
     """
     if self.mixed_dtype == torch.int8:
         # Validate up front. When ``__enter__`` raises, the ``with`` statement
         # never calls ``__exit__``, so raising after mutating the global
         # ``core`` state would leave it changed with nothing to restore it.
         # An explicit raise (instead of ``assert``) also survives ``python -O``.
         if self.running_mode not in ('inference', 'calibration'):
             raise ValueError(
                 'int8 quantization only supports inference and calibration '
                 'running modes, got %r' % (self.running_mode,))

     if self.mixed_dtype == torch.bfloat16:
         core.enable_mix_bf16_fp32()
         core.disable_mix_int8_fp32()
     elif self.mixed_dtype == torch.int8:
         core.enable_mix_int8_fp32()
         core.disable_mix_bf16_fp32()
         if self.running_mode == 'inference':
             core.disable_int8_calibration()
         else:
             # Only 'calibration' can reach here -- validated above.
             core.enable_int8_calibration()
     else:
         core.disable_mix_int8_fp32()
         core.disable_mix_bf16_fp32()
     core.set_execution_mode(train=(self.running_mode == 'training'))
Beispiel #2
0
 def __exit__(self, *args):
     """Restore the ``core`` state that was active before this context.

     If this context ran INT8 calibration, ``core.calibration_reset()`` is
     called first. The method then resets three things from attributes that
     are set outside this method (presumably in ``__init__``, which is not
     visible here):

     * the INT8 calibration switch, from ``self.pre_calibration_state``;
     * the mixed-precision dtype, from ``self.pre_mixed_dtype``;
     * the execution mode, from ``self.pre_running_mode``.

     Returns ``None``, so an exception raised inside the ``with`` body is not
     suppressed.
     """
     calibrated = (self.mixed_dtype == torch.int8
                   and self.running_mode == 'calibration')
     if calibrated:
         core.calibration_reset()

     # Put the INT8 calibration switch back the way it was.
     if not self.pre_calibration_state:
         core.disable_int8_calibration()
     else:
         core.enable_int8_calibration()

     # Re-select whichever mixed-precision dtype was active before (or none).
     previous_dtype = self.pre_mixed_dtype
     if previous_dtype == torch.bfloat16:
         core.enable_mix_bf16_fp32()
         core.disable_mix_int8_fp32()
     elif previous_dtype == torch.int8:
         core.enable_mix_int8_fp32()
         core.disable_mix_bf16_fp32()
     else:
         core.disable_mix_int8_fp32()
         core.disable_mix_bf16_fp32()

     # NOTE(review): this value is passed straight through as the ``train``
     # flag, which assumes ``pre_running_mode`` holds a boolean. ``__enter__``
     # instead derives ``train`` from a mode string. Confirm against where
     # ``pre_running_mode`` is set.
     core.set_execution_mode(train=self.pre_running_mode)
def enable_auto_mix_precision(mixed_dtype=torch.bfloat16,
                              train=False,
                              configure_file=None):
    """Globally enable automatic mixed precision in ``core``.

    Args:
        mixed_dtype: the mixing mode to select.
            ``torch.bfloat16`` enables BF16/FP32 mixing and disables
            INT8/FP32 mixing.
            ``torch.int8`` or ``torch.uint8`` enables INT8/FP32 mixing and
            disables BF16/FP32 mixing.
            Any other value disables both.
        train: forwarded to ``core.set_execution_mode(train=...)``.
        configure_file: INT8 only. Path to a JSON file whose parsed content is
            passed to ``core.load_indicators_file``; INT8 calibration is
            disabled when it is given. If it is omitted in INT8 mode, a
            warning reminds the caller to run calibration first.

    Raises:
        OSError: if ``configure_file`` cannot be opened.
        ValueError: (including ``json.JSONDecodeError``) if ``configure_file``
            is not valid UTF-8 JSON.
    """
    is_int8 = mixed_dtype == torch.int8 or mixed_dtype == torch.uint8

    configures = None
    if is_int8 and configure_file is not None:
        # Read and parse the file before touching any global ``core`` state,
        # so a bad path or malformed JSON cannot leave mixing half-configured.
        # The ``with`` block guarantees the file handle is closed.
        with open(configure_file, encoding='utf-8') as f:
            configures = json.load(f)

    if mixed_dtype == torch.bfloat16:
        core.enable_mix_bf16_fp32()
        core.disable_mix_int8_fp32()
    elif is_int8:
        core.enable_mix_int8_fp32()
        core.disable_mix_bf16_fp32()
        if configure_file is not None:
            core.disable_int8_calibration()
            core.load_indicators_file(configures)
        else:
            warnings.warn(
                "please not forget do calibration before doing validation step"
            )
    else:
        core.disable_mix_int8_fp32()
        core.disable_mix_bf16_fp32()
    core.set_execution_mode(train=train)
 def test_mix_bf16_fp32_train(self):
     """Check that ``ipex.get_train()`` tracks ``ipex.set_execution_mode``.

     The test starts from the non-training state, switches to training, and
     then switches back.
     """
     # The execution mode is expected to start out as non-training.
     self.assertFalse(ipex.get_train())
     for train_flag, check in ((True, self.assertTrue),
                               (False, self.assertFalse)):
         ipex.set_execution_mode(train=train_flag)
         check(ipex.get_train())