def _user_process_finished(server, hooks, wandb_process, stdout_redirector, stderr_redirector):
    """Tear down W&B state after the user's script has finished.

    Idempotent via the module-level ``_user_process_finished_called`` flag:
    the body runs at most once per process.  Fires the ``on_finished``
    trigger, stops the async log thread, restores stdout/stderr, then waits
    for the internal W&B child process to exit (polling every 100 ms),
    force-killing it if it is still alive after the wait is interrupted.

    Args:
        server: object whose ``done(exit_code)`` signals run completion.
        hooks: exit-hook holder; only ``hooks.exit_code`` is read here.
        wandb_process: the W&B child process (Popen-like: poll()/kill()/pid).
        stdout_redirector / stderr_redirector: console redirectors with a
            ``restore()`` method.
    """
    global _user_process_finished_called
    # Re-entrancy guard: atexit hooks and explicit calls may both land here.
    if _user_process_finished_called:
        return
    _user_process_finished_called = True
    trigger.call('on_finished')
    shutdown_async_log_thread()
    stdout_redirector.restore()
    # stderr is left redirected in debug mode so debug output stays captured.
    if not env.is_debug():
        stderr_redirector.restore()
    termlog()
    termlog("Waiting for W&B process to finish, PID {}".format(
        wandb_process.pid))
    server.done(hooks.exit_code)
    try:
        # Busy-wait (100 ms) until the child exits; poll() is None while alive.
        while wandb_process.poll() is None:
            time.sleep(0.1)
    except KeyboardInterrupt:
        # User pressed ctrl-c during the wait: fall through to the kill path.
        pass
    if wandb_process.poll() is None:
        termlog('Killing W&B process, PID {}'.format(wandb_process.pid))
        wandb_process.kill()
def _user_process_finished(server, hooks, wandb_process, stdout_redirector, stderr_redirector):
    """Tear down W&B state after the user's script has finished.

    Variant that also closes the global run's files and gives the user a
    two-stage ctrl-c: the first interrupt keeps waiting for the W&B child
    process; the second interrupt (or a timeout-less fallthrough) kills it.

    Args:
        server: object whose ``done(exit_code)`` signals run completion.
        hooks: exit-hook holder; only ``hooks.exit_code`` is read here.
        wandb_process: the W&B child process (Popen-like: poll()/kill()/pid).
        stdout_redirector / stderr_redirector: console redirectors with a
            ``restore()`` method.
    """
    global _user_process_finished_called
    # Re-entrancy guard: the body runs at most once per process.
    if _user_process_finished_called:
        return
    _user_process_finished_called = True
    trigger.call('on_finished')
    # Flush/close the run's open log files before restoring the console.
    if run:
        run.close_files()
    stdout_redirector.restore()
    # stderr is left redirected in debug mode so debug output stays captured.
    if not env.is_debug():
        stderr_redirector.restore()
    termlog()
    termlog("Waiting for W&B process to finish, PID {}".format(
        wandb_process.pid))
    server.done(hooks.exit_code)
    try:
        # First wait: poll every 100 ms until the child exits.
        while wandb_process.poll() is None:
            time.sleep(0.1)
    except KeyboardInterrupt:
        termlog(
            'Sending ctrl-c to W&B process, PID {}. Press ctrl-c again to kill it.'
            .format(wandb_process.pid))
        try:
            # Second wait: give the child a chance to shut down gracefully.
            while wandb_process.poll() is None:
                time.sleep(0.1)
        except KeyboardInterrupt:
            # Second ctrl-c: stop waiting and fall through to the kill path.
            pass
    if wandb_process.poll() is None:
        termlog('Killing W&B process, PID {}'.format(wandb_process.pid))
        wandb_process.kill()
def _fit_wrapper(self, fn, generator=None, *args, **kwargs):
    """Magic wrapper around keras ``fit``/``fit_generator``.

    Pops ``epochs``/``batch_size``/``callbacks`` from kwargs, overrides them
    from the magic config, and injects a TensorBoard callback when the magic
    config enables it.

    NOTE(review): as visible here this function mutates ``epochs``,
    ``batch_size`` and ``callbacks`` but never re-inserts them into
    ``kwargs`` and never calls ``fn`` — it looks truncated compared to the
    sibling ``_fit_wrapper`` further down in this file; confirm the missing
    WandbCallback section and the final ``fn(...)`` dispatch.
    """
    trigger.call("on_fit")
    # NOTE(review): `keras` and `tfkeras` are assigned but unused in the
    # visible body — presumably consumed by the truncated remainder.
    keras = sys.modules.get("keras", None)
    tfkeras = sys.modules.get("tensorflow.python.keras", None)
    epochs = kwargs.pop("epochs", None)
    batch_size = kwargs.pop("batch_size", None)
    # Magic-config values take precedence over user-supplied kwargs.
    magic_epochs = _magic_get_config("keras.fit.epochs", None)
    if magic_epochs is not None:
        epochs = magic_epochs
    magic_batch_size = _magic_get_config("keras.fit.batch_size", None)
    if magic_batch_size is not None:
        batch_size = magic_batch_size
    callbacks = kwargs.pop("callbacks", [])
    if tb_enabled := _magic_get_config(
            "keras.fit.callbacks.tensorboard.enable", None):
        if k := getattr(self, "_keras_or_tfkeras", None):
            tb_duplicate = _magic_get_config(
                "keras.fit.callbacks.tensorboard.duplicate", None)
            tb_overwrite = _magic_get_config(
                "keras.fit.callbacks.tensorboard.overwrite", None)
            tb_present = any(
                isinstance(cb, k.callbacks.TensorBoard) for cb in callbacks)
            # overwrite: drop any user-supplied TensorBoard callback first.
            if tb_present and tb_overwrite:
                callbacks = [
                    cb for cb in callbacks
                    if not isinstance(cb, k.callbacks.TensorBoard)
                ]
            # Add ours when overwriting, duplicating, or none was present.
            if tb_overwrite or tb_duplicate or not tb_present:
                tb_callback_kwargs = {"log_dir": wandb.run.dir}
                cb_args = (
                    "write_graph",
                    "histogram_freq",
                    "update_freq",
                    "write_grads",
                    "write_images",
                    "batch_size",
                )
                # Forward any per-argument overrides from the magic config.
                for cb_arg in cb_args:
                    v = _magic_get_config(
                        f'keras.fit.callbacks.tensorboard.{cb_arg}', None)
                    if v is not None:
                        tb_callback_kwargs[cb_arg] = v
                tb_callback = k.callbacks.TensorBoard(**tb_callback_kwargs)
                callbacks.append(tb_callback)
