Ejemplo n.º 1
0
    def send_run(self, data):
        """Handle a run record: persist its config, resolve resume, init the run.

        A record that arrives while ``self._run`` is still unset is treated
        as the ``wandb.init`` message: a default project name is filled in if
        needed, resume status is checked, and the run threads are started at
        the end. Later records only update the existing run.

        Args:
            data: Record providing ``run``, ``control`` and ``uuid``.
        """
        run = data.run
        first_run_message = self._run is None

        # Write any config shipped with this record to the config file.
        cfg_path = os.path.join(self._settings.files_dir,
                                filenames.CONFIG_FNAME)
        cfg = None
        if run.config:
            cfg = config_util.dict_from_proto_list(run.config.update)
            config_util.save_config_file_from_dict(cfg_path, cfg)

        if first_run_message:
            # A project is required to query for resume status.
            if run.project == "":
                run.project = util.auto_project_name(self._settings.program)
            # Resume status is only checked on `wandb.init`.
            resume_error = self._maybe_setup_resume(run)
            if resume_error is not None:
                if not data.control.req_resp:
                    logger.error("Got error in async mode: %s",
                                 resume_error.message)
                    return
                failure = wandb_internal_pb2.Result(uuid=data.uuid)
                failure.run_result.run.CopyFrom(run)
                failure.run_result.error.CopyFrom(resume_error)
                self._result_q.put(failure)
                return

        # The resumed config is the base; values from this record override it.
        # NOTE: the dict stored in self._resume_state is updated in place.
        # TODO: should we merge this with resumed config?
        resumed_cfg = self._resume_state["config"]
        if resumed_cfg is not None:
            resumed_cfg.update(cfg or {})
            cfg = resumed_cfg
            config_util.save_config_file_from_dict(cfg_path, cfg)

        self._init_run(run, cfg)

        if data.control.req_resp:
            reply = wandb_internal_pb2.Result(uuid=data.uuid)
            # TODO: we could do self._interface.publish_defer(resp) to notify
            # the handler not to actually perform server updates for this uuid
            # because the user process will send a summary update when we resume
            reply.run_result.run.CopyFrom(self._run)
            self._result_q.put(reply)

        if not first_run_message:
            logger.info("updated run: %s", self._run.run_id)
            return

        # Threads are only spun up for the first run message.
        self._start_run_threads()
Ejemplo n.º 2
0
    def send_request_defer(self, data):
        """Do the sender-side work for one defer state, then advance or exit.

        Every state except ``END`` finishes by publishing the next state
        (``state + 1``) through the interface. ``END`` builds the exit
        result, answers a pending synchronous exit if one is registered, and
        stores the result on ``self._exit_result``.

        Args:
            data: Record whose ``request.defer.state`` selects the work.

        Raises:
            AssertionError: If the state is not a known defer state.
        """
        defer = data.request.defer
        state = defer.state
        logger.info("handle sender defer: {}".format(state))

        # NOTE: FLUSH_STATS, FLUSH_TB and FLUSH_SUM are handled in
        # handler.py:handle_request_defer(); like BEGIN they need no work here.
        handled_elsewhere = (
            defer.BEGIN,
            defer.FLUSH_STATS,
            defer.FLUSH_TB,
            defer.FLUSH_SUM,
        )

        if state in handled_elsewhere:
            pass
        elif state == defer.FLUSH_DIR:
            if self._dir_watcher:
                self._dir_watcher.finish()
                self._dir_watcher = None
        elif state == defer.FLUSH_FP:
            if self._pusher:
                self._pusher.finish()
        elif state == defer.FLUSH_FS:
            if self._fs:
                # TODO(jhr): now is a good time to output pending output lines
                self._fs.finish(self._exit_code)
                self._fs = None
        elif state == defer.FLUSH_FINAL:
            self._interface.publish_final()
            self._interface.publish_footer()
        elif state == defer.END:
            exit_result = wandb_internal_pb2.RunExitResult()

            # This path is not the preferred way to return exit results,
            # as flushing the file pusher buffers can take a long time.
            if self._exit_sync_uuid:
                if self._pusher:
                    # NOTE: This will block until finished
                    self._pusher.print_status()
                    self._pusher.join()
                    self._pusher = None
                self._result_q.put(
                    wandb_internal_pb2.Result(
                        exit_result=exit_result, uuid=self._exit_sync_uuid
                    )
                )

            # Mark exit done in case we are polling on exit.
            self._exit_result = exit_result
            return
        else:
            raise AssertionError("unknown state")

        # Not finished yet: hand over to the next defer state.
        next_state = state + 1
        logger.info("send defer: {}".format(next_state))
        self._interface.publish_defer(next_state)
