예제 #1
0
  def is_supported(self, path):
    """Tells whether the module stored at `path` uses the ModuleDef V3 format.

    Args:
      path: location of a module directory.

    Returns:
      False when no module-def proto file exists for `path`; otherwise
      whether the parsed proto's `format` equals `ModuleDef.FORMAT_V3`.
    """
    proto_file = get_module_proto_path(path)
    # Without a module-def file there is nothing to inspect.
    if tf_v1.gfile.Exists(proto_file):
      parsed = module_def_pb2.ModuleDef()
      with tf_v1.gfile.Open(proto_file, "rb") as stream:
        parsed.ParseFromString(stream.read())
      return parsed.format == module_def_pb2.ModuleDef.FORMAT_V3
    return False
예제 #2
0
    def _export(self, path, variables_saver):
        """Internal: writes this module to `path`.

        Delegates the SavedModel export to the saved-model handler, which
        receives `variables_saver`. It then writes a ModuleDef proto marked
        FORMAT_V3 to the module-def path for `path`.

        Args:
          path: string where to export the module to.
          variables_saver: an unary-function that writes the module variables
            checkpoint on the given path.
        """
        handler = self._saved_model_handler
        handler.export(path, variables_saver=variables_saver)

        module_def = module_def_pb2.ModuleDef()
        module_def.format = module_def_pb2.ModuleDef.FORMAT_V3
        target_file = _get_module_proto_path(path)
        serialized = module_def.SerializeToString()
        # overwrite=False asks the helper not to replace an existing file.
        tf_utils.atomic_write_string_to_file(
            target_file, serialized, overwrite=False)
        tf.logging.info("Exported TF-Hub module to: %s", path)
예제 #3
0
  def __call__(self, path):
    """Loads a `_ModuleSpec` from the module directory at `path`.

    Args:
      path: location of a module directory.

    Returns:
      A `_ModuleSpec` built from the SavedModel loaded from `path` and the
      variables checkpoint path for `path`.

    Raises:
      ValueError: if the ModuleDef format is not FORMAT_V3, or if the module
        requires features outside `_MODULE_V3_SUPPORTED_FEATURES`.
    """
    proto_file = get_module_proto_path(path)
    proto = module_def_pb2.ModuleDef()
    with tf_v1.gfile.Open(proto_file, "rb") as stream:
      proto.ParseFromString(stream.read())

    if proto.format != module_def_pb2.ModuleDef.FORMAT_V3:
      raise ValueError("Unsupported module def format: %r" % proto.format)

    # Any feature the module demands but this loader does not know is fatal.
    missing = set(proto.required_features) - _MODULE_V3_SUPPORTED_FEATURES
    if missing:
      raise ValueError("Unsupported features: %r" % list(missing))

    return _ModuleSpec(saved_model_lib.load(path),
                       saved_model_lib.get_variables_path(path))
예제 #4
0
def load_module_spec(path):
    """Loads a ModuleSpec from the filesystem.

    Args:
      path: string describing the location of a module. There are several
        supported path encoding schemes:
        a) URL location specifying an archived module
           (e.g. http://domain/module.tgz)
        b) Any filesystem location of a module directory (e.g. /module_dir
           for a local filesystem). All filesystem implementations provided
           by TensorFlow are supported.

    Returns:
      A ModuleSpec.

    Raises:
      ValueError: on unexpected values in the module spec.
      tf.errors.OpError: on file handling exceptions.
    """
    # Turn the user-supplied location into a module directory path.
    resolver = compressed_module_resolver.get_default()
    module_dir = resolver.get_module_path(path)

    proto_file = _get_module_proto_path(module_dir)
    spec_proto = module_def_pb2.ModuleDef()
    with tf.gfile.Open(proto_file, "rb") as stream:
        spec_proto.ParseFromString(stream.read())

    if spec_proto.format != module_def_pb2.ModuleDef.FORMAT_V3:
        raise ValueError(
            "Unsupported module def format: %r" % spec_proto.format)

    missing = (set(spec_proto.required_features)
               - _MODULE_V3_SUPPORTED_FEATURES)
    if missing:
        raise ValueError("Unsupported features: %r" % list(missing))

    handler = saved_model_lib.load(module_dir)
    variables_path = saved_model_lib.get_variables_path(module_dir)
    return _ModuleSpec(handler, variables_path)
예제 #5
0
 def _get_module_def_proto(self, path):
     """Reads and parses the ModuleDef proto for the module at `path`.

     Args:
       path: location of a module directory.

     Returns:
       A `module_def_pb2.ModuleDef` parsed from the module-def file.
     """
     proto_file = get_module_proto_path(path)
     parsed = module_def_pb2.ModuleDef()
     with tf_v1.gfile.Open(proto_file, "rb") as stream:
         payload = stream.read()
     parsed.ParseFromString(payload)
     return parsed