def init(self):  # noqa: C901
    """Create, wire up and start a new W&B run.

    Handles four paths: (1) noop mode returns a ``Dummy`` stand-in run;
    (2) reinit finishes any run already on the global stack; (3) an active
    run without reinit is returned as-is; (4) otherwise launches the
    backend, negotiates the run with it (or publishes it offline) and
    installs the run's methods as module-level globals.

    Returns:
        The started run object (``Run`` or ``Dummy``), or the already
        active ``wandb.run``.

    Raises:
        UsageError: if communication with the backend fails.
    """
    trigger.call("on_init", **self.kwargs)
    s = self.settings
    config = self.config
    if s._noop:
        # Noop mode: build an inert Dummy whose log() only updates summary
        # and whose finish() just unsets the module globals.
        run = Dummy()
        run.config = wandb.wandb_sdk.wandb_config.Config()
        run.config.update(config)
        run.summary = DummyDict()
        run.log = lambda data, *_, **__: run.summary.update(data)
        run.finish = lambda *_, **__: module.unset_globals()
        run.step = 0
        run.resumed = False
        run.disabled = True
        run.id = shortuuid.uuid()
        run.name = "dummy-" + run.id
        run.dir = "/"
        module.set_global(
            run=run,
            config=run.config,
            log=run.log,
            summary=run.summary,
            save=run.save,
            use_artifact=run.use_artifact,
            log_artifact=run.log_artifact,
            plot_table=run.plot_table,
            alert=run.alert,
        )
        return run
    if s.reinit or (s._jupyter and s.reinit is not False):
        # reinit (implicitly on in Jupyter unless reinit=False): finish the
        # run currently on top of the global stack before starting another.
        if len(self._wl._global_run_stack) > 0:
            if len(self._wl._global_run_stack) > 1:
                wandb.termwarn(
                    "If you want to track multiple runs concurrently in wandb you should use multi-processing not threads"  # noqa: E501
                )
            last_id = self._wl._global_run_stack[-1]._run_id
            if s._jupyter and not s._silent:
                ipython.display_html(
                    "Finishing last run (ID:{}) before initializing another...".format(
                        last_id
                    )
                )
            self._wl._global_run_stack[-1].finish()
            if s._jupyter and not s._silent:
                ipython.display_html(
                    "...Successfully finished last run (ID:{}). Initializing new run:<br/><br/>".format(
                        last_id
                    )
                )
    elif isinstance(wandb.run, Run):
        # No reinit requested and a run is already active: return it.
        logger.info("wandb.init() called when a run is still active")
        return wandb.run
    use_redirect = True
    # Master/slave fds are placeholders here; console wiring happens below.
    stdout_master_fd, stderr_master_fd = None, None
    stdout_slave_fd, stderr_slave_fd = None, None
    backend = Backend()
    backend.ensure_launched(
        settings=s,
        stdout_fd=stdout_master_fd,
        stderr_fd=stderr_master_fd,
        use_redirect=use_redirect,
    )
    backend.server_connect()
    # Make sure we are logged in
    # wandb_login._login(_backend=backend, _settings=self.settings)

    # resuming needs access to the server, check server_status()?

    run = Run(config=config, settings=s)
    run._set_console(
        use_redirect=use_redirect,
        stdout_slave_fd=stdout_slave_fd,
        stderr_slave_fd=stderr_slave_fd,
    )
    run._set_library(self._wl)
    run._set_backend(backend)
    run._set_reporter(self._reporter)
    run._set_teardown_hooks(self._teardown_hooks)
    # TODO: pass mode to backend
    # run_synced = None
    backend._hack_set_run(run)
    backend.interface.publish_header()
    if s._offline:
        # Offline: publish the run record locally; no server round-trip.
        run_proto = backend.interface._make_run(run)
        backend.interface._publish_run(run_proto)
        run._set_run_obj_offline(run_proto)
    else:
        # Online: check client version, then negotiate the run with the
        # backend (30 s timeout).  Any failure tears everything down.
        ret = backend.interface.communicate_check_version(
            current_version=wandb.__version__
        )
        if ret:
            if ret.upgrade_message:
                run._set_upgraded_version_message(ret.upgrade_message)
            if ret.delete_message:
                run._set_deleted_version_message(ret.delete_message)
            if ret.yank_message:
                run._set_yanked_version_message(ret.yank_message)
        run._on_init()
        ret = backend.interface.communicate_run(run, timeout=30)
        error_message = None
        if not ret:
            error_message = "Error communicating with backend"
        if ret and ret.error:
            error_message = ret.error.message
        if error_message:
            # Shutdown the backend and get rid of the logger
            # we don't need to do console cleanup at this point
            backend.cleanup()
            self.teardown()
            raise UsageError(error_message)
        run._set_run_obj(ret.run)
    # initiate run (stats and metadata probing)
    _ = backend.interface.communicate_run_start()
    self._wl._global_run_stack.append(run)
    self.run = run
    self.backend = backend
    # Expose the run's methods as wandb module-level functions.
    module.set_global(
        run=run,
        config=run.config,
        log=run.log,
        summary=run.summary,
        save=run.save,
        use_artifact=run.use_artifact,
        log_artifact=run.log_artifact,
        plot_table=run.plot_table,
        alert=run.alert,
    )
    self._reporter.set_context(run=run)
    run._on_start()
    # Freeze the run so user code cannot mutate internals after start.
    run._freeze()
    return run