Ejemplo n.º 3
0
    def handle_request_defer(self, data):
        """Flush the watcher, pusher and file stream, then reply with file counts.

        The directory watcher and the file stream are finished and released;
        the file pusher is finished, its per-category file counts are copied
        into the response, and it is released after printing its status.

        A response is always sent. When no pusher exists the counts are
        reported as zero instead of failing on ``None``, which previously
        raised ``AttributeError`` before the response was queued.

        Args:
            data: Request record; its ``uuid`` is echoed in the response.
        """
        logger.info("handle defer")

        if self._dir_watcher:
            self._dir_watcher.finish()
            self._dir_watcher = None

        if self._pusher:
            self._pusher.finish()

        if self._fs:
            # TODO(jhr): now is a good time to output pending output lines
            self._fs.finish(self._exit_code)
            self._fs = None

        # The pusher is optional everywhere else in this method, so it must
        # not be dereferenced unconditionally here either.
        if self._pusher:
            file_counts = self._pusher.file_counts_by_category()
        else:
            file_counts = {"wandb": 0, "media": 0, "artifact": 0, "other": 0}

        # NB: assume we always need to send a response for this message
        # since it was sent on behalf of handle_exit() req/resp logic
        resp = wandb_internal_pb2.Result(uuid=data.uuid)
        resp.exit_result.files.wandb_count = file_counts["wandb"]
        resp.exit_result.files.media_count = file_counts["media"]
        resp.exit_result.files.artifact_count = file_counts["artifact"]
        resp.exit_result.files.other_count = file_counts["other"]
        self._resp_q.put(resp)

        # TODO(david): this info should be in exit_result footer?
        if self._pusher:
            self._pusher.print_status()
            self._pusher = None
Ejemplo n.º 4
0
    def send_request_poll_exit(self, record):
        """Answer an exit poll with upload progress and, once done, the exit result.

        Nothing is sent unless the record asks for a response. While a file
        pusher exists, its byte statistics and per-category file counts are
        included. The response is marked done only when an exit result has
        been recorded and the pusher (if any) reports it is no longer alive.

        Args:
            record: Poll request; its ``uuid`` is echoed in the response.
        """
        if not record.control.req_resp:
            return

        result = wandb_internal_pb2.Result(uuid=record.uuid)

        pusher_alive = False
        if self._pusher:
            pusher_alive, status = self._pusher.get_status()
            counts = self._pusher.file_counts_by_category()
            poll_resp = result.response.poll_exit_response
            for stat_name in ("uploaded_bytes", "total_bytes", "deduped_bytes"):
                setattr(poll_resp.pusher_stats, stat_name, status[stat_name])
            for category in ("wandb", "media", "artifact", "other"):
                setattr(poll_resp.file_counts, category + "_count",
                        counts[category])

        if self._exit_result and not pusher_alive:
            # The pusher was reported as not alive, so join should not block.
            if self._pusher:
                self._pusher.join()
            poll_resp = result.response.poll_exit_response
            poll_resp.exit_result.CopyFrom(self._exit_result)
            poll_resp.done = True

        self._result_q.put(result)
