def TopicType(value):
  """Converts [NAME:]TOPIC format string to TopicInfo object.

  Args:
    value: a topic string, provided in [NAME:]TOPIC format. NAME may be
      omitted (or empty), in which case the TopicInfo id is ''.

  Returns:
    TopicInfo message (id=NAME, topic=TOPIC)

  Raises:
    TopicTypeError: when topic format is invalid (more than one ':', a NAME
      with characters outside [a-zA-Z0-9_-], or an empty/ill-formed TOPIC).
  """
  # `\Z` anchors at the true end of the string. `$` would also match right
  # before a trailing newline, so e.g. 'name\n:topic' used to be accepted.
  topic_name_re = re.compile(r'^[a-zA-Z0-9-_]*\Z')
  # topic should not start with '$' and should not contain any of '+#,:'
  topic_re = re.compile(r'^[^$+#,:][^+#,:]*\Z')

  topic_parts = value.split(':')
  if len(topic_parts) > 2:
    raise TopicTypeError()
  if len(topic_parts) == 2:
    name, topic = topic_parts
  else:
    # str.split always yields at least one element, so this is the 1-part
    # form: TOPIC only.
    name, topic = '', topic_parts[0]

  if not topic_name_re.match(name) or not topic_re.match(topic):
    raise TopicTypeError()

  messages = util.GetMessagesModule()
  return messages.TopicInfo(id=name, topic=topic)
def ProcessModelHook(ref, args, req):
  """Analyzes given model, and converts model if necessary.

  Args:
    ref: A resource ref to the parsed Edge ML Model resource, unused in this
      hook
    args: The parsed args namespace from CLI
    req: Created request for the API call

  Returns:
    req, with new model URI, input/out tensor information, accelerator type if
    applicable.

  Raises:
    exceptions.InvalidArgumentException: if --framework disagrees with the
      model type reported by the Analyze call.
    InvalidFrameworkException: if framework is FRAMEWORK_UNSPECIFIED. This
      should not happen.
  """
  del ref  # Unused.
  edgeml_messages = edgeml_util.GetMessagesModule()
  model_types = edgeml_messages.AnalyzeModelResponse.ModelTypeValueValuesEnum
  tf_model_types = (
      model_types.TENSORFLOW_LITE,
      model_types.TENSORFLOW_LITE_EDGE_TPU_OPTIMIZED,
      model_types.TENSORFLOW_SAVED_MODEL,
  )
  edge_messages = edge_util.GetMessagesModule()
  framework_types = edge_messages.MlModel.FrameworkValueValuesEnum
  patch_req_type = (
      edge_messages.EdgeProjectsLocationsRegistriesDevicesMlModelsPatchRequest)

  analyze_result = edgeml.EdgeMlClient().Analyze(req.mlModel.modelUri)

  if req.mlModel.framework == framework_types.TFLITE:
    if analyze_result.modelType not in tf_model_types:
      raise exceptions.InvalidArgumentException(
          '--framework', 'tflite provided for non-Tensorflow model.')
    _ProcessTensorflowModel(req.mlModel, args, analyze_result)

    if isinstance(req, patch_req_type):
      # Make sure the patch carries the fields this hook may rewrite (model
      # URI, accelerator, tensors). updateMask should have some pre-filled
      # values, but drop empty entries and tolerate an unset mask so an
      # empty mask does not produce a leading ',' or crash on None.
      update_fields = {
          field for field in (req.updateMask or '').split(',') if field}
      update_fields.update(
          ('modelUri', 'acceleratorType', 'inputTensors', 'outputTensors'))
      req.updateMask = ','.join(sorted(update_fields))

  # Try to deploy as a scikit-learn model if it's not a TF model.
  elif req.mlModel.framework == framework_types.SCIKIT_LEARN:
    if analyze_result.modelType in tf_model_types:
      raise exceptions.InvalidArgumentException(
          '--framework', 'scikit-learn provided for Tensorflow model.')

  else:
    # FRAMEWORK_UNSPECIFIED is not allowed.
    raise InvalidFrameworkException()

  return req
