def test_econnrefused(self):
    """Fetching from a port with no listener raises CommunicationError.

    A socket is bound to an ephemeral loopback port but never put into
    listening mode, so any connection attempt to that port is refused.
    """
    family, host = _localhost()
    sock = socket.socket(family)
    sock.bind((host, 0))
    self.addCleanup(sock.close)
    bound_port = sock.getsockname()[1]
    url = "http://localhost:%d" % bound_port

    with self.assertRaises(server_info.CommunicationError) as ctx:
        server_info.fetch_server_info(url)

    message = str(ctx.exception)
    self.assertIn("Failed to connect to backend", message)
    # The OS error text check is skipped on Windows (os.name == "nt");
    # presumably the strerror wording is not reliably present there.
    if os.name == "nt":
        return
    self.assertIn(os.strerror(errno.ECONNREFUSED), message)
def test_non_ok_response(self):
    """A 502 reply surfaces as CommunicationError with status and body."""

    @wrappers.BaseRequest.application
    def handler(request):
        del request  # unused
        return wrappers.BaseResponse(b"very sad", status="502 Bad Gateway")

    backend = self._start_server(handler)
    with self.assertRaises(server_info.CommunicationError) as ctx:
        server_info.fetch_server_info(backend)

    message = str(ctx.exception)
    # Both the status line and the response body must appear in the error.
    for expected in (
        "Non-OK status from backend (502 Bad Gateway)",
        "very sad",
    ):
        self.assertIn(expected, message)
def test_corrupt_response(self):
    """An unparseable body surfaces as CommunicationError quoting the body."""

    @wrappers.BaseRequest.application
    def handler(request):
        del request  # unused
        # Default (OK) status, but the payload is not a serialized proto.
        return wrappers.BaseResponse(b"an unlikely proto")

    backend = self._start_server(handler)
    with self.assertRaises(server_info.CommunicationError) as ctx:
        server_info.fetch_server_info(backend)

    message = str(ctx.exception)
    for expected in ("Corrupt response from backend", "an unlikely proto"):
        self.assertIn(expected, message)
def _get_server_info(api_endpoint=None):
    """Return server info for DEFAULT_ORIGIN, requesting the upload plugins.

    Args:
      api_endpoint: Optional explicit API endpoint. When truthy, the server
        info is built locally via `create_server_info`. Otherwise it is
        fetched from DEFAULT_ORIGIN via `fetch_server_info`.

    Returns:
      Whatever `server_info_lib.create_server_info` or
      `server_info_lib.fetch_server_info` returns.
    """
    # TODO(cais): Add more plugins to the list when more plugin/data types
    # are supported
    plugins = ["scalars"]
    if not api_endpoint:
        return server_info_lib.fetch_server_info(DEFAULT_ORIGIN, plugins)
    return server_info_lib.create_server_info(
        DEFAULT_ORIGIN, api_endpoint, plugins
    )
def _get_server_info(flags):
    """Resolve server info from command-line flags.

    Args:
      flags: Parsed flags exposing `origin` and `api_endpoint` attributes
        (presumably an argparse namespace; confirm against caller).

    Returns:
      Server info built locally, when only an API endpoint was given.
      Otherwise, the server info fetched from the origin (or
      `_DEFAULT_ORIGIN`), with its API endpoint possibly overridden.
    """
    origin = flags.origin or _DEFAULT_ORIGIN
    if flags.api_endpoint and not flags.origin:
        # No explicit origin: skip the handshake and trust the given endpoint.
        return server_info_lib.create_server_info(origin, flags.api_endpoint)

    info = server_info_lib.fetch_server_info(origin)
    # Override with any API server explicitly specified on the command
    # line, but only if the server accepted our initial handshake.
    should_override = flags.api_endpoint and info.api_server.endpoint
    if should_override:
        info.api_server.endpoint = flags.api_endpoint
    return info
def test_user_agent(self):
    """The client identifies itself with a tensorboard/<version> User-Agent."""

    @wrappers.BaseRequest.application
    def echo_user_agent(request):
        # Reflect the incoming header back in the response's details field.
        reply = server_info_pb2.ServerInfoResponse()
        reply.compatibility.details = request.headers["User-Agent"]
        return wrappers.BaseResponse(reply.SerializeToString())

    backend = self._start_server(echo_user_agent)
    info = server_info.fetch_server_info(backend)
    self.assertEqual(
        info.compatibility.details, "tensorboard/%s" % version.VERSION
    )
def test_fetches_with_plugins(self):
    """Requested upload plugins are sent in the ServerInfoRequest."""

    @wrappers.BaseRequest.application
    def check_request(request):
        parsed = server_info_pb2.ServerInfoRequest.FromString(
            request.get_data()
        )
        # NOTE(review): these assertions run inside the request handler;
        # whether a failure here propagates to the test depends on how
        # _start_server serves requests. Confirm.
        self.assertEqual(parsed.version, version.VERSION)
        self.assertEqual(
            parsed.plugin_specification.upload_plugins,
            ["plugin1", "plugin2"],
        )
        empty_reply = server_info_pb2.ServerInfoResponse()
        return wrappers.BaseResponse(empty_reply.SerializeToString())

    backend = self._start_server(check_request)
    info = server_info.fetch_server_info(backend, ["plugin1", "plugin2"])
    self.assertIsNotNone(info)
def test_fetches_response(self):
    """A POST to /api/uploader returns the server's response unchanged."""
    expected_result = server_info_pb2.ServerInfoResponse()
    compat = expected_result.compatibility
    compat.verdict = server_info_pb2.VERDICT_OK
    compat.details = "all clear"
    expected_result.api_server.endpoint = "api.example.com:443"
    url_format = expected_result.url_format
    url_format.template = "http://localhost:8080/{{eid}}"
    url_format.id_placeholder = "{{eid}}"

    @wrappers.BaseRequest.application
    def handler(request):
        # Verify the wire shape of the client's request before replying.
        self.assertEqual(request.method, "POST")
        self.assertEqual(request.path, "/api/uploader")
        parsed = server_info_pb2.ServerInfoRequest.FromString(
            request.get_data()
        )
        self.assertEqual(parsed.version, version.VERSION)
        return wrappers.BaseResponse(expected_result.SerializeToString())

    backend = self._start_server(handler)
    self.assertEqual(server_info.fetch_server_info(backend), expected_result)