Ejemplo n.º 5
0
    def handle_request_run_start(self, record):
        """Start per-run helpers for a run-start request and acknowledge it.

        Depending on the settings flags, system stats collection is started
        and run metadata is probed and written. A TensorBoard watcher is
        always created. For a resumed run the step counter is set to the
        run's starting step.

        Args:
            record: Request record carrying ``request.run_start``; its
                ``uuid`` is echoed in the (empty) result.
        """
        run_start = record.request.run_start
        assert run_start
        run = run_start.run
        assert run

        settings = self._settings
        interface = self._interface

        if not settings._disable_stats:
            self._system_stats = stats.SystemStats(pid=os.getpid(),
                                                   interface=interface)
            self._system_stats.start()

        if not settings._disable_meta:
            meta_writer = meta.Meta(settings=settings, interface=interface)
            meta_writer.probe()
            meta_writer.write()

        self._tb_watcher = tb_watcher.TBWatcher(settings,
                                                interface=interface,
                                                run_proto=run)

        if run.resumed:
            self._step = run.starting_step

        self._result_q.put(wandb_internal_pb2.Result(uuid=record.uuid))
Ejemplo n.º 6
0
 def handle_request_get_summary(self, record: Record) -> None:
     """Reply with every consolidated summary entry, JSON-encoding each value.

     Args:
         record: Request record; its ``uuid`` is echoed in the result.
     """
     result = wandb_internal_pb2.Result(uuid=record.uuid)
     for key, value in six.iteritems(self._consolidated_summary):
         entry = wandb_internal_pb2.SummaryItem(
             key=key, value_json=json.dumps(value))
         result.response.get_summary_response.item.append(entry)
     self._result_q.put(result)
Ejemplo n.º 7
0
 def send_request_check_version(self, record):
     """Reply with the update notice for the installed wandb version, if any.

     Args:
         record: Request record; must ask for a response. Its ``uuid`` is
             echoed in the result.
     """
     assert record.control.req_resp
     result = wandb_internal_pb2.Result(uuid=record.uuid)
     notice = update.check_available(wandb.__version__)
     if message := notice:
         result.response.check_version_response.message = message
     self._result_q.put(result)
Ejemplo n.º 8
0
 def _communicate(self, rec, timeout=5, local=False):
     """Return a fresh result; ``poll_exit`` requests are answered as done.

     ``timeout`` and ``local`` are accepted but not used here.
     """
     resp = wandb_internal_pb2.Result()
     is_poll_exit = (
         rec.WhichOneof("record_type") == "request"
         and rec.request.WhichOneof("request_type") == "poll_exit"
     )
     if is_poll_exit:
         resp.response.poll_exit_response.done = True
     return resp
Ejemplo n.º 9
0
 def handle_request_login(self, data):
     """Fetch the viewer, cache its flags and entity, and reply if requested.

     The viewer's ``flags`` value is parsed as JSON into ``self._flags``
     and its ``entity`` is stored on ``self._entity``. When the record asks
     for a response, the entity is returned as ``active_entity``.

     Args:
         data: Login request record; its ``uuid`` is echoed in the result.
     """
     # TODO: do something with api_key or anonymous?
     # TODO: return an error if we aren't logged in?
     viewer = self._api.viewer()
     self._flags = json.loads(viewer.get("flags", "{}"))
     self._entity = viewer.get("entity")
     if data.control.req_resp:
         result = wandb_internal_pb2.Result(uuid=data.uuid)
         # The viewer may carry no entity; assigning None to a protobuf
         # string field raises TypeError, so only set it when present.
         if self._entity:
             result.response.login_response.active_entity = self._entity
         self._resp_q.put(result)
