Esempio n. 1
0
    def _initialize_handle_and_devices(self):
        """Initialize the native context handle and its devices, at most once.

        Builds context options from ``self._config``, ``self._device_policy``
        and ``self._execution_mode``, creates the native context from them,
        applies either ``self._server_def`` (remote execution) or
        ``self._collective_ops_server_def`` (collective ops), and finally calls
        ``self._initialize_devices()``.

        The whole sequence runs while holding ``self._initialize_lock``. If a
        context handle already exists, the call returns without doing anything.

        Raises:
          AssertionError: if both ``self._server_def`` and
            ``self._collective_ops_server_def`` are set, or if devices exist
            while no handle does (only when assertions are enabled).
        """
        with self._initialize_lock:
            if self._context_handle is not None:
                return
            assert self._context_devices is None
            # Validate the configuration *before* creating the native context.
            # If this check ran after the handle was assigned, a failure would
            # leave ``_context_handle`` set with no devices initialized, and
            # every later call would take the early return above, leaving the
            # context permanently half-initialized.
            assert not (
                self._server_def and self._collective_ops_server_def
            ), ("Cannot enable remote execution as well as collective ops at the "
                "moment. If this is important to you, please file an issue.")

            opts = pywrap_tensorflow.TFE_NewContextOptions()
            try:
                if self._config is not None:
                    pywrap_tensorflow.TFE_ContextOptionsSetConfig(
                        opts, self._config.SerializeToString())
                if self._device_policy is not None:
                    pywrap_tensorflow.TFE_ContextOptionsSetDevicePlacementPolicy(
                        opts, self._device_policy)
                if self._execution_mode == ASYNC:
                    pywrap_tensorflow.TFE_ContextOptionsSetAsync(opts, True)
                self._context_handle = pywrap_tensorflow.TFE_NewContext(opts)
            finally:
                # The options object is only needed to build the context; free
                # it whether or not context creation succeeded.
                pywrap_tensorflow.TFE_DeleteContextOptions(opts)

            if self._server_def is not None:
                server_def_str = self._server_def.SerializeToString()
                # NOTE(review): 600 is passed straight to the native call;
                # presumably a keep-alive interval in seconds — confirm.
                pywrap_tensorflow.TFE_ContextSetServerDef(
                    self._context_handle, 600, server_def_str)
            elif self._collective_ops_server_def is not None:
                server_def_str = (
                    self._collective_ops_server_def.SerializeToString())
                pywrap_tensorflow.TFE_EnableCollectiveOps(
                    self._context_handle, server_def_str)

            self._initialize_devices()
Esempio n. 2
0
    def enable_collective_ops(self, server_def):
        """Enable collective ops with an appropriate server_def.

        If the native context has not been created yet, ``server_def`` is only
        stored on ``self._collective_ops_server_def`` so that the lazy context
        initialization can apply it later. Otherwise it is serialized and
        applied to the live context right away, after which caches are
        cleared and devices are re-initialized.

        NOTE(review): earlier docs state that collective ops cannot be
        re-enabled once enabled; this method does not check that itself.

        Args:
          server_def: A tensorflow::ServerDef proto; it is serialized with
            ``SerializeToString()`` before being handed to the runtime.

        Raises:
          ValueError: if server_def is None (or otherwise falsy).
        """
        if not server_def:
            raise ValueError("server_def is None.")

        if not self._context_handle:
            # No context yet: defer, so context initialization picks it up.
            self._collective_ops_server_def = server_def
            return

        # A context already exists: apply the def immediately, then refresh
        # any cached state and the device list.
        serialized_def = server_def.SerializeToString()
        pywrap_tensorflow.TFE_EnableCollectiveOps(
            self._context_handle, serialized_def)
        self._clear_caches()
        self._initialize_devices()