def init(self):  # noqa: C901
    """Create, wire up and start a new W&B run (sweep/telemetry variant).

    Flow: fire init triggers -> noop short-circuit -> optional reinit of the
    run on the global stack -> launch+connect backend -> build ``Run`` ->
    record telemetry -> publish offline or negotiate with the backend ->
    start run threads and install module-level globals.

    Returns:
        The started ``Run`` (or disabled run in noop mode, or the already
        active ``wandb.run``).

    Raises:
        UsageError: if communication with the wandb process fails.
    """
    assert logger
    logger.info("calling init triggers")
    trigger.call("on_init", **self.kwargs)
    s = self.settings
    sweep_config = self.sweep_config
    config = self.config
    logger.info(
        "wandb.init called with sweep_config: {}\nconfig: {}".format(
            sweep_config, config
        )
    )
    if s._noop:
        return self._make_run_disabled()
    if s.reinit or (s._jupyter and s.reinit is not False):
        # reinit: finish the run on top of the global stack before starting
        # another one (with Jupyter-friendly HTML status messages).
        if len(self._wl._global_run_stack) > 0:
            if len(self._wl._global_run_stack) > 1:
                wandb.termwarn(
                    "If you want to track multiple runs concurrently in wandb you should use multi-processing not threads"  # noqa: E501
                )
            last_id = self._wl._global_run_stack[-1]._run_id
            logger.info(
                "re-initializing run, found existing run on stack: {}".format(
                    last_id
                )
            )
            jupyter = (
                s._jupyter and not s._silent and ipython._get_python_type() == "jupyter"
            )
            if jupyter:
                ipython.display_html(
                    "Finishing last run (ID:{}) before initializing another...".format(
                        last_id
                    )
                )
            self._wl._global_run_stack[-1].finish()
            if jupyter:
                ipython.display_html(
                    "...Successfully finished last run (ID:{}). Initializing new run:<br/><br/>".format(
                        last_id
                    )
                )
    elif isinstance(wandb.run, Run):
        # No reinit requested and a run is already active: return it.
        logger.info("wandb.init() called when a run is still active")
        return wandb.run
    logger.info("starting backend")
    backend = Backend(settings=s)
    backend.ensure_launched()
    backend.server_connect()
    logger.info("backend started and connected")
    # Make sure we are logged in
    # wandb_login._login(_backend=backend, _settings=self.settings)

    # resuming needs access to the server, check server_status()?

    run = Run(config=config, settings=s, sweep_config=sweep_config)

    # probe the active start method
    active_start_method = None
    if s.start_method == "thread":
        active_start_method = s.start_method
    else:
        # Fall back to multiprocessing's configured start method if exposed.
        get_start_fn = getattr(backend._multiprocessing, "get_start_method", None)
        active_start_method = get_start_fn() if get_start_fn else None

    # Populate intial telemetry
    with telemetry.context(run=run) as tel:
        tel.cli_version = wandb.__version__
        tel.python_version = platform.python_version()
        hf_version = _huggingface_version()
        if hf_version:
            tel.huggingface_version = hf_version
        if s._jupyter:
            tel.env.jupyter = True
        if s._kaggle:
            tel.env.kaggle = True
        if s._windows:
            tel.env.windows = True
        run._telemetry_imports(tel.imports_init)
        if active_start_method == "spawn":
            tel.env.start_spawn = True
        elif active_start_method == "fork":
            tel.env.start_fork = True
        elif active_start_method == "forkserver":
            tel.env.start_forkserver = True
        elif active_start_method == "thread":
            tel.env.start_thread = True

    logger.info("updated telemetry")
    run._set_library(self._wl)
    run._set_backend(backend)
    run._set_reporter(self._reporter)
    run._set_teardown_hooks(self._teardown_hooks)
    # TODO: pass mode to backend
    # run_synced = None
    backend._hack_set_run(run)
    backend.interface.publish_header()
    if s._offline:
        # Offline: publish the run record locally; no server round-trip.
        with telemetry.context(run=run) as tel:
            tel.feature.offline = True
        run_proto = backend.interface._make_run(run)
        backend.interface._publish_run(run_proto)
        run._set_run_obj_offline(run_proto)
    else:
        logger.info("communicating current version")
        ret = backend.interface.communicate_check_version(
            current_version=wandb.__version__
        )
        if ret:
            logger.info("got version response {}".format(ret))
            if ret.upgrade_message:
                run._set_upgraded_version_message(ret.upgrade_message)
            if ret.delete_message:
                run._set_deleted_version_message(ret.delete_message)
            if ret.yank_message:
                run._set_yanked_version_message(ret.yank_message)
        run._on_init()
        logger.info("communicating run to backend with 30 second timeout")
        ret = backend.interface.communicate_run(run, timeout=30)
        error_message = None
        if not ret:
            logger.error("backend process timed out")
            error_message = "Error communicating with wandb process"
            # Suggest alternate start methods; fork-related hangs are a
            # common cause of this timeout.
            if active_start_method != "fork":
                error_message += "\ntry: wandb.init(settings=wandb.Settings(start_method='fork'))"
                error_message += "\nor: wandb.init(settings=wandb.Settings(start_method='thread'))"
                error_message += "\nFor more info see: https://docs.wandb.ai/library/init#init-start-error"
        if ret and ret.error:
            error_message = ret.error.message
        if error_message:
            logger.error("encountered error: {}".format(error_message))
            # Shutdown the backend and get rid of the logger
            # we don't need to do console cleanup at this point
            backend.cleanup()
            self.teardown()
            raise UsageError(error_message)
        if ret.run.resumed:
            logger.info("run resumed")
            with telemetry.context(run=run) as tel:
                tel.feature.resumed = True
        run._set_run_obj(ret.run)
    logger.info("starting run threads in backend")
    # initiate run (stats and metadata probing)
    run_obj = run._run_obj or run._run_obj_offline
    _ = backend.interface.communicate_run_start(run_obj)
    self._wl._global_run_stack.append(run)
    self.run = run
    self.backend = backend
    # Expose the run's methods as wandb module-level functions.
    module.set_global(
        run=run,
        config=run.config,
        log=run.log,
        summary=run.summary,
        save=run.save,
        use_artifact=run.use_artifact,
        log_artifact=run.log_artifact,
        define_metric=run._define_metric,
        plot_table=run.plot_table,
        alert=run.alert,
    )
    self._reporter.set_context(run=run)
    run._on_start()
    # Freeze the run so user code cannot mutate internals after start.
    run._freeze()
    logger.info("run started, returning control to user process")
    return run
