def from_saved_model(cls, saved_model_dir, signature_keys=None, tags=None):
  """Creates a TFLiteConverter object from a SavedModel directory.

  Args:
    saved_model_dir: SavedModel directory to convert.
    signature_keys: List of keys identifying SignatureDef containing inputs
      and outputs. Elements should not be duplicated. A single key may also
      be given as a plain string. By default the `signatures` attribute of
      the MetaGraphDef is used. (default saved_model.signatures)
    tags: Set of tags identifying the MetaGraphDef within the SavedModel to
      analyze. All tags in the tag set must be present. (default
      set(SERVING))

  Returns:
    TFLiteConverter object, built as `cls(funcs, saved_model)` where `funcs`
    are the signature functions selected by `signature_keys`, in order.

  Raises:
    ValueError: If an entry of `signature_keys` is not a signature key of the
      loaded SavedModel.
  """
  # Ensures any graphs created in Eager mode are able to run. This is required
  # in order to create a tf.estimator.Exporter that exports a TFLite model.
  with context.eager_mode():
    saved_model = _load(saved_model_dir, tags)

  if not signature_keys:
    signature_keys = saved_model.signatures
  elif isinstance(signature_keys, str):
    # Iterating a bare string would yield its characters, and each character
    # would be looked up as a separate key. Treat it as a one-element list.
    signature_keys = [signature_keys]

  funcs = []
  for key in signature_keys:
    if key not in saved_model.signatures:
      raise ValueError("Invalid signature key '{}' found. Valid keys are "
                       "'{}'.".format(key, ",".join(saved_model.signatures)))
    funcs.append(saved_model.signatures[key])

  return cls(funcs, saved_model)
def from_saved_model(cls, saved_model_dir, signature_keys=None, tags=None):
  """Creates a TFLiteConverter object from a SavedModel directory.

  Args:
    saved_model_dir: SavedModel directory to convert.
    signature_keys: List of keys identifying SignatureDef containing inputs
      and outputs. Elements should not be duplicated. A single key may also
      be given as a plain string. By default the `signatures` attribute of
      the MetaGraphDef is used. (default saved_model.signatures)
    tags: Set of tags identifying the MetaGraphDef within the SavedModel to
      analyze. All tags in the tag set must be present. (default
      set(SERVING))

  Returns:
    TFLiteConverter object, built as `cls(funcs, saved_model)` where `funcs`
    are the signature functions selected by `signature_keys`, in order.

  Raises:
    ValueError: If an entry of `signature_keys` is not a signature key of the
      loaded SavedModel.
  """
  # Ensures any graphs created in Eager mode are able to run. This is required
  # in order to create a tf.estimator.Exporter that exports a TFLite model.
  with context.eager_mode():
    saved_model = _load(saved_model_dir, tags)

  if not signature_keys:
    signature_keys = saved_model.signatures
  elif isinstance(signature_keys, str):
    # Iterating a bare string would yield its characters, and each character
    # would be looked up as a separate key. Treat it as a one-element list.
    signature_keys = [signature_keys]

  funcs = []
  for key in signature_keys:
    if key not in saved_model.signatures:
      raise ValueError("Invalid signature key '{}' found. Valid keys are "
                       "'{}'.".format(key, ",".join(saved_model.signatures)))
    funcs.append(saved_model.signatures[key])

  return cls(funcs, saved_model)