Ejemplo n.º 10
0
 def handle_request_sampled_history(self, record: Record) -> None:
     """Reply with the sampled history values for every tracked key.

     Values go into ``values_int`` when all of them are integral, otherwise
     into ``values_float`` when all of them are real numbers; any other mix
     yields an item that carries only its key.

     Args:
         record: Request record; its ``uuid`` is echoed in the result.
     """
     result = wandb_internal_pb2.Result(uuid=record.uuid)
     for key, sampled in six.iteritems(self._sampled_history):
         entry = wandb_internal_pb2.SampledHistoryItem(key=key)
         values: Iterable[Any] = sampled.get()
         # First matching number kind wins; integral is tried before real.
         candidates = (
             (numbers.Integral, entry.values_int),
             (numbers.Real, entry.values_float),
         )
         for number_kind, target in candidates:
             if all(isinstance(v, number_kind) for v in values):
                 target.extend(values)
                 break
         result.response.sampled_history_response.item.append(entry)
     self._result_q.put(result)
Ejemplo n.º 11
0
    def send_request_stop_status(self, record):
        """Reply with whether a stop has been requested for the current run.

        ``run_should_stop`` defaults to ``False``. The API is only queried
        when entity, project and run id are all known; a failing query is
        logged and the default is kept.

        Args:
            record: Request record; must ask for a response. Its ``uuid``
                is echoed in the result.
        """
        assert record.control.req_resp

        result = wandb_internal_pb2.Result(uuid=record.uuid)
        status_resp = result.response.stop_status_response
        status_resp.run_should_stop = False

        run_is_identified = (self._entity and self._project
                             and self._run.run_id)
        if run_is_identified:
            try:
                stop_requested = self._api.check_stop_requested(
                    self._project, self._entity, self._run.run_id)
                status_resp.run_should_stop = stop_requested
            except Exception as exc:
                logger.warning("Failed to check stop requested status: %s",
                               exc)

        self._result_q.put(result)
Ejemplo n.º 12
0
    def send_request_log_artifact(self, record):
        """Send the requested artifact and reply with its id or an error message.

        Any exception raised while sending the artifact or storing its id is
        reported in the response's ``error_message`` instead of propagating.

        Args:
            record: Request record; must ask for a response. Its ``uuid``
                is echoed in the result.
        """
        assert record.control.req_resp
        artifact = record.request.log_artifact.artifact
        result = wandb_internal_pb2.Result(uuid=record.uuid)
        log_resp = result.response.log_artifact_response

        try:
            sent = self._send_artifact(artifact)
            log_resp.artifact_id = sent.get("id")
        except Exception as exc:
            log_resp.error_message = 'error logging artifact "{}/{}": {}'.format(
                artifact.type, artifact.name, exc)

        self._result_q.put(result)
Ejemplo n.º 13
0
 def send_request_check_version(self, record):
     """Reply with upgrade/yank/delete notices for the installed wandb version.

     Each notice is copied into the response only when it is present and
     non-empty. A missing key previously raised ``KeyError`` and a ``None``
     value raised ``TypeError`` on assignment to the protobuf string field,
     so no result was ever queued.

     Args:
         record: Request record; must ask for a response. Its ``uuid`` is
             echoed in the result.
     """
     assert record.control.req_resp
     result = wandb_internal_pb2.Result(uuid=record.uuid)
     current_version = wandb.__version__
     messages = update.check_available(current_version)
     if messages:
         upgrade_message = messages.get("upgrade_message")
         if upgrade_message:
             result.response.check_version_response.upgrade_message = (
                 upgrade_message)
         yank_message = messages.get("yank_message")
         if yank_message:
             result.response.check_version_response.yank_message = (
                 yank_message)
         delete_message = messages.get("delete_message")
         if delete_message:
             result.response.check_version_response.delete_message = (
                 delete_message)
     self._result_q.put(result)
Ejemplo n.º 14
0
    def send_request_network_status(self, record):
        """Drain the retry queue into a network status response and send it.

        Entries are taken without blocking until the queue reports empty.
        Any other error while moving an entry is logged and draining goes on.

        Args:
            record: Request record; must ask for a response. Its ``uuid``
                is echoed in the result.
        """
        assert record.control.req_resp

        result = wandb_internal_pb2.Result(uuid=record.uuid)
        responses = result.response.network_status_response.network_responses
        while True:
            try:
                pending = self._retry_q.get_nowait()
                responses.append(pending)
            except queue.Empty:
                # Nothing left to report.
                break
            except Exception as exc:
                logger.warning("Error emptying retry queue: {}".format(exc))
        self._result_q.put(result)