def init(job_type=None, dir=None, config=None, project=None, entity=None,
         reinit=None, tags=None, group=None, allow_val_change=False,
         resume=False, force=False, tensorboard=False, sync_tensorboard=False,
         name=None, notes=None, id=None, magic=None):
    """Initialize W&B

    If called from within Jupyter, initializes a new run and waits for a call to
    `wandb.log` to begin pushing metrics. Otherwise, spawns a new process
    to communicate with W&B.

    Args:
        job_type (str, optional): The type of job running, defaults to 'train'
        config (dict, argparse, or tf.FLAGS, optional): The hyper parameters to store with the run
        project (str, optional): The project to push metrics to
        entity (str, optional): The entity to push metrics to
        dir (str, optional): An absolute path to a directory where metadata will be stored
        group (str, optional): A unique string shared by all runs in a given group
        tags (list, optional): A list of tags to apply to the run
        id (str, optional): A globally unique (per project) identifier for the run
        name (str, optional): A display name which does not have to be unique
        notes (str, optional): A multiline string associated with the run
        reinit (bool, optional): Allow multiple calls to init in the same process
        resume (bool, str, optional): Automatically resume this run if run from the same machine,
            you can also pass a unique run_id
        sync_tensorboard (bool, optional): Synchronize wandb logs to tensorboard or tensorboardX
        force (bool, optional): Force authentication with wandb, defaults to False
        magic (bool, dict, or str, optional): magic configuration as bool, dict, json string,
            yaml filename

    Returns:
        A wandb.run object for metric and config logging.
    """
    trigger.call('on_init', **locals())
    global run
    global __stage_dir__

    # We allow re-initialization when we're in Jupyter or explicity opt-in to it.
    in_jupyter = _get_python_type() != "python"
    if reinit or (in_jupyter and reinit != False):
        # Drop mutable env state from a previous run so the new one starts clean.
        reset_env(exclude=env.immutable_keys())
        run = None

    # TODO: deprecate tensorboard
    # NOTE(review): `and` binds tighter than `or`, so this reads as
    # `tensorboard or (sync_tensorboard and not-yet-patched)` — i.e. with
    # tensorboard=True the patch is attempted even if already patched.
    # Verify whether `(tensorboard or sync_tensorboard) and ...` was intended.
    if tensorboard or sync_tensorboard and len(patched["tensorboard"]) == 0:
        util.get_module("wandb.tensorboard").patch()

    sagemaker_config = util.parse_sm_config()
    tf_config = util.parse_tfjob_config()
    # Environment variables fill in group/job_type when not passed explicitly.
    if group == None:
        group = os.getenv(env.RUN_GROUP)
    if job_type == None:
        job_type = os.getenv(env.JOB_TYPE)
    if sagemaker_config:
        # Set run_id and potentially grouping if we're in SageMaker
        run_id = os.getenv('TRAINING_JOB_NAME')
        if run_id:
            os.environ[env.RUN_ID] = '-'.join([
                run_id,
                os.getenv('CURRENT_HOST', socket.gethostname())])
        conf = json.load(
            open("/opt/ml/input/config/resourceconfig.json"))
        if group == None and len(conf["hosts"]) > 1:
            group = os.getenv('TRAINING_JOB_NAME')
        # Set secret variables
        if os.path.exists("secrets.env"):
            for line in open("secrets.env", "r"):
                key, val = line.strip().split('=', 1)
                os.environ[key] = val
    elif tf_config:
        # TFJob: derive run id / job_type / group from the cluster spec.
        cluster = tf_config.get('cluster')
        job_name = tf_config.get('task', {}).get('type')
        task_index = tf_config.get('task', {}).get('index')
        if job_name is not None and task_index is not None:
            # TODO: set run_id for resuming?
            run_id = cluster[job_name][task_index].rsplit(":")[0]
            if job_type == None:
                job_type = job_name
            if group == None and len(cluster.get("worker", [])) > 0:
                group = cluster[job_name][0].rsplit("-" + job_name, 1)[0]
    image = util.image_id_from_k8s()
    if image:
        os.environ[env.DOCKER] = image
    # Explicit arguments are persisted via environment variables so the
    # spawned W&B process (and any children) see the same configuration.
    if project:
        os.environ[env.PROJECT] = project
    if entity:
        os.environ[env.ENTITY] = entity
    if group:
        os.environ[env.RUN_GROUP] = group
    if job_type:
        os.environ[env.JOB_TYPE] = job_type
    if tags:
        os.environ[env.TAGS] = ",".join(tags)
    if id:
        os.environ[env.RUN_ID] = id
        if name is None:
            # We do this because of https://github.com/wandb/core/issues/2170
            # to ensure that the run's name is explicitly set to match its
            # id. If we don't do this and the id is eight characters long, the
            # backend will set the name to a generated human-friendly value.
            #
            # In any case, if the user is explicitly setting `id` but not
            # `name`, their id is probably a meaningful string that we can
            # use to label the run.
            name = os.environ.get(env.NAME, id)  # environment variable takes precedence over this.
    if name:
        os.environ[env.NAME] = name
    if notes:
        os.environ[env.NOTES] = notes
    if magic is not None and magic is not False:
        if isinstance(magic, dict):
            os.environ[env.MAGIC] = json.dumps(magic)
        elif isinstance(magic, str):
            os.environ[env.MAGIC] = magic
        elif isinstance(magic, bool):
            pass
        else:
            termwarn("wandb.init called with invalid magic parameter type", repeat=False)
        from wandb import magic_impl
        magic_impl.magic_install()
    if dir:
        os.environ[env.DIR] = dir
        util.mkdir_exists_ok(wandb_dir())
    resume_path = os.path.join(wandb_dir(), wandb_run.RESUME_FNAME)
    if resume == True:
        os.environ[env.RESUME] = "auto"
    elif resume:
        os.environ[env.RESUME] = os.environ.get(env.RESUME, "allow")
        # TODO: remove allowing resume as a string in the future
        os.environ[env.RUN_ID] = id or resume
    elif os.path.exists(resume_path):
        # Not resuming: discard any stale resume marker from a prior run.
        os.remove(resume_path)
    if os.environ.get(env.RESUME) == 'auto' and os.path.exists(resume_path):
        if not os.environ.get(env.RUN_ID):
            os.environ[env.RUN_ID] = json.load(open(resume_path))["run_id"]

    # the following line is useful to ensure that no W&B logging happens in the user
    # process that might interfere with what they do
    # logging.basicConfig(format='user process %(asctime)s - %(name)s - %(levelname)s - %(message)s')

    # If a thread calls wandb.init() it will get the same Run object as
    # the parent. If a child process with distinct memory space calls
    # wandb.init(), it won't get an error, but it will get a result of
    # None.
    # This check ensures that a child process can safely call wandb.init()
    # after a parent has (only the parent will create the Run object).
    # This doesn't protect against the case where the parent doesn't call
    # wandb.init but two children do.
    if run or os.getenv(env.INITED):
        return run

    if __stage_dir__ is None:
        __stage_dir__ = "wandb"
        util.mkdir_exists_ok(wandb_dir())

    try:
        # SIGQUIT drops into the debugger; not available on all platforms
        # (AttributeError on Windows), hence the guard.
        signal.signal(signal.SIGQUIT, _debugger)
    except AttributeError:
        pass

    try:
        run = wandb_run.Run.from_environment_or_defaults()
    except IOError as e:
        termerror('Failed to create run directory: {}'.format(e))
        raise LaunchError("Could not write to filesystem.")

    run.set_environment()

    def set_global_config(run):
        global config  # because we already have a local config
        config = run.config
    set_global_config(run)
    global summary
    summary = run.summary

    # set this immediately after setting the run and the config. if there is an
    # exception after this it'll probably break the user script anyway
    os.environ[env.INITED] = '1'

    # we do these checks after setting the run and the config because users scripts
    # may depend on those things
    if sys.platform == 'win32' and run.mode != 'clirun':
        termerror(
            'To use wandb on Windows, you need to run the command "wandb run python <your_train_script>.py"')
        return run

    if in_jupyter:
        _init_jupyter(run)
    elif run.mode == 'clirun':
        pass
    elif run.mode == 'run':
        api = InternalApi()
        # let init_jupyter handle this itself
        if not in_jupyter and not api.api_key:
            termlog(
                "W&B is a tool that helps track and visualize machine learning experiments")
            if force:
                termerror(
                    "No credentials found. Run \"wandb login\" or \"wandb off\" to disable wandb")
            else:
                if run.check_anonymous():
                    _init_headless(run)
                else:
                    termlog(
                        "No credentials found. Run \"wandb login\" to visualize your metrics")
                    run.mode = "dryrun"
                    _init_headless(run, False)
        else:
            _init_headless(run)
    elif run.mode == 'dryrun':
        termlog(
            'Dry run mode, not syncing to the cloud.')
        _init_headless(run, False)
    else:
        termerror(
            'Invalid run mode "%s". Please unset WANDB_MODE.'
            % run.mode)
        raise LaunchError("The WANDB_MODE environment variable is invalid.")

    # set the run directory in the config so it actually gets persisted
    run.config.set_run_dir(run.dir)

    if sagemaker_config:
        run.config.update(sagemaker_config)
        allow_val_change = True
    if config:
        run.config.update(config, allow_val_change=allow_val_change)

    # Access history to ensure resumed is set when resuming
    run.history
    # Load the summary to support resuming
    run.summary.load()

    atexit.register(run.close_files)

    return run
