Example #1
def load_file_system_library(library_filename):
  """Loads a TensorFlow plugin, containing file system implementation.

  Pass `library_filename` to a platform-specific mechanism for dynamically
  loading a library. The rules for determining the exact location of the
  library are platform-specific and are not documented here.

  Args:
    library_filename: Path to the plugin.
      Relative or absolute filesystem path to a dynamic library file.

  Returns:
    None.

  Raises:
    RuntimeError: when unable to load the library.
  """
  status = py_tf.TF_NewStatus()
  lib_handle = py_tf.TF_LoadLibrary(library_filename, status)
  try:
    error_code = py_tf.TF_GetCode(status)
    if error_code != 0:
      error_msg = compat.as_text(py_tf.TF_Message(status))
      # pylint: disable=protected-access
      raise errors._make_specific_exception(None, None, error_msg, error_code)
      # pylint: enable=protected-access
  finally:
    py_tf.TF_DeleteStatus(status)
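A minimal usage sketch, not part of the snippet above: it assumes the function is exported as tf.load_file_system_library and uses a hypothetical plugin path.

import tensorflow as tf

# Hypothetical plugin path; the call returns None and raises if the library cannot be loaded.
tf.load_file_system_library('/path/to/libmy_file_system.so')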
Example #2
def load_op_library(library_filename):
  """Loads a TensorFlow plugin, containing custom ops and kernels.

  Pass "library_filename" to a platform-specific mechanism for dynamically
  loading a library. The rules for determining the exact location of the
  library are platform-specific and are not documented here. When the
  library is loaded, ops and kernels registered in the library via the
  REGISTER_* macros are made available in the TensorFlow process. Note
  that ops with the same name as an existing op are rejected and not
  registered with the process.

  Args:
    library_filename: Path to the plugin.
      Relative or absolute filesystem path to a dynamic library file.

  Returns:
    A python module containing the Python wrappers for Ops defined in
    the plugin.

  Raises:
    RuntimeError: when unable to load the library or get the python wrappers.
  """
  status = py_tf.TF_NewStatus()

  lib_handle = py_tf.TF_LoadLibrary(library_filename, status)
  try:
    error_code = py_tf.TF_GetCode(status)
    if error_code != 0:
      error_msg = compat.as_text(py_tf.TF_Message(status))
      with _OP_LIBRARY_MAP_LOCK:
        if (error_code == error_codes_pb2.ALREADY_EXISTS and
            'has already been loaded' in error_msg and
            library_filename in _OP_LIBRARY_MAP):
          return _OP_LIBRARY_MAP[library_filename]
      # pylint: disable=protected-access
      raise errors._make_specific_exception(None, None, error_msg, error_code)
      # pylint: enable=protected-access
  finally:
    py_tf.TF_DeleteStatus(status)

  op_list_str = py_tf.TF_GetOpList(lib_handle)
  op_list = op_def_pb2.OpList()
  op_list.ParseFromString(compat.as_bytes(op_list_str))
  wrappers = py_tf.GetPythonWrappers(op_list_str)

  # Get a unique name for the module.
  module_name = hashlib.md5(wrappers).hexdigest()
  module = imp.new_module(module_name)
  # pylint: disable=exec-used
  exec(wrappers, module.__dict__)
  # Stash away the library handle for making calls into the dynamic library.
  module.LIB_HANDLE = lib_handle
  # OpDefs of the list of ops defined in the library.
  module.OP_LIST = op_list
  sys.modules[module_name] = module
  # Memoize the filename to module mapping.
  with _OP_LIBRARY_MAP_LOCK:
    _OP_LIBRARY_MAP[library_filename] = module
  return module
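A usage sketch for the returned wrapper module; the plugin path and op name below are assumptions for illustration only.

import tensorflow as tf

# Hypothetical plugin exposing an op whose generated Python wrapper is named my_op.
my_ops = tf.load_op_library('/path/to/libmy_ops.so')
print(my_ops.OP_LIST)              # OpList proto of the ops defined in the plugin
output = my_ops.my_op([1.0, 2.0])  # call the generated Python wrapper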
Example #3
  def _do_run(self, target_list, fetch_list, feed_dict):
    """Runs a step based on the given fetches and feeds.

    Args:
      target_list: A list of byte arrays corresponding to names of tensors
        or operations to be run to, but not fetched.
      fetch_list: A list of byte arrays corresponding to names of tensors to
        be fetched and operations to be run.
      feed_dict: A dictionary that maps tensor names (as byte arrays) to
        numpy ndarrays.

    Returns:
      A list of numpy ndarrays, corresponding to the elements of
      `fetch_list`.  If the ith element of `fetch_list` contains the
      name of an operation, the first Tensor output of that operation
      will be returned for that element.
    """
    try:
      # Ensure any changes to the graph are reflected in the runtime.
      with self._extend_lock:
        if self._graph.version > self._current_version:
          graph_def = self._graph.as_graph_def(
              from_version=self._current_version)

          try:
            status = tf_session.TF_NewStatus()
            tf_session.TF_ExtendGraph(
                self._session, graph_def.SerializeToString(), status)
            if tf_session.TF_GetCode(status) != 0:
              raise RuntimeError(compat.as_text(tf_session.TF_Message(status)))
            self._opened = True
          finally:
            tf_session.TF_DeleteStatus(status)

          self._current_version = self._graph.version

      return tf_session.TF_Run(self._session, feed_dict, fetch_list,
                               target_list)

    except tf_session.StatusNotOK as e:
      e_type, e_value, e_traceback = sys.exc_info()
      error_message = compat.as_text(e.error_message)
      m = BaseSession._NODEDEF_NAME_RE.search(error_message)
      if m is not None:
        node_name = m.group(1)
        node_def = None
        try:
          op = self._graph.get_operation_by_name(node_name)
          node_def = op.node_def
        except KeyError:
          op = None
        # pylint: disable=protected-access
        raise errors._make_specific_exception(node_def, op, error_message,
                                              e.code)
        # pylint: enable=protected-access
      six.reraise(e_type, e_value, e_traceback)
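The node-name extraction above relies on BaseSession._NODEDEF_NAME_RE, which is not shown in the snippet. Below is a self-contained sketch of how such a pattern can pull a node name out of a runtime error message; the exact regex is an assumption, not taken from the source.

import re

# Assumed pattern; the real _NODEDEF_NAME_RE may differ.
_NODEDEF_NAME_RE = re.compile(r'\[\[Node: ([^ ]*?) =')

error_message = 'InvalidArgumentError: ... [[Node: MatMul_1 = MatMul[T=DT_FLOAT](a, b)]]'
m = _NODEDEF_NAME_RE.search(error_message)
if m is not None:
  print(m.group(1))  # -> MatMul_1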
Example #4
  def _do_run(self, target_list, fetch_list, feed_dict):
    """Runs a step based on the given fetches and feeds.

    Args:
      target_list: A list of byte arrays corresponding to names of tensors
        or operations to be run to, but not fetched.
      fetch_list: A list of byte arrays corresponding to names of tensors to
        be fetched and operations to be run.
      feed_dict: A dictionary that maps tensor names (as byte arrays) to
        numpy ndarrays.

    Returns:
      A list of numpy ndarrays, corresponding to the elements of
      `fetch_list`.  If the ith element of `fetch_list` contains the
      name of an operation, the first Tensor output of that operation
      will be returned for that element.
    """
    try:
      # Ensure any changes to the graph are reflected in the runtime.
      with self._extend_lock:
        if self._graph.version > self._current_version:
          graph_def = self._graph.as_graph_def(
              from_version=self._current_version)

          try:
            status = tf_session.TF_NewStatus()
            tf_session.TF_ExtendGraph(
                self._session, graph_def.SerializeToString(), status)
            if tf_session.TF_GetCode(status) != 0:
              raise RuntimeError(compat.as_text(tf_session.TF_Message(status)))
            self._opened = True
          finally:
            tf_session.TF_DeleteStatus(status)

          self._current_version = self._graph.version

      return tf_session.TF_Run(self._session, feed_dict, fetch_list,
                               target_list)

    except tf_session.StatusNotOK as e:
      e_type, e_value, e_traceback = sys.exc_info()
      error_message = compat.as_text(e.error_message)
      m = BaseSession._NODEDEF_NAME_RE.search(error_message)
      if m is not None:
        node_name = m.group(1)
        node_def = None
        try:
          op = self._graph.get_operation_by_name(node_name)
          node_def = op.node_def
        except KeyError:
          op = None
        # pylint: disable=protected-access
        raise errors._make_specific_exception(node_def, op, error_message,
                                              e.code)
        # pylint: enable=protected-access
      six.reraise(e_type, e_value, e_traceback)
Example #5
def load_op_library(library_filename):
    """Loads a TensorFlow plugin, containing custom ops and kernels.

  Pass "library_filename" to a platform-specific mechanism for dynamically
  loading a library. The rules for determining the exact location of the
  library are platform-specific and are not documented here.

  Args:
    library_filename: Path to the plugin.
      Relative or absolute filesystem path to a dynamic library file.

  Returns:
    A python module containing the Python wrappers for Ops defined in
    the plugin.

  Raises:
    RuntimeError: when unable to load the library or get the python wrappers.
  """
    status = py_tf.TF_NewStatus()

    lib_handle = py_tf.TF_LoadLibrary(library_filename, status)
    try:
        error_code = py_tf.TF_GetCode(status)
        if error_code != 0:
            error_msg = compat.as_text(py_tf.TF_Message(status))
            with _OP_LIBRARY_MAP_LOCK:
                if (error_code == error_codes_pb2.ALREADY_EXISTS
                        and 'has already been loaded' in error_msg
                        and library_filename in _OP_LIBRARY_MAP):
                    return _OP_LIBRARY_MAP[library_filename]
            # pylint: disable=protected-access
            raise errors._make_specific_exception(None, None, error_msg,
                                                  error_code)
            # pylint: enable=protected-access
    finally:
        py_tf.TF_DeleteStatus(status)

    op_list_str = py_tf.TF_GetOpList(lib_handle)
    op_list = op_def_pb2.OpList()
    op_list.ParseFromString(compat.as_bytes(op_list_str))
    wrappers = py_tf.GetPythonWrappers(op_list_str)

    # Get a unique name for the module.
    module_name = hashlib.md5(wrappers).hexdigest()
    module = imp.new_module(module_name)
    # pylint: disable=exec-used
    exec(wrappers, module.__dict__)
    # Stash away the library handle for making calls into the dynamic library.
    module.LIB_HANDLE = lib_handle
    # OpDefs of the list of ops defined in the library.
    module.OP_LIST = op_list
    sys.modules[module_name] = module
    # Memoize the filename to module mapping.
    with _OP_LIBRARY_MAP_LOCK:
        _OP_LIBRARY_MAP[library_filename] = module
    return module
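This variant refers to _OP_LIBRARY_MAP and _OP_LIBRARY_MAP_LOCK without defining them. A sketch of the module-level state it presumably relies on; the exact definitions are an assumption.

import threading

# Assumed module-level state used to memoize loaded op libraries.
_OP_LIBRARY_MAP = {}                     # library_filename -> wrapper module
_OP_LIBRARY_MAP_LOCK = threading.Lock()  # guards _OP_LIBRARY_MAP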
Example #6
def load_op_library(library_filename):
    """Loads a TensorFlow plugin, containing custom ops and kernels.

  Pass "library_filename" to a platform-specific mechanism for dynamically
  loading a library. The rules for determining the exact location of the
  library are platform-specific and are not documented here. When the
  library is loaded, ops and kernels registered in the library via the
  REGISTER_* macros are made available in the TensorFlow process. Note
  that ops with the same name as an existing op are rejected and not
  registered with the process.

  Args:
    library_filename: Path to the plugin.
      Relative or absolute filesystem path to a dynamic library file.

  Returns:
    A python module containing the Python wrappers for Ops defined in
    the plugin.

  Raises:
    RuntimeError: when unable to load the library or get the python wrappers.
  """
    status = py_tf.TF_NewStatus()

    lib_handle = py_tf.TF_LoadLibrary(library_filename, status)
    try:
        error_code = py_tf.TF_GetCode(status)
        if error_code != 0:
            error_msg = compat.as_text(py_tf.TF_Message(status))
            # pylint: disable=protected-access
            raise errors._make_specific_exception(None, None, error_msg,
                                                  error_code)
            # pylint: enable=protected-access
    finally:
        py_tf.TF_DeleteStatus(status)

    op_list_str = py_tf.TF_GetOpList(lib_handle)
    op_list = op_def_pb2.OpList()
    op_list.ParseFromString(compat.as_bytes(op_list_str))
    wrappers = py_tf.GetPythonWrappers(op_list_str)

    # Get a unique name for the module.
    module_name = hashlib.md5(wrappers).hexdigest()
    if module_name in sys.modules:
        return sys.modules[module_name]
    module = imp.new_module(module_name)
    # pylint: disable=exec-used
    exec(wrappers, module.__dict__)
    # Stash away the library handle for making calls into the dynamic library.
    module.LIB_HANDLE = lib_handle
    # OpDefs of the list of ops defined in the library.
    module.OP_LIST = op_list
    sys.modules[module_name] = module
    return module
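A self-contained sketch of the imp.new_module/exec pattern used above to turn generated wrapper source into an importable module; the wrapper source here is a toy stand-in, and imp is used only to mirror the snippet even though it is deprecated in modern Python.

import hashlib
import imp
import sys

wrappers = b'def my_op(x):\n  return x\n'        # toy stand-in for generated wrappers
module_name = hashlib.md5(wrappers).hexdigest()  # content-addressed module name
module = imp.new_module(module_name)
exec(compile(wrappers, module_name, 'exec'), module.__dict__)
sys.modules[module_name] = module
print(module.my_op(42))  # -> 42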
Example #7
 def _do_call(self, fn, *args):
   try:
     return fn(*args)
   except tf_session.StatusNotOK as e:
     error_message = compat.as_text(e.error_message)
     m = BaseSession._NODEDEF_NAME_RE.search(error_message)
     node_def = None
     op = None
     if m is not None:
       node_name = m.group(1)
       try:
         op = self._graph.get_operation_by_name(node_name)
         node_def = op.node_def
       except KeyError:
         pass
     # pylint: disable=protected-access
     raise errors._make_specific_exception(node_def, op, error_message,
                                           e.code)
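_do_call follows a catch-and-translate pattern: invoke a callable and, if the low-level layer raises, rewrite the exception into a richer typed error. A self-contained sketch of that pattern with stand-in exception types; the names below are assumptions, not TensorFlow APIs.

class LowLevelError(Exception):
  pass

class FriendlyError(Exception):
  pass

def do_call(fn, *args):
  try:
    return fn(*args)
  except LowLevelError as e:
    # Translate the low-level failure into a higher-level, typed error.
    raise FriendlyError('wrapped: %s' % e)

def boom():
  raise LowLevelError('device out of memory')

try:
  do_call(boom)
except FriendlyError as e:
  print(e)  # -> wrapped: device out of memory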
Example #8
 def _do_call(self, fn, *args):
     try:
         return fn(*args)
     except tf_session.StatusNotOK as e:
         error_message = compat.as_text(e.error_message)
         m = BaseSession._NODEDEF_NAME_RE.search(error_message)
         node_def = None
         op = None
         if m is not None:
             node_name = m.group(1)
             try:
                 op = self._graph.get_operation_by_name(node_name)
                 node_def = op.node_def
             except KeyError:
                 pass
         # pylint: disable=protected-access
         raise errors._make_specific_exception(node_def, op, error_message,
                                               e.code)
Example #9
    def __init__(self,
                 server_or_cluster_def,
                 job_name=None,
                 task_index=None,
                 protocol=None,
                 start=True):
        """Creates a new server with the given definition.

    The `job_name`, `task_index`, and `protocol` arguments are optional, and
    override any information provided in `server_or_cluster_def`.

    Args:
      server_or_cluster_def: A `tf.train.ServerDef` or
        `tf.train.ClusterDef` protocol buffer, or a
        `tf.train.ClusterSpec` object, describing the server to be
        created and/or the cluster of which it is a member.
      job_name: (Optional.) Specifies the name of the job of which the server
        is a member. Defaults to the value in `server_or_cluster_def`, if
        specified.
      task_index: (Optional.) Specifies the task index of the server in its
        job. Defaults to the value in `server_or_cluster_def`, if specified.
        Otherwise defaults to 0 if the server's job has only one task.
      protocol: (Optional.) Specifies the protocol to be used by the server.
        Acceptable values include `"grpc"`. Defaults to the value in
        `server_or_cluster_def`, if specified. Otherwise defaults to `"grpc"`.
      start: (Optional.) Boolean, indicating whether to start the server
        after creating it. Defaults to `True`.
    """
        server_def = _make_server_def(server_or_cluster_def, job_name,
                                      task_index, protocol)
        try:
            self._server = pywrap_tensorflow.NewServer(
                server_def.SerializeToString())
        except pywrap_tensorflow.StatusNotOK as e:
            # pylint: disable=protected-access
            raise errors._make_specific_exception(None, None, e.error_message,
                                                  e.code)
            # pylint: enable=protected-access
        if start:
            self.start()
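A minimal usage sketch, assuming the class is exported as tf.train.Server; the cluster layout and addresses below are hypothetical.

import tensorflow as tf

# Hypothetical two-task worker job; addresses are placeholders.
cluster = tf.train.ClusterSpec({'worker': ['localhost:2222', 'localhost:2223']})
server = tf.train.Server(cluster, job_name='worker', task_index=0)
# start defaults to True, so the server is already running here.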
Example #10
  def __init__(self,
               server_or_cluster_def,
               job_name=None,
               task_index=None,
               protocol=None,
               start=True):
    """Creates a new server with the given definition.

    The `job_name`, `task_index`, and `protocol` arguments are optional, and
    override any information provided in `server_or_cluster_def`.

    Args:
      server_or_cluster_def: A `tf.train.ServerDef` or
        `tf.train.ClusterDef` protocol buffer, or a
        `tf.train.ClusterSpec` object, describing the server to be
        created and/or the cluster of which it is a member.
      job_name: (Optional.) Specifies the name of the job of which the server
        is a member. Defaults to the value in `server_or_cluster_def`, if
        specified.
      task_index: (Optional.) Specifies the task index of the server in its
        job. Defaults to the value in `server_or_cluster_def`, if specified.
        Otherwise defaults to 0 if the server's job has only one task.
      protocol: (Optional.) Specifies the protocol to be used by the server.
        Acceptable values include `"grpc"`. Defaults to the value in
        `server_or_cluster_def`, if specified. Otherwise defaults to `"grpc"`.
      start: (Optional.) Boolean, indicating whether to start the server
        after creating it. Defaults to `True`.
    """
    server_def = _make_server_def(server_or_cluster_def,
                                  job_name, task_index, protocol)
    try:
      self._server = pywrap_tensorflow.NewServer(server_def.SerializeToString())
    except pywrap_tensorflow.StatusNotOK as e:
      # pylint: disable=protected-access
      raise errors._make_specific_exception(None, None, e.error_message, e.code)
      # pylint: enable=protected-access
    if start:
      self.start()
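A variant sketch showing the start argument documented above; the single-task cluster is hypothetical.

import tensorflow as tf

cluster = tf.train.ClusterSpec({'ps': ['localhost:2222']})
server = tf.train.Server(cluster, job_name='ps', task_index=0, start=False)
server.start()   # required explicitly because start=False was passed
# server.join()  # would block until the server terminates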
Example #11
def load_file_system_library(library_filename):
    """Loads a TensorFlow plugin, containing file system implementation.

  Pass `library_filename` to a platform-specific mechanism for dynamically
  loading a library. The rules for determining the exact location of the
  library are platform-specific and are not documented here.

  Args:
    library_filename: Path to the plugin.
      Relative or absolute filesystem path to a dynamic library file.

  Returns:
    None.

  Raises:
    RuntimeError: when unable to load the library.
  """
    status = py_tf.TF_NewStatus()
    lib_handle = py_tf.TF_LoadLibrary(library_filename, status)
    try:
        error_code = py_tf.TF_GetCode(status)
        if error_code != 0:
            error_msg = compat.as_text(py_tf.TF_Message(status))
            with _FILE_SYSTEM_LIBRARY_MAP_LOCK:
                if (error_code == error_codes_pb2.ALREADY_EXISTS
                        and 'has already been loaded' in error_msg
                        and library_filename in _FILE_SYSTEM_LIBRARY_MAP):
                    return
            # pylint: disable=protected-access
            raise errors._make_specific_exception(None, None, error_msg,
                                                  error_code)
            # pylint: enable=protected-access
    finally:
        py_tf.TF_DeleteStatus(status)

    with _FILE_SYSTEM_LIBRARY_MAP_LOCK:
        _FILE_SYSTEM_LIBRARY_MAP[library_filename] = lib_handle
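As with the op-library variant, this snippet refers to _FILE_SYSTEM_LIBRARY_MAP and _FILE_SYSTEM_LIBRARY_MAP_LOCK without defining them. A sketch of the assumed module-level state.

import threading

# Assumed module-level state used to memoize loaded file system plugins.
_FILE_SYSTEM_LIBRARY_MAP = {}                     # library_filename -> library handle
_FILE_SYSTEM_LIBRARY_MAP_LOCK = threading.Lock()  # guards the map above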