Ejemplo n.º 15
0
 def send_request_login(self, record):
     """Re-authenticate, cache the viewer's entity, and reply if requested.

     Server info returned alongside the viewer is logged when present. The
     response carries ``active_entity`` only when an entity is known.

     Args:
         record: Login request record; its ``uuid`` is echoed in the result.
     """
     # TODO: do something with api_key or anonymous?
     # TODO: return an error if we aren't logged in?
     self._api.reauth()
     viewer, server_info = self._api.viewer_server_info()
     if server_info:
         logger.info("Login server info: {}".format(server_info))
     self._entity = viewer.get("entity")

     if not record.control.req_resp:
         return

     result = wandb_internal_pb2.Result(uuid=record.uuid)
     if self._entity:
         result.response.login_response.active_entity = self._entity
     self._result_q.put(result)
Ejemplo n.º 16
0
    def handle_request_status(self, data):
        """Answer a status request, optionally checking for a requested stop.

        Nothing is sent unless the record asks for a response. When the
        request asks for a stop check, ``run_should_stop`` starts as
        ``False`` and is updated from the API only if entity, project and
        run id are all known; a failing query is logged and ignored.

        Args:
            data: Status request record; its ``uuid`` is echoed in the result.
        """
        if not data.control.req_resp:
            return

        result = wandb_internal_pb2.Result(uuid=data.uuid)
        status_resp = result.response.status_response

        wants_stop_check = data.request.status.check_stop_req
        if wants_stop_check:
            status_resp.run_should_stop = False
            run_is_identified = (self._entity and self._project
                                 and self._run.run_id)
            if run_is_identified:
                try:
                    stop_requested = self._api.check_stop_requested(
                        self._project, self._entity, self._run.run_id)
                    status_resp.run_should_stop = stop_requested
                except Exception as exc:
                    logger.warning("Failed to check stop requested status: %s",
                                   exc)

        self._resp_q.put(result)
Ejemplo n.º 17
0
    def _future_poll_artifact(self, xid: str) -> Optional[pb.Result]:
        """Poll the stub for an artifact result.

        Args:
            xid: Identifier placed in the poll request.

        Returns:
            ``None`` while the poll response is not ready; otherwise a result
            shaped like a ``log_artifact`` response, carrying the artifact id
            and error message from the poll response.
        """
        poll_request = pb.ArtifactPollRequest(xid=xid)

        assert self._stub
        self._assign(poll_request)
        poll_response = self._stub.ArtifactPoll(poll_request)

        if not poll_response.ready:
            return None

        # Emulate the log_artifact response expected by the old
        # _communicate_artifact() protocol.
        result = pb.Result()
        log_resp = result.response.log_artifact_response
        log_resp.artifact_id = poll_response.artifact_id
        log_resp.error_message = poll_response.error_message
        return result
Ejemplo n.º 18
0
 def send_request_check_version(self, record):
     """Reply with upgrade/yank/delete notices for the checked wandb version.

     The version comes from the request when it supplies one, otherwise the
     installed ``wandb.__version__`` is used. Only notices that are present
     and non-empty are copied into the response.

     Args:
         record: Request record; must ask for a response. Its ``uuid`` is
             echoed in the result.
     """
     assert record.control.req_resp
     result = wandb_internal_pb2.Result(uuid=record.uuid)
     requested_version = record.request.check_version.current_version
     messages = update.check_available(requested_version or wandb.__version__)
     if messages:
         # Message keys and response field names are identical.
         for field in ("upgrade_message", "yank_message", "delete_message"):
             text = messages.get(field)
             if text:
                 setattr(result.response.check_version_response, field, text)
     self._result_q.put(result)