def _fit_wrapper(self, fn, generator=None, *args, **kwargs):
    """Magic wrapper around keras ``fit``/``fit_generator``.

    Pops ``epochs``/``batch_size``/``callbacks`` from kwargs, applies magic
    config overrides, optionally injects a TensorBoard callback and a
    ``WandbCallback``, then dispatches to the wrapped ``fn`` (prepending
    ``generator`` when one was given).

    Args:
        fn: the original keras fit function being wrapped.
        generator: data generator for ``fit_generator``-style calls, or None.

    Returns:
        Whatever ``fn`` returns (the keras History object).
    """
    trigger.call('on_fit')
    # NOTE(review): `keras` and `tfkeras` are assigned but never used in
    # this body — candidates for removal; kept to preserve behavior exactly.
    keras = sys.modules.get("keras", None)
    tfkeras = sys.modules.get("tensorflow.python.keras", None)
    epochs = kwargs.pop("epochs", None)
    batch_size = kwargs.pop("batch_size", None)
    # Magic-config values take precedence over user-supplied kwargs.
    magic_epochs = _magic_get_config("keras.fit.epochs", None)
    if magic_epochs is not None:
        epochs = magic_epochs
    magic_batch_size = _magic_get_config("keras.fit.batch_size", None)
    if magic_batch_size is not None:
        batch_size = magic_batch_size
    callbacks = kwargs.pop("callbacks", [])

    tb_enabled = _magic_get_config("keras.fit.callbacks.tensorboard.enable", None)
    if tb_enabled:
        k = getattr(self, '_keras_or_tfkeras', None)
        if k:
            tb_duplicate = _magic_get_config(
                "keras.fit.callbacks.tensorboard.duplicate", None)
            tb_overwrite = _magic_get_config(
                "keras.fit.callbacks.tensorboard.overwrite", None)
            tb_present = any(
                [isinstance(cb, k.callbacks.TensorBoard) for cb in callbacks])
            # overwrite: drop any user-supplied TensorBoard callback first.
            if tb_present and tb_overwrite:
                callbacks = [
                    cb for cb in callbacks
                    if not isinstance(cb, k.callbacks.TensorBoard)
                ]
            # Add ours when overwriting, duplicating, or none was present.
            if tb_overwrite or tb_duplicate or not tb_present:
                tb_callback_kwargs = {'log_dir': wandb.run.dir}
                cb_args = ('write_graph', 'histogram_freq', 'update_freq',
                           'write_grads', 'write_images', 'batch_size')
                # Forward any per-argument overrides from the magic config.
                for cb_arg in cb_args:
                    v = _magic_get_config(
                        "keras.fit.callbacks.tensorboard." + cb_arg, None)
                    if v is not None:
                        tb_callback_kwargs[cb_arg] = v
                tb_callback = k.callbacks.TensorBoard(**tb_callback_kwargs)
                callbacks.append(tb_callback)

    wandb_enabled = _magic_get_config("keras.fit.callbacks.wandb.enable", None)
    if wandb_enabled:
        wandb_duplicate = _magic_get_config(
            "keras.fit.callbacks.wandb.duplicate", None)
        wandb_overwrite = _magic_get_config(
            "keras.fit.callbacks.wandb.overwrite", None)
        wandb_present = any(
            [isinstance(cb, wandb.keras.WandbCallback) for cb in callbacks])
        # Same overwrite/duplicate logic as above, for WandbCallback.
        if wandb_present and wandb_overwrite:
            callbacks = [
                cb for cb in callbacks
                if not isinstance(cb, wandb.keras.WandbCallback)
            ]
        if wandb_overwrite or wandb_duplicate or not wandb_present:
            wandb_callback_kwargs = {}
            # log_gradients needs explicit x/y data to compute against.
            log_gradients = _magic_get_config(
                "keras.fit.callbacks.wandb.log_gradients", None)
            if log_gradients and kwargs.get('x') and kwargs.get('y'):
                wandb_callback_kwargs['log_gradients'] = log_gradients
            cb_args = ("predictions", "log_weights", "data_type", "save_model",
                       "save_weights_only", "monitor", "mode", "verbose",
                       "input_type", "output_type", "log_evaluation", "labels")
            # Forward any per-argument overrides from the magic config.
            for cb_arg in cb_args:
                v = _magic_get_config("keras.fit.callbacks.wandb." + cb_arg, None)
                if v is not None:
                    wandb_callback_kwargs[cb_arg] = v
            wandb_callback = wandb.keras.WandbCallback(**wandb_callback_kwargs)
            callbacks.append(wandb_callback)

    # Re-inject the (possibly modified) values and dispatch to keras.
    kwargs["callbacks"] = callbacks
    if epochs is not None:
        kwargs["epochs"] = epochs
    if batch_size is not None:
        kwargs["batch_size"] = batch_size
    if generator:
        return fn(generator, *args, **kwargs)
    return fn(*args, **kwargs)
