Esempio n. 1
0
    def test_data_location(self):
        """`data_location` returns the response's location and sends one RPC."""
        expected_location = "./logs/mnist"
        canned = data_provider_pb2.GetExperimentResponse()
        canned.data_location = expected_location
        self.stub.GetExperiment.return_value = canned

        got = self.provider.data_location(self.ctx, experiment_id="123")
        self.assertEqual(got, expected_location)

        # The provider must forward the experiment ID in a single request.
        self.stub.GetExperiment.assert_called_once_with(
            data_provider_pb2.GetExperimentRequest(experiment_id="123")
        )
Esempio n. 2
0
    def test_experiment_metadata_when_only_data_location_set(self):
        """Only `data_location` is populated; the other fields stay empty/zero.

        The response sets just `data_location`. The test checks that this value
        is copied into `ExperimentMetadata.data_location` and that the unset
        name, description and creation time come back as "", "" and 0.
        """
        res = data_provider_pb2.GetExperimentResponse()
        res.data_location = "./logs/mnist"
        self.stub.GetExperiment.return_value = res

        actual = self.provider.experiment_metadata(self.ctx,
                                                   experiment_id="123")
        self.assertEqual(
            actual,
            provider.ExperimentMetadata(
                data_location="./logs/mnist",
                experiment_name="",
                experiment_description="",
                creation_time=0,
            ),
        )

        req = data_provider_pb2.GetExperimentRequest()
        req.experiment_id = "123"
        self.stub.GetExperiment.assert_called_once_with(req)
Esempio n. 3
0
 def experiment_metadata(self, ctx, *, experiment_id):
     """Fetch an experiment's metadata from the server via `GetExperiment`.

     Args:
       ctx: Request context; not referenced by this method.
       experiment_id: ID of the experiment to look up.

     Returns:
       A `provider.ExperimentMetadata` whose data location, name,
       description and creation time are copied from the RPC response.
       The creation time is converted via `_timestamp_proto_to_float`.
     """
     request = data_provider_pb2.GetExperimentRequest()
     request.experiment_id = experiment_id
     # NOTE(review): `_translate_grpc_error` presumably maps gRPC failures
     # to data-provider errors; only the RPC call itself is wrapped.
     with _translate_grpc_error():
         response = self._stub.GetExperiment(request)
     return provider.ExperimentMetadata(
         data_location=response.data_location,
         experiment_name=response.name,
         experiment_description=response.description,
         creation_time=_timestamp_proto_to_float(response.creation_time),
     )
Esempio n. 4
0
    def test_experiment_metadata_with_partial_metadata(self):
        """A response with only a name yields empty description and zero time."""
        canned = data_provider_pb2.GetExperimentResponse(name="mnist")
        self.stub.GetExperiment.return_value = canned

        got = self.provider.experiment_metadata(self.ctx, experiment_id="123")
        expected = provider.ExperimentMetadata(
            experiment_name="mnist",
            experiment_description="",
            creation_time=0,
        )
        self.assertEqual(got, expected)

        # Exactly one request, carrying the requested experiment ID.
        self.stub.GetExperiment.assert_called_once_with(
            data_provider_pb2.GetExperimentRequest(experiment_id="123")
        )
Esempio n. 5
0
    def test_experiment_metadata_with_creation_time(self):
        """A 1500 ms proto timestamp is reported as 1.5 seconds."""
        canned = data_provider_pb2.GetExperimentResponse(
            name="mnist",
            description="big breakthroughs",
        )
        canned.creation_time.FromMilliseconds(1500)
        self.stub.GetExperiment.return_value = canned

        got = self.provider.experiment_metadata(self.ctx, experiment_id="123")
        expected = provider.ExperimentMetadata(
            experiment_name="mnist",
            experiment_description="big breakthroughs",
            creation_time=1.5,
        )
        self.assertEqual(got, expected)

        # Exactly one request, carrying the requested experiment ID.
        self.stub.GetExperiment.assert_called_once_with(
            data_provider_pb2.GetExperimentRequest(experiment_id="123")
        )
Esempio n. 6
0
 def data_location(self, ctx, *, experiment_id):
     """Return the `data_location` string reported by `GetExperiment`.

     Args:
       ctx: Request context; not referenced by this method.
       experiment_id: ID of the experiment to look up.
     """
     request = data_provider_pb2.GetExperimentRequest()
     request.experiment_id = experiment_id
     # NOTE(review): presumably converts gRPC failures into provider errors.
     with _translate_grpc_error():
         response = self._stub.GetExperiment(request)
     location = response.data_location
     return location