Ejemplo n.º 19
0
 def handle_request_shutdown(self, record: Record) -> None:
     """Acknowledge the shutdown request, then set the stopped event.

     Args:
         record: Request record; its ``uuid`` is echoed in the result.
     """
     # TODO(jhr): should we drain things and stop new requests from coming in?
     self._result_q.put(wandb_internal_pb2.Result(uuid=record.uuid))
     self._stopped.set()
Ejemplo n.º 20
0
def _result_from_record(record: "pb.Record") -> "pb.Result":
    """Return a new result that carries the record's ``uuid`` and ``control``."""
    return pb.Result(uuid=record.uuid, control=record.control)
Ejemplo n.º 21
0
    def handle_run(self, data):
        """Create or update the run on the server and start streaming for it.

        Steps, in order:
          1. Save the record's config (if the field is set) to the config file.
          2. On the first run message (``self._run`` still unset), check
             resume status; on error, reply with it (or log it when no
             response was requested) and stop.
          3. Upsert the run through the API, including git remote/commit.
          4. Store the run on ``self._run`` with start time and starting step
             adjusted by ``self._offsets``, and copy back the ids and names
             returned by the upsert.
          5. Reply with the stored run when a response was requested.
          6. On the first run message only, start the file stream, file
             pusher and watchers.

        Args:
            data: Record providing ``run``, ``control`` and ``uuid``.
        """
        run = data.run
        # Copy the tags so a plain list is passed to the upsert call below.
        run_tags = run.tags[:]
        error = None
        # No run stored yet means this is the message sent by `wandb.init`.
        is_wandb_init = self._run is None

        # build config dict
        config_dict = None
        if run.HasField("config"):
            config_dict = _config_dict_from_proto_list(run.config.update)
            config_path = os.path.join(self._settings.files_dir, CONFIG_FNAME)
            save_config_file_from_dict(config_path, config_dict)

        # Source of the remote URL and last commit sent with the upsert.
        repo = GitRepo(remote=self._settings.git_remote)

        if is_wandb_init:
            # Only check resume status on `wandb.init`
            error = self._maybe_setup_resume(run)

        if error is not None:
            if data.control.req_resp:
                # Report the failure together with the run as it was received.
                resp = wandb_internal_pb2.Result(uuid=data.uuid)
                resp.run_result.run.CopyFrom(run)
                resp.run_result.error.CopyFrom(error)
                self._resp_q.put(resp)
            else:
                logger.error("Got error in async mode: %s", error.message)
            return

        # TODO: we don't check inserted currently, ultimately we should make
        # the upsert know the resume state and fail transactionally
        # Empty/falsy values are sent as None rather than as empty values.
        ups, inserted = self._api.upsert_run(
            name=run.run_id,
            entity=run.entity or None,
            project=run.project or None,
            group=run.run_group or None,
            job_type=run.job_type or None,
            display_name=run.display_name or None,
            notes=run.notes or None,
            tags=run_tags or None,
            config=config_dict or None,
            sweep_name=run.sweep_id or None,
            host=run.host or None,
            program_path=self._settings.program or None,
            repo=repo.remote_url,
            commit=repo.last_commit,
        )

        # We subtract the previous runs runtime when resuming
        start_time = run.start_time.ToSeconds() - self._offsets["runtime"]
        self._run = run
        self._run.starting_step = self._offsets["step"]
        self._run.start_time.FromSeconds(start_time)
        # Copy back whatever the upsert response provides; absent or empty
        # values leave the corresponding run fields untouched.
        storage_id = ups.get("id")
        if storage_id:
            self._run.storage_id = storage_id
        display_name = ups.get("displayName")
        if display_name:
            self._run.display_name = display_name
        project = ups.get("project")
        if project:
            project_name = project.get("name")
            if project_name:
                self._run.project = project_name
                self._project = project_name
            entity = project.get("entity")
            if entity:
                entity_name = entity.get("name")
                if entity_name:
                    self._run.entity = entity_name
                    self._entity = entity_name

        if data.control.req_resp:
            resp = wandb_internal_pb2.Result(uuid=data.uuid)
            resp.run_result.run.CopyFrom(self._run)
            self._resp_q.put(resp)

        # Keep the API settings in sync with the entity/project in use.
        if self._entity is not None:
            self._api_settings["entity"] = self._entity
        if self._project is not None:
            self._api_settings["project"] = self._project

        # Only spin up our threads on the first run message
        if is_wandb_init:
            self._fs = file_stream.FileStreamApi(self._api,
                                                 run.run_id,
                                                 start_time,
                                                 settings=self._api_settings)
            # Ensure the streaming policies have the proper offsets
            self._fs.set_file_policy("wandb-summary.json",
                                     file_stream.SummaryFilePolicy())
            self._fs.set_file_policy(
                "wandb-history.jsonl",
                file_stream.JsonlFilePolicy(
                    start_chunk_id=self._offsets["history"]),
            )
            self._fs.set_file_policy(
                "wandb-events.jsonl",
                file_stream.JsonlFilePolicy(
                    start_chunk_id=self._offsets["events"]),
            )
            self._fs.set_file_policy(
                "output.log",
                file_stream.CRDedupeFilePolicy(
                    start_chunk_id=self._offsets["output"]),
            )
            self._fs.start()
            # The directory watcher is handed the pusher it should use.
            self._pusher = FilePusher(self._api)
            self._dir_watcher = DirWatcher(self._settings, self._api,
                                           self._pusher)
            self._tb_watcher = tb_watcher.TBWatcher(self._settings,
                                                    sender=self)
            if self._run_meta:
                self._run_meta.write()
            sentry_set_scope("internal", run.entity, run.project)
            logger.info("run started: %s with start time %s", self._run.run_id,
                        start_time)
        else:
            logger.info("updated run: %s", self._run.run_id)