def init(self):
    """Create, wire up and start a new W&B run (RunDummy variant).

    Handles: noop mode (returns a ``RunDummy``), reinit (finishes the run on
    the global stack), an already-active run (returned as-is), and the
    normal path: launch the backend, negotiate the run (or publish it
    offline) and install the run's methods as module-level globals.

    Returns:
        The started run object, or the already active ``wandb.run``.

    Raises:
        UsageError: if communication with the backend fails.
    """
    trigger.call("on_init", **self.kwargs)
    s = self.settings
    config = self.config
    if s._noop:
        # Noop mode: inert stand-in run; still exported as module globals.
        run = RunDummy()
        module.set_global(
            run=run,
            config=run.config,
            log=run.log,
            summary=run.summary,
            save=run.save,
            restore=run.restore,
            use_artifact=run.use_artifact,
            log_artifact=run.log_artifact,
            plot_table=run.plot_table,
        )
        return run
    if s.reinit or (s._jupyter and s.reinit is not False):
        # reinit: finish the run on top of the global stack first.
        if len(self._wl._global_run_stack) > 0:
            if len(self._wl._global_run_stack) > 1:
                wandb.termwarn(
                    "If you want to track multiple runs concurrently in wandb you should use multi-processing not threads"  # noqa: E501
                )
            self._wl._global_run_stack[-1].finish()
    elif wandb.run:
        # No reinit requested and a run is already active: return it.
        logger.info("wandb.init() called when a run is still active")
        return wandb.run
    use_redirect = True
    # Master/slave fds are placeholders here; console wiring happens below.
    stdout_master_fd, stderr_master_fd = None, None
    stdout_slave_fd, stderr_slave_fd = None, None
    backend = Backend()
    backend.ensure_launched(
        settings=s,
        stdout_fd=stdout_master_fd,
        stderr_fd=stderr_master_fd,
        use_redirect=use_redirect,
    )
    backend.server_connect()
    # Make sure we are logged in
    # wandb_login._login(_backend=backend, _settings=self.settings)

    # resuming needs access to the server, check server_status()?

    run = Run(config=config, settings=s)
    run._set_console(
        use_redirect=use_redirect,
        stdout_slave_fd=stdout_slave_fd,
        stderr_slave_fd=stderr_slave_fd,
    )
    run._set_library(self._wl)
    run._set_backend(backend)
    run._set_reporter(self._reporter)
    run._set_teardown_hooks(self._teardown_hooks)
    # TODO: pass mode to backend
    # run_synced = None
    backend._hack_set_run(run)
    backend.interface.publish_header()
    if s._offline:
        # Offline: publish the run record locally; no server round-trip.
        run_proto = backend.interface._make_run(run)
        backend.interface._publish_run(run_proto)
        run._set_run_obj_offline(run_proto)
    else:
        # Online: version check, then negotiate the run (30 s timeout).
        ret = backend.interface.communicate_check_version(
            current_version=wandb.__version__)
        if ret:
            if ret.upgrade_message:
                run._set_upgraded_version_message(ret.upgrade_message)
            if ret.delete_message:
                run._set_deleted_version_message(ret.delete_message)
            if ret.yank_message:
                run._set_yanked_version_message(ret.yank_message)
        run._on_init()
        ret = backend.interface.communicate_run(run, timeout=30)
        error_message = None
        if not ret:
            error_message = "Error communicating with backend"
        if ret and ret.error:
            error_message = ret.error.message
        if error_message:
            # Shutdown the backend and get rid of the logger
            # we don't need to do console cleanup at this point
            backend.cleanup()
            self.teardown()
            raise UsageError(error_message)
        run._set_run_obj(ret.run)
    # initiate run (stats and metadata probing)
    _ = backend.interface.communicate_run_start()
    self._wl._global_run_stack.append(run)
    self.run = run
    self.backend = backend
    # Expose the run's methods as wandb module-level functions.
    module.set_global(
        run=run,
        config=run.config,
        log=run.log,
        summary=run.summary,
        save=run.save,
        restore=run.restore,
        use_artifact=run.use_artifact,
        log_artifact=run.log_artifact,
        plot_table=run.plot_table,
    )
    self._reporter.set_context(run=run)
    run._on_start()
    # Freeze the run so user code cannot mutate internals after start.
    run._freeze()
    return run
