def is_supported(self, path):
    """Reports whether the module at `path` is one this loader can read.

    Args:
      path: string, location of a module directory.

    Returns:
      True when the module definition proto file exists under `path` and
      declares `module_def_pb2.ModuleDef.FORMAT_V3`; False when that file is
      missing or declares any other format.
    """
    proto_file = get_module_proto_path(path)
    # Only a present proto file can describe a readable module.
    if tf_v1.gfile.Exists(proto_file):
        module_def = module_def_pb2.ModuleDef()
        with tf_v1.gfile.Open(proto_file, "rb") as handle:
            serialized = handle.read()
        module_def.ParseFromString(serialized)
        return module_def.format == module_def_pb2.ModuleDef.FORMAT_V3
    return False
def _export(self, path, variables_saver):
    """Internal. Writes this module to `path`.

    First exports the underlying SavedModel through the saved-model handler,
    then writes a ModuleDef proto tagged FORMAT_V3 next to it and logs the
    destination.

    Args:
      path: string where to export the module to.
      variables_saver: a unary function that writes the module variables
        checkpoint on the given path.
    """
    self._saved_model_handler.export(path, variables_saver=variables_saver)

    module_def = module_def_pb2.ModuleDef()
    module_def.format = module_def_pb2.ModuleDef.FORMAT_V3
    target_file = _get_module_proto_path(path)
    # overwrite=False is passed to the atomic-write helper; presumably this
    # refuses to replace an existing proto file — confirm in tf_utils.
    tf_utils.atomic_write_string_to_file(
        target_file,
        module_def.SerializeToString(),
        overwrite=False,
    )
    tf.logging.info("Exported TF-Hub module to: %s", path)
def __call__(self, path):
    """Loads the module stored at `path` and returns its ModuleSpec.

    Args:
      path: string, location of a module directory.

    Returns:
      A `_ModuleSpec` built from the SavedModel handler loaded from `path`
      and the variables checkpoint path reported by `saved_model_lib`.

    Raises:
      ValueError: if the module definition declares a format other than
        FORMAT_V3, or requires features not listed in
        `_MODULE_V3_SUPPORTED_FEATURES`.
      Errors raised by `tf_v1.gfile` while opening/reading the proto file
      propagate unchanged.
    """
    proto_file = get_module_proto_path(path)
    module_def = module_def_pb2.ModuleDef()
    with tf_v1.gfile.Open(proto_file, "rb") as handle:
        module_def.ParseFromString(handle.read())

    declared_format = module_def.format
    if declared_format != module_def_pb2.ModuleDef.FORMAT_V3:
        raise ValueError("Unsupported module def format: %r" % declared_format)

    # Reject modules that depend on features this loader does not implement.
    missing = set(module_def.required_features) - _MODULE_V3_SUPPORTED_FEATURES
    if missing:
        raise ValueError("Unsupported features: %r" % list(missing))

    handler = saved_model_lib.load(path)
    variables_path = saved_model_lib.get_variables_path(path)
    return _ModuleSpec(handler, variables_path)
def load_module_spec(path):
    """Loads a ModuleSpec from the filesystem.

    Args:
      path: string describing the location of a module. The supported
        encodings are:
          a) a URL pointing to an archived module
             (e.g. http://domain/module.tgz);
          b) any filesystem location of a module directory
             (e.g. /module_dir on a local filesystem). All filesystem
             implementations provided by TensorFlow are supported.

    Returns:
      A ModuleSpec.

    Raises:
      ValueError: on unexpected values in the module spec.
      tf.OpError: on file handling exceptions.
    """
    # Map the caller's path (possibly a URL to an archive) to the module
    # location that every read below uses.
    resolver = compressed_module_resolver.get_default()
    module_dir = resolver.get_module_path(path)

    proto_file = _get_module_proto_path(module_dir)
    module_def = module_def_pb2.ModuleDef()
    with tf.gfile.Open(proto_file, "rb") as handle:
        module_def.ParseFromString(handle.read())

    declared_format = module_def.format
    if declared_format != module_def_pb2.ModuleDef.FORMAT_V3:
        raise ValueError("Unsupported module def format: %r" % declared_format)

    # Reject modules that depend on features this loader does not implement.
    missing = set(module_def.required_features) - _MODULE_V3_SUPPORTED_FEATURES
    if missing:
        raise ValueError("Unsupported features: %r" % list(missing))

    handler = saved_model_lib.load(module_dir)
    variables_path = saved_model_lib.get_variables_path(module_dir)
    return _ModuleSpec(handler, variables_path)
def _get_module_def_proto(self, path):
    """Reads and parses the ModuleDef proto of the module at `path`.

    Args:
      path: string, location of a module directory.

    Returns:
      A `module_def_pb2.ModuleDef` parsed from the file named by
      `get_module_proto_path(path)`. File-handling and parse errors
      propagate to the caller.
    """
    proto_file = get_module_proto_path(path)
    result = module_def_pb2.ModuleDef()
    with tf_v1.gfile.Open(proto_file, "rb") as handle:
        payload = handle.read()
    result.ParseFromString(payload)
    return result