예제 #1
0
  def deployable_model(self, src_dir, used_for_xmodel=False):
    """Rebuild the quantized deployable model from a saved state dict.

    A deep copy of ``self._model`` is loaded with the weights stored at
    ``os.path.join(src_dir, _DEPLOYABLE_MODEL_NAME)``. It is then wrapped in
    a ``qproc.TorchQuantProcessor`` created in ``'test'`` mode. The processor
    is stored on ``self._qprocessor`` and its quantized model is returned.

    Args:
      src_dir: Directory holding the deployable state dict. It is also
        passed to the processor as ``output_dir``.
      used_for_xmodel: If True, work on CPU so the model can be forwarded
        to dump an xmodel (``self._inputs`` is copied to CPU). Otherwise
        ``self._device`` and ``self._inputs`` are used as-is.

    Returns:
      Whatever ``qprocessor.quant_model()`` returns.
    """
    if used_for_xmodel:
      device = torch.device('cpu')
      inputs = self._inputs.to(device)
    else:
      device = self._device
      inputs = self._inputs

    model = copy.deepcopy(self._model)
    # Map the checkpoint onto the target device. Without map_location, a
    # state dict saved from a GPU run would require CUDA even for the CPU
    # (xmodel) path. load_state_dict copies values into the model's existing
    # parameters, so where the loaded tensors live does not otherwise matter.
    state_dict = torch.load(
        os.path.join(src_dir, _DEPLOYABLE_MODEL_NAME), map_location=device)
    model.load_state_dict(state_dict)
    qprocessor = qproc.TorchQuantProcessor(
        'test',
        model,
        inputs,
        output_dir=src_dir,
        bitwidth_w=self._bitwidth,
        bitwidth_a=self._bitwidth,
        mix_bit=self._mix_bit,
        device=device)
    self._qprocessor = qprocessor
    if used_for_xmodel:
      # NOTE(review): this method does not slice inputs to batch_size=1.
      # Presumably self._inputs already has batch size 1; confirm with caller.
      logging.info(
          'Forward the deployable model with data of batch_size=1 in cpu mode to dump xmodel.'
      )
    return qprocessor.quant_model()
예제 #2
0
def fuse_conv_bn(model):
  """Merge batchnorm layers into their convolutions throughout ``model``.

  ``conv_fused.fuse_conv_bn`` is passed to ``model.apply``. For a
  ``torch.nn.Module``, this visits every submodule and the model itself.
  Afterwards the model is tagged with ``conv_bn_fused = True`` so later
  code can tell that fusion already happened.

  Args:
    model: The model to fuse in place.
  """
  fuse_fn = conv_fused.fuse_conv_bn
  model.apply(fuse_fn)
  # The flag is set only after apply() succeeds.
  model.conv_bn_fused = True
  logging.info('Merge batchnorm to conv.')
예제 #3
0
def freeze_bn_stats(model):
  """Freeze the batch-normalization running statistics in ``model``.

  ``conv_fused.freeze_bn_stats`` is passed to ``model.apply``. For a
  ``torch.nn.Module``, this visits every submodule and the model itself.

  Args:
    model: The model whose batchnorm statistics are frozen in place.
  """
  model.apply(conv_fused.freeze_bn_stats)
  # Fixed typo in the log text: "normlization" -> "normalization".
  logging.info('Running statistics of batch normalization has been frozen.')
예제 #4
0
def freeze_quant(model):
  """Freeze the quantizer scales in ``model``.

  ``quantizer_mod.freeze_quant`` is passed to ``model.apply``. For a
  ``torch.nn.Module``, this visits every submodule and the model itself.

  Args:
    model: The model whose quantizers are frozen in place.
  """
  freeze_fn = quantizer_mod.freeze_quant
  model.apply(freeze_fn)
  logging.info('Scale of quantizer has been frozen.')
예제 #5
0
def enable_warmup(model):
  """Turn on quantizer warmup (initialization) for ``model``.

  ``quantizer_mod.enable_warmup`` is passed to ``model.apply``. For a
  ``torch.nn.Module``, this visits every submodule and the model itself.

  Args:
    model: The model whose quantizers are switched to warmup in place.
  """
  warmup_fn = quantizer_mod.enable_warmup
  model.apply(warmup_fn)
  logging.info('Initialize quantizer.')
예제 #6
0
def disable_quant(model):
  """Switch ``model`` to floating-point execution by disabling quantization.

  ``quantizer_mod.disable_quant`` is passed to ``model.apply``. For a
  ``torch.nn.Module``, this visits every submodule and the model itself.

  Args:
    model: The model whose quantizers are disabled in place.
  """
  switch_off = quantizer_mod.disable_quant
  model.apply(switch_off)
  message = 'Disable quantization: floating point operations will be performed.'
  logging.info(message)
예제 #7
0
def enable_quant(model):
  """Switch ``model`` to quantized execution by enabling quantization.

  ``quantizer_mod.enable_quant`` is passed to ``model.apply``. For a
  ``torch.nn.Module``, this visits every submodule and the model itself.

  Args:
    model: The model whose quantizers are enabled in place.
  """
  switch_on = quantizer_mod.enable_quant
  model.apply(switch_on)
  message = 'Enable quantization: quantized operations will be performed.'
  logging.info(message)