Esempio n. 7
0
    def start(self):
        """Spawn the data server subprocess and attach a gRPC data provider.

        Does nothing if `self._data_provider` is already set.

        Otherwise, launches the server binary with `--port=0` and asks it to
        write its chosen port to a file in a fresh temporary directory. It
        then polls that file until a newline-terminated port number appears.
        Finally, it sends one `GetExperiment` RPC to confirm the server
        answers, and stores a `grpc_provider.GrpcDataProvider` in
        `self._data_provider`.

        Raises:
          DataServerStartupError: If the server process exits while we are
            polling for its port, never finishes writing the port file within
            the polling window (~10 s), or fails the initial RPC.
        """
        if self._data_provider:
            return

        tmpdir = tempfile.TemporaryDirectory(prefix="tensorboard_data_server_")
        port_file_path = os.path.join(tmpdir.name, "port")
        error_file_path = os.path.join(tmpdir.name, "startup_error")

        # A non-positive reload interval means "load once, never reload".
        if self._reload_interval <= 0:
            reload = "once"
        else:
            reload = str(int(self._reload_interval))

        # A per-plugin sample count of 0 is spelled "all" on the command line.
        sample_hint_pairs = [
            "%s=%s" % (k, "all" if v == 0 else v)
            for k, v in self._samples_per_plugin.items()
        ]
        samples_per_plugin = ",".join(sample_hint_pairs)

        args = [
            self._server_binary.path,
            "--logdir=%s" % os.path.expanduser(self._logdir),
            "--reload=%s" % reload,
            "--samples-per-plugin=%s" % samples_per_plugin,
            "--port=0",
            "--port-file=%s" % (port_file_path, ),
            "--die-after-stdin",
        ]
        # Older server binaries do not understand `--error-file`.
        if self._server_binary.at_least_version("0.5.0a0"):
            args.append("--error-file=%s" % (error_file_path, ))
        if logger.isEnabledFor(logging.INFO):
            args.append("--verbose")
        if logger.isEnabledFor(logging.DEBUG):
            args.append("--verbose")  # Repeat arg to increase verbosity.
        args.extend(self._extra_flags)

        logger.info("Spawning data server: %r", args)
        popen = subprocess.Popen(args, stdin=subprocess.PIPE)
        # Stash stdin to avoid calling its destructor: on Windows, this
        # is a `subprocess.Handle` that closes itself in `__del__`,
        # which would cause the data server to shut down. (This is not
        # documented; you have to read CPython source to figure it out.)
        # We want that to happen at end of process, but not before.
        self._stdin_handle = popen.stdin  # stash to avoid stdin being closed

        port = None
        # The server only needs about 10 microseconds to spawn on my machine,
        # but give a few orders of magnitude of padding, and then poll.
        time.sleep(0.01)
        for i in range(20):
            returncode = popen.poll()
            if returncode is not None:
                # Prefer the server's own error message, if it wrote one.
                msg = (_maybe_read_file(error_file_path) or "").strip()
                if not msg:
                    msg = ("exited with %d; check stderr for details" %
                           returncode)
                raise DataServerStartupError(msg)
            logger.info("Polling for data server port (attempt %d)", i)
            port_file_contents = _maybe_read_file(port_file_path)
            logger.info("Port file contents: %r", port_file_contents)
            # A trailing newline marks the port file as completely written.
            if (port_file_contents or "").endswith("\n"):
                port = int(port_file_contents)
                break
            # Else, not done writing yet.
            time.sleep(0.5)
        if port is None:
            raise DataServerStartupError(
                "Timed out while waiting for data server to start. "
                "It may still be running as pid %d." % popen.pid)

        addr = "localhost:%d" % port
        stub = _make_stub(addr, self._channel_creds_type)
        logger.info(
            "Opened channel to data server at pid %d via %s",
            popen.pid,
            addr,
        )

        # Sanity-check the connection with a cheap RPC before handing it out.
        req = data_provider_pb2.GetExperimentRequest()
        try:
            stub.GetExperiment(req, timeout=5)  # should be near-instant
        except grpc.RpcError as e:
            msg = "Failed to communicate with data server at %s: %s" % (addr,
                                                                        e)
            # Use the module logger (not the root `logging` module) so this
            # message honours the same logger configuration as the rest.
            logger.warning("%s", msg)
            raise DataServerStartupError(msg) from e
        logger.info("Got valid response from data server")
        self._data_provider = grpc_provider.GrpcDataProvider(addr, stub)