def _attach(
    attach_id: Optional[str] = None,
    run_id: Optional[str] = None,
) -> Union[Run, RunDisabled, None]:
    """Attach to a run currently executing in another process/thread.

    Arguments:
        attach_id: (str, optional) The id of the run or an attach identifier
            that maps to a run.
        run_id: (str, optional) The id of the run to attach to.
    """
    attach_id = attach_id or run_id
    if attach_id is None:
        raise UsageError("attach_id or run_id must be specified")

    wandb._assert_is_user_process()

    _wl = wandb_setup._setup()
    _set_logger(_wl._get_logger())
    if logger is None:
        raise UsageError("logger is not initialized")

    manager = _wl._get_manager()
    if manager:
        manager._inform_attach(attach_id=attach_id)

    settings: Settings = copy.copy(_wl._settings)
    settings.update(run_id=attach_id, source=Source.INIT)

    # TODO: consolidate this codepath with wandb.init()
    backend = Backend(settings=settings, manager=manager)
    backend.ensure_launched()
    backend.server_connect()
    logger.info("attach backend started and connected")

    run = Run(settings=settings)
    run._set_library(_wl)
    run._set_backend(backend)
    backend._hack_set_run(run)
    assert backend.interface

    resp = backend.interface.communicate_attach(attach_id)
    if not resp:
        raise UsageError("problem")
    if resp and resp.error and resp.error.message:
        raise UsageError("bad: {}".format(resp.error.message))
    run._set_run_obj(resp.run)
    return run
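A minimal usage sketch for the helper above, assuming a wandb service/manager already owns a run in the parent process; the run id is illustrative, not taken from the source.

def child_worker() -> None:
    # Hypothetical: attach from a child process to the parent's run.
    run = _attach(attach_id="1a2b3c4d")  # illustrative run id
    if run is not None:
        run.log({"child_metric": 1.0})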
def no_retry_4xx(e):
    if not isinstance(e, requests.HTTPError):
        return True
    if not (e.response.status_code >= 400 and e.response.status_code < 500):
        return True
    body = json.loads(e.response.content)
    raise UsageError(body["errors"][0]["message"])
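A sketch of how a predicate like this is typically consumed: returning True marks the exception as retryable, while a 4xx response short-circuits by raising UsageError with the server's message. The retry loop below is illustrative, not part of the library.

import time

def call_with_retries(fn, check_retry_fn=no_retry_4xx, attempts=3, backoff=1.0):
    # Illustrative retry loop: retry only while the predicate says the error is
    # transient; a 4xx error raises UsageError from inside the predicate.
    for attempt in range(attempts):
        try:
            return fn()
        except Exception as e:
            if attempt == attempts - 1 or not check_retry_fn(e):
                raise
            time.sleep(backoff * (2 ** attempt))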
def prompt_api_key(self):
    key, status = self._prompt_api_key()
    if status == ApiKeyStatus.NOTTY:
        directive = (
            "wandb login [your_api_key]"
            if self._settings._cli_only_mode
            else "wandb.login(key=[your_api_key])"
        )
        raise UsageError("api_key not configured (no-tty). call " + directive)

    self.update_session(key, status=status)
    self._key = key
def prompt_api_key(self):
    api = Api(self._settings)
    key = apikey.prompt_api_key(
        self._settings,
        api=api,
        no_offline=self._settings.force,
        no_create=self._settings.force,
    )
    if key is False:
        raise UsageError("api_key not configured (no-tty). Run wandb login")
    self.update_session(key)
    self._key = key
def downsample(values: Sequence, target_length: int) -> list:
    """Downsamples 1d values to target_length, including start and end.

    Algorithm just rounds index down.

    Values can be any sequence, including a generator.
    """
    if not target_length > 1:
        raise UsageError("target_length must be > 1")
    values = list(values)
    if len(values) < target_length:
        return values
    ratio = float(len(values) - 1) / (target_length - 1)
    result = []
    for i in range(target_length):
        result.append(values[int(i * ratio)])
    return result
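For example, downsampling ten values to five keeps the first and last points and floors the intermediate indices:

values = list(range(10))         # [0, 1, ..., 9]
print(downsample(values, 5))     # -> [0, 2, 4, 6, 9]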
def prompt_api_key(self):
    api = Api(self._settings)
    key = apikey.prompt_api_key(
        self._settings,
        api=api,
        no_offline=self._settings.force,
        no_create=self._settings.force,
    )
    if key is False:
        directive = (
            "wandb login [your_api_key]"
            if self._settings._cli_only_mode
            else "wandb.login(key=[your_api_key])"
        )
        raise UsageError("api_key not configured (no-tty). call " + directive)
    self.update_session(key)
    self._key = key
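When no TTY is available, each of the prompt_api_key variants above raises UsageError and tells the caller to supply the key explicitly. A non-interactive job (CI, cron) would typically do that up front; the environment variable and project name below are illustrative.

import os
import wandb

# Pass the key directly so the no-tty prompt path above is never reached.
wandb.login(key=os.environ["WANDB_API_KEY"])
run = wandb.init(project="my-project")  # illustrative project name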
def parse_config(params, exclude=None, include=None):
    if exclude and include:
        raise UsageError("Expected at most only one of exclude or include")
    if isinstance(params, six.string_types):
        params = config_util.dict_from_config_file(params, must_exist=True)
    params = _to_dict(params)
    if include:
        params = {
            key: value for key, value in six.iteritems(params) if key in include
        }
    if exclude:
        params = {
            key: value for key, value in six.iteritems(params) if key not in exclude
        }
    return params
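An illustrative call, assuming `_to_dict` passes a plain dict through unchanged:

params = {"lr": 0.01, "batch_size": 32, "seed": 7}

# keep only the listed keys
parse_config(params, include=["lr", "batch_size"])  # -> {"lr": 0.01, "batch_size": 32}

# drop the listed keys
parse_config(params, exclude=["seed"])               # -> {"lr": 0.01, "batch_size": 32}

# passing both exclude and include raises UsageError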
def init(self):  # noqa: C901
    assert logger
    logger.info("calling init triggers")
    trigger.call("on_init", **self.kwargs)

    s = self.settings
    sweep_config = self.sweep_config
    config = self.config
    logger.info(
        "wandb.init called with sweep_config: {}\nconfig: {}".format(
            sweep_config, config
        )
    )
    if s._noop:
        return self._make_run_disabled()
    if s.reinit or (s._jupyter and s.reinit is not False):
        if len(self._wl._global_run_stack) > 0:
            if len(self._wl._global_run_stack) > 1:
                wandb.termwarn(
                    "If you want to track multiple runs concurrently in wandb you should use multi-processing not threads"  # noqa: E501
                )

            last_id = self._wl._global_run_stack[-1]._run_id
            logger.info(
                "re-initializing run, found existing run on stack: {}".format(last_id)
            )
            jupyter = (
                s._jupyter
                and not s._silent
                and ipython._get_python_type() == "jupyter"
            )
            if jupyter:
                ipython.display_html(
                    "Finishing last run (ID:{}) before initializing another...".format(
                        last_id
                    )
                )

            self._wl._global_run_stack[-1].finish()

            if jupyter:
                ipython.display_html(
                    "...Successfully finished last run (ID:{}). Initializing new run:<br/><br/>".format(
                        last_id
                    )
                )
    elif isinstance(wandb.run, Run):
        logger.info("wandb.init() called when a run is still active")
        return wandb.run

    logger.info("starting backend")
    backend = Backend(settings=s)
    backend.ensure_launched()
    backend.server_connect()
    logger.info("backend started and connected")
    # Make sure we are logged in
    # wandb_login._login(_backend=backend, _settings=self.settings)

    # resuming needs access to the server, check server_status()?

    run = Run(config=config, settings=s, sweep_config=sweep_config)

    # probe the active start method
    active_start_method = None
    if s.start_method == "thread":
        active_start_method = s.start_method
    else:
        get_start_fn = getattr(backend._multiprocessing, "get_start_method", None)
        active_start_method = get_start_fn() if get_start_fn else None

    # Populate initial telemetry
    with telemetry.context(run=run) as tel:
        tel.cli_version = wandb.__version__
        tel.python_version = platform.python_version()
        hf_version = _huggingface_version()
        if hf_version:
            tel.huggingface_version = hf_version
        if s._jupyter:
            tel.env.jupyter = True
        if s._kaggle:
            tel.env.kaggle = True
        if s._windows:
            tel.env.windows = True
        run._telemetry_imports(tel.imports_init)
        if active_start_method == "spawn":
            tel.env.start_spawn = True
        elif active_start_method == "fork":
            tel.env.start_fork = True
        elif active_start_method == "forkserver":
            tel.env.start_forkserver = True
        elif active_start_method == "thread":
            tel.env.start_thread = True

    logger.info("updated telemetry")
    run._set_library(self._wl)
    run._set_backend(backend)
    run._set_reporter(self._reporter)
    run._set_teardown_hooks(self._teardown_hooks)
    # TODO: pass mode to backend
    # run_synced = None

    backend._hack_set_run(run)
    backend.interface.publish_header()

    if s._offline:
        with telemetry.context(run=run) as tel:
            tel.feature.offline = True
        run_proto = backend.interface._make_run(run)
        backend.interface._publish_run(run_proto)
        run._set_run_obj_offline(run_proto)
    else:
        logger.info("communicating current version")
        ret = backend.interface.communicate_check_version(
            current_version=wandb.__version__
        )
        if ret:
            logger.info("got version response {}".format(ret))
            if ret.upgrade_message:
                run._set_upgraded_version_message(ret.upgrade_message)
            if ret.delete_message:
                run._set_deleted_version_message(ret.delete_message)
            if ret.yank_message:
                run._set_yanked_version_message(ret.yank_message)
        run._on_init()

        logger.info("communicating run to backend with 30 second timeout")
        ret = backend.interface.communicate_run(run, timeout=30)

        error_message = None
        if not ret:
            logger.error("backend process timed out")
            error_message = "Error communicating with wandb process"
            if active_start_method != "fork":
                error_message += (
                    "\ntry: wandb.init(settings=wandb.Settings(start_method='fork'))"
                )
                error_message += (
                    "\nor: wandb.init(settings=wandb.Settings(start_method='thread'))"
                )
                error_message += "\nFor more info see: https://docs.wandb.ai/library/init#init-start-error"
        if ret and ret.error:
            error_message = ret.error.message
        if error_message:
            logger.error("encountered error: {}".format(error_message))
            # Shutdown the backend and get rid of the logger
            # we don't need to do console cleanup at this point
            backend.cleanup()
            self.teardown()
            raise UsageError(error_message)
        if ret.run.resumed:
            logger.info("run resumed")
            with telemetry.context(run=run) as tel:
                tel.feature.resumed = True
        run._set_run_obj(ret.run)

    logger.info("starting run threads in backend")
    # initiate run (stats and metadata probing)
    run_obj = run._run_obj or run._run_obj_offline
    _ = backend.interface.communicate_run_start(run_obj)

    self._wl._global_run_stack.append(run)
    self.run = run
    self.backend = backend
    module.set_global(
        run=run,
        config=run.config,
        log=run.log,
        summary=run.summary,
        save=run.save,
        use_artifact=run.use_artifact,
        log_artifact=run.log_artifact,
        define_metric=run._define_metric,
        plot_table=run.plot_table,
        alert=run.alert,
    )
    self._reporter.set_context(run=run)
    run._on_start()

    run._freeze()
    logger.info("run started, returning control to user process")
    return run
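The timeout branch above suggests switching the multiprocessing start method; from user code that looks like:

import wandb

# If init times out with "Error communicating with wandb process", the error
# message above recommends forcing a different start method:
run = wandb.init(settings=wandb.Settings(start_method="fork"))
# or, where fork is unavailable (e.g. Windows or some notebook environments):
run = wandb.init(settings=wandb.Settings(start_method="thread"))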
def torch_trace_handler():
    """Creates a trace handler for traces generated by the profiler.

    Provide as an argument to `torch.profiler.profile`:

    ```python
    torch.profiler.profile(..., on_trace_ready=wandb.profiler.torch_trace_handler())
    ```

    Calling this function ensures that profiler charts & tables can be viewed in
    your run dashboard on wandb.ai.

    Please note that `wandb.init()` must be called before this function is invoked.
    The PyTorch (torch) version must also be at least 1.9, in order to ensure
    stability of their Profiler API.

    Args:
        None

    Returns:
        None

    Raises:
        UsageError if wandb.init() hasn't been called before profiling.
        Error if torch version is less than 1.9.0.

    Examples:
        ```python
        run = wandb.init()
        run.config.id = "profile_code"

        with torch.profiler.profile(
            schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
            on_trace_ready=wandb.profiler.torch_trace_handler(),
            record_shapes=True,
            with_stack=True,
        ) as prof:
            for step, batch in enumerate(dataloader):
                if step >= 5:
                    break
                train(batch)
                prof.step()
        ```
    """
    torch = wandb.util.get_module(PYTORCH_MODULE, required=True)
    torch_profiler = wandb.util.get_module(PYTORCH_PROFILER_MODULE, required=True)

    version = tuple(map(int, torch.__version__.replace("+cpu", "").split(".")))
    if version < (1, 9, 0):
        raise Error(
            f"torch version must be at least 1.9 in order to use the PyTorch Profiler API."
            f"\nVersion of torch currently installed: {torch.__version__}"
        )

    try:
        logdir = os.path.join(wandb.run.dir, "pytorch_traces")  # type: ignore
        os.mkdir(logdir)
    except AttributeError:
        raise UsageError(
            "Please call `wandb.init()` before `wandb.profiler.torch_trace_handler()`"
        ) from None

    with telemetry.context() as tel:
        tel.feature.torch_profiler_trace = True

    return torch_profiler.tensorboard_trace_handler(logdir)
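The version gate above parses `torch.__version__` into an integer tuple before comparing against (1, 9, 0); a quick illustration of what that accepts:

version_str = "1.9.0+cpu"                                      # example version string
version = tuple(map(int, version_str.replace("+cpu", "").split(".")))
print(version)              # (1, 9, 0)
print(version >= (1, 9, 0)) # True -> the profiler handler is allowed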
def init(self) -> Union[Run, RunDisabled, None]:  # noqa: C901
    if logger is None:
        raise RuntimeError("Logger not initialized")
    logger.info("calling init triggers")
    trigger.call("on_init", **self.kwargs)

    logger.info(
        f"wandb.init called with sweep_config: {self.sweep_config}\nconfig: {self.config}"
    )
    if self.settings._noop:
        return self._make_run_disabled()
    if self.settings.reinit or (
        self.settings._jupyter and self.settings.reinit is not False
    ):
        if len(self._wl._global_run_stack) > 0:
            if len(self._wl._global_run_stack) > 1:
                wandb.termwarn(
                    "If you want to track multiple runs concurrently in wandb, "
                    "you should use multi-processing not threads"  # noqa: E501
                )

            last_id = self._wl._global_run_stack[-1]._run_id
            logger.info(f"re-initializing run, found existing run on stack: {last_id}")

            jupyter = (
                self.settings._jupyter
                and not self.settings.silent
                and ipython.in_jupyter()
            )
            if jupyter:
                ipython.display_html(
                    f"Finishing last run (ID:{last_id}) before initializing another..."
                )

            self._wl._global_run_stack[-1].finish()

            if jupyter:
                ipython.display_html(
                    f"Successfully finished last run (ID:{last_id}). Initializing new run:<br/>"
                )
    elif isinstance(wandb.run, Run):
        allow_return_run = True
        manager = self._wl._get_manager()
        if manager:
            current_pid = os.getpid()
            if current_pid != wandb.run._init_pid:
                # We shouldn't return a stale global run if we are in a new pid
                allow_return_run = False

        if allow_return_run:
            logger.info("wandb.init() called when a run is still active")
            with telemetry.context() as tel:
                tel.feature.init_return_run = True
            return wandb.run

    logger.info("starting backend")

    manager = self._wl._get_manager()
    if manager:
        manager._inform_init(settings=self.settings, run_id=self.settings.run_id)

    backend = Backend(settings=self.settings, manager=manager)
    backend.ensure_launched()
    backend.server_connect()
    logger.info("backend started and connected")
    # Make sure we are logged in
    # wandb_login._login(_backend=backend, _settings=self.settings)

    # resuming needs access to the server, check server_status()?

    run = Run(
        config=self.config,
        settings=self.settings,
        sweep_config=self.sweep_config,
    )

    # probe the active start method
    active_start_method: Optional[str] = None
    if self.settings.start_method == "thread":
        active_start_method = self.settings.start_method
    else:
        get_start_fn = getattr(backend._multiprocessing, "get_start_method", None)
        active_start_method = get_start_fn() if get_start_fn else None

    # Populate initial telemetry
    with telemetry.context(run=run) as tel:
        tel.cli_version = wandb.__version__
        tel.python_version = platform.python_version()
        hf_version = _huggingface_version()
        if hf_version:
            tel.huggingface_version = hf_version
        if self.settings._jupyter:
            tel.env.jupyter = True
        if self.settings._kaggle:
            tel.env.kaggle = True
        if self.settings._windows:
            tel.env.windows = True
        run._telemetry_imports(tel.imports_init)

        if self._use_sagemaker:
            tel.feature.sagemaker = True
        if self._set_init_config:
            tel.feature.set_init_config = True
        if self._set_init_name:
            tel.feature.set_init_name = True
        if self._set_init_id:
            tel.feature.set_init_id = True
        if self._set_init_tags:
            tel.feature.set_init_tags = True
        if self.settings.launch:
            tel.feature.launch = True

        if active_start_method == "spawn":
            tel.env.start_spawn = True
        elif active_start_method == "fork":
            tel.env.start_fork = True
        elif active_start_method == "forkserver":
            tel.env.start_forkserver = True
        elif active_start_method == "thread":
            tel.env.start_thread = True

        if manager:
            tel.feature.service = True

        tel.env.maybe_mp = _maybe_mp_process(backend)

        # fixme: detected issues with settings
        if self.settings.__dict__["_Settings__preprocessing_warnings"]:
            tel.issues.settings__preprocessing_warnings = True
        if self.settings.__dict__["_Settings__validation_warnings"]:
            tel.issues.settings__validation_warnings = True
        if self.settings.__dict__["_Settings__unexpected_args"]:
            tel.issues.settings__unexpected_args = True

    if not self.settings.label_disable:
        if self.notebook:
            run._label_probe_notebook(self.notebook)
        else:
            run._label_probe_main()

    logger.info("updated telemetry")

    run._set_library(self._wl)
    run._set_backend(backend)
    run._set_reporter(self._reporter)
    run._set_teardown_hooks(self._teardown_hooks)
    # TODO: pass mode to backend
    # run_synced = None

    backend._hack_set_run(run)
    assert backend.interface
    backend.interface.publish_header()

    # Using GitRepo() blocks & can be slow, depending on user's current git setup.
    # We don't want to block run initialization/start request, so populate run's git
    # info beforehand.
    if not self.settings.disable_git:
        run._populate_git_info()

    if self.settings._offline:
        with telemetry.context(run=run) as tel:
            tel.feature.offline = True
        run_proto = backend.interface._make_run(run)
        backend.interface._publish_run(run_proto)
        run._set_run_obj_offline(run_proto)
        if self.settings.resume:
            wandb.termwarn(
                "`resume` will be ignored since W&B syncing is set to `offline`. "
                f"Starting a new run with run id {run.id}."
            )
    else:
        logger.info("communicating run to backend with 30 second timeout")
        run_result = backend.interface.communicate_run(run, timeout=30)

        error_message: Optional[str] = None
        if not run_result:
            logger.error("backend process timed out")
            error_message = "Error communicating with wandb process"
            if active_start_method != "fork":
                error_message += (
                    "\ntry: wandb.init(settings=wandb.Settings(start_method='fork'))"
                )
                error_message += (
                    "\nor: wandb.init(settings=wandb.Settings(start_method='thread'))"
                )
                error_message += "\nFor more info see: https://docs.wandb.ai/library/init#init-start-error"
        elif run_result.error:
            error_message = run_result.error.message
        if error_message:
            logger.error(f"encountered error: {error_message}")
            # Shutdown the backend and get rid of the logger
            # we don't need to do console cleanup at this point
            backend.cleanup()
            self.teardown()
            raise UsageError(error_message)

        assert run_result and run_result.run
        if run_result.run.resumed:
            logger.info("run resumed")
            with telemetry.context(run=run) as tel:
                tel.feature.resumed = True
        run._set_run_obj(run_result.run)

    run._on_init()

    logger.info("starting run threads in backend")
    # initiate run (stats and metadata probing)
    run_obj = run._run_obj or run._run_obj_offline

    self.settings._apply_run_start(message_to_dict(run_obj))
    run._update_settings(self.settings)
    if manager:
        manager._inform_start(settings=self.settings, run_id=self.settings.run_id)

    assert backend.interface
    assert run_obj
    _ = backend.interface.communicate_run_start(run_obj)

    self._wl._global_run_stack.append(run)
    self.run = run
    self.backend = backend
    module.set_global(
        run=run,
        config=run.config,
        log=run.log,
        summary=run.summary,
        save=run.save,
        use_artifact=run.use_artifact,
        log_artifact=run.log_artifact,
        define_metric=run.define_metric,
        plot_table=run.plot_table,
        alert=run.alert,
        mark_preempting=run.mark_preempting,
    )
    self._reporter.set_context(run=run)
    run._on_start()
    run._freeze()
    logger.info("run started, returning control to user process")
    return run