def VolumeBindingType(value):
  """Parses a volume binding flag value into a VolumeBinding message.

  Accepted forms (split on ':'):
    SOURCE:DESTINATION:MODE
    PATH:MODE            (MODE is 'ro' or 'rw'; PATH is both ends)
    SOURCE:DESTINATION   (mode defaults to 'rw')
    PATH                 (PATH is both ends, mode 'rw')

  Args:
    value: a volume binding string parsed by ArgList from CLI flag.

  Returns:
    VolumeBinding message. An empty SOURCE falls back to DESTINATION.

  Raises:
    VolumeBindingTypeError: when there are more than three fields, when
      DESTINATION (or a non-empty SOURCE) is not an absolute path, or when
      MODE is neither 'ro' nor 'rw'.
  """
  fields = value.split(':')
  field_count = len(fields)
  if field_count > 3:
    raise VolumeBindingTypeError()

  if field_count == 3:
    source, destination, read_only = fields
  elif field_count == 2 and fields[1] in ('ro', 'rw'):
    # PATH:MODE -- the single path is used on both sides.
    source = destination = fields[0]
    read_only = fields[1]
  elif field_count == 2:
    source, destination = fields
    read_only = 'rw'
  else:
    # str.split always returns at least one field, so this is the bare PATH.
    source = destination = fields[0]
    read_only = 'rw'

  if not destination.startswith('/'):
    raise VolumeBindingTypeError(
        'DESTINATION {0} is not a valid absolute path.'.format(destination))
  if source and not source.startswith('/'):
    raise VolumeBindingTypeError(
        'SOURCE {0} is not a valid absolute path.'.format(source))
  if read_only not in ('ro', 'rw'):
    raise VolumeBindingTypeError(
        'The last value should be "ro" for read-only volume, and'
        ' "rw" for writable volume.')

  is_read_only = read_only == 'ro'
  return util.GetMessagesModule().VolumeBinding(
      source=source or destination,
      destination=destination,
      readOnly=is_read_only)
def _ConvertTensorRef(edgeml_tensor_refs):
  """Maps edgeml TensorRef messages to edge TensorInfo messages.

  Args:
    edgeml_tensor_refs: iterable of edgeml.TensorRef messages.

  Returns:
    A list of edge.TensorInfo messages, one per input ref, in input order.
  """
  edge_messages = edge_util.GetMessagesModule()
  edge_inference_enum = edge_messages.TensorInfo.InferenceTypeValueValuesEnum
  return [
      edge_messages.TensorInfo(
          index=ref.index,
          dimensions=ref.tensorInfo.dimensions,
          tensorName=ref.tensorInfo.tensorName,
          # The two APIs have separate enum classes; translate via the
          # member's name.
          inferenceType=edge_inference_enum(
              ref.tensorInfo.inferenceType.name))
      for ref in edgeml_tensor_refs
  ]
def DeviceBindingType(value):
  """Verifies device binding flag format, and returns device binding list.

  Accepted forms (split on ':'):
    SOURCE:DESTINATION:CGROUP_PERMS
    SOURCE:CGROUP_PERMS   (SOURCE is used as DESTINATION too)
    SOURCE:DESTINATION    (CGROUP_PERMS defaults to 'rwm')
    SOURCE                (SOURCE is both ends, CGROUP_PERMS 'rwm')

  Args:
    value: a device binding string parsed by ArgList from CLI flag.

  Returns:
    DeviceBinding message. An empty DESTINATION falls back to SOURCE.

  Raises:
    DeviceBindingTypeError: when the format is invalid.
  """
  # Each of 'r', 'w', 'm' at most once and in that order; '' also matches,
  # so 'SOURCE:' is read as SOURCE with empty CGROUP_PERMS.
  # `\Z` (not `$`) so a trailing newline is not silently accepted.
  cgroup_perms_re = re.compile(r'^r?w?m?\Z')

  binding_parts = value.split(':')
  if len(binding_parts) > 3:
    raise DeviceBindingTypeError()
  if len(binding_parts) == 3:
    source, destination, cgroup_permissions = binding_parts
  elif len(binding_parts) == 2 and cgroup_perms_re.match(binding_parts[1]):
    source = destination = binding_parts[0]
    cgroup_permissions = binding_parts[1]
  elif len(binding_parts) == 2:
    source, destination = binding_parts
    cgroup_permissions = 'rwm'
  else:
    # str.split always yields at least one element: bare SOURCE.
    source = destination = binding_parts[0]
    cgroup_permissions = 'rwm'

  if not source.startswith('/'):
    raise DeviceBindingTypeError(
        'SOURCE {0} is not a valid absolute path.'.format(source))
  if destination and not destination.startswith('/'):
    raise DeviceBindingTypeError(
        'DESTINATION {0} is not a valid absolute path.'.format(destination))
  if not cgroup_perms_re.match(cgroup_permissions):
    raise DeviceBindingTypeError(
        'CGROUP_PERMS should be a combination of the following flags'
        ' in order: "r/w/m."')

  messages = util.GetMessagesModule()
  return messages.DeviceBinding(
      source=source,
      destination=destination or source,
      cgroupPermissions=cgroup_permissions)
