예제 #1
0
def load_file_system_library(library_filename):
    """Loads a TensorFlow plugin, containing file system implementation.

    Pass `library_filename` to a platform-specific mechanism for dynamically
    loading a library. The rules for determining the exact location of the
    library are platform-specific and are not documented here.

    Args:
      library_filename: Path to the plugin.
        Relative or absolute filesystem path to a dynamic library file.

    Returns:
      None.

    Raises:
      The exception built by `errors._make_specific_exception` from the
      status error code when the library cannot be loaded.
    """
    status = py_tf.TF_NewStatus()
    try:
        # The load happens inside `try` so the status object is released even
        # if the loader call itself raises.
        py_tf.TF_LoadLibrary(library_filename, status)
        error_code = py_tf.TF_GetCode(status)
        if error_code != 0:
            error_msg = compat.as_text(py_tf.TF_Message(status))
            # pylint: disable=protected-access
            raise errors._make_specific_exception(None, None, error_msg,
                                                  error_code)
            # pylint: enable=protected-access
    finally:
        py_tf.TF_DeleteStatus(status)
예제 #2
0
def load_op_library(library_filename):
    """Loads a TensorFlow plugin, containing custom ops and kernels.

    Pass "library_filename" to a platform-specific mechanism for dynamically
    loading a library. The rules for determining the exact location of the
    library are platform-specific and are not documented here. When the
    library is loaded, ops and kernels registered in the library via the
    REGISTER_* macros are made available in the TensorFlow process. Note
    that ops with the same name as an existing op are rejected and not
    registered with the process.

    Args:
      library_filename: Path to the plugin.
        Relative or absolute filesystem path to a dynamic library file.

    Returns:
      A python module containing the Python wrappers for Ops defined in
      the plugin. If the same filename was already loaded successfully by
      this function, the previously created module is returned.

    Raises:
      The exception built by `errors._make_specific_exception` from the
      status error code when the library cannot be loaded.
    """
    # `imp.new_module` is deprecated (and `imp` is gone in Python 3.12);
    # `types.ModuleType` creates the same bare, unregistered module object.
    import types

    status = py_tf.TF_NewStatus()
    try:
        # The load happens inside `try` so the status object is released even
        # if the loader call itself raises.
        lib_handle = py_tf.TF_LoadLibrary(library_filename, status)
        error_code = py_tf.TF_GetCode(status)
        if error_code != 0:
            error_msg = compat.as_text(py_tf.TF_Message(status))
            with _OP_LIBRARY_MAP_LOCK:
                # A repeated load is reported as ALREADY_EXISTS; hand back the
                # module memoized by the earlier successful load instead.
                if (error_code == error_codes_pb2.ALREADY_EXISTS
                        and 'has already been loaded' in error_msg
                        and library_filename in _OP_LIBRARY_MAP):
                    return _OP_LIBRARY_MAP[library_filename]
            # pylint: disable=protected-access
            raise errors._make_specific_exception(None, None, error_msg,
                                                  error_code)
            # pylint: enable=protected-access
    finally:
        py_tf.TF_DeleteStatus(status)

    op_list_str = py_tf.TF_GetOpList(lib_handle)
    op_list = op_def_pb2.OpList()
    op_list.ParseFromString(compat.as_bytes(op_list_str))
    wrappers = py_tf.GetPythonWrappers(op_list_str)

    # Get a unique name for the module.
    module_name = hashlib.md5(wrappers).hexdigest()
    module = types.ModuleType(module_name)
    # pylint: disable=exec-used
    exec(wrappers, module.__dict__)
    # Stash away the library handle for making calls into the dynamic library.
    module.LIB_HANDLE = lib_handle
    # OpDefs of the list of ops defined in the library.
    module.OP_LIST = op_list
    sys.modules[module_name] = module
    # Memoize the filename to module mapping.
    with _OP_LIBRARY_MAP_LOCK:
        _OP_LIBRARY_MAP[library_filename] = module
    return module
