Beispiel #1
0
def ProcessModelHook(ref, args, req):
  """Analyzes given model, and converts model if necessary.

  The model at req.mlModel.modelUri is analyzed first. For the TFLITE
  framework the analyzed model must be a Tensorflow model; it is then
  processed by _ProcessTensorflowModel and, for patch requests, the fields
  that processing may touch are added to req.updateMask. For the
  SCIKIT_LEARN framework the analyzed model must NOT be a Tensorflow model.

  Args:
    ref: A resource ref to the parsed Edge ML Model resource,
      unused in this hook.
    args: The parsed args namespace from CLI.
    req: Created request for the API call.

  Returns:
    req, with new model URI, input/output tensor information, and
    accelerator type if applicable.

  Raises:
    exceptions.InvalidArgumentException: if --framework does not match the
      type of the analyzed model.
    InvalidFrameworkException: if framework is FRAMEWORK_UNSPECIFIED.
      This should not happen.
  """
  del ref  # Unused.
  edgeml_messages = edgeml_util.GetMessagesModule()
  model_types = edgeml_messages.AnalyzeModelResponse.ModelTypeValueValuesEnum
  tf_model_types = (
      model_types.TENSORFLOW_LITE,
      model_types.TENSORFLOW_LITE_EDGE_TPU_OPTIMIZED,
      model_types.TENSORFLOW_SAVED_MODEL,
  )

  edge_messages = edge_util.GetMessagesModule()
  framework_types = edge_messages.MlModel.FrameworkValueValuesEnum
  patch_req_type = (
      edge_messages.EdgeProjectsLocationsRegistriesDevicesMlModelsPatchRequest)

  analyze_result = edgeml.EdgeMlClient().Analyze(req.mlModel.modelUri)

  if req.mlModel.framework == framework_types.TFLITE:

    if analyze_result.modelType not in tf_model_types:
      raise exceptions.InvalidArgumentException(
          '--framework', 'tflite provided for non-Tensorflow model.')

    _ProcessTensorflowModel(req.mlModel, args, analyze_result)

    if isinstance(req, patch_req_type):
      # updateMask should have some pre-filled values, but guard against a
      # missing or empty mask: ''.split(',') yields [''], which would
      # otherwise leave an empty field (a leading comma) in the new mask.
      update_fields = {
          field for field in (req.updateMask or '').split(',') if field
      }
      # Processing above may have changed these fields, so the patch must
      # include them.
      update_fields.update({
          'modelUri', 'acceleratorType', 'inputTensors', 'outputTensors'
      })
      req.updateMask = ','.join(sorted(update_fields))

  elif req.mlModel.framework == framework_types.SCIKIT_LEARN:
    # Only validation is needed here: a scikit-learn deployment must not
    # point at a Tensorflow model.
    if analyze_result.modelType in tf_model_types:
      raise exceptions.InvalidArgumentException(
          '--framework', 'scikit-learn provided for Tensorflow model.')

  else:
    raise InvalidFrameworkException()  # FRAMEWORK_UNSPECIFIED is not allowed.

  return req
Beispiel #2
0
def _ProcessTensorflowModel(model, args, analyze_result):
  """Converts and/or compiles a Tensorflow (Lite) model per its analysis.

  A TF SavedModel is first converted to TF Lite. A TF Lite model is then
  compiled for the Edge TPU when --accelerator is 'tpu'. For an Edge TPU
  optimized model the accelerator must be TPU; when --accelerator was not
  specified, model.acceleratorType is set to TPU. Finally the resulting
  model signature is applied to the model via _FillModelSignature.

  Args:
    model: edge.MlModel message from request; modelUri and acceleratorType
      may be updated in place.
    args: The parsed args namespace from CLI.
    analyze_result: edgeml.AnalyzeModelResponse from Analyze method call.

  Raises:
    UncompilableModelException: if given model cannot be optimized for Edge TPU.
    exceptions.InvalidArgumentException: if an accelerator other than 'tpu'
      is specified for an Edge TPU optimized model.
  """
  client = edgeml.EdgeMlClient()
  edgeml_messages = edgeml_util.GetMessagesModule()
  edge_messages = edge_util.GetMessagesModule()
  model_kinds = edgeml_messages.AnalyzeModelResponse.ModelTypeValueValuesEnum
  accelerators = edge_messages.MlModel.AcceleratorTypeValueValuesEnum

  current_kind = analyze_result.modelType
  signature = analyze_result.modelSignature
  compilability = analyze_result.edgeTpuCompilability

  # Step 1: a TF SavedModel has to become a TF Lite model before anything else.
  if current_kind == model_kinds.TENSORFLOW_SAVED_MODEL:
    converted, model.modelUri = client.Convert(model.modelUri)
    signature = converted.modelSignature
    compilability = converted.edgeTpuCompilability
    current_kind = model_kinds.TENSORFLOW_LITE

  # Step 2: compile TF Lite for the Edge TPU on request. The command-line
  # accelerator value wins over any acceleratorType already on the model.
  if current_kind == model_kinds.TENSORFLOW_LITE and args.accelerator == 'tpu':
    reason = compilability.uncompilableReason
    if reason:
      raise UncompilableModelException(reason)

    compiled, model.modelUri = client.Compile(model.modelUri)
    signature = compiled.modelSignature
    current_kind = model_kinds.TENSORFLOW_LITE_EDGE_TPU_OPTIMIZED

  # Step 3: an Edge TPU optimized model can only target the TPU accelerator.
  if current_kind == model_kinds.TENSORFLOW_LITE_EDGE_TPU_OPTIMIZED:
    if not args.IsSpecified('accelerator'):
      log.info('Setting accelerator to TPU for Edge TPU model.')
      model.acceleratorType = accelerators.TPU
    elif args.accelerator != 'tpu':
      raise exceptions.InvalidArgumentException(
          '--accelerator',
          'TPU should be provided for Edge TPU optimized model.')

  _FillModelSignature(model, signature)