Example #1
    def _start_run_threads(self):
        self._fs = file_stream.FileStreamApi(
            self._api,
            self._run.run_id,
            self._run.start_time.ToSeconds(),
            settings=self._api_settings,
        )
        # Ensure the streaming policies have the proper offsets
        self._fs.set_file_policy("wandb-summary.json", file_stream.SummaryFilePolicy())
        self._fs.set_file_policy(
            "wandb-history.jsonl",
            file_stream.JsonlFilePolicy(start_chunk_id=self._resume_state["history"]),
        )
        self._fs.set_file_policy(
            "wandb-events.jsonl",
            file_stream.JsonlFilePolicy(start_chunk_id=self._resume_state["events"]),
        )
        self._fs.set_file_policy(
            "output.log",
            file_stream.CRDedupeFilePolicy(start_chunk_id=self._resume_state["output"]),
        )
        util.sentry_set_scope(
            "internal",
            entity=self._run.entity,
            project=self._run.project,
            email=self._settings.email,
        )
        self._fs.start()
        self._pusher = FilePusher(self._api, silent=self._settings.silent)
        self._dir_watcher = DirWatcher(self._settings, self._api, self._pusher)
        logger.info(
            "run started: %s with start time %s",
            self._run.run_id,
            self._run.start_time.ToSeconds(),
        )
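
The detail worth noticing above is how the resume offsets are threaded into the
file-stream policies via start_chunk_id, so a resumed run appends to the existing
stream instead of restarting at chunk 0. A minimal, self-contained sketch of that
idea (the class below is an illustrative stand-in, not wandb's actual
JsonlFilePolicy):

class OffsetChunkPolicy:
    """Illustrative stand-in: number outgoing chunks from a resume offset."""

    def __init__(self, start_chunk_id=0):
        self._chunk_id = start_chunk_id

    def process_chunks(self, chunks):
        first = self._chunk_id
        self._chunk_id += len(chunks)
        return {"offset": first, "content": list(chunks)}

policy = OffsetChunkPolicy(start_chunk_id=120)  # e.g. a resumed history offset
print(policy.process_chunks(['{"loss": 0.5}']))
# -> {'offset': 120, 'content': ['{"loss": 0.5}']}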
Example #2
class SendManager(object):

    _telemetry_obj: telemetry.TelemetryRecord

    def __init__(
        self,
        settings,
        record_q,
        result_q,
        interface,
    ):
        self._settings = settings
        self._record_q = record_q
        self._result_q = result_q
        self._interface = interface

        self._fs = None
        self._pusher = None
        self._dir_watcher = None

        # State updated by login
        self._entity = None
        self._flags = None

        # State updated by wandb.init
        self._run = None
        self._project = None

        # keep track of config from key/val updates
        self._consolidated_config: DictNoValues = dict()
        self._telemetry_obj = telemetry.TelemetryRecord()
        self._config_metric_pbdict_list: List[Dict[int, Any]] = []
        self._config_metric_index_dict: Dict[str, int] = {}
        self._config_metric_dict: Dict[str,
                                       wandb_internal_pb2.MetricRecord] = {}

        # State updated by resuming
        self._resume_state = {
            "step": 0,
            "history": 0,
            "events": 0,
            "output": 0,
            "runtime": 0,
            "summary": None,
            "config": None,
            "resumed": False,
        }

        # State added when run_exit needs results
        self._exit_sync_uuid = None

        # State added when run_exit is complete
        self._exit_result = None

        self._api = internal_api.Api(default_settings=settings,
                                     retry_callback=self.retry_callback)
        self._api_settings = dict()

        # queue filled by retry_callback
        self._retry_q: "Queue[HttpResponse]" = queue.Queue()

        # do we need to debounce?
        self._config_needs_debounce: bool = False

        # TODO(jhr): do something better, why do we need to send full lines?
        self._partial_output = dict()

        self._exit_code = 0

    @classmethod
    def setup(cls, root_dir):
        """This is a helper class method to setup a standalone SendManager.
        Currently we're using this primarily for `sync.py`.
        """
        files_dir = os.path.join(root_dir, "files")
        sd = dict(
            files_dir=files_dir,
            root_dir=root_dir,
            _start_time=0,
            git_remote=None,
            resume=None,
            program=None,
            ignore_globs=(),
            run_id=None,
            entity=None,
            project=None,
            run_group=None,
            job_type=None,
            run_tags=None,
            run_name=None,
            run_notes=None,
            save_code=None,
            email=None,
            silent=None,
        )
        settings = settings_static.SettingsStatic(sd)
        record_q = queue.Queue()
        result_q = queue.Queue()
        publish_interface = interface.BackendSender(record_q=record_q)
        return SendManager(
            settings=settings,
            record_q=record_q,
            result_q=result_q,
            interface=publish_interface,
        )

    def __len__(self):
        return self._record_q.qsize()

    def retry_callback(self, status, response_text):
        response = wandb_internal_pb2.HttpResponse()
        response.http_status_code = status
        response.http_response_text = response_text
        self._retry_q.put(response)

    def send(self, record):
        record_type = record.WhichOneof("record_type")
        assert record_type
        handler_str = "send_" + record_type
        send_handler = getattr(self, handler_str, None)
        # Don't log output to reduce log noise
        if record_type not in {"output", "request"}:
            logger.debug("send: {}".format(record_type))
        assert send_handler, "unknown send handler: {}".format(handler_str)
        send_handler(record)

    def send_preempting(self, record):
        if self._fs:
            self._fs.enqueue_preempting()

    def send_request(self, record):
        request_type = record.request.WhichOneof("request_type")
        assert request_type
        handler_str = "send_request_" + request_type
        send_handler = getattr(self, handler_str, None)
        if request_type != "network_status":
            logger.debug("send_request: {}".format(request_type))
        assert send_handler, "unknown handle: {}".format(handler_str)
        send_handler(record)

    def _flatten(self, dictionary):
        if type(dictionary) == dict:
            for k, v in list(dictionary.items()):
                if type(v) == dict:
                    self._flatten(v)
                    dictionary.pop(k)
                    for k2, v2 in v.items():
                        dictionary[k + "." + k2] = v2
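
    # Note: _flatten rewrites nested dicts in place using dotted keys, e.g.
    #   row = {"system": {"cpu": 0.5, "gpu": {"0": 0.9}}}
    #   self._flatten(row)  # row becomes {"system.cpu": 0.5, "system.gpu.0": 0.9}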

    def send_request_check_version(self, record):
        assert record.control.req_resp
        result = wandb_internal_pb2.Result(uuid=record.uuid)
        current_version = (record.request.check_version.current_version
                           or wandb.__version__)
        messages = update.check_available(current_version)
        if messages:
            upgrade_message = messages.get("upgrade_message")
            if upgrade_message:
                result.response.check_version_response.upgrade_message = upgrade_message
            yank_message = messages.get("yank_message")
            if yank_message:
                result.response.check_version_response.yank_message = yank_message
            delete_message = messages.get("delete_message")
            if delete_message:
                result.response.check_version_response.delete_message = delete_message
        self._result_q.put(result)

    def send_request_stop_status(self, record):
        assert record.control.req_resp

        result = wandb_internal_pb2.Result(uuid=record.uuid)
        status_resp = result.response.stop_status_response
        status_resp.run_should_stop = False
        if self._entity and self._project and self._run.run_id:
            try:
                status_resp.run_should_stop = self._api.check_stop_requested(
                    self._project, self._entity, self._run.run_id)
            except Exception as e:
                logger.warning("Failed to check stop requested status: %s", e)
        self._result_q.put(result)

    def debounce(self) -> None:
        if self._config_needs_debounce:
            self._debounce_config()

    def _debounce_config(self):
        config_value_dict = self._config_format(self._consolidated_config)
        # TODO(jhr): check result of upsert_run?
        self._api.upsert_run(name=self._run.run_id,
                             config=config_value_dict,
                             **self._api_settings)
        self._config_save(config_value_dict)
        self._config_needs_debounce = False
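
    # Note: config changes elsewhere only flip _config_needs_debounce (via
    # _update_config); the actual upsert_run() round-trip happens later, when
    # debounce() is invoked, so many key/value updates coalesce into one request.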

    def send_request_network_status(self, record):
        assert record.control.req_resp

        result = wandb_internal_pb2.Result(uuid=record.uuid)
        status_resp = result.response.network_status_response
        while True:
            try:
                status_resp.network_responses.append(
                    self._retry_q.get_nowait())
            except queue.Empty:
                break
            except Exception as e:
                logger.warning("Error emptying retry queue: {}".format(e))
        self._result_q.put(result)

    def send_request_login(self, record):
        # TODO: do something with api_key or anonymous?
        # TODO: return an error if we aren't logged in?
        self._api.reauth()
        viewer_tuple = self._api.viewer_server_info()
        # self._login_flags = json.loads(viewer.get("flags", "{}"))
        # self._login_entity = viewer.get("entity")
        viewer, server_info = viewer_tuple
        if server_info:
            logger.info("Login server info: {}".format(server_info))
        self._entity = viewer.get("entity")
        if record.control.req_resp:
            result = wandb_internal_pb2.Result(uuid=record.uuid)
            if self._entity:
                result.response.login_response.active_entity = self._entity
            self._result_q.put(result)

    def send_exit(self, data):
        exit = data.exit
        self._exit_code = exit.exit_code
        logger.info("handling exit code: %s", exit.exit_code)

        # Pass the responsibility to respond to handle_request_defer()
        if data.control.req_resp:
            self._exit_sync_uuid = data.uuid

        # We need to give the request queue a chance to empty between states
        # so use handle_request_defer as a state machine.
        logger.info("send defer")
        self._interface.publish_defer()

    def send_final(self, data):
        pass

    def send_request_defer(self, data):
        defer = data.request.defer
        state = defer.state
        logger.info("handle sender defer: {}".format(state))

        def transition_state():
            state = defer.state + 1
            logger.info("send defer: {}".format(state))
            self._interface.publish_defer(state)

        done = False
        if state == defer.BEGIN:
            transition_state()
        elif state == defer.FLUSH_STATS:
            # NOTE: this is handled in handler.py:handle_request_defer()
            transition_state()
        elif state == defer.FLUSH_TB:
            # NOTE: this is handled in handler.py:handle_request_defer()
            transition_state()
        elif state == defer.FLUSH_SUM:
            # NOTE: this is handled in handler.py:handle_request_defer()
            transition_state()
        elif state == defer.FLUSH_DEBOUNCER:
            self.debounce()
            transition_state()
        elif state == defer.FLUSH_DIR:
            if self._dir_watcher:
                self._dir_watcher.finish()
                self._dir_watcher = None
            transition_state()
        elif state == defer.FLUSH_FP:
            if self._pusher:
                # FilePusher generates some events for FileStreamApi, so we
                # need to wait for pusher to finish before going to the next
                # state to ensure that filestream gets all the events that we
                # want before telling it to finish up
                self._pusher.finish(transition_state)
            else:
                transition_state()
        elif state == defer.FLUSH_FS:
            if self._fs:
                # TODO(jhr): now is a good time to output pending output lines
                self._fs.finish(self._exit_code)
                self._fs = None
            transition_state()
        elif state == defer.FLUSH_FINAL:
            self._interface.publish_final()
            self._interface.publish_footer()
            transition_state()
        elif state == defer.END:
            done = True
        else:
            raise AssertionError("unknown state")

        if not done:
            return

        exit_result = wandb_internal_pb2.RunExitResult()

        # This path is not the preferred method to return exit results
        # as it could take a long time to flush the file pusher buffers
        if self._exit_sync_uuid:
            if self._pusher:
                # NOTE: This will block until finished
                self._pusher.print_status()
                self._pusher.join()
                self._pusher = None
            resp = wandb_internal_pb2.Result(exit_result=exit_result,
                                             uuid=self._exit_sync_uuid)
            self._result_q.put(resp)

        # mark exit done in case we are polling on exit
        self._exit_result = exit_result

    def send_request_poll_exit(self, record):
        if not record.control.req_resp:
            return

        result = wandb_internal_pb2.Result(uuid=record.uuid)

        alive = False
        if self._pusher:
            alive, status = self._pusher.get_status()
            file_counts = self._pusher.file_counts_by_category()
            resp = result.response.poll_exit_response
            resp.pusher_stats.uploaded_bytes = status["uploaded_bytes"]
            resp.pusher_stats.total_bytes = status["total_bytes"]
            resp.pusher_stats.deduped_bytes = status["deduped_bytes"]
            resp.file_counts.wandb_count = file_counts["wandb"]
            resp.file_counts.media_count = file_counts["media"]
            resp.file_counts.artifact_count = file_counts["artifact"]
            resp.file_counts.other_count = file_counts["other"]

        if self._exit_result and not alive:
            # pusher join should not block as it was reported as not alive
            if self._pusher:
                self._pusher.join()
            result.response.poll_exit_response.exit_result.CopyFrom(
                self._exit_result)
            result.response.poll_exit_response.done = True
        self._result_q.put(result)

    def _maybe_setup_resume(self,
                            run) -> "Optional[wandb_internal_pb2.ErrorInfo]":
        """This maybe queries the backend for a run and fails if the settings are
        incompatible."""
        if not self._settings.resume:
            return None

        # TODO: This causes a race; we need to make the upsert atomically
        # create or update depending on the resume config.
        # Use the run's entity if set, otherwise fall back to the user's entity.
        entity = run.entity or self._entity
        logger.info("checking resume status for %s/%s/%s", entity, run.project,
                    run.run_id)
        resume_status = self._api.run_resume_status(entity=entity,
                                                    project_name=run.project,
                                                    name=run.run_id)

        if not resume_status:
            if self._settings.resume == "must":
                error = wandb_internal_pb2.ErrorInfo()
                error.code = wandb_internal_pb2.ErrorInfo.ErrorCode.INVALID
                error.message = "resume='must' but run (%s) doesn't exist" % run.run_id
                return error
            return None

        #
        # handle cases where we have resume_status
        #
        if self._settings.resume == "never":
            error = wandb_internal_pb2.ErrorInfo()
            error.code = wandb_internal_pb2.ErrorInfo.ErrorCode.INVALID
            error.message = "resume='never' but run (%s) exists" % run.run_id
            return error

        history = {}
        events = {}
        config = {}
        summary = {}
        try:
            events_rt = 0
            history_rt = 0
            history = json.loads(resume_status["historyTail"])
            if history:
                history = json.loads(history[-1])
                history_rt = history.get("_runtime", 0)
            events = json.loads(resume_status["eventsTail"])
            if events:
                events = json.loads(events[-1])
                events_rt = events.get("_runtime", 0)
            config = json.loads(resume_status["config"] or "{}")
            summary = json.loads(resume_status["summaryMetrics"] or "{}")
        except (IndexError, ValueError) as e:
            logger.error("unable to load resume tails", exc_info=e)
            if self._settings.resume == "must":
                error = wandb_internal_pb2.ErrorInfo()
                error.code = wandb_internal_pb2.ErrorInfo.ErrorCode.INVALID
                error.message = "resume='must' but could not resume (%s) " % run.run_id
                return error

        # TODO: Do we need to restore config / summary?
        # System metrics runtime is usually greater than history
        self._resume_state["runtime"] = max(events_rt, history_rt)
        self._resume_state["step"] = history.get("_step",
                                                 -1) + 1 if history else 0
        self._resume_state["history"] = resume_status["historyLineCount"]
        self._resume_state["events"] = resume_status["eventsLineCount"]
        self._resume_state["output"] = resume_status["logLineCount"]
        self._resume_state["config"] = config
        self._resume_state["summary"] = summary
        self._resume_state["resumed"] = True
        logger.info("configured resuming with: %s" % self._resume_state)
        return None
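
    # Illustrative example of the state this produces: if the backend's
    # resume_status has historyLineCount 42 and a historyTail whose last
    # entry decodes to {"_step": 41, "_runtime": 99}, the run continues at
    # starting_step 42, streams history from chunk 42, and credits the
    # previous 99 seconds through _resume_state["runtime"].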

    def _telemetry_get_framework(self) -> str:
        """Get telemetry data for internal config structure."""
        # detect framework by checking what is loaded
        imp: telemetry.TelemetryImports
        if self._telemetry_obj.HasField("imports_finish"):
            imp = self._telemetry_obj.imports_finish
        elif self._telemetry_obj.HasField("imports_init"):
            imp = self._telemetry_obj.imports_init
        else:
            return ""
        priority = _framework_priority(imp)
        framework = next((f for b, f in priority if b), "")
        return framework
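
    # Note: the first framework with a truthy import flag wins, in the order
    # returned by _framework_priority(imp); if several frameworks were
    # imported, only the highest-priority one is reported.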

    def _config_telemetry_update(self, config_dict: Dict[str, Any]) -> None:
        """Add legacy telemetry to config object."""
        wandb_key = "_wandb"
        config_dict.setdefault(wandb_key, dict())
        s: str
        b: bool
        s = self._telemetry_obj.python_version
        if s:
            config_dict[wandb_key]["python_version"] = s
        s = self._telemetry_obj.cli_version
        if s:
            config_dict[wandb_key]["cli_version"] = s
        s = self._telemetry_get_framework()
        if s:
            config_dict[wandb_key]["framework"] = s
        s = self._telemetry_obj.huggingface_version
        if s:
            config_dict[wandb_key]["huggingface_version"] = s
        b = self._telemetry_obj.env.jupyter
        config_dict[wandb_key]["is_jupyter_run"] = b
        b = self._telemetry_obj.env.kaggle
        config_dict[wandb_key]["is_kaggle_kernel"] = b

        t: Dict[int,
                Any] = proto_util.proto_encode_to_dict(self._telemetry_obj)
        config_dict[wandb_key]["t"] = t

    def _config_metric_update(self, config_dict: Dict[str, Any]) -> None:
        """Add default xaxis to config."""
        if not self._config_metric_pbdict_list:
            return
        wandb_key = "_wandb"
        config_dict.setdefault(wandb_key, dict())
        config_dict[wandb_key]["m"] = self._config_metric_pbdict_list

    def _config_format(self,
                       config_data: Optional[DictNoValues]) -> DictWithValues:
        """Format dict into value dict with telemetry info."""
        config_dict: Dict[str,
                          Any] = config_data.copy() if config_data else dict()
        self._config_telemetry_update(config_dict)
        self._config_metric_update(config_dict)
        config_value_dict: DictWithValues = config_util.dict_add_value_dict(
            config_dict)
        return config_value_dict

    def _config_save(self, config_value_dict: DictWithValues) -> None:
        config_path = os.path.join(self._settings.files_dir, "config.yaml")
        config_util.save_config_file_from_dict(config_path, config_value_dict)

    def _sync_spell(self, env=None):
        """Syncs this run with spell"""
        try:
            env = env or os.environ
            self._interface.publish_config(key=("_wandb", "spell_url"),
                                           val=env.get("SPELL_RUN_URL"))
            url = "{}/{}/{}/runs/{}".format(self._api.app_url,
                                            self._run.entity,
                                            self._run.project,
                                            self._run.run_id)
            return requests.put(
                env.get("SPELL_API_URL", "https://api.spell.run") +
                "/wandb_url",
                json={
                    "access_token": env.get("WANDB_ACCESS_TOKEN"),
                    "url": url
                },
                timeout=2,
            )
        except requests.RequestException:
            return False

    def send_run(self, data, file_dir=None) -> None:
        run = data.run
        error = None
        is_wandb_init = self._run is None

        # update telemetry
        if run.telemetry:
            self._telemetry_obj.MergeFrom(run.telemetry)

        # build config dict
        config_value_dict: Optional[DictWithValues] = None
        if run.config:
            config_util.update_from_proto(self._consolidated_config,
                                          run.config)
            config_value_dict = self._config_format(self._consolidated_config)
            self._config_save(config_value_dict)

        if is_wandb_init:
            # Ensure we have a project to query for status
            if run.project == "":
                run.project = util.auto_project_name(self._settings.program)
            # Only check resume status on `wandb.init`
            error = self._maybe_setup_resume(run)

        if error is not None:
            if data.control.req_resp:
                resp = wandb_internal_pb2.Result(uuid=data.uuid)
                resp.run_result.run.CopyFrom(run)
                resp.run_result.error.CopyFrom(error)
                self._result_q.put(resp)
            else:
                logger.error("Got error in async mode: %s", error.message)
            return

        # Save the resumed config
        if self._resume_state["config"] is not None:
            # TODO: should we merge this with resumed config?
            config_override = self._consolidated_config
            config_dict = self._resume_state["config"]
            config_dict = config_util.dict_strip_value_dict(config_dict)
            config_dict.update(config_override)
            self._consolidated_config.update(config_dict)
            config_value_dict = self._config_format(self._consolidated_config)
            self._config_save(config_value_dict)

        # handle empty config
        # TODO(jhr): consolidate the 4 ways config is built:
        #            (passed config, empty config, resume config, send_config)
        if not config_value_dict:
            config_value_dict = self._config_format(None)
            self._config_save(config_value_dict)

        self._init_run(run, config_value_dict)

        if data.control.req_resp:
            resp = wandb_internal_pb2.Result(uuid=data.uuid)
            # TODO: we could do self._interface.publish_defer(resp) to notify
            # the handler not to actually perform server updates for this uuid
            # because the user process will send a summary update when we resume
            resp.run_result.run.CopyFrom(self._run)
            self._result_q.put(resp)

        # Only spin up our threads on the first run message
        if is_wandb_init:
            self._start_run_threads(file_dir)
        else:
            logger.info("updated run: %s", self._run.run_id)

    def _init_run(self, run, config_dict):
        # We subtract the previous run's runtime when resuming
        start_time = run.start_time.ToSeconds() - self._resume_state["runtime"]
        repo = GitRepo(remote=self._settings.git_remote)
        # TODO: we don't check inserted currently, ultimately we should make
        # the upsert know the resume state and fail transactionally
        server_run, inserted = self._api.upsert_run(
            name=run.run_id,
            entity=run.entity or None,
            project=run.project or None,
            group=run.run_group or None,
            job_type=run.job_type or None,
            display_name=run.display_name or None,
            notes=run.notes or None,
            tags=run.tags[:] or None,
            config=config_dict or None,
            sweep_name=run.sweep_id or None,
            host=run.host or None,
            program_path=self._settings.program or None,
            repo=repo.remote_url,
            commit=repo.last_commit,
        )
        self._run = run
        if self._resume_state.get("resumed"):
            self._run.resumed = True
        self._run.starting_step = self._resume_state["step"]
        self._run.start_time.FromSeconds(int(start_time))
        self._run.config.CopyFrom(self._interface._make_config(config_dict))
        if self._resume_state["summary"] is not None:
            self._run.summary.CopyFrom(
                self._interface._make_summary_from_dict(
                    self._resume_state["summary"]))
        storage_id = server_run.get("id")
        if storage_id:
            self._run.storage_id = storage_id
        id = server_run.get("name")
        if id:
            self._api.set_current_run_id(id)
        display_name = server_run.get("displayName")
        if display_name:
            self._run.display_name = display_name
        project = server_run.get("project")
        # TODO: remove self._api.set_settings, and make self._project a property?
        if project:
            project_name = project.get("name")
            if project_name:
                self._run.project = project_name
                self._project = project_name
                self._api_settings["project"] = project_name
                self._api.set_setting("project", project_name)
            entity = project.get("entity")
            if entity:
                entity_name = entity.get("name")
                if entity_name:
                    self._run.entity = entity_name
                    self._entity = entity_name
                    self._api_settings["entity"] = entity_name
                    self._api.set_setting("entity", entity_name)
        sweep_id = server_run.get("sweepName")
        if sweep_id:
            self._run.sweep_id = sweep_id
        if os.getenv("SPELL_RUN_URL"):
            self._sync_spell()

    def _start_run_threads(self, file_dir=None):
        self._fs = file_stream.FileStreamApi(
            self._api,
            self._run.run_id,
            self._run.start_time.ToSeconds(),
            settings=self._api_settings,
        )
        # Ensure the streaming policies have the proper offsets
        self._fs.set_file_policy("wandb-summary.json",
                                 file_stream.SummaryFilePolicy())
        self._fs.set_file_policy(
            "wandb-history.jsonl",
            file_stream.JsonlFilePolicy(
                start_chunk_id=self._resume_state["history"]),
        )
        self._fs.set_file_policy(
            "wandb-events.jsonl",
            file_stream.JsonlFilePolicy(
                start_chunk_id=self._resume_state["events"]),
        )
        self._fs.set_file_policy(
            "output.log",
            file_stream.CRDedupeFilePolicy(
                start_chunk_id=self._resume_state["output"]),
        )
        util.sentry_set_scope(
            "internal",
            entity=self._run.entity,
            project=self._run.project,
            email=self._settings.email,
        )
        self._fs.start()
        self._pusher = FilePusher(self._api,
                                  self._fs,
                                  silent=self._settings.silent)
        self._dir_watcher = DirWatcher(self._settings, self._api, self._pusher,
                                       file_dir)
        logger.info(
            "run started: %s with start time %s",
            self._run.run_id,
            self._run.start_time.ToSeconds(),
        )

    def _save_history(self, history_dict):
        if self._fs:
            self._fs.push(filenames.HISTORY_FNAME, json.dumps(history_dict))

    def send_history(self, data):
        history = data.history
        history_dict = proto_util.dict_from_proto_list(history.item)
        self._save_history(history_dict)

    def send_summary(self, data):
        summary_dict = proto_util.dict_from_proto_list(data.summary.update)
        json_summary = json.dumps(summary_dict)
        if self._fs:
            self._fs.push(filenames.SUMMARY_FNAME, json_summary)
        # TODO(jhr): we should only write this at the end of the script
        summary_path = os.path.join(self._settings.files_dir,
                                    filenames.SUMMARY_FNAME)
        with open(summary_path, "w") as f:
            f.write(json_summary)
        self._save_file(filenames.SUMMARY_FNAME)

    def send_stats(self, data):
        stats = data.stats
        if stats.stats_type != wandb_internal_pb2.StatsRecord.StatsType.SYSTEM:
            return
        if not self._fs:
            return
        now = stats.timestamp.seconds
        d = dict()
        for item in stats.item:
            d[item.key] = json.loads(item.value_json)
        row = dict(system=d)
        self._flatten(row)
        row["_wandb"] = True
        row["_timestamp"] = now
        row["_runtime"] = int(now - self._run.start_time.ToSeconds())
        self._fs.push(filenames.EVENTS_FNAME, json.dumps(row))
        # TODO(jhr): check fs.push results?

    def send_output(self, data):
        if not self._fs:
            return
        out = data.output
        prepend = ""
        stream = "stdout"
        if out.output_type == wandb_internal_pb2.OutputRecord.OutputType.STDERR:
            stream = "stderr"
            prepend = "ERROR "
        line = out.line
        if not line.endswith("\n"):
            self._partial_output.setdefault(stream, "")
            if line.startswith("\r"):
                self._partial_output[stream] = ""
            self._partial_output[stream] += line
            # TODO(jhr): how do we make sure this gets flushed?
            # we might need this for other stuff like telemetry
        else:
            # TODO(jhr): use time from timestamp proto
            # TODO(jhr): do we need to make sure we write full lines?
            # seems to be some issues with line breaks
            cur_time = time.time()
            timestamp = datetime.utcfromtimestamp(cur_time).isoformat() + " "
            prev_str = self._partial_output.get(stream, "")
            line = u"{}{}{}{}".format(prepend, timestamp, prev_str, line)
            self._fs.push(filenames.OUTPUT_FNAME, line)
            self._partial_output[stream] = ""
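
    # Note: partial lines are buffered per stream until a newline arrives,
    # and a line starting with "\r" (e.g. a redrawn progress bar) clears the
    # buffered prefix so only the latest carriage-returned content is kept.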

    def _update_config(self):
        self._config_needs_debounce = True

    def send_config(self, data):
        cfg = data.config
        config_util.update_from_proto(self._consolidated_config, cfg)
        self._update_config()

    def send_metric(self, data: wandb_internal_pb2.Record) -> None:
        metric = data.metric
        if metric.glob_name:
            logger.warning("Seen metric with glob (shouldnt happen)")
            return

        # merge or overwrite
        old_metric = self._config_metric_dict.get(
            metric.name, wandb_internal_pb2.MetricRecord())
        if metric._control.overwrite:
            old_metric.CopyFrom(metric)
        else:
            old_metric.MergeFrom(metric)
        self._config_metric_dict[metric.name] = old_metric
        metric = old_metric

        # convert step_metric to index
        if metric.step_metric:
            find_step_idx = self._config_metric_index_dict.get(
                metric.step_metric)
            if find_step_idx is not None:
                # make a copy of this metric as we will be modifying it
                rec = wandb_internal_pb2.Record()
                rec.metric.CopyFrom(metric)
                metric = rec.metric

                metric.ClearField("step_metric")
                metric.step_metric_index = find_step_idx + 1

        md: Dict[int, Any] = proto_util.proto_encode_to_dict(metric)
        find_idx = self._config_metric_index_dict.get(metric.name)
        if find_idx is not None:
            self._config_metric_pbdict_list[find_idx] = md
        else:
            next_idx = len(self._config_metric_pbdict_list)
            self._config_metric_pbdict_list.append(md)
            self._config_metric_index_dict[metric.name] = next_idx
        self._update_config()
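
    # Note: metrics are stored positionally in _config_metric_pbdict_list
    # (serialized into the config under "_wandb"/"m"), and a step_metric
    # reference is rewritten to a 1-based step_metric_index into that list.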

    def send_telemetry(self, data):
        telem = data.telemetry
        self._telemetry_obj.MergeFrom(telem)
        self._update_config()

    def _save_file(self, fname, policy="end"):
        logger.info("saving file %s with policy %s", fname, policy)
        if self._dir_watcher:
            self._dir_watcher.update_policy(fname, policy)

    def send_files(self, data):
        files = data.files
        for k in files.files:
            # TODO(jhr): fix paths with directories
            self._save_file(k.path, interface.file_enum_to_policy(k.policy))

    def send_header(self, data):
        pass

    def send_footer(self, data):
        pass

    def send_tbrecord(self, data):
        # tbrecord watching threads are handled by handler.py
        pass

    def send_request_log_artifact(self, record):
        assert record.control.req_resp
        result = wandb_internal_pb2.Result(uuid=record.uuid)
        artifact = record.request.log_artifact.artifact

        try:
            result.response.log_artifact_response.artifact_id = self._send_artifact(
                artifact).get("id")
        except Exception as e:
            result.response.log_artifact_response.error_message = 'error logging artifact "{}/{}": {}'.format(
                artifact.type, artifact.name, e)

        self._result_q.put(result)

    def send_artifact(self, data):
        artifact = data.artifact
        try:
            self._send_artifact(artifact)
        except Exception as e:
            logger.error(
                'send_artifact: failed for artifact "{}/{}": {}'.format(
                    artifact.type, artifact.name, e))

    def _send_artifact(self, artifact):
        saver = artifacts.ArtifactSaver(
            api=self._api,
            digest=artifact.digest,
            manifest_json=artifacts._manifest_json_from_proto(
                artifact.manifest),
            file_pusher=self._pusher,
            is_user_created=artifact.user_created,
        )

        if artifact.distributed_id:
            max_cli_version = self._max_cli_version()
            if max_cli_version is None or parse_version(
                    max_cli_version) < parse_version("0.10.16"):
                logger.warning(
                    "This W&B server doesn't support distributed artifacts, "
                    "have your administrator install wandb/local >= 0.9.37")
                return

        metadata = json.loads(artifact.metadata) if artifact.metadata else None
        return saver.save(
            type=artifact.type,
            name=artifact.name,
            metadata=metadata,
            description=artifact.description,
            aliases=artifact.aliases,
            use_after_commit=artifact.use_after_commit,
            distributed_id=artifact.distributed_id,
            finalize=artifact.finalize,
            incremental=artifact.incremental_beta1,
        )

    def send_alert(self, data):
        alert = data.alert
        max_cli_version = self._max_cli_version()
        if max_cli_version is None or parse_version(
                max_cli_version) < parse_version("0.10.9"):
            logger.warning(
                "This W&B server doesn't support alerts, "
                "have your administrator install wandb/local >= 0.9.31")
        else:
            try:
                self._api.notify_scriptable_run_alert(
                    title=alert.title,
                    text=alert.text,
                    level=alert.level,
                    wait_duration=alert.wait_duration,
                )
            except Exception as e:
                logger.error('send_alert: failed for alert "{}": {}'.format(
                    alert.title, e))

    def finish(self):
        logger.info("shutting down sender")
        # if self._tb_watcher:
        #     self._tb_watcher.finish()
        if self._dir_watcher:
            self._dir_watcher.finish()
            self._dir_watcher = None
        if self._pusher:
            self._pusher.finish()
            self._pusher.join()
            self._pusher = None
        if self._fs:
            self._fs.finish(self._exit_code)
            self._fs = None

    def _max_cli_version(self):
        _, server_info = self._api.viewer_server_info()
        max_cli_version = server_info.get("cliVersionInfo",
                                          {}).get("max_cli_version", None)
        return max_cli_version

    def __next__(self):
        return self._record_q.get(block=True)

    next = __next__
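
A hedged driver sketch assembled only from the methods above (setup(), send(),
__len__, __next__, finish()); the import path, the directory, and the record
construction are assumptions for illustration, not something these examples
confirm:

from wandb.proto import wandb_internal_pb2

sm = SendManager.setup("/tmp/offline-run")  # hypothetical root_dir
record = wandb_internal_pb2.Record()
record.header.SetInParent()                 # select the no-op "header" oneof
sm.send(record)                             # dispatched to send_header()

# Records the sender publishes to itself (e.g. the defer chain started by
# send_exit) land on record_q; drain them back through send().
while len(sm):
    sm.send(next(sm))
sm.finish()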
Example #3
class SendManager(object):
    def __init__(self, settings, process_q, notify_q, resp_q, run_meta=None):
        self._settings = settings
        self._resp_q = resp_q
        self._run_meta = run_meta

        self._fs = None
        self._pusher = None
        self._dir_watcher = None
        self._tb_watcher = None

        # State updated by login
        self._entity = None
        self._flags = None

        # State updated by wandb.init
        self._run = None
        self._project = None

        # State updated by resuming
        self._offsets = {
            "step": 0,
            "history": 0,
            "events": 0,
            "output": 0,
            "runtime": 0,
        }

        self._api = internal_api.Api(default_settings=settings)
        self._api_settings = dict()

        # TODO(jhr): do something better, why do we need to send full lines?
        self._partial_output = dict()

        self._interface = interface.BackendSender(
            process_queue=process_q,
            notify_queue=notify_q,
        )

        self._exit_code = 0

        # keep track of config and summary from key/val updates
        # self._consolidated_config = dict()
        self._consolidated_summary = dict()

    def send(self, record):
        record_type = record.WhichOneof("record_type")
        if record_type is None:
            print("unknown record")
            return
        handler = getattr(self, "handle_" + record_type, None)
        if handler is None:
            print("unknown handle", record_type)
            return
        handler(record)

    def send_request(self, record):
        request_type = record.request.WhichOneof("request_type")
        if request_type is None:
            print("unknown request")
            return
        handler = getattr(self, "handle_request_" + request_type, None)
        if handler is None:
            print("unknown request handle", request_type)
            return
        handler(record)

    def _flatten(self, dictionary):
        if type(dictionary) == dict:
            for k, v in list(dictionary.items()):
                if type(v) == dict:
                    self._flatten(v)
                    dictionary.pop(k)
                    for k2, v2 in v.items():
                        dictionary[k + "." + k2] = v2

    def handle_request_status(self, data):
        if not data.control.req_resp:
            return

        result = wandb_internal_pb2.Result(uuid=data.uuid)
        status_resp = result.response.status_response
        if data.request.status.check_stop_req:
            status_resp.run_should_stop = False
            if self._entity and self._project and self._run.run_id:
                try:
                    status_resp.run_should_stop = self._api.check_stop_requested(
                        self._project, self._entity, self._run.run_id)
                except Exception as e:
                    logger.warning("Failed to check stop requested status: %s",
                                   e)
        self._resp_q.put(result)

    def handle_tbdata(self, data):
        if self._tb_watcher:
            tbdata = data.tbdata
            self._tb_watcher.add(tbdata.log_dir, tbdata.save)

    def handle_request(self, rec):
        self.send_request(rec)

    def handle_request_login(self, data):
        # TODO: do something with api_key or anonymous?
        # TODO: return an error if we aren't logged in?
        viewer = self._api.viewer()
        self._flags = json.loads(viewer.get("flags", "{}"))
        self._entity = viewer.get("entity")
        if data.control.req_resp:
            result = wandb_internal_pb2.Result(uuid=data.uuid)
            result.response.login_response.active_entity = self._entity
            self._resp_q.put(result)

    def handle_exit(self, data):
        exit = data.exit
        self._exit_code = exit.exit_code

        logger.info("handling exit code: %s", exit.exit_code)

        # Ensure we've at least noticed every file in the run directory. Sometimes
        # we miss things because asynchronously watching filesystems isn't reliable.
        run_dir = self._settings.files_dir
        logger.info("scan: %s", run_dir)

        # shutdown tensorboard workers so we get all metrics flushed
        if self._tb_watcher:
            self._tb_watcher.finish()
            self._tb_watcher = None

        # Pass the responsibility to respond to handle_final()
        if data.control.req_resp:
            # send exit_final to give the queue a chance to flush
            # response will be handled in handle_exit_final
            logger.info("send defer")
            self._interface.send_defer(data.uuid)

    def handle_request_defer(self, data):
        logger.info("handle defer")

        if self._dir_watcher:
            self._dir_watcher.finish()
            self._dir_watcher = None

        if self._pusher:
            self._pusher.finish()

        if self._fs:
            # TODO(jhr): now is a good time to output pending output lines
            self._fs.finish(self._exit_code)
            self._fs = None

        # NB: assume we always need to send a response for this message
        # since it was sent on behalf of handle_exit() req/resp logic
        resp = wandb_internal_pb2.Result(uuid=data.uuid)
        file_counts = self._pusher.file_counts_by_category()
        resp.exit_result.files.wandb_count = file_counts["wandb"]
        resp.exit_result.files.media_count = file_counts["media"]
        resp.exit_result.files.artifact_count = file_counts["artifact"]
        resp.exit_result.files.other_count = file_counts["other"]
        self._resp_q.put(resp)

        # TODO(david): this info should be in exit_result footer?
        if self._pusher:
            self._pusher.print_status()
            self._pusher = None

    def _maybe_setup_resume(self, run):
        """This maybe queries the backend for a run and fails if the settings are
        incompatible."""
        error = None
        if self._settings.resume:
            # TODO: This causes a race; we need to make the upsert atomically
            # create or update depending on the resume config.
            # Use the run's entity if set, otherwise fall back to the user's entity.
            entity = run.entity or self._entity
            logger.info("checking resume status for %s/%s/%s", entity,
                        run.project, run.run_id)
            resume_status = self._api.run_resume_status(
                entity=entity, project_name=run.project, name=run.run_id)
            logger.info("resume status %s", resume_status)
            if resume_status is None:
                if self._settings.resume == "must":
                    error = wandb_internal_pb2.ErrorInfo()
                    error.code = wandb_internal_pb2.ErrorInfo.ErrorCode.INVALID
                    error.message = (
                        "resume='must' but run (%s) doesn't exist" %
                        run.run_id)
            else:
                if self._settings.resume == "never":
                    error = wandb_internal_pb2.ErrorInfo()
                    error.code = wandb_internal_pb2.ErrorInfo.ErrorCode.INVALID
                    error.message = "resume='never' but run (%s) exists" % run.run_id
                elif self._settings.resume in ("allow", "auto"):
                    history = {}
                    events = {}
                    try:
                        history = json.loads(
                            json.loads(resume_status["historyTail"])[-1])
                        events = json.loads(
                            json.loads(resume_status["eventsTail"])[-1])
                    except (IndexError, ValueError) as e:
                        logger.error("unable to load resume tails", exc_info=e)
                    # TODO: Do we need to restore config / summary?
                    # System metrics runtime is usually greater than history
                    events_rt = events.get("_runtime", 0)
                    history_rt = history.get("_runtime", 0)
                    self._offsets["runtime"] = max(events_rt, history_rt)
                    self._offsets["step"] = history.get("_step", -1) + 1
                    self._offsets["history"] = resume_status[
                        "historyLineCount"]
                    self._offsets["events"] = resume_status["eventsLineCount"]
                    self._offsets["output"] = resume_status["logLineCount"]
                    logger.info("configured resuming with: %s" % self._offsets)
        return error

    def handle_run(self, data):
        run = data.run
        run_tags = run.tags[:]
        error = None
        is_wandb_init = self._run is None

        # build config dict
        config_dict = None
        if run.HasField("config"):
            config_dict = _config_dict_from_proto_list(run.config.update)
            config_path = os.path.join(self._settings.files_dir, CONFIG_FNAME)
            save_config_file_from_dict(config_path, config_dict)

        repo = GitRepo(remote=self._settings.git_remote)

        if is_wandb_init:
            # Only check resume status on `wandb.init`
            error = self._maybe_setup_resume(run)

        if error is not None:
            if data.control.req_resp:
                resp = wandb_internal_pb2.Result(uuid=data.uuid)
                resp.run_result.run.CopyFrom(run)
                resp.run_result.error.CopyFrom(error)
                self._resp_q.put(resp)
            else:
                logger.error("Got error in async mode: %s", error.message)
            return

        # TODO: we don't check inserted currently, ultimately we should make
        # the upsert know the resume state and fail transactionally
        ups, inserted = self._api.upsert_run(
            name=run.run_id,
            entity=run.entity or None,
            project=run.project or None,
            group=run.run_group or None,
            job_type=run.job_type or None,
            display_name=run.display_name or None,
            notes=run.notes or None,
            tags=run_tags or None,
            config=config_dict or None,
            sweep_name=run.sweep_id or None,
            host=run.host or None,
            program_path=self._settings.program or None,
            repo=repo.remote_url,
            commit=repo.last_commit,
        )

        # We subtract the previous run's runtime when resuming
        start_time = run.start_time.ToSeconds() - self._offsets["runtime"]
        self._run = run
        self._run.starting_step = self._offsets["step"]
        self._run.start_time.FromSeconds(start_time)
        storage_id = ups.get("id")
        if storage_id:
            self._run.storage_id = storage_id
        display_name = ups.get("displayName")
        if display_name:
            self._run.display_name = display_name
        project = ups.get("project")
        if project:
            project_name = project.get("name")
            if project_name:
                self._run.project = project_name
                self._project = project_name
            entity = project.get("entity")
            if entity:
                entity_name = entity.get("name")
                if entity_name:
                    self._run.entity = entity_name
                    self._entity = entity_name

        if data.control.req_resp:
            resp = wandb_internal_pb2.Result(uuid=data.uuid)
            resp.run_result.run.CopyFrom(self._run)
            self._resp_q.put(resp)

        if self._entity is not None:
            self._api_settings["entity"] = self._entity
        if self._project is not None:
            self._api_settings["project"] = self._project

        # Only spin up our threads on the first run message
        if is_wandb_init:
            self._fs = file_stream.FileStreamApi(self._api,
                                                 run.run_id,
                                                 start_time,
                                                 settings=self._api_settings)
            # Ensure the streaming policies have the proper offsets
            self._fs.set_file_policy("wandb-summary.json",
                                     file_stream.SummaryFilePolicy())
            self._fs.set_file_policy(
                "wandb-history.jsonl",
                file_stream.JsonlFilePolicy(
                    start_chunk_id=self._offsets["history"]),
            )
            self._fs.set_file_policy(
                "wandb-events.jsonl",
                file_stream.JsonlFilePolicy(
                    start_chunk_id=self._offsets["events"]),
            )
            self._fs.set_file_policy(
                "output.log",
                file_stream.CRDedupeFilePolicy(
                    start_chunk_id=self._offsets["output"]),
            )
            self._fs.start()
            self._pusher = FilePusher(self._api)
            self._dir_watcher = DirWatcher(self._settings, self._api,
                                           self._pusher)
            self._tb_watcher = tb_watcher.TBWatcher(self._settings,
                                                    sender=self)
            if self._run_meta:
                self._run_meta.write()
            sentry_set_scope("internal", run.entity, run.project)
            logger.info("run started: %s with start time %s", self._run.run_id,
                        start_time)
        else:
            logger.info("updated run: %s", self._run.run_id)

    def _save_history(self, history_dict):
        if self._fs:
            # print("\n\nABOUT TO SAVE:\n", history_dict, "\n\n")
            self._fs.push(HISTORY_FNAME, json.dumps(history_dict))
            # print("got", x)
        # save history into summary
        self._consolidated_summary.update(history_dict)
        self._save_summary(self._consolidated_summary)

    def handle_history(self, data):
        history = data.history
        history_dict = dict_from_proto_list(history.item)
        self._save_history(history_dict)

    def _save_summary(self, summary_dict):
        json_summary = json.dumps(summary_dict)
        if self._fs:
            self._fs.push(SUMMARY_FNAME, json_summary)
        summary_path = os.path.join(self._settings.files_dir, SUMMARY_FNAME)
        with open(summary_path, "w") as f:
            f.write(json_summary)
            self._save_file(SUMMARY_FNAME)

    def handle_summary(self, data):
        summary = data.summary
        summary_dict = dict_from_proto_list(summary.update)
        self._consolidated_summary.update(summary_dict)
        self._save_summary(self._consolidated_summary)

    def handle_stats(self, data):
        stats = data.stats
        if stats.stats_type != wandb_internal_pb2.StatsRecord.StatsType.SYSTEM:
            return
        if not self._fs:
            return
        now = stats.timestamp.seconds
        d = dict()
        for item in stats.item:
            d[item.key] = json.loads(item.value_json)
        row = dict(system=d)
        self._flatten(row)
        row["_wandb"] = True
        row["_timestamp"] = now
        row["_runtime"] = int(now - self._run.start_time.ToSeconds())
        self._fs.push(EVENTS_FNAME, json.dumps(row))
        # TODO(jhr): check fs.push results?

    def handle_output(self, data):
        if not self._fs:
            return
        out = data.output
        prepend = ""
        stream = "stdout"
        if out.output_type == wandb_internal_pb2.OutputRecord.OutputType.STDERR:
            stream = "stderr"
            prepend = "ERROR "
        line = out.line
        if not line.endswith("\n"):
            self._partial_output.setdefault(stream, "")
            self._partial_output[stream] += line
            # TODO(jhr): how do we make sure this gets flushed?
            # we might need this for other stuff like telemetry
        else:
            # TODO(jhr): use time from timestamp proto
            # TODO(jhr): do we need to make sure we write full lines?
            # seems to be some issues with line breaks
            cur_time = time.time()
            timestamp = datetime.utcfromtimestamp(cur_time).isoformat() + " "
            prev_str = self._partial_output.get(stream, "")
            line = u"{}{}{}{}".format(prepend, timestamp, prev_str, line)
            self._fs.push(OUTPUT_FNAME, line)
            self._partial_output[stream] = ""

    def handle_config(self, data):
        cfg = data.config
        config_dict = _config_dict_from_proto_list(cfg.update)
        self._api.upsert_run(name=self._run.run_id,
                             config=config_dict,
                             **self._api_settings)
        config_path = os.path.join(self._settings.files_dir, "config.yaml")
        save_config_file_from_dict(config_path, config_dict)
        # TODO(jhr): check result of upsert_run?

    def _save_file(self, fname, policy="end"):
        logger.info("saving file %s with policy %s", fname, policy)
        self._dir_watcher.update_policy(fname, policy)

    def handle_files(self, data):
        files = data.files
        for k in files.files:
            # TODO(jhr): fix paths with directories
            self._save_file(k.path, interface.file_enum_to_policy(k.policy))

    def handle_artifact(self, data):
        artifact = data.artifact
        saver = artifacts.ArtifactSaver(
            api=self._api,
            digest=artifact.digest,
            manifest_json=artifacts._manifest_json_from_proto(
                artifact.manifest),
            file_pusher=self._pusher,
            is_user_created=artifact.user_created,
        )

        saver.save(
            type=artifact.type,
            name=artifact.name,
            metadata=artifact.metadata,
            description=artifact.description,
            aliases=artifact.aliases,
            use_after_commit=artifact.use_after_commit,
        )

    def handle_request_get_summary(self, data):
        result = wandb_internal_pb2.Result(uuid=data.uuid)
        for key, value in six.iteritems(self._consolidated_summary):
            item = wandb_internal_pb2.SummaryItem()
            item.key = key
            item.value_json = json.dumps(value)
            result.response.get_summary_response.item.append(item)
        self._resp_q.put(result)

    def finish(self):
        logger.info("shutting down sender")
        if self._tb_watcher:
            self._tb_watcher.finish()
        if self._dir_watcher:
            self._dir_watcher.finish()
        if self._pusher:
            self._pusher.finish()
        if self._fs:
            self._fs.finish(self._exit_code)
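
Both revisions share the same getattr-based dispatch trick; stripped of the
protobuf details it reduces to the following self-contained illustration (all
names here are invented for the demo):

class Dispatcher(object):
    def handle_run(self, rec):
        print("run:", rec)

    def send(self, rec):
        record_type = rec["type"]  # stand-in for record.WhichOneof("record_type")
        handler = getattr(self, "handle_" + record_type, None)
        if handler is None:
            print("unknown handle", record_type)
            return
        handler(rec)

Dispatcher().send({"type": "run", "run_id": "abc123"})
# -> run: {'type': 'run', 'run_id': 'abc123'}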
Beispiel #4
0
    def handle_run(self, data):
        run = data.run
        run_tags = run.tags[:]
        error = None
        is_wandb_init = self._run is None

        # build config dict
        config_dict = None
        if run.HasField("config"):
            config_dict = _config_dict_from_proto_list(run.config.update)
            config_path = os.path.join(self._settings.files_dir, CONFIG_FNAME)
            save_config_file_from_dict(config_path, config_dict)

        repo = GitRepo(remote=self._settings.git_remote)

        if is_wandb_init:
            # Only check resume status on `wandb.init`
            error = self._maybe_setup_resume(run)

        if error is not None:
            if data.control.req_resp:
                resp = wandb_internal_pb2.Result(uuid=data.uuid)
                resp.run_result.run.CopyFrom(run)
                resp.run_result.error.CopyFrom(error)
                self._resp_q.put(resp)
            else:
                logger.error("Got error in async mode: %s", error.message)
            return

        # TODO: we don't check `inserted` currently; ultimately the upsert
        # should know the resume state and fail transactionally
        ups, inserted = self._api.upsert_run(
            name=run.run_id,
            entity=run.entity or None,
            project=run.project or None,
            group=run.run_group or None,
            job_type=run.job_type or None,
            display_name=run.display_name or None,
            notes=run.notes or None,
            tags=run_tags or None,
            config=config_dict or None,
            sweep_name=run.sweep_id or None,
            host=run.host or None,
            program_path=self._settings.program or None,
            repo=repo.remote_url,
            commit=repo.last_commit,
        )

        # We subtract the previous run's runtime when resuming
        start_time = run.start_time.ToSeconds() - self._offsets["runtime"]
        self._run = run
        self._run.starting_step = self._offsets["step"]
        self._run.start_time.FromSeconds(start_time)
        storage_id = ups.get("id")
        if storage_id:
            self._run.storage_id = storage_id
        display_name = ups.get("displayName")
        if display_name:
            self._run.display_name = display_name
        project = ups.get("project")
        if project:
            project_name = project.get("name")
            if project_name:
                self._run.project = project_name
                self._project = project_name
            entity = project.get("entity")
            if entity:
                entity_name = entity.get("name")
                if entity_name:
                    self._run.entity = entity_name
                    self._entity = entity_name

        if data.control.req_resp:
            resp = wandb_internal_pb2.Result(uuid=data.uuid)
            resp.run_result.run.CopyFrom(self._run)
            self._resp_q.put(resp)

        if self._entity is not None:
            self._api_settings["entity"] = self._entity
        if self._project is not None:
            self._api_settings["project"] = self._project

        # Only spin up our threads on the first run message
        if is_wandb_init:
            self._fs = file_stream.FileStreamApi(self._api,
                                                 run.run_id,
                                                 start_time,
                                                 settings=self._api_settings)
            # Ensure the streaming policies have the proper offsets
            self._fs.set_file_policy("wandb-summary.json",
                                     file_stream.SummaryFilePolicy())
            self._fs.set_file_policy(
                "wandb-history.jsonl",
                file_stream.JsonlFilePolicy(
                    start_chunk_id=self._offsets["history"]),
            )
            self._fs.set_file_policy(
                "wandb-events.jsonl",
                file_stream.JsonlFilePolicy(
                    start_chunk_id=self._offsets["events"]),
            )
            self._fs.set_file_policy(
                "output.log",
                file_stream.CRDedupeFilePolicy(
                    start_chunk_id=self._offsets["output"]),
            )
            self._fs.start()
            self._pusher = FilePusher(self._api)
            self._dir_watcher = DirWatcher(self._settings, self._api,
                                           self._pusher)
            self._tb_watcher = tb_watcher.TBWatcher(self._settings,
                                                    sender=self)
            if self._run_meta:
                self._run_meta.write()
            sentry_set_scope("internal", run.entity, run.project)
            logger.info("run started: %s with start time %s", self._run.run_id,
                        start_time)
        else:
            logger.info("updated run: %s", self._run.run_id)
Example #5
class SendManager(object):
    def __init__(
        self,
        settings,
        record_q,
        result_q,
        interface,
    ):
        self._settings = settings
        self._record_q = record_q
        self._result_q = result_q
        self._interface = interface

        self._fs = None
        self._pusher = None
        self._dir_watcher = None

        # State updated by login
        self._entity = None
        self._flags = None

        # State updated by wandb.init
        self._run = None
        self._project = None

        # State updated by resuming
        self._resume_state = {
            "step": 0,
            "history": 0,
            "events": 0,
            "output": 0,
            "runtime": 0,
            "summary": None,
            "config": None,
            "resumed": False,
        }

        # State added when run_exit needs results
        self._exit_sync_uuid = None

        # State added when run_exit is complete
        self._exit_result = None

        self._api = internal_api.Api(default_settings=settings)
        self._api_settings = dict()

        # TODO(jhr): do something better, why do we need to send full lines?
        self._partial_output = dict()

        self._exit_code = 0

    def send(self, record):
        record_type = record.WhichOneof("record_type")
        assert record_type
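        # Dispatch by naming convention: a record with record_type "run"
        # routes to send_run, "output" to send_output, and so on.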
        handler_str = "send_" + record_type
        send_handler = getattr(self, handler_str, None)
        # Don't log output to reduce log noise
        if record_type != "output":
            logger.debug("send: {}".format(record_type))
        assert send_handler, "unknown send handler: {}".format(handler_str)
        send_handler(record)

    def send_request(self, record):
        request_type = record.request.WhichOneof("request_type")
        assert request_type
        handler_str = "send_request_" + request_type
        send_handler = getattr(self, handler_str, None)
        logger.debug("send_request: {}".format(request_type))
        assert send_handler, "unknown handle: {}".format(handler_str)
        send_handler(record)

    def _flatten(self, dictionary):
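        # Flattens nested dicts in place into dotted keys, e.g.
        # {"system": {"cpu": 1.0, "gpu": {"0": 0.5}}} becomes
        # {"system.cpu": 1.0, "system.gpu.0": 0.5}.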
        if type(dictionary) == dict:
            for k, v in list(dictionary.items()):
                if type(v) == dict:
                    self._flatten(v)
                    dictionary.pop(k)
                    for k2, v2 in v.items():
                        dictionary[k + "." + k2] = v2

    def send_request_check_version(self, record):
        assert record.control.req_resp
        result = wandb_internal_pb2.Result(uuid=record.uuid)
        current_version = wandb.__version__
        messages = update.check_available(current_version)
        if messages:
            result.response.check_version_response.upgrade_message = messages[
                "upgrade_message"]
            result.response.check_version_response.yank_message = messages[
                "yank_message"]
            result.response.check_version_response.delete_message = messages[
                "delete_message"]
        self._result_q.put(result)

    def send_request_status(self, record):
        assert record.control.req_resp

        result = wandb_internal_pb2.Result(uuid=record.uuid)
        status_resp = result.response.status_response
        if record.request.status.check_stop_req:
            status_resp.run_should_stop = False
            if self._entity and self._project and self._run.run_id:
                try:
                    status_resp.run_should_stop = self._api.check_stop_requested(
                        self._project, self._entity, self._run.run_id)
                except Exception as e:
                    logger.warning("Failed to check stop requested status: %s",
                                   e)
        self._result_q.put(result)

    def send_request_login(self, record):
        # TODO: do something with api_key or anonymous?
        # TODO: return an error if we aren't logged in?
        self._api.reauth()
        viewer_tuple = self._api.viewer_server_info()
        # self._login_flags = json.loads(viewer.get("flags", "{}"))
        # self._login_entity = viewer.get("entity")
        viewer, server_info = viewer_tuple
        if server_info:
            logger.info("Login server info: {}".format(server_info))
        self._entity = viewer.get("entity")
        if record.control.req_resp:
            result = wandb_internal_pb2.Result(uuid=record.uuid)
            if self._entity:
                result.response.login_response.active_entity = self._entity
            self._result_q.put(result)

    def send_exit(self, data):
        exit = data.exit
        self._exit_code = exit.exit_code

        logger.info("handling exit code: %s", exit.exit_code)

        # Pass the responsibility to respond to handle_request_defer()
        if data.control.req_resp:
            self._exit_sync_uuid = data.uuid

        # We need to give the request queue a chance to empty between states
        # so use handle_request_defer as a state machine.
        logger.info("send defer")
        self._interface.publish_defer()

    def send_final(self, data):
        pass

    def send_request_defer(self, data):
        defer = data.request.defer
        state = defer.state
        logger.info("handle sender defer: {}".format(state))

        done = False
        if state == defer.BEGIN:
            pass
        elif state == defer.FLUSH_STATS:
            # NOTE: this is handled in handler.py:handle_request_defer()
            pass
        elif state == defer.FLUSH_TB:
            # NOTE: this is handled in handler.py:handle_request_defer()
            pass
        elif state == defer.FLUSH_SUM:
            # NOTE: this is handled in handler.py:handle_request_defer()
            pass
        elif state == defer.FLUSH_DIR:
            if self._dir_watcher:
                self._dir_watcher.finish()
                self._dir_watcher = None
        elif state == defer.FLUSH_FP:
            if self._pusher:
                self._pusher.finish()
        elif state == defer.FLUSH_FS:
            if self._fs:
                # TODO(jhr): now is a good time to output pending output lines
                self._fs.finish(self._exit_code)
                self._fs = None
        elif state == defer.FLUSH_FINAL:
            self._interface.publish_final()
            self._interface.publish_footer()
        elif state == defer.END:
            done = True
        else:
            raise AssertionError("unknown state")

        if not done:
            state += 1
            logger.info("send defer: {}".format(state))
            self._interface.publish_defer(state)
            return

        exit_result = wandb_internal_pb2.RunExitResult()

        # This path is not the preferred way to return exit results, as it
        # can take a long time to flush the file pusher buffers
        if self._exit_sync_uuid:
            if self._pusher:
                # NOTE: This will block until finished
                self._pusher.print_status()
                self._pusher.join()
                self._pusher = None
            resp = wandb_internal_pb2.Result(exit_result=exit_result,
                                             uuid=self._exit_sync_uuid)
            self._result_q.put(resp)

        # mark exit done in case we are polling on exit
        self._exit_result = exit_result

    def send_request_poll_exit(self, record):
        if not record.control.req_resp:
            return

        result = wandb_internal_pb2.Result(uuid=record.uuid)

        alive = False
        if self._pusher:
            alive, status = self._pusher.get_status()
            file_counts = self._pusher.file_counts_by_category()
            resp = result.response.poll_exit_response
            resp.pusher_stats.uploaded_bytes = status["uploaded_bytes"]
            resp.pusher_stats.total_bytes = status["total_bytes"]
            resp.pusher_stats.deduped_bytes = status["deduped_bytes"]
            resp.file_counts.wandb_count = file_counts["wandb"]
            resp.file_counts.media_count = file_counts["media"]
            resp.file_counts.artifact_count = file_counts["artifact"]
            resp.file_counts.other_count = file_counts["other"]
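        # Only attach the exit result once the pusher reports it is no
        # longer alive; clients are expected to poll until done is True.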

        if self._exit_result and not alive:
            # pusher join should not block as it was reported as not alive
            if self._pusher:
                self._pusher.join()
            result.response.poll_exit_response.exit_result.CopyFrom(
                self._exit_result)
            result.response.poll_exit_response.done = True
        self._result_q.put(result)

    def _maybe_setup_resume(self, run):
        """This maybe queries the backend for a run and fails if the settings are
        incompatible."""
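        # Resume semantics as implemented below:
        #   resume unset    -> no-op, start fresh
        #   resume="must"   -> error unless the run exists and its tails load
        #   resume="never"  -> error if the run already exists
        #   otherwise       -> restore offsets/config/summary when found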
        if not self._settings.resume:
            return

        # TODO: This causes a race; we need to make the upsert atomically
        # create-or-update depending on the resume config.
        # Use the run's entity if set, otherwise fall back to the user's entity.
        entity = run.entity or self._entity
        logger.info("checking resume status for %s/%s/%s", entity, run.project,
                    run.run_id)
        resume_status = self._api.run_resume_status(entity=entity,
                                                    project_name=run.project,
                                                    name=run.run_id)

        if not resume_status:
            if self._settings.resume == "must":
                error = wandb_internal_pb2.ErrorInfo()
                error.code = wandb_internal_pb2.ErrorInfo.ErrorCode.INVALID
                error.message = "resume='must' but run (%s) doesn't exist" % run.run_id
                return error
            return

        #
        # handle cases where we have resume_status
        #
        if self._settings.resume == "never":
            error = wandb_internal_pb2.ErrorInfo()
            error.code = wandb_internal_pb2.ErrorInfo.ErrorCode.INVALID
            error.message = "resume='never' but run (%s) exists" % run.run_id
            return error

        history = {}
        events = {}
        config = {}
        summary = {}
        try:
            events_rt = 0
            history_rt = 0
            history = json.loads(resume_status["historyTail"])
            if history:
                history = json.loads(history[-1])
                history_rt = history.get("_runtime", 0)
            events = json.loads(resume_status["eventsTail"])
            if events:
                events = json.loads(events[-1])
                events_rt = events.get("_runtime", 0)
            config = json.loads(resume_status["config"])
            summary = json.loads(resume_status["summaryMetrics"])
        except (IndexError, ValueError) as e:
            logger.error("unable to load resume tails", exc_info=e)
            if self._settings.resume == "must":
                error = wandb_internal_pb2.ErrorInfo()
                error.code = wandb_internal_pb2.ErrorInfo.ErrorCode.INVALID
                error.message = "resume='must' but could not resume (%s) " % run.run_id
                return error

        # TODO: Do we need to restore config / summary?
        # System metrics runtime is usually greater than history runtime
        self._resume_state["runtime"] = max(events_rt, history_rt)
        self._resume_state["step"] = history.get("_step",
                                                 -1) + 1 if history else 0
        self._resume_state["history"] = resume_status["historyLineCount"]
        self._resume_state["events"] = resume_status["eventsLineCount"]
        self._resume_state["output"] = resume_status["logLineCount"]
        self._resume_state["config"] = config
        self._resume_state["summary"] = summary
        self._resume_state["resumed"] = True
        logger.info("configured resuming with: %s" % self._resume_state)
        return

    def send_run(self, data):
        run = data.run
        error = None
        is_wandb_init = self._run is None

        # build config dict
        config_dict = None
        config_path = os.path.join(self._settings.files_dir,
                                   filenames.CONFIG_FNAME)
        if run.HasField("config"):
            config_dict = config_util.dict_from_proto_list(run.config.update)
            config_util.save_config_file_from_dict(config_path, config_dict)

        if is_wandb_init:
            # Ensure we have a project to query for status
            if run.project == "":
                run.project = util.auto_project_name(self._settings.program)
            # Only check resume status on `wandb.init`
            error = self._maybe_setup_resume(run)

        if error is not None:
            if data.control.req_resp:
                resp = wandb_internal_pb2.Result(uuid=data.uuid)
                resp.run_result.run.CopyFrom(run)
                resp.run_result.error.CopyFrom(error)
                self._result_q.put(resp)
            else:
                logger.error("Got error in async mode: %s", error.message)
            return

        # Save the resumed config: start from the config restored from the
        # server and let locally supplied values override it
        if self._resume_state["config"] is not None:
            # TODO: should we merge this with resumed config?
            config_override = config_dict or {}
            config_dict = self._resume_state["config"]
            config_dict.update(config_override)
            config_util.save_config_file_from_dict(config_path, config_dict)

        self._init_run(run, config_dict)

        if data.control.req_resp:
            resp = wandb_internal_pb2.Result(uuid=data.uuid)
            # TODO: we could do self._interface.publish_defer(resp) to notify
            # the handler not to actually perform server updates for this uuid
            # because the user process will send a summary update when we resume
            resp.run_result.run.CopyFrom(self._run)
            self._result_q.put(resp)

        # Only spin up our threads on the first run message
        if is_wandb_init:
            self._start_run_threads()
        else:
            logger.info("updated run: %s", self._run.run_id)

    def _init_run(self, run, config_dict):
        # We subtract the previous run's runtime when resuming
        start_time = run.start_time.ToSeconds() - self._resume_state["runtime"]
        repo = GitRepo(remote=self._settings.git_remote)
        # TODO: we don't check `inserted` currently; ultimately the upsert
        # should know the resume state and fail transactionally
        server_run, inserted = self._api.upsert_run(
            name=run.run_id,
            entity=run.entity or None,
            project=run.project or None,
            group=run.run_group or None,
            job_type=run.job_type or None,
            display_name=run.display_name or None,
            notes=run.notes or None,
            tags=run.tags[:] or None,
            config=config_dict or None,
            sweep_name=run.sweep_id or None,
            host=run.host or None,
            program_path=self._settings.program or None,
            repo=repo.remote_url,
            commit=repo.last_commit,
        )
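        # Fields echoed back by the server (ids, display name, project,
        # entity, sweep) are authoritative and override the local values.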
        self._run = run
        if self._resume_state.get("resumed"):
            self._run.resumed = True
        self._run.starting_step = self._resume_state["step"]
        self._run.start_time.FromSeconds(start_time)
        self._run.config.CopyFrom(self._interface._make_config(config_dict))
        if self._resume_state["summary"] is not None:
            self._run.summary.CopyFrom(
                self._interface._make_summary_from_dict(
                    self._resume_state["summary"]))
        storage_id = server_run.get("id")
        if storage_id:
            self._run.storage_id = storage_id
        id = server_run.get("name")
        if id:
            self._api.set_current_run_id(id)
        display_name = server_run.get("displayName")
        if display_name:
            self._run.display_name = display_name
        project = server_run.get("project")
        # TODO: remove self._api.set_settings, and make self._project a property?
        if project:
            project_name = project.get("name")
            if project_name:
                self._run.project = project_name
                self._project = project_name
                self._api_settings["project"] = project_name
                self._api.set_setting("project", project_name)
            entity = project.get("entity")
            if entity:
                entity_name = entity.get("name")
                if entity_name:
                    self._run.entity = entity_name
                    self._entity = entity_name
                    self._api_settings["entity"] = entity_name
                    self._api.set_setting("entity", entity_name)
        sweep_id = server_run.get("sweepName")
        if sweep_id:
            self._run.sweep_id = sweep_id

    def _start_run_threads(self):
        self._fs = file_stream.FileStreamApi(
            self._api,
            self._run.run_id,
            self._run.start_time.ToSeconds(),
            settings=self._api_settings,
        )
        # Ensure the streaming policies have the proper offsets: on resume the
        # server already holds lines for these files, so new chunks must start
        # at the recorded line counts rather than at zero
        self._fs.set_file_policy("wandb-summary.json",
                                 file_stream.SummaryFilePolicy())
        self._fs.set_file_policy(
            "wandb-history.jsonl",
            file_stream.JsonlFilePolicy(
                start_chunk_id=self._resume_state["history"]),
        )
        self._fs.set_file_policy(
            "wandb-events.jsonl",
            file_stream.JsonlFilePolicy(
                start_chunk_id=self._resume_state["events"]),
        )
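        # output.log uses a carriage-return-deduping policy so progress bars
        # rewritten with "\r" do not flood the stream with stale lines.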
        self._fs.set_file_policy(
            "output.log",
            file_stream.CRDedupeFilePolicy(
                start_chunk_id=self._resume_state["output"]),
        )
        self._fs.start()
        self._pusher = FilePusher(self._api)
        self._dir_watcher = DirWatcher(self._settings, self._api, self._pusher)
        util.sentry_set_scope(
            "internal",
            entity=self._run.entity,
            project=self._run.project,
            email=self._settings.email,
        )
        logger.info(
            "run started: %s with start time %s",
            self._run.run_id,
            self._run.start_time.ToSeconds(),
        )

    def _save_history(self, history_dict):
        if self._fs:
            self._fs.push(filenames.HISTORY_FNAME, json.dumps(history_dict))

    def send_history(self, data):
        history = data.history
        history_dict = proto_util.dict_from_proto_list(history.item)
        self._save_history(history_dict)

    def send_summary(self, data):
        summary_dict = proto_util.dict_from_proto_list(data.summary.update)
        json_summary = json.dumps(summary_dict)
        if self._fs:
            self._fs.push(filenames.SUMMARY_FNAME, json_summary)
        # TODO(jhr): we should only write this at the end of the script
        summary_path = os.path.join(self._settings.files_dir,
                                    filenames.SUMMARY_FNAME)
        with open(summary_path, "w") as f:
            f.write(json_summary)
        self._save_file(filenames.SUMMARY_FNAME)

    def send_stats(self, data):
        stats = data.stats
        if stats.stats_type != wandb_internal_pb2.StatsRecord.StatsType.SYSTEM:
            return
        if not self._fs:
            return
        now = stats.timestamp.seconds
        d = dict()
        for item in stats.item:
            d[item.key] = json.loads(item.value_json)
        row = dict(system=d)
        self._flatten(row)
        row["_wandb"] = True
        row["_timestamp"] = now
        row["_runtime"] = int(now - self._run.start_time.ToSeconds())
        self._fs.push(filenames.EVENTS_FNAME, json.dumps(row))
        # TODO(jhr): check fs.push results?

    def send_output(self, data):
        if not self._fs:
            return
        out = data.output
        prepend = ""
        stream = "stdout"
        if out.output_type == wandb_internal_pb2.OutputRecord.OutputType.STDERR:
            stream = "stderr"
            prepend = "ERROR "
        line = out.line
        if not line.endswith("\n"):
            self._partial_output.setdefault(stream, "")
            self._partial_output[stream] += line
            # TODO(jhr): how do we make sure this gets flushed?
            # we might need this for other stuff like telemetry
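            # Buffered fragments are prepended to the next newline-terminated
            # chunk below, so "loss=0." followed by "123\n" emits one line.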
        else:
            # TODO(jhr): use time from timestamp proto
            # TODO(jhr): do we need to make sure we write full lines?
            # seems to be some issues with line breaks
            cur_time = time.time()
            timestamp = datetime.utcfromtimestamp(cur_time).isoformat() + " "
            prev_str = self._partial_output.get(stream, "")
            line = u"{}{}{}{}".format(prepend, timestamp, prev_str, line)
            self._fs.push(filenames.OUTPUT_FNAME, line)
            self._partial_output[stream] = ""

    def send_config(self, data):
        cfg = data.config
        config_dict = config_util.dict_from_proto_list(cfg.update)
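        # Push the config update to the backend and mirror it to the run's
        # local files_dir so config.yaml stays in sync.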
        self._api.upsert_run(name=self._run.run_id,
                             config=config_dict,
                             **self._api_settings)
        config_path = os.path.join(self._settings.files_dir, "config.yaml")
        config_util.save_config_file_from_dict(config_path, config_dict)
        # TODO(jhr): check result of upsert_run?

    def _save_file(self, fname, policy="end"):
        logger.info("saving file %s with policy %s", fname, policy)
        if self._dir_watcher:
            self._dir_watcher.update_policy(fname, policy)

    def send_files(self, data):
        files = data.files
        for k in files.files:
            # TODO(jhr): fix paths with directories
            self._save_file(k.path, interface.file_enum_to_policy(k.policy))

    def send_header(self, data):
        pass

    def send_footer(self, data):
        pass

    def send_tbrecord(self, data):
        # tbrecord watching threads are handled by handler.py
        pass

    def send_artifact(self, data):
        artifact = data.artifact
        saver = artifacts.ArtifactSaver(
            api=self._api,
            digest=artifact.digest,
            manifest_json=artifacts._manifest_json_from_proto(
                artifact.manifest),
            file_pusher=self._pusher,
            is_user_created=artifact.user_created,
        )

        metadata = json.loads(artifact.metadata) if artifact.metadata else None
        saver.save(
            type=artifact.type,
            name=artifact.name,
            metadata=metadata,
            description=artifact.description,
            aliases=artifact.aliases,
            use_after_commit=artifact.use_after_commit,
        )

    def finish(self):
        logger.info("shutting down sender")
        # if self._tb_watcher:
        #     self._tb_watcher.finish()
        if self._dir_watcher:
            self._dir_watcher.finish()
            self._dir_watcher = None
        if self._pusher:
            self._pusher.finish()
            self._pusher.join()
            self._pusher = None
        if self._fs:
            self._fs.finish(self._exit_code)
            self._fs = None
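
A minimal driver sketch (an illustration, not one of the examples above): a SendManager is typically fed from its record queue and dispatches every record through send(). The None sentinel and the one-second poll timeout below are assumptions made for this sketch; the real internal process has its own shutdown signaling.

import queue

def sender_loop(send_manager, record_q):
    # Pull records until a None sentinel arrives (assumed convention for
    # this sketch).
    while True:
        try:
            record = record_q.get(timeout=1)
        except queue.Empty:
            continue
        if record is None:
            break
        # send() dispatches on the record's oneof type; request records
        # are routed to send_request() by the same naming convention.
        send_manager.send(record)
    # Flush the dir watcher, file pusher and file stream.
    send_manager.finish()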