def send_run(self, data):
    run = data.run
    error = None
    is_wandb_init = self._run is None

    # build config dict
    config_dict = None
    config_path = os.path.join(self._settings.files_dir, filenames.CONFIG_FNAME)
    if run.config:
        config_dict = config_util.dict_from_proto_list(run.config.update)
        config_util.save_config_file_from_dict(config_path, config_dict)

    if is_wandb_init:
        # Ensure we have a project to query for status
        if run.project == "":
            run.project = util.auto_project_name(self._settings.program)
        # Only check resume status on `wandb.init`
        error = self._maybe_setup_resume(run)

    if error is not None:
        if data.control.req_resp:
            resp = wandb_internal_pb2.Result(uuid=data.uuid)
            resp.run_result.run.CopyFrom(run)
            resp.run_result.error.CopyFrom(error)
            self._result_q.put(resp)
        else:
            logger.error("Got error in async mode: %s", error.message)
        return

    # Save the resumed config
    if self._resume_state["config"] is not None:
        # TODO: should we merge this with resumed config?
        config_override = config_dict or {}
        config_dict = self._resume_state["config"]
        config_dict.update(config_override)
        config_util.save_config_file_from_dict(config_path, config_dict)

    self._init_run(run, config_dict)

    if data.control.req_resp:
        resp = wandb_internal_pb2.Result(uuid=data.uuid)
        # TODO: we could do self._interface.publish_defer(resp) to notify
        # the handler not to actually perform server updates for this uuid
        # because the user process will send a summary update when we resume
        resp.run_result.run.CopyFrom(self._run)
        self._result_q.put(resp)

    # Only spin up our threads on the first run message
    if is_wandb_init:
        self._start_run_threads()
    else:
        logger.info("updated run: %s", self._run.run_id)

def send_request_defer(self, data):
    defer = data.request.defer
    state = defer.state
    logger.info("handle sender defer: {}".format(state))

    done = False
    if state == defer.BEGIN:
        pass
    elif state == defer.FLUSH_STATS:
        # NOTE: this is handled in handler.py:handle_request_defer()
        pass
    elif state == defer.FLUSH_TB:
        # NOTE: this is handled in handler.py:handle_request_defer()
        pass
    elif state == defer.FLUSH_SUM:
        # NOTE: this is handled in handler.py:handle_request_defer()
        pass
    elif state == defer.FLUSH_DIR:
        if self._dir_watcher:
            self._dir_watcher.finish()
            self._dir_watcher = None
    elif state == defer.FLUSH_FP:
        if self._pusher:
            self._pusher.finish()
    elif state == defer.FLUSH_FS:
        if self._fs:
            # TODO(jhr): now is a good time to output pending output lines
            self._fs.finish(self._exit_code)
            self._fs = None
    elif state == defer.FLUSH_FINAL:
        self._interface.publish_final()
        self._interface.publish_footer()
    elif state == defer.END:
        done = True
    else:
        raise AssertionError("unknown state")

    if not done:
        state += 1
        logger.info("send defer: {}".format(state))
        self._interface.publish_defer(state)
        return

    exit_result = wandb_internal_pb2.RunExitResult()

    # This path is not the preferred method to return exit results
    # as it could take a long time to flush the file pusher buffers
    if self._exit_sync_uuid:
        if self._pusher:
            # NOTE: This will block until finished
            self._pusher.print_status()
            self._pusher.join()
            self._pusher = None
        resp = wandb_internal_pb2.Result(
            exit_result=exit_result, uuid=self._exit_sync_uuid
        )
        self._result_q.put(resp)

    # mark exit done in case we are polling on exit
    self._exit_result = exit_result

def handle_request_defer(self, data):
    logger.info("handle defer")

    if self._dir_watcher:
        self._dir_watcher.finish()
        self._dir_watcher = None

    if self._pusher:
        self._pusher.finish()

    if self._fs:
        # TODO(jhr): now is a good time to output pending output lines
        self._fs.finish(self._exit_code)
        self._fs = None

    # NB: assume we always need to send a response for this message
    # since it was sent on behalf of handle_exit() req/resp logic
    resp = wandb_internal_pb2.Result(uuid=data.uuid)
    file_counts = self._pusher.file_counts_by_category()
    resp.exit_result.files.wandb_count = file_counts["wandb"]
    resp.exit_result.files.media_count = file_counts["media"]
    resp.exit_result.files.artifact_count = file_counts["artifact"]
    resp.exit_result.files.other_count = file_counts["other"]
    self._resp_q.put(resp)

    # TODO(david): this info should be in exit_result footer?
    if self._pusher:
        self._pusher.print_status()
        self._pusher = None