Ejemplo n.º 22
0
 def _request_response(self, rec, timeout=5):
     """Return a fresh, empty result; ``rec`` and ``timeout`` are not used."""
     return wandb_internal_pb2.Result()
Ejemplo n.º 23
0
    def send_run(self, data, file_dir=None) -> None:
        """Handle a run record: merge telemetry and config, resolve resume, init.

        A record that arrives while ``self._run`` is still unset is treated
        as the ``wandb.init`` message: a default project name is filled in if
        needed, resume status is checked, and the run threads are started at
        the end. Later records only update the existing run.

        Args:
            data: Record providing ``run``, ``control`` and ``uuid``.
            file_dir: Passed through to ``_start_run_threads`` on the first
                run message.
        """
        run = data.run
        first_run_message = self._run is None

        # Fold the record's telemetry into the accumulated telemetry.
        if run.telemetry:
            self._telemetry_obj.MergeFrom(run.telemetry)

        # Merge the record's config into the consolidated config and save it.
        config_value_dict: Optional[DictWithValues] = None
        if run.config:
            config_util.update_from_proto(self._consolidated_config,
                                          run.config)
            config_value_dict = self._config_format(self._consolidated_config)
            self._config_save(config_value_dict)

        if first_run_message:
            # A project is required to query for resume status.
            if run.project == "":
                run.project = util.auto_project_name(self._settings.program)
            # Resume status is only checked on `wandb.init`.
            resume_error = self._maybe_setup_resume(run)
            if resume_error is not None:
                if not data.control.req_resp:
                    logger.error("Got error in async mode: %s",
                                 resume_error.message)
                    return
                failure = wandb_internal_pb2.Result(uuid=data.uuid)
                failure.run_result.run.CopyFrom(run)
                failure.run_result.error.CopyFrom(resume_error)
                self._result_q.put(failure)
                return

        # Resumed config is the base; the consolidated config overrides it,
        # and the merged result becomes the new consolidated config.
        # TODO: should we merge this with resumed config?
        resumed_config = self._resume_state["config"]
        if resumed_config is not None:
            merged = config_util.dict_strip_value_dict(resumed_config)
            merged.update(self._consolidated_config)
            self._consolidated_config.update(merged)
            config_value_dict = self._config_format(self._consolidated_config)
            self._config_save(config_value_dict)

        # Handle an empty config by formatting and saving a default one.
        # TODO(jhr): consolidate the 4 ways config is built:
        #            (passed config, empty config, resume config, send_config)
        if not config_value_dict:
            config_value_dict = self._config_format(None)
            self._config_save(config_value_dict)

        self._init_run(run, config_value_dict)

        if data.control.req_resp:
            reply = wandb_internal_pb2.Result(uuid=data.uuid)
            # TODO: we could do self._interface.publish_defer(resp) to notify
            # the handler not to actually perform server updates for this uuid
            # because the user process will send a summary update when we resume
            reply.run_result.run.CopyFrom(self._run)
            self._result_q.put(reply)

        if not first_run_message:
            logger.info("updated run: %s", self._run.run_id)
            return

        # Threads are only spun up for the first run message.
        self._start_run_threads(file_dir)