def init(self) -> Union[Run, RunDisabled, None]:  # noqa: C901
    """Create, wire up and start a new W&B run (manager/service variant).

    Flow: fire init triggers -> noop short-circuit -> optional reinit of the
    run on the global stack -> stale-run pid check -> launch+connect backend
    (optionally via the service manager) -> build ``Run`` -> record
    telemetry and labels -> publish offline or negotiate with the backend ->
    apply backend-provided settings, start run threads, install module
    globals.

    Returns:
        The started ``Run``, a ``RunDisabled`` in noop mode, or the already
        active ``wandb.run``.

    Raises:
        RuntimeError: if the module logger was never initialized.
        UsageError: if communication with the wandb process fails.
    """
    if logger is None:
        raise RuntimeError("Logger not initialized")
    logger.info("calling init triggers")
    trigger.call("on_init", **self.kwargs)
    logger.info(
        f"wandb.init called with sweep_config: {self.sweep_config}\nconfig: {self.config}"
    )
    if self.settings._noop:
        return self._make_run_disabled()
    if self.settings.reinit or (
        self.settings._jupyter and self.settings.reinit is not False
    ):
        # reinit: finish the run on top of the global stack before starting
        # another one (with Jupyter-friendly HTML status messages).
        if len(self._wl._global_run_stack) > 0:
            if len(self._wl._global_run_stack) > 1:
                wandb.termwarn(
                    "If you want to track multiple runs concurrently in wandb, "
                    "you should use multi-processing not threads"  # noqa: E501
                )
            last_id = self._wl._global_run_stack[-1]._run_id
            logger.info(
                f"re-initializing run, found existing run on stack: {last_id}"
            )
            jupyter = (
                self.settings._jupyter
                and not self.settings.silent
                and ipython.in_jupyter()
            )
            if jupyter:
                ipython.display_html(
                    f"Finishing last run (ID:{last_id}) before initializing another..."
                )
            self._wl._global_run_stack[-1].finish()
            if jupyter:
                ipython.display_html(
                    f"Successfully finished last run (ID:{last_id}). Initializing new run:<br/>"
                )
    elif isinstance(wandb.run, Run):
        allow_return_run = True
        manager = self._wl._get_manager()
        if manager:
            current_pid = os.getpid()
            if current_pid != wandb.run._init_pid:
                # We shouldn't return a stale global run if we are in a new pid
                allow_return_run = False
        if allow_return_run:
            logger.info("wandb.init() called when a run is still active")
            with telemetry.context() as tel:
                tel.feature.init_return_run = True
            return wandb.run

    logger.info("starting backend")
    manager = self._wl._get_manager()
    if manager:
        # Service mode: tell the manager a new run is being initialized.
        manager._inform_init(settings=self.settings, run_id=self.settings.run_id)
    backend = Backend(settings=self.settings, manager=manager)
    backend.ensure_launched()
    backend.server_connect()
    logger.info("backend started and connected")
    # Make sure we are logged in
    # wandb_login._login(_backend=backend, _settings=self.settings)

    # resuming needs access to the server, check server_status()?

    run = Run(
        config=self.config, settings=self.settings, sweep_config=self.sweep_config
    )

    # probe the active start method
    active_start_method: Optional[str] = None
    if self.settings.start_method == "thread":
        active_start_method = self.settings.start_method
    else:
        # Fall back to multiprocessing's configured start method if exposed.
        get_start_fn = getattr(backend._multiprocessing, "get_start_method", None)
        active_start_method = get_start_fn() if get_start_fn else None

    # Populate initial telemetry
    with telemetry.context(run=run) as tel:
        tel.cli_version = wandb.__version__
        tel.python_version = platform.python_version()
        hf_version = _huggingface_version()
        if hf_version:
            tel.huggingface_version = hf_version
        if self.settings._jupyter:
            tel.env.jupyter = True
        if self.settings._kaggle:
            tel.env.kaggle = True
        if self.settings._windows:
            tel.env.windows = True
        run._telemetry_imports(tel.imports_init)
        if self._use_sagemaker:
            tel.feature.sagemaker = True
        if self._set_init_config:
            tel.feature.set_init_config = True
        if self._set_init_name:
            tel.feature.set_init_name = True
        if self._set_init_id:
            tel.feature.set_init_id = True
        if self._set_init_tags:
            tel.feature.set_init_tags = True
        if self.settings.launch:
            tel.feature.launch = True
        if active_start_method == "spawn":
            tel.env.start_spawn = True
        elif active_start_method == "fork":
            tel.env.start_fork = True
        elif active_start_method == "forkserver":
            tel.env.start_forkserver = True
        elif active_start_method == "thread":
            tel.env.start_thread = True
        if manager:
            tel.feature.service = True
        tel.env.maybe_mp = _maybe_mp_process(backend)
        # fixme: detected issues with settings
        if self.settings.__dict__["_Settings__preprocessing_warnings"]:
            tel.issues.settings__preprocessing_warnings = True
        if self.settings.__dict__["_Settings__validation_warnings"]:
            tel.issues.settings__validation_warnings = True
        if self.settings.__dict__["_Settings__unexpected_args"]:
            tel.issues.settings__unexpected_args = True

    if not self.settings.label_disable:
        if self.notebook:
            run._label_probe_notebook(self.notebook)
        else:
            run._label_probe_main()

    logger.info("updated telemetry")
    run._set_library(self._wl)
    run._set_backend(backend)
    run._set_reporter(self._reporter)
    run._set_teardown_hooks(self._teardown_hooks)
    # TODO: pass mode to backend
    # run_synced = None
    backend._hack_set_run(run)
    assert backend.interface
    backend.interface.publish_header()

    # Using GitRepo() blocks & can be slow, depending on user's current git setup.
    # We don't want to block run initialization/start request, so populate run's git
    # info beforehand.
    if not self.settings.disable_git:
        run._populate_git_info()

    if self.settings._offline:
        # Offline: publish the run record locally; no server round-trip.
        with telemetry.context(run=run) as tel:
            tel.feature.offline = True
        run_proto = backend.interface._make_run(run)
        backend.interface._publish_run(run_proto)
        run._set_run_obj_offline(run_proto)
        if self.settings.resume:
            wandb.termwarn(
                "`resume` will be ignored since W&B syncing is set to `offline`. "
                f"Starting a new run with run id {run.id}."
            )
    else:
        logger.info("communicating run to backend with 30 second timeout")
        run_result = backend.interface.communicate_run(run, timeout=30)
        error_message: Optional[str] = None
        if not run_result:
            logger.error("backend process timed out")
            error_message = "Error communicating with wandb process"
            # Suggest alternate start methods; fork-related hangs are a
            # common cause of this timeout.
            if active_start_method != "fork":
                error_message += "\ntry: wandb.init(settings=wandb.Settings(start_method='fork'))"
                error_message += "\nor: wandb.init(settings=wandb.Settings(start_method='thread'))"
                error_message += "\nFor more info see: https://docs.wandb.ai/library/init#init-start-error"
        elif run_result.error:
            error_message = run_result.error.message
        if error_message:
            logger.error(f"encountered error: {error_message}")
            # Shutdown the backend and get rid of the logger
            # we don't need to do console cleanup at this point
            backend.cleanup()
            self.teardown()
            raise UsageError(error_message)
        assert run_result and run_result.run
        if run_result.run.resumed:
            logger.info("run resumed")
            with telemetry.context(run=run) as tel:
                tel.feature.resumed = True
        run._set_run_obj(run_result.run)
    run._on_init()

    logger.info("starting run threads in backend")
    # initiate run (stats and metadata probing)
    run_obj = run._run_obj or run._run_obj_offline
    # Apply backend-provided run-start settings back onto this process.
    self.settings._apply_run_start(message_to_dict(run_obj))
    run._update_settings(self.settings)
    if manager:
        manager._inform_start(settings=self.settings, run_id=self.settings.run_id)
    assert backend.interface
    assert run_obj
    _ = backend.interface.communicate_run_start(run_obj)

    self._wl._global_run_stack.append(run)
    self.run = run
    self.backend = backend
    # Expose the run's methods as wandb module-level functions.
    module.set_global(
        run=run,
        config=run.config,
        log=run.log,
        summary=run.summary,
        save=run.save,
        use_artifact=run.use_artifact,
        log_artifact=run.log_artifact,
        define_metric=run.define_metric,
        plot_table=run.plot_table,
        alert=run.alert,
        mark_preempting=run.mark_preempting,
    )
    self._reporter.set_context(run=run)
    run._on_start()
    # Freeze the run so user code cannot mutate internals after start.
    run._freeze()
    logger.info("run started, returning control to user process")
    return run