def send_request_poll_exit(self, record):
    if not record.control.req_resp:
        return

    result = wandb_internal_pb2.Result(uuid=record.uuid)

    alive = False
    if self._pusher:
        alive, status = self._pusher.get_status()
        file_counts = self._pusher.file_counts_by_category()
        resp = result.response.poll_exit_response
        resp.pusher_stats.uploaded_bytes = status["uploaded_bytes"]
        resp.pusher_stats.total_bytes = status["total_bytes"]
        resp.pusher_stats.deduped_bytes = status["deduped_bytes"]
        resp.file_counts.wandb_count = file_counts["wandb"]
        resp.file_counts.media_count = file_counts["media"]
        resp.file_counts.artifact_count = file_counts["artifact"]
        resp.file_counts.other_count = file_counts["other"]

    if self._exit_result and not alive:
        # pusher join should not block as it was reported as not alive
        if self._pusher:
            self._pusher.join()
        result.response.poll_exit_response.exit_result.CopyFrom(self._exit_result)
        result.response.poll_exit_response.done = True

    self._result_q.put(result)

def handle_request_run_start(self, record):
    run_start = record.request.run_start
    assert run_start
    assert run_start.run

    if not self._settings._disable_stats:
        pid = os.getpid()
        self._system_stats = stats.SystemStats(pid=pid, interface=self._interface)
        self._system_stats.start()

    if not self._settings._disable_meta:
        run_meta = meta.Meta(settings=self._settings, interface=self._interface)
        run_meta.probe()
        run_meta.write()

    self._tb_watcher = tb_watcher.TBWatcher(
        self._settings, interface=self._interface, run_proto=run_start.run
    )

    if run_start.run.resumed:
        self._step = run_start.run.starting_step
    result = wandb_internal_pb2.Result(uuid=record.uuid)
    self._result_q.put(result)

def handle_request_get_summary(self, record: Record) -> None:
    result = wandb_internal_pb2.Result(uuid=record.uuid)
    for key, value in six.iteritems(self._consolidated_summary):
        item = wandb_internal_pb2.SummaryItem()
        item.key = key
        item.value_json = json.dumps(value)
        result.response.get_summary_response.item.append(item)
    self._result_q.put(result)

def send_request_check_version(self, record):
    assert record.control.req_resp
    result = wandb_internal_pb2.Result(uuid=record.uuid)
    current_version = wandb.__version__
    message = update.check_available(current_version)
    if message:
        result.response.check_version_response.message = message
    self._result_q.put(result)

def _communicate(self, rec, timeout=5, local=False):
    resp = wandb_internal_pb2.Result()
    record_type = rec.WhichOneof("record_type")
    if record_type == "request":
        req = rec.request
        req_type = req.WhichOneof("request_type")
        if req_type == "poll_exit":
            resp.response.poll_exit_response.done = True
    return resp

def handle_request_login(self, data):
    # TODO: do something with api_key or anonymous?
    # TODO: return an error if we aren't logged in?
    viewer = self._api.viewer()
    self._flags = json.loads(viewer.get("flags", "{}"))
    self._entity = viewer.get("entity")
    if data.control.req_resp:
        result = wandb_internal_pb2.Result(uuid=data.uuid)
        result.response.login_response.active_entity = self._entity
        self._resp_q.put(result)

def handle_request_sampled_history(self, record: Record) -> None:
    result = wandb_internal_pb2.Result(uuid=record.uuid)
    for key, sampled in six.iteritems(self._sampled_history):
        item = wandb_internal_pb2.SampledHistoryItem()
        item.key = key
        values: Iterable[Any] = sampled.get()
        if all(isinstance(i, numbers.Integral) for i in values):
            item.values_int.extend(values)
        elif all(isinstance(i, numbers.Real) for i in values):
            item.values_float.extend(values)
        result.response.sampled_history_response.item.append(item)
    self._result_q.put(result)