예제 #3
0
    def _do_run(self, target_list, fetch_list, feed_dict):
        """Runs a step based on the given fetches and feeds.

        Any graph changes made since the last extension are first pushed to
        the runtime, then the step is executed.

        Args:
          target_list: A list of byte arrays corresponding to names of tensors
            or operations to be run to, but not fetched.
          fetch_list: A list of byte arrays corresponding to names of tensors to
            be fetched and operations to be run.
          feed_dict: A dictionary that maps tensor names (as byte arrays) to
            numpy ndarrays.

        Returns:
          A list of numpy ndarrays, corresponding to the elements of
          `fetch_list`.  If the ith element of `fetch_list` contains the
          name of an operation, the first Tensor output of that operation
          will be returned for that element.

        Raises:
          RuntimeError: If extending the runtime graph fails.
          The exception built by `errors._make_specific_exception` when a
          `tf_session.StatusNotOK` error message names a node; otherwise the
          original `StatusNotOK` is re-raised with its traceback.
        """
        try:
            # Ensure any changes to the graph are reflected in the runtime.
            with self._extend_lock:
                if self._graph.version > self._current_version:
                    graph_def = self._graph.as_graph_def(
                        from_version=self._current_version)

                    # Created before `try` so `finally` never references an
                    # unbound name if status creation itself fails.
                    status = tf_session.TF_NewStatus()
                    try:
                        tf_session.TF_ExtendGraph(
                            self._session, graph_def.SerializeToString(),
                            status)
                        if tf_session.TF_GetCode(status) != 0:
                            raise RuntimeError(
                                compat.as_text(tf_session.TF_Message(status)))
                        self._opened = True
                    finally:
                        tf_session.TF_DeleteStatus(status)

                    self._current_version = self._graph.version

            return tf_session.TF_Run(self._session, feed_dict, fetch_list,
                                     target_list)

        except tf_session.StatusNotOK as e:
            e_type, e_value, e_traceback = sys.exc_info()
            error_message = compat.as_text(e.error_message)
            m = BaseSession._NODEDEF_NAME_RE.search(error_message)
            if m is not None:
                node_name = m.group(1)
                node_def = None
                try:
                    op = self._graph.get_operation_by_name(node_name)
                    node_def = op.node_def
                except KeyError:
                    # The named node is not in this graph; report without it.
                    op = None
                # pylint: disable=protected-access
                raise errors._make_specific_exception(node_def, op,
                                                      error_message, e.code)
                # pylint: enable=protected-access
            six.reraise(e_type, e_value, e_traceback)
예제 #4
0
 def __del__(self):
     """Closes the session and deletes the underlying C session handle.

     Raises:
       RuntimeError: If `TF_DeleteSession` reports an error.
     """
     try:
         self.close()
     finally:
         # Delete the C session even if `close()` raised, so it is not leaked.
         # Nothing to allocate or free when there is no session.
         if self._session is not None:
             # Created before `try` so `finally` never sees an unbound name.
             status = tf_session.TF_NewStatus()
             try:
                 tf_session.TF_DeleteSession(self._session, status)
                 if tf_session.TF_GetCode(status) != 0:
                     raise RuntimeError(
                         compat.as_text(tf_session.TF_Message(status)))
                 self._session = None
             finally:
                 tf_session.TF_DeleteStatus(status)
예제 #5
0
def raise_exception_on_not_ok_status():
    """Yields a fresh TF status object and raises if it reports an error.

    This is a generator; presumably it is wrapped by a context-manager
    decorator outside this view (verify). Once the consumer resumes it, a
    non-zero status code is converted into an exception via
    `_make_specific_exception`. The status object is deleted on every path.
    """
    tf_status = pywrap_tensorflow.TF_NewStatus()
    try:
        yield tf_status
        if pywrap_tensorflow.TF_GetCode(tf_status) == 0:
            return
        message = compat.as_text(pywrap_tensorflow.TF_Message(tf_status))
        raise _make_specific_exception(
            None, None, message, pywrap_tensorflow.TF_GetCode(tf_status))
    finally:
        pywrap_tensorflow.TF_DeleteStatus(tf_status)
