from copy import deepcopy

import torch
import torch.nn as nn
from torch.backends._nnapi.prepare import convert_model_to_nnapi
from torch.utils.data import DataLoader


def deploy_nnapi(dataloader: DataLoader,
                 model: nn.Module,
                 fuse: bool,
                 name: str,
                 backend: str = 'qnnpack'):
    model = deepcopy(model)
    # Select the quantized engine and attach the matching default qconfig.
    torch.backends.quantized.engine = backend
    model.qconfig = torch.quantization.get_default_qconfig(backend)
    path = f'./{name}_nnapi'

    model = model.eval()
    if fuse:
        # Assumes the model exposes a fuse() method that merges
        # conv/bn/relu (or similar) modules before quantization.
        model.fuse()
        path += '_fused'

    # Post-training static quantization: insert observers, calibrate on
    # the dataloader (expected to yield bare input tensors), then convert.
    model_prepared = torch.quantization.prepare(model)
    for sample in dataloader:
        model_prepared(sample)
    model_quantized = torch.quantization.convert(model_prepared)

    input_float = torch.rand(1, 3, 224, 224)

    # Detach the quant/dequant stubs so the converted model takes and
    # returns quantized tensors; quantize the example input manually.
    quantizer = model_quantized.quant
    dequantizer = model_quantized.dequant
    model_quantized.quant = torch.nn.Identity()
    model_quantized.dequant = torch.nn.Identity()
    input_tensor = quantizer(input_float)

    # NNAPI expects NHWC layout; the nnapi_nhwc attribute tells the
    # converter to treat this input as channels-last.
    input_tensor = input_tensor.contiguous(memory_format=torch.channels_last)
    input_tensor.nnapi_nhwc = True

    with torch.no_grad():
        model_quantized_traced = torch.jit.trace(model_quantized, input_tensor)
    nnapi_model = convert_model_to_nnapi(model_quantized_traced, input_tensor)
    # Wrap the NNAPI model so callers can pass and receive float tensors.
    nnapi_model_float_interface = torch.jit.script(
        torch.nn.Sequential(quantizer, nnapi_model, dequantizer))

    traced_path = path + '_traced.pt'
    traced_float_path = path + '_float_interface_traced.pt'
    nnapi_model.save(traced_path)
    nnapi_model_float_interface.save(traced_float_path)
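
# Usage sketch (illustrative, not from the original source): it assumes a
# quantization-ready model exposing `quant`/`dequant` stubs and a DataLoader
# that yields bare input tensors for calibration. torchvision's quantizable
# MobileNetV2 provides the stubs (its fusion helper is named fuse_model(),
# not fuse(), so we pass fuse=False here).
if __name__ == '__main__':
    from torch.utils.data import TensorDataset
    from torchvision.models.quantization import mobilenet_v2

    calib_data = TensorDataset(torch.rand(32, 3, 224, 224))
    calib_loader = DataLoader(
        calib_data, batch_size=8,
        # TensorDataset yields 1-tuples; stack them into a plain tensor batch.
        collate_fn=lambda batch: torch.stack([item[0] for item in batch]))

    float_model = mobilenet_v2(pretrained=True, quantize=False)
    deploy_nnapi(calib_loader, float_model, fuse=False, name='mobilenet_v2')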
Example #2
def check(
    self,
    module,
    arg_or_args,
    *,
    trace_args=None,
    convert_args=None,
    atol_rtol=None,
    limit=None,
):
    with torch.no_grad():
        # Normalize a single-tensor argument into an argument list.
        if isinstance(arg_or_args, torch.Tensor):
            args = [arg_or_args]
        else:
            args = arg_or_args
        module.eval()
        traced = torch.jit.trace(module, trace_args or args)
        nnapi_module = convert_model_to_nnapi(traced, convert_args or args)
        if not self.can_run_nnapi:
            # Only test that the model was converted successfully.
            return
        # Compare eager execution against the NNAPI-lowered module.
        eager_output = module(*args)
        nnapi_output = nnapi_module(*args)
        kwargs = {}
        if atol_rtol is not None:
            kwargs["atol"] = atol_rtol[0]
            kwargs["rtol"] = atol_rtol[1]
        self.assertEqual(eager_output, nnapi_output, **kwargs)
        if limit is not None:
            # Count element-wise mismatches on the quantized integer
            # representations.
            mismatches = (
                eager_output.int_repr().to(torch.int32)
                - nnapi_output.int_repr().to(torch.int32)
            )
            if mismatches.count_nonzero() > limit:
                # Too many mismatches.  Re-run the check with no tolerance
                # to get a nice message.
                self.assertEqual(eager_output,
                                 nnapi_output,
                                 atol=0,
                                 rtol=0)
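
# Hypothetical usage (illustrative, not from the original source): inside
# the test class that defines check() and sets self.can_run_nnapi, a minimal
# conversion/accuracy test for a pointwise op might look like this.
def test_relu(self):
    arg = torch.rand(1, 8, 16, 16)
    self.check(torch.nn.ReLU(), arg)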
Example #3
def call_lowering_to_nnapi(self, traced_module, args):
    # Lower an already-traced module to NNAPI.
    return convert_model_to_nnapi(traced_module, args)
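
# Hypothetical usage sketch (a made-up helper on the same class, not from
# the original source): trace a module eagerly, then lower the traced graph
# through the wrapper above, following the same pattern as check().
def lower_relu_example(self):
    example_input = torch.rand(1, 4, 8, 8)
    traced = torch.jit.trace(torch.nn.ReLU(), example_input)
    return self.call_lowering_to_nnapi(traced, example_input)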