Ejemplo n.º 1
0
 def test_econnrefused(self):
     """A refused connection surfaces as a CommunicationError.

     A socket is bound to an ephemeral port on the local address but never
     put into the listening state, so connecting to that port is expected to
     be refused. The error message must mention the failed connection and,
     except on Windows, the OS description of ECONNREFUSED.
     """
     family, host = _localhost()
     reserved = socket.socket(family)
     reserved.bind((host, 0))
     self.addCleanup(reserved.close)
     bound_port = reserved.getsockname()[1]
     url = "http://localhost:%d" % bound_port
     with self.assertRaises(server_info.CommunicationError) as ctx:
         server_info.fetch_server_info(url)
     message = str(ctx.exception)
     self.assertIn("Failed to connect to backend", message)
     # NOTE(review): the strerror text is skipped on Windows, presumably
     # because the message differs there; confirm.
     if os.name == "nt":
         return
     self.assertIn(os.strerror(errno.ECONNREFUSED), message)
Ejemplo n.º 2
0
    def test_non_ok_response(self):
        """A non-OK HTTP status from the backend raises CommunicationError.

        The fake backend answers with "502 Bad Gateway" and a short body. Both
        the status line and the body text must appear in the error message.
        """

        @wrappers.BaseRequest.application
        def handler(request):
            del request  # unused
            payload = b"very sad"
            return wrappers.BaseResponse(payload, status="502 Bad Gateway")

        backend = self._start_server(handler)
        with self.assertRaises(server_info.CommunicationError) as ctx:
            server_info.fetch_server_info(backend)
        text = str(ctx.exception)
        expected_fragments = (
            "Non-OK status from backend (502 Bad Gateway)",
            "very sad",
        )
        for fragment in expected_fragments:
            self.assertIn(fragment, text)
Ejemplo n.º 3
0
    def test_corrupt_response(self):
        """An undecodable response body raises CommunicationError.

        The fake backend replies with a body that is not a valid server-info
        response. The error message must flag the response as corrupt and
        echo the offending body.
        """

        @wrappers.BaseRequest.application
        def handler(request):
            del request  # unused
            garbage = b"an unlikely proto"
            return wrappers.BaseResponse(garbage)

        backend = self._start_server(handler)
        with self.assertRaises(server_info.CommunicationError) as ctx:
            server_info.fetch_server_info(backend)
        text = str(ctx.exception)
        for fragment in ("Corrupt response from backend", "an unlikely proto"):
            self.assertIn(fragment, text)
def _get_server_info(api_endpoint=None):
    """Return server info for DEFAULT_ORIGIN, advertising the scalars plugin.

    Args:
      api_endpoint: Optional API endpoint. When falsy, the origin is queried
        via `server_info_lib.fetch_server_info`. Otherwise the info is
        constructed via `server_info_lib.create_server_info` with this
        endpoint.

    Returns:
      The result of whichever `server_info_lib` call was made.
    """
    # TODO(cais): Add more plugins to the list when more plugin/data types
    # are supported
    plugins = ["scalars"]
    if not api_endpoint:
        return server_info_lib.fetch_server_info(DEFAULT_ORIGIN, plugins)
    return server_info_lib.create_server_info(
        DEFAULT_ORIGIN, api_endpoint, plugins
    )
Ejemplo n.º 5
0
def _get_server_info(flags):
    """Resolve server info from command-line flags.

    If an API endpoint is given but no origin is, the info is built with
    `create_server_info` from the default origin and that endpoint. Otherwise
    the origin (explicit or default) is queried with `fetch_server_info`. Any
    explicitly given API endpoint then replaces the server-provided one, but
    only when the server returned a non-empty endpoint.

    Args:
      flags: Object exposing `origin` and `api_endpoint` attributes.

    Returns:
      The server-info object from `server_info_lib`, possibly with its
      `api_server.endpoint` overridden.
    """
    requested_endpoint = flags.api_endpoint
    origin = flags.origin or _DEFAULT_ORIGIN
    if requested_endpoint and not flags.origin:
        return server_info_lib.create_server_info(origin, requested_endpoint)
    info = server_info_lib.fetch_server_info(origin)
    # Override with any API server explicitly specified on the command
    # line, but only if the server accepted our initial handshake.
    if requested_endpoint and info.api_server.endpoint:
        info.api_server.endpoint = requested_endpoint
    return info
Ejemplo n.º 6
0
    def test_user_agent(self):
        """The client sends a User-Agent naming the TensorBoard version.

        The fake backend copies the received User-Agent header into the
        response's compatibility details so the test can inspect it.
        """

        @wrappers.BaseRequest.application
        def echo_user_agent(request):
            reply = server_info_pb2.ServerInfoResponse()
            reply.compatibility.details = request.headers["User-Agent"]
            return wrappers.BaseResponse(reply.SerializeToString())

        backend = self._start_server(echo_user_agent)
        info = server_info.fetch_server_info(backend)
        self.assertEqual(
            info.compatibility.details, "tensorboard/%s" % version.VERSION
        )
Ejemplo n.º 7
0
    def test_fetches_with_plugins(self):
        """Requested upload plugins are forwarded in the ServerInfoRequest.

        The fake backend parses the request body and checks both the client
        version and the upload plugin list, then replies with an empty
        ServerInfoResponse.
        """
        plugin_names = ["plugin1", "plugin2"]

        @wrappers.BaseRequest.application
        def handler(request):
            parsed = server_info_pb2.ServerInfoRequest.FromString(
                request.get_data()
            )
            # NOTE(review): these assertions run inside request handling;
            # a failure presumably reaches the test as an error response
            # from the backend. Confirm.
            self.assertEqual(parsed.version, version.VERSION)
            self.assertEqual(
                parsed.plugin_specification.upload_plugins, plugin_names
            )
            empty_reply = server_info_pb2.ServerInfoResponse()
            return wrappers.BaseResponse(empty_reply.SerializeToString())

        backend = self._start_server(handler)
        # Pass a separate copy so the expected list above cannot be affected.
        outcome = server_info.fetch_server_info(backend, list(plugin_names))
        self.assertIsNotNone(outcome)
Ejemplo n.º 8
0
    def test_fetches_response(self):
        """A well-formed backend reply is returned as an equal message.

        The fake backend checks that it receives a POST to /api/uploader whose
        body is a ServerInfoRequest carrying the current version. It then
        answers with a fully populated ServerInfoResponse, and the client's
        result must compare equal to that response.
        """
        canned = server_info_pb2.ServerInfoResponse()
        canned.compatibility.verdict = server_info_pb2.VERDICT_OK
        canned.compatibility.details = "all clear"
        canned.api_server.endpoint = "api.example.com:443"
        canned.url_format.template = "http://localhost:8080/{{eid}}"
        canned.url_format.id_placeholder = "{{eid}}"

        @wrappers.BaseRequest.application
        def handler(request):
            self.assertEqual(request.method, "POST")
            self.assertEqual(request.path, "/api/uploader")
            parsed = server_info_pb2.ServerInfoRequest.FromString(
                request.get_data()
            )
            self.assertEqual(parsed.version, version.VERSION)
            return wrappers.BaseResponse(canned.SerializeToString())

        backend = self._start_server(handler)
        self.assertEqual(server_info.fetch_server_info(backend), canned)