def _ProcessTensorflowModel(model, args, analyze_result):
  """Prepares a TensorFlow (Lite) model according to its analysis result.

  A TensorFlow SavedModel is first converted to TF Lite. A TF Lite model is
  compiled for Edge TPU when --accelerator is 'tpu'. A model that is (or has
  just become) Edge TPU optimized gets acceleratorType TPU. Finally, the
  resulting model signature is filled into `model`.

  Args:
    model: edge.MlModel message from request; modelUri (and, for Edge TPU
      models, acceleratorType) are updated in place.
    args: The parsed args namespace from CLI
    analyze_result: edgeml.AnalyzeModelResponse from Analyze method call.

  Raises:
    UncompilableModelException: if given model cannot be optimized for Edge
      TPU.
    exceptions.InvalidArgumentException: if a non-TPU --accelerator is given
      for an Edge TPU optimized model.
  """
  ml_client = edgeml.EdgeMlClient()
  edgeml_api = edgeml_util.GetMessagesModule()
  edge_api = edge_util.GetMessagesModule()
  kinds = edgeml_api.AnalyzeModelResponse.ModelTypeValueValuesEnum
  accelerators = edge_api.MlModel.AcceleratorTypeValueValuesEnum

  kind = analyze_result.modelType
  signature = analyze_result.modelSignature
  tpu_compilability = analyze_result.edgeTpuCompilability

  # Convert method converts TF SavedModel to TF Lite model; it also reports
  # the signature and Edge TPU compilability of the converted artifact.
  if kind == kinds.TENSORFLOW_SAVED_MODEL:
    converted, model.modelUri = ml_client.Convert(model.modelUri)
    signature = converted.modelSignature
    tpu_compilability = converted.edgeTpuCompilability
    kind = kinds.TENSORFLOW_LITE

  # Always use accelerator value from command line, and ignore previous
  # acceleratorType of the model. Only a TPU request triggers compilation.
  if kind == kinds.TENSORFLOW_LITE and args.accelerator == 'tpu':
    reason = tpu_compilability.uncompilableReason
    if reason:
      raise UncompilableModelException(reason)
    compiled, model.modelUri = ml_client.Compile(model.modelUri)
    signature = compiled.modelSignature
    kind = kinds.TENSORFLOW_LITE_EDGE_TPU_OPTIMIZED

  # Edge TPU optimized models must run on TPU.
  if kind == kinds.TENSORFLOW_LITE_EDGE_TPU_OPTIMIZED:
    if args.IsSpecified('accelerator'):
      if args.accelerator != 'tpu':
        raise exceptions.InvalidArgumentException(
            '--accelerator',
            'TPU should be provided for Edge TPU optimized model.')
    else:
      log.info('Setting accelerator to TPU for Edge TPU model.')
    model.acceleratorType = accelerators.TPU

  _FillModelSignature(model, signature)
def ParseSamplingInfo(path):
  """Loads an MlSamplingInfo message from a file.

  Args:
    path: path of the file to read, passed to
      cloudbuild_util.LoadMessageFromPath.

  Returns:
    The MlSamplingInfo message parsed from the file.
  """
  sampling_info_type = edge_util.GetMessagesModule().MlSamplingInfo
  return cloudbuild_util.LoadMessageFromPath(
      path, sampling_info_type, 'Edge ML sampling info')
def ParseTopicBridgingTable(unused_ref, args, req):
  """Request hook that fills req.rules from a topic bridging table file.

  Args:
    unused_ref: resource reference; not used.
    args: parsed CLI args; args.rule_file is the file to load.
    req: the API request being built; its `rules` field is overwritten.

  Returns:
    req, with rules taken from the loaded TopicBridgingTable message.
  """
  table_type = util.GetMessagesModule().TopicBridgingTable
  req.rules = cloudbuild_util.LoadMessageFromPath(
      args.rule_file, table_type, 'topic bridging table').rules
  return req