def send_request_stop_status(self, record):
    assert record.control.req_resp

    result = wandb_internal_pb2.Result(uuid=record.uuid)
    status_resp = result.response.stop_status_response
    status_resp.run_should_stop = False
    if self._entity and self._project and self._run.run_id:
        try:
            status_resp.run_should_stop = self._api.check_stop_requested(
                self._project, self._entity, self._run.run_id
            )
        except Exception as e:
            logger.warning("Failed to check stop requested status: %s", e)
    self._result_q.put(result)

def send_request_log_artifact(self, record):
    assert record.control.req_resp
    result = wandb_internal_pb2.Result(uuid=record.uuid)
    artifact = record.request.log_artifact.artifact

    try:
        result.response.log_artifact_response.artifact_id = self._send_artifact(
            artifact
        ).get("id")
    except Exception as e:
        result.response.log_artifact_response.error_message = (
            'error logging artifact "{}/{}": {}'.format(artifact.type, artifact.name, e)
        )

    self._result_q.put(result)

def send_request_check_version(self, record):
    assert record.control.req_resp
    result = wandb_internal_pb2.Result(uuid=record.uuid)
    current_version = wandb.__version__
    messages = update.check_available(current_version)
    if messages:
        result.response.check_version_response.upgrade_message = messages["upgrade_message"]
        result.response.check_version_response.yank_message = messages["yank_message"]
        result.response.check_version_response.delete_message = messages["delete_message"]
    self._result_q.put(result)

def send_request_network_status(self, record):
    assert record.control.req_resp

    result = wandb_internal_pb2.Result(uuid=record.uuid)
    status_resp = result.response.network_status_response
    while True:
        try:
            status_resp.network_responses.append(self._retry_q.get_nowait())
        except queue.Empty:
            break
        except Exception as e:
            logger.warning("Error emptying retry queue: {}".format(e))
    self._result_q.put(result)

def send_request_login(self, record):
    # TODO: do something with api_key or anonymous?
    # TODO: return an error if we aren't logged in?
    self._api.reauth()
    viewer_tuple = self._api.viewer_server_info()
    # self._login_flags = json.loads(viewer.get("flags", "{}"))
    # self._login_entity = viewer.get("entity")
    viewer, server_info = viewer_tuple
    if server_info:
        logger.info("Login server info: {}".format(server_info))
    self._entity = viewer.get("entity")
    if record.control.req_resp:
        result = wandb_internal_pb2.Result(uuid=record.uuid)
        if self._entity:
            result.response.login_response.active_entity = self._entity
        self._result_q.put(result)

def handle_request_status(self, data):
    if not data.control.req_resp:
        return

    result = wandb_internal_pb2.Result(uuid=data.uuid)
    status_resp = result.response.status_response
    if data.request.status.check_stop_req:
        status_resp.run_should_stop = False
        if self._entity and self._project and self._run.run_id:
            try:
                status_resp.run_should_stop = self._api.check_stop_requested(
                    self._project, self._entity, self._run.run_id
                )
            except Exception as e:
                logger.warning("Failed to check stop requested status: %s", e)
    self._resp_q.put(result)

def _future_poll_artifact(self, xid: str) -> Optional[pb.Result]:
    art_poll = pb.ArtifactPollRequest(xid=xid)
    assert self._stub
    self._assign(art_poll)
    art_poll_resp = self._stub.ArtifactPoll(art_poll)
    if not art_poll_resp.ready:
        return None
    # emulate log_artifact response for old _communicate_artifact() protocol
    result = pb.Result()
    result.response.log_artifact_response.artifact_id = art_poll_resp.artifact_id
    result.response.log_artifact_response.error_message = art_poll_resp.error_message
    return result

def send_request_check_version(self, record):
    assert record.control.req_resp
    result = wandb_internal_pb2.Result(uuid=record.uuid)
    current_version = record.request.check_version.current_version or wandb.__version__
    messages = update.check_available(current_version)
    if messages:
        upgrade_message = messages.get("upgrade_message")
        if upgrade_message:
            result.response.check_version_response.upgrade_message = upgrade_message
        yank_message = messages.get("yank_message")
        if yank_message:
            result.response.check_version_response.yank_message = yank_message
        delete_message = messages.get("delete_message")
        if delete_message:
            result.response.check_version_response.delete_message = delete_message
    self._result_q.put(result)

def handle_request_shutdown(self, record: Record) -> None:
    # TODO(jhr): should we drain things and stop new requests from coming in?
    result = wandb_internal_pb2.Result(uuid=record.uuid)
    self._result_q.put(result)
    self._stopped.set()