예제 #6
0
 def __del__(self):
   """Best-effort teardown: closes the session and deletes the C handle."""
   # cleanly ignore all exceptions
   try:
     self.close()
   except Exception:  # pylint: disable=broad-except
     pass
   if self._session is not None:
     # Pre-bind `status` so the `finally` clause cannot hit an unbound local
     # if `TF_NewStatus()` itself fails (for example if module state has
     # already been torn down during interpreter shutdown).
     status = None
     try:
       status = tf_session.TF_NewStatus()
       tf_session.TF_DeleteDeprecatedSession(self._session, status)
     finally:
       if status is not None:
         tf_session.TF_DeleteStatus(status)
     self._session = None
예제 #7
0
def is_directory(dirname):
    """Returns whether the path is a directory or not.

    Args:
      dirname: string, path to a potential directory. It is converted with
        `compat.as_bytes` before being passed to the C layer.

    Returns:
      True, if the path is a directory; False otherwise
    """
    # Created before `try` so `finally` never references an unbound name if
    # status creation itself fails.
    status = pywrap_tensorflow.TF_NewStatus()
    try:
        return pywrap_tensorflow.IsDirectory(compat.as_bytes(dirname), status)
    finally:
        pywrap_tensorflow.TF_DeleteStatus(status)
예제 #8
0
def load_op_library(library_filename):
    """Loads a TensorFlow plugin, containing custom ops and kernels.

    Pass "library_filename" to a platform-specific mechanism for dynamically
    loading a library. The rules for determining the exact location of the
    library are platform-specific and are not documented here.
    Expects the symbols "RegisterOps", "RegisterKernels", and "GetOpList", to be
    defined in the library.

    Args:
      library_filename: Path to the plugin.
        Relative or absolute filesystem path to a dynamic library file.

    Returns:
      A python module containing the Python wrappers for Ops defined in
      the plugin.

    Raises:
      RuntimeError: when unable to load the library or get the python wrappers.
    """
    # `imp.new_module` is deprecated (and `imp` is gone in Python 3.12);
    # `types.ModuleType` creates the same bare, unregistered module object.
    import types

    status = py_tf.TF_NewStatus()
    try:
        # The load happens inside `try` so the status object is released even
        # if the loader call itself raises.
        lib_handle = py_tf.TF_LoadLibrary(library_filename, status)
        if py_tf.TF_GetCode(status) != 0:
            raise RuntimeError(compat.as_text(py_tf.TF_Message(status)))
    finally:
        py_tf.TF_DeleteStatus(status)

    op_list_str = py_tf.TF_GetOpList(lib_handle)
    op_list = op_def_pb2.OpList()
    op_list.ParseFromString(compat.as_bytes(op_list_str))
    wrappers = py_tf.GetPythonWrappers(op_list_str, len(op_list_str))

    # Get a unique name for the module.
    module_name = hashlib.md5(wrappers).hexdigest()
    module = types.ModuleType(module_name)
    # pylint: disable=exec-used
    exec(wrappers, module.__dict__)
    # Stash away the library handle for making calls into the dynamic library.
    module.LIB_HANDLE = lib_handle
    # OpDefs of the list of ops defined in the library.
    module.OP_LIST = op_list
    sys.modules[module_name] = module
    return module
예제 #9
0
    def close(self):
        """Closes this session.

        Calling this method frees all resources associated with the session.
        It does nothing unless the session has been opened and not yet closed.

        Raises:
          RuntimeError: If an error occurs while closing the session.
        """
        with self._extend_lock:
            if self._opened and not self._closed:
                # Marked closed before the C call, so a failed close is not
                # attempted again on a later call.
                self._closed = True
                # Created before `try` so `finally` never references an
                # unbound name if status creation itself fails.
                status = tf_session.TF_NewStatus()
                try:
                    tf_session.TF_CloseSession(self._session, status)
                    if tf_session.TF_GetCode(status) != 0:
                        raise RuntimeError(
                            compat.as_text(tf_session.TF_Message(status)))
                finally:
                    tf_session.TF_DeleteStatus(status)
