Example #1
0
    def __init__(self, target='', graph=None, config=None):
        """Creates a new TensorFlow session after validating its arguments.

        Args:
          target: (Optional) The TensorFlow execution engine to connect to.
            A non-None value is converted with `compat.as_bytes`.
          graph: (Optional) The graph to be used. When None, the default
            graph is used instead.
          config: (Optional) ConfigProto proto used to configure the session.

        Raises:
          tf.errors.OpError: Or one of its subclasses if an error occurs while
            creating the TensorFlow session.
          TypeError: If one of the arguments has the wrong type.
        """
        # An explicit graph must be an ops.Graph; None selects the default.
        if graph is not None and not isinstance(graph, ops.Graph):
            raise TypeError(
                'graph must be a tf.Graph, but got %s' % type(graph))
        self._graph = ops.get_default_graph() if graph is None else graph

        self._opened = self._closed = False

        self._current_version = 0
        self._extend_lock = threading.Lock()
        if target is None:
            self._target = None
        else:
            try:
                self._target = compat.as_bytes(target)
            except TypeError:
                # Re-raise with a message that names the offending argument.
                raise TypeError(
                    'target must be a string, but got %s' % type(target))

        self._delete_lock = threading.Lock()
        self._dead_handles = []

        # An explicit config must be a ConfigProto; None disables shape
        # inference bookkeeping.
        if config is not None and not isinstance(
                config, config_pb2.ConfigProto):
            raise TypeError(
                'config must be a tf.ConfigProto, but got %s' % type(config))
        self._config = config
        self._add_shapes = (
            False if config is None else config.graph_options.infer_shapes)

        self._session = None
        session_options = tf_session.TF_NewSessionOptions(
            target=self._target, config=config)
        try:
            with errors.raise_exception_on_not_ok_status() as status:
                self._session = tf_session.TF_NewSession(
                    session_options, status)
        finally:
            # The options object is released whether or not creation worked.
            tf_session.TF_DeleteSessionOptions(session_options)
Example #2
0
 def testInvalidDeviceNumber(self):
   # Checks that asking for the memory size of a device index one past the
   # end of the device list raises errors.InvalidArgumentError.
   #
   # Every native handle created here (session options, session, device
   # list) is released in a `finally` block so that a failing assertion does
   # not leak it.
   opts = tf_session.TF_NewSessionOptions()
   try:
     c_session = tf_session.TF_NewSession(ops.get_default_graph()._c_graph, opts)
   finally:
     # The options are only needed to create the session.
     tf_session.TF_DeleteSessionOptions(opts)
   try:
     raw_device_list = tf_session.TF_SessionListDevices(c_session)
     try:
       # `size` is the device count, so it is the first out-of-range index.
       size = tf_session.TF_DeviceListCount(raw_device_list)
       with self.assertRaises(errors.InvalidArgumentError):
         tf_session.TF_DeviceListMemoryBytes(raw_device_list, size)
     finally:
       tf_session.TF_DeleteDeviceList(raw_device_list)
   finally:
     tf_session.TF_CloseSession(c_session)
Example #3
0
 def testInvalidDeviceNumber(self):
     # Checks that asking for the memory size of a device index one past the
     # end of the device list yields -1.
     #
     # Every native handle created here (session options, session, device
     # list) is released in a `finally` block so that a failing assertion
     # does not leak it.
     opts = tf_session.TF_NewSessionOptions()
     try:
         # Session creation gets its own status scope so a failure is
         # reported before the session handle is used by a later call.
         # NOTE(review): assumes the status is checked when the `with`
         # block exits -- confirm against raise_exception_on_not_ok_status.
         with errors.raise_exception_on_not_ok_status() as status:
             c_session = tf_session.TF_NewSession(
                 ops.get_default_graph()._c_graph, opts, status)
     finally:
         # The options are only needed to create the session.
         tf_session.TF_DeleteSessionOptions(opts)
     try:
         with errors.raise_exception_on_not_ok_status() as status:
             raw_device_list = tf_session.TF_SessionListDevices(
                 c_session, status)
         try:
             # `size` is the device count, so it is the first out-of-range
             # index.
             size = tf_session.TF_DeviceListCount(raw_device_list)
             # Test that invalid device numbers return -1 rather than a
             # Swig-wrapped pointer.
             status_no_exception = c_api_util.ScopedTFStatus()
             memory = tf_session.TF_DeviceListMemoryBytes(
                 raw_device_list, size, status_no_exception)
             self.assertEqual(memory, -1)
         finally:
             tf_session.TF_DeleteDeviceList(raw_device_list)
     finally:
         with errors.raise_exception_on_not_ok_status() as status:
             tf_session.TF_CloseSession(c_session, status)
Example #4
0
    def __init__(self, target='', graph=None, config=None):
        """Creates a new TensorFlow session.

        Args:
          target: (Optional) The TensorFlow execution engine to connect to.
          graph: (Optional) The graph to be used. When None, the default
            graph is used instead.
          config: (Optional) ConfigProto proto used to configure the session.

        Raises:
          RuntimeError: If an error occurs while creating the TensorFlow
            session.
        """
        self._graph = ops.get_default_graph() if graph is None else graph

        self._opened = self._closed = False

        self._current_version = 0
        self._extend_lock = threading.Lock()
        self._target = target

        self._delete_lock = threading.Lock()
        self._dead_handles = []

        self._session = None

        session_options = tf_session.TF_NewSessionOptions(
            target=target, config=config)
        try:
            status = tf_session.TF_NewStatus()
            try:
                self._session = tf_session.TF_NewSession(
                    session_options, status)
                # A non-zero status code is surfaced as a RuntimeError
                # carrying the status message.
                creation_failed = tf_session.TF_GetCode(status) != 0
                if creation_failed:
                    message = compat.as_text(tf_session.TF_Message(status))
                    raise RuntimeError(message)
            finally:
                tf_session.TF_DeleteStatus(status)
        finally:
            # Release the options regardless of how session creation ended.
            tf_session.TF_DeleteSessionOptions(session_options)
Example #5
0
    def __init__(self, target='', graph=None, config=None):
        """Constructs a new TensorFlow session.

        Args:
          target: (Optional) The TensorFlow execution engine to connect to.
          graph: (Optional) The graph to be used. If this argument is None,
            the default graph will be used.
          config: (Optional) ConfigProto proto used to configure the session.

        Raises:
          tf.errors.OpError: Or one of its subclasses if an error occurs while
            creating the TensorFlow session.
        """
        if graph is None:
            self._graph = ops.get_default_graph()
        else:
            self._graph = graph

        self._opened = False
        self._closed = False

        self._current_version = 0
        self._extend_lock = threading.Lock()
        self._target = target

        self._delete_lock = threading.Lock()
        self._dead_handles = []

        self._session = None
        self._config = config
        # Shape inference is only taken from the config when one is supplied.
        self._add_shapes = config.graph_options.infer_shapes if (
            config and config.graph_options) else False

        # Create the options *before* entering the try block. If this call
        # raised inside the try, the `finally` clause would reference an
        # unbound `opts` and replace the real error with UnboundLocalError.
        opts = tf_session.TF_NewSessionOptions(target=target, config=config)
        try:
            with errors.raise_exception_on_not_ok_status() as status:
                self._session = tf_session.TF_NewSession(opts, status)
        finally:
            tf_session.TF_DeleteSessionOptions(opts)