def _result_from_record(record: "pb.Record") -> "pb.Result":
    result = pb.Result(uuid=record.uuid, control=record.control)
    return result

def handle_run(self, data):
    run = data.run
    run_tags = run.tags[:]
    error = None
    is_wandb_init = self._run is None

    # build config dict
    config_dict = None
    if run.HasField("config"):
        config_dict = _config_dict_from_proto_list(run.config.update)
        config_path = os.path.join(self._settings.files_dir, CONFIG_FNAME)
        save_config_file_from_dict(config_path, config_dict)

    repo = GitRepo(remote=self._settings.git_remote)

    if is_wandb_init:
        # Only check resume status on `wandb.init`
        error = self._maybe_setup_resume(run)

    if error is not None:
        if data.control.req_resp:
            resp = wandb_internal_pb2.Result(uuid=data.uuid)
            resp.run_result.run.CopyFrom(run)
            resp.run_result.error.CopyFrom(error)
            self._resp_q.put(resp)
        else:
            logger.error("Got error in async mode: %s", error.message)
        return

    # TODO: we don't check inserted currently, ultimately we should make
    # the upsert know the resume state and fail transactionally
    ups, inserted = self._api.upsert_run(
        name=run.run_id,
        entity=run.entity or None,
        project=run.project or None,
        group=run.run_group or None,
        job_type=run.job_type or None,
        display_name=run.display_name or None,
        notes=run.notes or None,
        tags=run_tags or None,
        config=config_dict or None,
        sweep_name=run.sweep_id or None,
        host=run.host or None,
        program_path=self._settings.program or None,
        repo=repo.remote_url,
        commit=repo.last_commit,
    )

    # We subtract the previous run's runtime when resuming
    start_time = run.start_time.ToSeconds() - self._offsets["runtime"]

    self._run = run
    self._run.starting_step = self._offsets["step"]
    self._run.start_time.FromSeconds(start_time)
    storage_id = ups.get("id")
    if storage_id:
        self._run.storage_id = storage_id
    display_name = ups.get("displayName")
    if display_name:
        self._run.display_name = display_name
    project = ups.get("project")
    if project:
        project_name = project.get("name")
        if project_name:
            self._run.project = project_name
            self._project = project_name
        entity = project.get("entity")
        if entity:
            entity_name = entity.get("name")
            if entity_name:
                self._run.entity = entity_name
                self._entity = entity_name

    if data.control.req_resp:
        resp = wandb_internal_pb2.Result(uuid=data.uuid)
        resp.run_result.run.CopyFrom(self._run)
        self._resp_q.put(resp)

    if self._entity is not None:
        self._api_settings["entity"] = self._entity
    if self._project is not None:
        self._api_settings["project"] = self._project

    # Only spin up our threads on the first run message
    if is_wandb_init:
        self._fs = file_stream.FileStreamApi(
            self._api, run.run_id, start_time, settings=self._api_settings
        )
        # Ensure the streaming policies have the proper offsets
        self._fs.set_file_policy("wandb-summary.json", file_stream.SummaryFilePolicy())
        self._fs.set_file_policy(
            "wandb-history.jsonl",
            file_stream.JsonlFilePolicy(start_chunk_id=self._offsets["history"]),
        )
        self._fs.set_file_policy(
            "wandb-events.jsonl",
            file_stream.JsonlFilePolicy(start_chunk_id=self._offsets["events"]),
        )
        self._fs.set_file_policy(
            "output.log",
            file_stream.CRDedupeFilePolicy(start_chunk_id=self._offsets["output"]),
        )
        self._fs.start()
        self._pusher = FilePusher(self._api)
        self._dir_watcher = DirWatcher(self._settings, self._api, self._pusher)
        self._tb_watcher = tb_watcher.TBWatcher(self._settings, sender=self)
        if self._run_meta:
            self._run_meta.write()
        sentry_set_scope("internal", run.entity, run.project)
        logger.info("run started: %s with start time %s", self._run.run_id, start_time)
    else:
        logger.info("updated run: %s", self._run.run_id)

def _request_response(self, rec, timeout=5):
    resp = wandb_internal_pb2.Result()
    return resp