Ejemplo n.º 24
0
    def send_request_defer(self, data):
        """Do the sender-side work for one defer state, then advance or exit.

        Each state except ``END`` publishes the next state once its work is
        done. For ``FLUSH_FP`` with an active pusher the advance is handed to
        ``self._pusher.finish`` as a callback instead of being called here.
        ``END`` builds the exit result, answers a pending synchronous exit if
        one is registered, and stores the result on ``self._exit_result``.

        Args:
            data: Record whose ``request.defer.state`` selects the work.

        Raises:
            AssertionError: If the state is not a known defer state.
        """
        defer = data.request.defer
        state = defer.state
        logger.info("handle sender defer: {}".format(state))

        def advance():
            # Reads defer.state when invoked, which may be later than now
            # if the pusher calls this back.
            next_state = defer.state + 1
            logger.info("send defer: {}".format(next_state))
            self._interface.publish_defer(next_state)

        # NOTE: FLUSH_STATS, FLUSH_TB and FLUSH_SUM are handled in
        # handler.py:handle_request_defer(); like BEGIN they only advance.
        advance_only = (
            defer.BEGIN,
            defer.FLUSH_STATS,
            defer.FLUSH_TB,
            defer.FLUSH_SUM,
        )

        if state in advance_only:
            advance()
        elif state == defer.FLUSH_DEBOUNCER:
            self.debounce()
            advance()
        elif state == defer.FLUSH_DIR:
            if self._dir_watcher:
                self._dir_watcher.finish()
                self._dir_watcher = None
            advance()
        elif state == defer.FLUSH_FP:
            if not self._pusher:
                advance()
            else:
                # FilePusher generates some events for FileStreamApi, so the
                # pusher has to finish before moving to the next state; this
                # ensures the file stream gets all those events before it is
                # told to finish up.
                self._pusher.finish(advance)
        elif state == defer.FLUSH_FS:
            if self._fs:
                # TODO(jhr): now is a good time to output pending output lines
                self._fs.finish(self._exit_code)
                self._fs = None
            advance()
        elif state == defer.FLUSH_FINAL:
            self._interface.publish_final()
            self._interface.publish_footer()
            advance()
        elif state == defer.END:
            exit_result = wandb_internal_pb2.RunExitResult()

            # This path is not the preferred way to return exit results,
            # as flushing the file pusher buffers can take a long time.
            if self._exit_sync_uuid:
                if self._pusher:
                    # NOTE: This will block until finished
                    self._pusher.print_status()
                    self._pusher.join()
                    self._pusher = None
                self._result_q.put(
                    wandb_internal_pb2.Result(exit_result=exit_result,
                                              uuid=self._exit_sync_uuid))

            # Mark exit done in case we are polling on exit.
            self._exit_result = exit_result
        else:
            raise AssertionError("unknown state")