예제 #10
0
  def _extend_graph(self):
    """Pushes graph changes made since the last extension to the runtime.

    Raises:
      RuntimeError: If `TF_ExtendGraph` reports an error.
    """
    # Ensure any changes to the graph are reflected in the runtime.
    with self._extend_lock:
      if self._graph.version > self._current_version:
        graph_def = self._graph.as_graph_def(
            from_version=self._current_version)

        # Created before `try` so `finally` never references an unbound name
        # if status creation itself fails.
        status = tf_session.TF_NewStatus()
        try:
          tf_session.TF_ExtendGraph(
              self._session, graph_def.SerializeToString(), status)
          if tf_session.TF_GetCode(status) != 0:
            raise RuntimeError(compat.as_text(tf_session.TF_Message(status)))
          self._opened = True
        finally:
          tf_session.TF_DeleteStatus(status)

        self._current_version = self._graph.version
예제 #11
0
    def __init__(self, target='', graph=None, config=None):
        """Creates a new TensorFlow session over `graph`.

        Args:
          target: (Optional) The TensorFlow execution engine to connect to.
          graph: (Optional) The graph to be used. If this argument is None,
            the default graph will be used.
          config: (Optional) ConfigProto proto used to configure the session.

        Raises:
          RuntimeError: If an error occurs while creating the TensorFlow
            session.
        """
        self._graph = ops.get_default_graph() if graph is None else graph

        # Lifecycle flags.
        self._opened = False
        self._closed = False

        # Graph-extension bookkeeping.
        self._current_version = 0
        self._extend_lock = threading.Lock()
        self._target = target

        # Deferred-deletion state (presumably for handles; verify at use sites).
        self._delete_lock = threading.Lock()
        self._dead_handles = []

        # Bound before the C call so the attribute exists even if it raises.
        self._session = None

        session_opts = tf_session.TF_NewSessionOptions(target=target,
                                                       config=config)
        try:
            new_status = tf_session.TF_NewStatus()
            try:
                self._session = tf_session.TF_NewSession(session_opts,
                                                         new_status)
                failed = tf_session.TF_GetCode(new_status) != 0
                if failed:
                    message = compat.as_text(tf_session.TF_Message(new_status))
                    raise RuntimeError(message)
            finally:
                tf_session.TF_DeleteStatus(new_status)
        finally:
            tf_session.TF_DeleteSessionOptions(session_opts)
예제 #12
0
 def __del__(self):
   """Best-effort teardown: closes the session and deletes the C handle."""
   # Any failure from close() is deliberately ignored.
   try:
     self.close()
   except Exception:  # pylint: disable=broad-except
     pass
   if self._session is None:
     return
   # `status` is pre-bound to None so the `finally` clause only deletes a
   # status that was actually created. At shutdown `tf_session` may already
   # be torn down, in which case TF_NewStatus() can fail and the session
   # object is simply leaked.
   status = None
   try:
     status = tf_session.TF_NewStatus()
     tf_session.TF_DeleteDeprecatedSession(self._session, status)
   finally:
     if status is not None:
       tf_session.TF_DeleteStatus(status)
   self._session = None
예제 #13
0
 def __del__(self):
     """Deletes the wrapped C status object, if the C API is still present."""
     # Note: when we're destructing the global context (i.e when the process is
     # terminating) we can have already deleted other modules.
     # `getattr` with a default covers both `c_api` being None and the
     # attribute being None or missing; direct attribute access would raise
     # AttributeError in the missing case.
     delete_status = getattr(c_api, "TF_DeleteStatus", None)
     if delete_status is not None:
         delete_status(self.status)