def send_run(self, data, file_dir=None) -> None:
    run = data.run
    error = None
    is_wandb_init = self._run is None

    # update telemetry
    if run.telemetry:
        self._telemetry_obj.MergeFrom(run.telemetry)

    # build config dict
    config_value_dict: Optional[DictWithValues] = None
    if run.config:
        config_util.update_from_proto(self._consolidated_config, run.config)
        config_value_dict = self._config_format(self._consolidated_config)
        self._config_save(config_value_dict)

    if is_wandb_init:
        # Ensure we have a project to query for status
        if run.project == "":
            run.project = util.auto_project_name(self._settings.program)
        # Only check resume status on `wandb.init`
        error = self._maybe_setup_resume(run)

    if error is not None:
        if data.control.req_resp:
            resp = wandb_internal_pb2.Result(uuid=data.uuid)
            resp.run_result.run.CopyFrom(run)
            resp.run_result.error.CopyFrom(error)
            self._result_q.put(resp)
        else:
            logger.error("Got error in async mode: %s", error.message)
        return

    # Save the resumed config
    if self._resume_state["config"] is not None:
        # TODO: should we merge this with resumed config?
        config_override = self._consolidated_config
        config_dict = self._resume_state["config"]
        config_dict = config_util.dict_strip_value_dict(config_dict)
        config_dict.update(config_override)
        self._consolidated_config.update(config_dict)
        config_value_dict = self._config_format(self._consolidated_config)
        self._config_save(config_value_dict)

    # handle empty config
    # TODO(jhr): consolidate the 4 ways config is built:
    #            (passed config, empty config, resume config, send_config)
    if not config_value_dict:
        config_value_dict = self._config_format(None)
        self._config_save(config_value_dict)

    self._init_run(run, config_value_dict)

    if data.control.req_resp:
        resp = wandb_internal_pb2.Result(uuid=data.uuid)
        # TODO: we could do self._interface.publish_defer(resp) to notify
        # the handler not to actually perform server updates for this uuid
        # because the user process will send a summary update when we resume
        resp.run_result.run.CopyFrom(self._run)
        self._result_q.put(resp)

    # Only spin up our threads on the first run message
    if is_wandb_init:
        self._start_run_threads(file_dir)
    else:
        logger.info("updated run: %s", self._run.run_id)

def send_request_defer(self, data):
    defer = data.request.defer
    state = defer.state
    logger.info("handle sender defer: {}".format(state))

    def transition_state():
        state = defer.state + 1
        logger.info("send defer: {}".format(state))
        self._interface.publish_defer(state)

    done = False
    if state == defer.BEGIN:
        transition_state()
    elif state == defer.FLUSH_STATS:
        # NOTE: this is handled in handler.py:handle_request_defer()
        transition_state()
    elif state == defer.FLUSH_TB:
        # NOTE: this is handled in handler.py:handle_request_defer()
        transition_state()
    elif state == defer.FLUSH_SUM:
        # NOTE: this is handled in handler.py:handle_request_defer()
        transition_state()
    elif state == defer.FLUSH_DEBOUNCER:
        self.debounce()
        transition_state()
    elif state == defer.FLUSH_DIR:
        if self._dir_watcher:
            self._dir_watcher.finish()
            self._dir_watcher = None
        transition_state()
    elif state == defer.FLUSH_FP:
        if self._pusher:
            # FilePusher generates some events for FileStreamApi, so we
            # need to wait for pusher to finish before going to the next
            # state to ensure that filestream gets all the events that we
            # want before telling it to finish up
            self._pusher.finish(transition_state)
        else:
            transition_state()
    elif state == defer.FLUSH_FS:
        if self._fs:
            # TODO(jhr): now is a good time to output pending output lines
            self._fs.finish(self._exit_code)
            self._fs = None
        transition_state()
    elif state == defer.FLUSH_FINAL:
        self._interface.publish_final()
        self._interface.publish_footer()
        transition_state()
    elif state == defer.END:
        done = True
    else:
        raise AssertionError("unknown state")

    if not done:
        return

    exit_result = wandb_internal_pb2.RunExitResult()

    # This path is not the preferred method to return exit results
    # as it could take a long time to flush the file pusher buffers
    if self._exit_sync_uuid:
        if self._pusher:
            # NOTE: This will block until finished
            self._pusher.print_status()
            self._pusher.join()
            self._pusher = None
        resp = wandb_internal_pb2.Result(
            exit_result=exit_result, uuid=self._exit_sync_uuid
        )
        self._result_q.put(resp)

    # mark exit done in case we are polling on exit
    self._exit_result = exit_result