class JupyterAgent(object):
    """Lazily starts metric logging after the first `wandb.log` call and
    shuts logging down again when the notebook cell completes."""

    def __init__(self):
        # Agent starts out paused; `start()` flips this once logging begins.
        self.paused = True

    def start(self):
        """Spin up the run manager and begin mirroring output, if paused."""
        if not self.paused:
            return
        self.api = InternalApi()
        self.rm = RunManager(self.api, wandb.run, output=False)
        self.api._file_stream_api = None
        self.api.set_current_run_id(wandb.run.id)
        self.rm.mirror_stdout_stderr()
        self.paused = False
        # Init will return the last step of a resumed run; we update the
        # run's history._steps in extreme hack fashion.
        # TODO: this reserves a bigtime refactor
        resumed_step = self.rm.init_run(dict(os.environ))
        if resumed_step:
            wandb.run.history._steps = resumed_step + 1

    def stop(self):
        """Stop output mirroring, shut the run manager down, and pause."""
        if self.paused:
            return
        self.rm.unmirror_stdout_stderr()
        self.rm.shutdown()
        wandb.run.close_files()
        self.paused = True
def upload_debug(self):
    """Uploads the debug log to cloud storage"""
    # Nothing to do when no local debug log has been written.
    if not os.path.exists(self.log_fname):
        return
    api = InternalApi()
    api.set_current_run_id(self.id)
    pusher = FilePusher(api)
    pusher.update_file("wandb-debug.log", self.log_fname)
    pusher.file_changed("wandb-debug.log", self.log_fname)
    pusher.finish()
def fake_run_manager(mocker, run=None, cloud=True, rm_class=wandb.run_manager.RunManager):
    """Build a RunManager wired up with mocks for tests.

    Creates (or reuses) a global wandb.run, replaces all of the run manager's
    I/O collaborators with MagicMocks, and starts the socket-listener and
    sync threads.  A `test_shutdown` callable is attached to the returned
    run manager so tests can tear the threads down.
    """
    # NOTE: This will create a run directory so make sure it's called in an isolated file system
    # We have an optional rm_class object because we mock it above so we need it before it's mocked
    # (the default is bound at def time, i.e. before any patching in the caller).
    api = InternalApi(load_settings=False)
    api.set_setting('project', 'testing')
    if wandb.run is None:
        wandb.run = run or Run()
    wandb.config = wandb.run.config
    wandb.run._api = api
    wandb.run._mkdir()
    wandb.run.socket = wandb_socket.Server()
    api.set_current_run_id(wandb.run.id)
    # Prevent any real file-stream traffic: patch the class and stub the instance.
    mocker.patch('wandb.apis.internal.FileStreamApi')
    api._file_stream_api = mocker.MagicMock()
    run_manager = rm_class(wandb.run, cloud=cloud, port=wandb.run.socket.port)

    class FakeProc(object):
        # Stand-in for the user process: poll() returning None means "still running".
        def poll(self):
            return None

        def exit(self, code=0):
            return None
    run_manager.proc = FakeProc()
    # Replace every output/metadata collaborator with a mock so no real I/O happens.
    run_manager._meta = mocker.MagicMock()
    run_manager._stdout_tee = mocker.MagicMock()
    run_manager._stderr_tee = mocker.MagicMock()
    run_manager._output_log = mocker.MagicMock()
    run_manager._stdout_stream = mocker.MagicMock()
    run_manager._stderr_stream = mocker.MagicMock()
    run_manager.mirror_stdout_stderr = mocker.MagicMock()
    run_manager.unmirror_stdout_stderr = mocker.MagicMock()
    # Ordering matters below: the socket listener must be running before
    # _socket.ready() is signalled and the sync thread is started.
    socket_thread = threading.Thread(target=wandb.run.socket.listen)
    socket_thread.start()
    run_manager._socket.ready()
    thread = threading.Thread(target=run_manager._sync_etc)
    thread.daemon = True
    thread.start()

    def test_shutdown():
        # Signal the socket server to stop, then join both worker threads.
        if wandb.run and wandb.run.socket:
            wandb.run.socket.done()
            # TODO: is this needed?
            socket_thread.join()
            thread.join()
    run_manager.test_shutdown = test_shutdown
    run_manager._unblock_file_observer()
    run_manager._file_pusher._push_function = mocker.MagicMock()
    return run_manager
def _init_jupyter(run):
    """Configure authentication if needed and attach a run to the notebook.

    Prompts for an API key when the machine isn't authenticated, saves the
    run, and registers the W&B magics plus pre/post cell hooks.  Log pushing
    and system stats don't start until `wandb.monitor()` is called.
    """
    from wandb import jupyter
    # TODO: Should we log to jupyter?
    # Global logging had to be disabled because it set the level to debug,
    # and run logging is disabled because we're rarely using it:
    # try_to_set_up_global_logging()
    # run.enable_logging()
    api = InternalApi()
    if not api.api_key:
        termerror(
            "Not authenticated. Copy a key from https://app.wandb.ai/authorize"
        )
        entered_key = getpass.getpass("API Key: ").strip()
        if len(entered_key) != 40:
            raise ValueError("API Key must be 40 characters long")
        os.environ[env.API_KEY] = entered_key
        util.write_netrc(api.api_url, "user", entered_key)
        # Rebuild the client so it picks up the freshly written key.
        api = InternalApi()
    os.environ["WANDB_JUPYTER"] = "true"
    run.resume = "allow"
    api.set_current_run_id(run.id)
    print("W&B Run: %s" % run.get_url(api))
    print(
        "Call `%%wandb` in the cell containing your training loop to display live results."
    )
    try:
        run.save(api=api)
    except (CommError, ValueError) as save_err:
        termerror(str(save_err))
    run.set_environment()
    run._init_jupyter_agent()
    ipython = get_ipython()
    ipython.register_magics(jupyter.WandBMagics)

    def reset_start():
        """Reset START_TIME to when the cell starts"""
        global START_TIME
        START_TIME = time.time()

    ipython.events.register("pre_run_cell", reset_start)
    ipython.events.register('post_run_cell', run._stop_jupyter_agent)
class Run(object):
    """Represents a single W&B run: its identity, config, history, files
    and the API handles used to sync it to the backend."""

    def __init__(self, run_id=None, mode=None, dir=None, group=None,
                 job_type=None, config=None, sweep_id=None, storage_id=None,
                 description=None, resume=None, program=None, args=None,
                 wandb_dir=None, tags=None, name=None, notes=None, api=None):
        """Create a Run.

        Arguments:
            description (str): This is the old, deprecated style of description: the run's
                name followed by a newline, followed by multiline notes.
        """
        # self.storage_id is "id" in GQL.
        self.storage_id = storage_id
        # self.id is "name" in GQL.
        self.id = run_id if run_id else util.generate_id()
        # self._name is "display_name" in GQL.
        self._name = None
        self.notes = None

        self.resume = resume if resume else 'never'
        self.mode = mode if mode else 'run'
        self.group = group
        self.job_type = job_type
        self.pid = os.getpid()
        self.resumed = False  # we set resume when history is first accessed
        if api:
            # A single InternalApi instance is pinned to one run id; refuse to share.
            if api.current_run_id and api.current_run_id != self.id:
                raise RuntimeError('Api object passed to run {} is already being used by run {}'.format(
                    self.id, api.current_run_id))
            else:
                api.set_current_run_id(self.id)
        self._api = api

        if dir is None:
            self._dir = run_dir_path(self.id, dry=self.mode == 'dryrun')
        else:
            self._dir = os.path.abspath(dir)
        self._mkdir()

        # self.name and self.notes used to be combined into a single field.
        # Now if name and notes don't have their own values, we get them from
        # self._name_and_description, but we don't update description.md
        # if they're changed. This is to discourage relying on self.description
        # and self._name_and_description so that we can drop them later.
        #
        # This needs to be set before name and notes because name and notes may
        # influence it. They have higher precedence.
        self._name_and_description = None
        if description:
            wandb.termwarn('Run.description is deprecated. Please use wandb.init(notes="long notes") instead.')
            self._name_and_description = description
        elif os.path.exists(self.description_path):
            with open(self.description_path) as d_file:
                self._name_and_description = d_file.read()

        if name is not None:
            self.name = name
        if notes is not None:
            self.notes = notes

        self.program = program
        if not self.program:
            try:
                import __main__
                self.program = __main__.__file__
            except (ImportError, AttributeError):
                # probably `python -c`, an embedded interpreter or something
                self.program = '<python with no main file>'
        self.args = args
        if self.args is None:
            self.args = sys.argv[1:]
        self.wandb_dir = wandb_dir

        # Tag the Sentry scope so crash reports carry run context.
        with configure_scope() as scope:
            self.project = self.api.settings("project")
            scope.set_tag("project", self.project)
            scope.set_tag("entity", self.entity)
            try:
                scope.set_tag("url", self.get_url(self.api, network=False))  # TODO: Move this somewhere outside of init
            except CommError:
                pass

        if self.resume == "auto":
            # Persist the run id so a future process can auto-resume this run.
            util.mkdir_exists_ok(wandb.wandb_dir())
            resume_path = os.path.join(wandb.wandb_dir(), RESUME_FNAME)
            with open(resume_path, "w") as f:
                f.write(json.dumps({"run_id": self.id}))

        if config is None:
            self.config = Config()
        else:
            self.config = config

        # socket server, currently only available in headless mode
        self.socket = None

        self.tags = tags if tags else []

        self.sweep_id = sweep_id

        # Lazily-created collaborators (see the corresponding properties).
        self._history = None
        self._events = None
        self._summary = None
        self._meta = None
        self._run_manager = None
        self._jupyter_agent = None

    @property
    def config_static(self):
        """An immutable snapshot view of the run config."""
        return ConfigStatic(self.config)

    @property
    def api(self):
        """Lazily-created InternalApi bound to this run's id."""
        if self._api is None:
            self._api = InternalApi()
            self._api.set_current_run_id(self.id)
        return self._api

    @property
    def entity(self):
        return self.api.settings('entity')

    @entity.setter
    def entity(self, entity):
        self.api.set_setting("entity", entity)

    @property
    def path(self):
        """The run's "entity/project/run_id" path string."""
        # TODO: theres an edge case where self.entity is None
        return "/".join([str(self.entity), self.project_name(), self.id])

    def _init_jupyter_agent(self):
        # Imported lazily so non-notebook usage never touches jupyter deps.
        from wandb.jupyter import JupyterAgent
        self._jupyter_agent = JupyterAgent()

    def _stop_jupyter_agent(self):
        self._jupyter_agent.stop()

    def send_message(self, options):
        """ Sends a message to the wandb process changing the policy
        of saved files.  This is primarily used internally by wandb.save
        """
        if not options.get("save_policy") and not options.get("tensorboard"):
            raise ValueError(
                "Only configuring save_policy and tensorboard is supported")
        if self.socket:
            # In the user process
            self.socket.send(options)
        elif self._jupyter_agent:
            # Running in jupyter
            self._jupyter_agent.start()
            if options.get("save_policy"):
                self._jupyter_agent.rm.update_user_file_policy(
                    options["save_policy"])
            elif options.get("tensorboard"):
                self._jupyter_agent.rm.start_tensorboard_watcher(
                    options["tensorboard"]["logdir"],
                    options["tensorboard"]["save"])
        elif self._run_manager:
            # Running in the wandb process, used for tfevents saving
            if options.get("save_policy"):
                self._run_manager.update_user_file_policy(
                    options["save_policy"])
        else:
            wandb.termerror(
                "wandb.init hasn't been called, can't configure run")

    @classmethod
    def from_environment_or_defaults(cls, environment=None):
        """Create a Run object taking values from the local environment where possible.

        The run ID comes from WANDB_RUN_ID or is randomly generated.
        The run mode ("dryrun", or "run") comes from WANDB_MODE or defaults to "dryrun".
        The run directory comes from WANDB_RUN_DIR or is generated from the run ID.

        The Run will have a .config attribute but its run directory won't be
        set by default.
        """
        if environment is None:
            environment = os.environ
        run_id = environment.get(env.RUN_ID)
        resume = environment.get(env.RESUME)
        storage_id = environment.get(env.RUN_STORAGE_ID)
        mode = environment.get(env.MODE)
        api = InternalApi(environ=environment)
        disabled = api.disabled()
        if not mode and disabled:
            mode = "dryrun"
        elif disabled and mode != "dryrun":
            wandb.termwarn(
                "WANDB_MODE is set to run, but W&B was disabled. Run `wandb on` to remove this message")
        elif disabled:
            wandb.termlog(
                'W&B is disabled in this directory. Run `wandb on` to enable cloud syncing.')
        group = environment.get(env.RUN_GROUP)
        job_type = environment.get(env.JOB_TYPE)
        run_dir = environment.get(env.RUN_DIR)
        sweep_id = environment.get(env.SWEEP_ID)
        program = environment.get(env.PROGRAM)
        description = environment.get(env.DESCRIPTION)
        name = environment.get(env.NAME)
        notes = environment.get(env.NOTES)
        args = env.get_args(env=environment)
        wandb_dir = env.get_dir(env=environment)
        tags = env.get_tags(env=environment)
        # TODO(adrian): should pass environment into here as well.
        config = Config.from_environment_or_defaults()
        run = cls(run_id, mode, run_dir, group, job_type, config, sweep_id,
                  storage_id, program=program, description=description,
                  args=args, wandb_dir=wandb_dir, tags=tags, name=name,
                  notes=notes, resume=resume, api=api)
        return run

    @classmethod
    def from_directory(cls, directory, project=None, entity=None,
                       run_id=None, api=None, ignore_globs=None):
        """Sync a previously-written local run directory to the backend.

        Streams history (or converts tfevents), uploads config/summary and
        all remaining files, and returns the created Run object.
        """
        api = api or InternalApi()
        run_id = run_id or util.generate_id()
        run = Run(run_id=run_id, dir=directory)
        run_name = None
        project_from_meta = None
        snap = DirectorySnapshot(directory)
        meta = next((p for p in snap.paths if METADATA_FNAME in p), None)
        if meta:
            # NOTE(review): open() handle is never closed here — presumably
            # acceptable for a short-lived CLI path; confirm.
            meta = json.load(open(meta))
            run_name = meta.get("name")
            project_from_meta = meta.get("project")
        project = project or project_from_meta or api.settings(
            "project") or run.auto_project_name(api=api)
        if project is None:
            raise ValueError("You must specify project")
        api.set_current_run_id(run_id)
        api.set_setting("project", project)
        if entity:
            api.set_setting("entity", entity)
        res = api.upsert_run(name=run_id, project=project, entity=entity,
                             display_name=run_name)
        entity = res["project"]["entity"]["name"]
        wandb.termlog("Syncing {} to:".format(directory))
        try:
            wandb.termlog(res["displayName"] + " " + run.get_url(api))
        except CommError as e:
            wandb.termwarn(e.message)
        file_api = api.get_file_stream_api()
        file_api.start()
        paths = [
            os.path.relpath(abs_path, directory) for abs_path in snap.paths
            if os.path.isfile(abs_path)
        ]
        if ignore_globs:
            # Drop any file matching an ignore glob before uploading.
            paths = set(paths)
            for g in ignore_globs:
                paths = paths - set(fnmatch.filter(paths, g))
            paths = list(paths)
        run_update = {"id": res["id"]}
        tfevents = sorted([p for p in snap.paths if ".tfevents." in p])
        history = next((p for p in snap.paths if HISTORY_FNAME in p), None)
        event = next((p for p in snap.paths if EVENTS_FNAME in p), None)
        config = next((p for p in snap.paths if CONFIG_FNAME in p), None)
        user_config = next((p for p in snap.paths if USER_CONFIG_FNAME in p),
                           None)
        summary = next((p for p in snap.paths if SUMMARY_FNAME in p), None)
        if history:
            wandb.termlog("Uploading history metrics")
            file_api.stream_file(history)
            snap.paths.remove(history)
        elif len(tfevents) > 0:
            from wandb import tensorflow as wbtf
            wandb.termlog("Found tfevents file, converting...")
            summary = {}
            for path in tfevents:
                # Namespace each tfevents file by its subdirectory within `directory`.
                filename = os.path.basename(path)
                namespace = path.replace(filename, "").replace(
                    directory, "").strip(os.sep)
                summary.update(
                    wbtf.stream_tfevents(path, file_api, run,
                                         namespace=namespace))
            for path in glob.glob(os.path.join(directory, "media/**/*"),
                                  recursive=True):
                if os.path.isfile(path):
                    paths.append(path)
        else:
            wandb.termerror(
                "No history or tfevents files found, only syncing files")
        if event:
            file_api.stream_file(event)
            snap.paths.remove(event)
        if config:
            run_update["config"] = util.load_yaml(open(config))
        elif user_config:
            # TODO: half backed support for config.json
            run_update["config"] = {
                k: {
                    "value": v
                }
                for k, v in six.iteritems(user_config)
            }
        if isinstance(summary, dict):
            #TODO: summary should already have data_types converted here...
            run_update["summary_metrics"] = util.json_dumps_safer(summary)
        elif summary:
            run_update["summary_metrics"] = open(summary).read()
        if meta:
            if meta.get("git"):
                run_update["commit"] = meta["git"].get("commit")
                run_update["repo"] = meta["git"].get("remote")
            if meta.get("host"):
                run_update["host"] = meta["host"]
            run_update["program_path"] = meta["program"]
            run_update["job_type"] = meta.get("jobType")
            run_update["notes"] = meta.get("notes")
        else:
            run_update["host"] = run.host
        wandb.termlog("Updating run and uploading files")
        api.upsert_run(**run_update)
        pusher = FilePusher(api)
        for k in paths:
            path = os.path.abspath(os.path.join(directory, k))
            pusher.update_file(k, path)
            pusher.file_changed(k, path)
        pusher.finish()
        pusher.print_status()
        file_api.finish(0)
        # Remove temporary media images generated from tfevents
        if history is None and os.path.exists(os.path.join(directory, "media")):
            shutil.rmtree(os.path.join(directory, "media"))
        wandb.termlog("Finished!")
        return run

    def auto_project_name(self, api):
        """Derive a project name from the running program's path."""
        return util.auto_project_name(self.program, api)

    def save(self, id=None, program=None, summary_metrics=None,
             num_retries=None, api=None):
        """Upsert this run on the backend and record the resulting storage id
        and display name."""
        api = api or self.api
        project = api.settings('project')
        if project is None:
            project = self.auto_project_name(api)
        upsert_result = api.upsert_run(id=id or self.storage_id,
                                       name=self.id,
                                       commit=api.git.last_commit,
                                       project=project,
                                       entity=self.entity,
                                       group=self.group,
                                       tags=self.tags if len(self.tags) > 0 else None,
                                       config=self.config.as_dict(),
                                       description=self._name_and_description,
                                       host=self.host,
                                       program_path=program or self.program,
                                       repo=api.git.remote_url,
                                       sweep_name=self.sweep_id,
                                       display_name=self._name,
                                       notes=self.notes,
                                       summary_metrics=summary_metrics,
                                       job_type=self.job_type,
                                       num_retries=num_retries)
        self.storage_id = upsert_result['id']
        self.name = upsert_result.get('displayName')
        return upsert_result

    def set_environment(self, environment=None):
        """Set environment variables needed to reconstruct this object inside
        a user scripts (eg. in `wandb.init()`).
        """
        if environment is None:
            environment = os.environ
        environment[env.RUN_ID] = self.id
        environment[env.RESUME] = self.resume
        if self.storage_id:
            environment[env.RUN_STORAGE_ID] = self.storage_id
        environment[env.MODE] = self.mode
        environment[env.RUN_DIR] = self.dir
        if self.group:
            environment[env.RUN_GROUP] = self.group
        if self.job_type:
            environment[env.JOB_TYPE] = self.job_type
        if self.wandb_dir:
            environment[env.DIR] = self.wandb_dir
        if self.sweep_id is not None:
            environment[env.SWEEP_ID] = self.sweep_id
        if self.program is not None:
            environment[env.PROGRAM] = self.program
        if self.args is not None:
            environment[env.ARGS] = json.dumps(self.args)
        if self._name_and_description is not None:
            environment[env.DESCRIPTION] = self._name_and_description
        if self._name is not None:
            environment[env.NAME] = self._name
        if self.notes is not None:
            environment[env.NOTES] = self.notes
        if len(self.tags) > 0:
            environment[env.TAGS] = ",".join(self.tags)
        return environment

    def _mkdir(self):
        # Create the run directory (idempotent).
        util.mkdir_exists_ok(self._dir)

    def project_name(self, api=None):
        """Project name from settings, else auto-derived, else "uncategorized"."""
        api = api or self.api
        return api.settings('project') or self.auto_project_name(
            api) or "uncategorized"

    def _generate_query_string(self, api, params=None):
        """URL encodes dictionary of params"""
        params = params or {}
        # Anonymous runs embed the api key so the URL remains viewable.
        if str(api.settings().get('anonymous', 'false')) == 'true':
            params['apiKey'] = api.api_key
        if not params:
            return ""
        return '?' + urllib.parse.urlencode(params)

    def _load_entity(self, api, network):
        """Resolve the entity, optionally querying the backend; raises CommError
        when it can't be determined."""
        if not api.api_key:
            raise CommError(
                "Can't find API key, run wandb login or set WANDB_API_KEY")
        entity = api.settings('entity')
        if network:
            if api.settings('entity') is None:
                viewer = api.viewer()
                if viewer.get('entity'):
                    api.set_setting('entity', viewer['entity'])
            entity = api.settings('entity')
        if not entity:
            # This can happen on network failure
            raise CommError(
                "Can't connect to network to query entity from API key")
        return entity

    def get_project_url(self, api=None, network=True, params=None):
        """Generate a url for a project.

        If network is false and entity isn't specified in the environment
        raises wandb.apis.CommError
        """
        params = params or {}
        api = api or self.api
        self._load_entity(api, network)
        return "{base}/{entity}/{project}{query_string}".format(
            base=api.app_url,
            entity=urllib.parse.quote_plus(api.settings('entity')),
            project=urllib.parse.quote_plus(self.project_name(api)),
            query_string=self._generate_query_string(api, params))

    def get_sweep_url(self, api=None, network=True, params=None):
        """Generate a url for a sweep.

        If network is false and entity isn't specified in the environment
        raises wandb.apis.CommError

        Returns:
            string - url if the run is part of a sweep
            None - if the run is not part of the sweep
        """
        params = params or {}
        api = api or self.api
        self._load_entity(api, network)
        sweep_id = self.sweep_id
        if sweep_id is None:
            return
        return "{base}/{entity}/{project}/sweeps/{sweepid}{query_string}".format(
            base=api.app_url,
            entity=urllib.parse.quote_plus(api.settings('entity')),
            project=urllib.parse.quote_plus(self.project_name(api)),
            sweepid=urllib.parse.quote_plus(sweep_id),
            query_string=self._generate_query_string(api, params))

    def get_url(self, api=None, network=True, params=None):
        """Generate a url for a run.

        If network is false and entity isn't specified in the environment
        raises wandb.apis.CommError
        """
        params = params or {}
        api = api or self.api
        self._load_entity(api, network)
        return "{base}/{entity}/{project}/runs/{run}{query_string}".format(
            base=api.app_url,
            entity=urllib.parse.quote_plus(api.settings('entity')),
            project=urllib.parse.quote_plus(self.project_name(api)),
            run=urllib.parse.quote_plus(self.id),
            query_string=self._generate_query_string(api, params))

    def upload_debug(self):
        """Uploads the debug log to cloud storage"""
        if os.path.exists(self.log_fname):
            pusher = FilePusher(self.api)
            pusher.update_file("wandb-debug.log", self.log_fname)
            pusher.file_changed("wandb-debug.log", self.log_fname)
            pusher.finish()

    def __repr__(self):
        try:
            return "W&B Run: %s" % self.get_url()
        except CommError as e:
            return "W&B Error: %s" % e.message

    @property
    def name(self):
        """Display name; falls back to the first line of the legacy
        name-and-description field."""
        if self._name is not None:
            return self._name
        elif self._name_and_description is not None:
            return self._name_and_description.split("\n")[0]
        else:
            return None

    @name.setter
    def name(self, name):
        self._name = name
        # Keep the legacy combined field's first line in sync.
        if self._name_and_description is not None:
            parts = self._name_and_description.split("\n", 1)
            parts[0] = name
            self._name_and_description = "\n".join(parts)

    @property
    def description(self):
        """Deprecated: the notes portion of the legacy combined field."""
        wandb.termwarn(
            'Run.description is deprecated. Please use run.notes instead.')
        if self._name_and_description is None:
            self._name_and_description = ''
        parts = self._name_and_description.split("\n", 1)
        if len(parts) > 1:
            return parts[1]
        else:
            return ""

    @description.setter
    def description(self, desc):
        wandb.termwarn(
            'Run.description is deprecated. Please use wandb.init(notes="long notes") instead.')
        if self._name_and_description is None:
            self._name_and_description = self._name or ""
        parts = self._name_and_description.split("\n", 1)
        if len(parts) == 1:
            parts.append("")
        parts[1] = desc
        self._name_and_description = "\n".join(parts)
        # Unlike name/notes, setting description still persists to description.md.
        with open(self.description_path, 'w') as d_file:
            d_file.write(self._name_and_description)

    @property
    def host(self):
        return os.environ.get(env.HOST, socket.gethostname())

    @property
    def dir(self):
        return self._dir

    @property
    def log_fname(self):
        # TODO: we started work to log to a file in the run dir, but it had issues.
        # For now all logs goto the same place.
        return util.get_log_file_path()

    def enable_logging(self):
        """Enable logging to the global debug log.  This adds a run_id to the log,
        in case of muliple processes on the same machine.

        Currently no way to disable logging after it's enabled.
        """
        handler = logging.FileHandler(self.log_fname)
        handler.setLevel(logging.INFO)
        run_id = self.id

        class WBFilter(logging.Filter):
            # Injects this run's id into every record for the formatter below.
            def filter(self, record):
                record.run_id = run_id
                return True

        formatter = logging.Formatter(
            '%(asctime)s %(levelname)-7s %(threadName)-10s:%(process)d [%(run_id)s:%(filename)s:%(funcName)s():%(lineno)s] %(message)s'
        )

        handler.setFormatter(formatter)
        handler.addFilter(WBFilter())
        root = logging.getLogger()
        root.addHandler(handler)

    @property
    def summary(self):
        # Lazily open the summary file.
        if self._summary is None:
            self._summary = summary.FileSummary(self)
        return self._summary

    @property
    def has_summary(self):
        return self._summary or os.path.exists(
            os.path.join(self._dir, summary.SUMMARY_FNAME))

    def _history_added(self, row):
        # Mirror every history row into the summary without clobbering
        # explicitly-set summary values.
        self.summary.update(row, overwrite=False)

    def log(self, row=None, commit=None, step=None, sync=True, *args, **kwargs):
        """Record a row of metrics; with sync=False the row is queued on a
        background thread instead of written inline."""
        if sync == False:
            wandb._ensure_async_log_thread_started()
            return wandb._async_log_queue.put({
                "row": row,
                "commit": commit,
                "step": step
            })

        if row is None:
            row = {}
        for k in row:
            if isinstance(row[k], Visualize):
                self._add_viz(k, row[k].viz_id)
                row[k] = row[k].value
        # NOTE(review): collections.Mapping was removed in Python 3.10;
        # modern code should use collections.abc.Mapping — confirm target
        # interpreter before changing.
        if not isinstance(row, collections.Mapping):
            raise ValueError("wandb.log must be passed a dictionary")

        if any(not isinstance(key, six.string_types) for key in row.keys()):
            raise ValueError(
                "Key values passed to `wandb.log` must be strings.")

        if commit is not False or step is not None:
            self.history.add(row, *args, step=step, commit=commit, **kwargs)
        else:
            self.history.update(row, *args, **kwargs)

    def _add_viz(self, key, viz_id):
        # Record custom-visualization metadata for `key` in the hidden
        # _wandb section of the config.
        if not 'viz' in self.config['_wandb']:
            self.config._set_wandb('viz', {})
        self.config['_wandb']['viz'][key] = {
            'id': viz_id,
            'historyFieldSettings': {
                'key': key,
                'x-axis': '_step'
            }
        }
        self.config.persist()

    @property
    def history(self):
        """Lazily-created History; marks the run resumed when prior steps exist."""
        if self._history is None:
            jupyter_callback = self._jupyter_agent.start if self._jupyter_agent else None
            self._history = history.History(self,
                                            add_callback=self._history_added,
                                            jupyter_callback=jupyter_callback)
            if self._history._steps > 0:
                self.resumed = True
        return self._history

    @property
    def step(self):
        return self.history._steps

    @property
    def has_history(self):
        return self._history or os.path.exists(
            os.path.join(self._dir, HISTORY_FNAME))

    @property
    def events(self):
        # Lazily open the system-events jsonl file.
        if self._events is None:
            self._events = jsonlfile.JsonlEventsFile(EVENTS_FNAME, self._dir)
        return self._events

    @property
    def has_events(self):
        return self._events or os.path.exists(
            os.path.join(self._dir, EVENTS_FNAME))

    @property
    def description_path(self):
        return os.path.join(self.dir, DESCRIPTION_FNAME)

    def close_files(self):
        """Close open files to avoid Python warnings on termination:

        Exception ignored in: <_io.FileIO name='wandb/dryrun-20180130_144602-9vmqjhgy/wandb-history.jsonl' mode='wb' closefd=True>
        ResourceWarning: unclosed file <_io.TextIOWrapper name='wandb/dryrun-20180130_144602-9vmqjhgy/wandb-history.jsonl' mode='w' encoding='UTF-8'>
        """
        if self._events is not None:
            self._events.close()
            self._events = None
        if self._history is not None:
            self._history.close()
            self._history = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Treat any exception as exit code 1; suppress it only on success.
        exit_code = 0 if exc_type is None else 1
        wandb.join(exit_code)
        return exc_type is None
def wandb_init_run(request, tmpdir, request_mocker, upsert_run, query_run_resume_status,
                   upload_logs, monkeypatch, mocker, capsys, local_netrc):
    """Fixture that calls wandb.init(), yields a run (or an exception)
    that gets created, then cleans up afterward.  This is meant to test the logic
    in wandb.init, it should generally not spawn a run_manager.  If you need to
    test run_manager logic use that fixture.
    """
    # save the environment so we can restore it later. pytest
    # may actually do this itself. didn't check.
    orig_environ = dict(os.environ)
    orig_namespace = None
    run = None
    api = InternalApi(load_settings=False)
    try:
        with CliRunner().isolated_filesystem():
            upsert_run(request_mocker)
            if request.node.get_closest_marker('jupyter'):
                query_run_resume_status(request_mocker)

                def fake_ipython():
                    # Minimal stand-in for an IPython shell: just enough surface
                    # (events hook + register_magics) for _init_jupyter.
                    class Jupyter(object):
                        __module__ = "jupyter"

                        def __init__(self):
                            class Hook(object):
                                def register(self, what, where):
                                    pass
                            self.events = Hook()

                        def register_magics(self, magic):
                            pass
                    return Jupyter()
                wandb.get_ipython = fake_ipython
            # no i/o wrapping - it breaks pytest
            os.environ['WANDB_MODE'] = 'clirun'
            if request.node.get_closest_marker('headless'):
                mocker.patch('subprocess.Popen')
            else:
                def mock_headless(run, cloud=True):
                    print("_init_headless called with cloud=%s" % cloud)
                mocker.patch('wandb._init_headless', mock_headless)
            if not request.node.get_closest_marker('unconfigured'):
                os.environ['WANDB_API_KEY'] = 'test'
                os.environ['WANDB_ENTITY'] = 'test'
                os.environ['WANDB_PROJECT'] = 'unit-test-project'
            else:
                # when unconfigured we enable run mode to test missing creds
                os.environ['WANDB_MODE'] = 'run'
                monkeypatch.setattr('wandb.apis.InternalApi.api_key', None)
                monkeypatch.setattr(
                    'getpass.getpass',
                    lambda x: "0123456789012345678901234567890123456789")
                assert InternalApi().api_key == None
            os.environ['WANDB_RUN_DIR'] = str(tmpdir)
            assert wandb.run is None
            assert wandb.config is None
            orig_namespace = vars(wandb)
            # Mock out run_manager, we add it to run to access state in tests
            orig_rm = wandb.run_manager.RunManager
            mock = mocker.patch('wandb.run_manager.RunManager')

            def fake_init(api, run, port=None, output=None):
                # NOTE(review): fake_run_manager's signature is
                # (mocker, run=None, cloud=True, rm_class=...), so `api` lands
                # in the `run` slot and `run` in `cloud` — verify intended.
                print("Initialized fake run manager")
                rm = fake_run_manager(mocker, api, run, rm_class=orig_rm)
                rm._block_file_observer()
                run.run_manager = rm
                return rm
            mock.side_effect = fake_init
            if request.node.get_closest_marker('args'):
                kwargs = request.node.get_closest_marker('args').kwargs
                # Unfortunate to enable the test to work
                if kwargs.get("dir"):
                    del os.environ['WANDB_RUN_DIR']
                if kwargs.get("tensorboard"):
                    # The test uses tensorboardX so we need to be sure it's imported
                    # we use get_module because tensorboardX isn't available in py2
                    wandb.util.get_module("tensorboardX")
                if kwargs.get("error"):
                    err = kwargs["error"]
                    del kwargs['error']
                    if err == "io":
                        @classmethod
                        def error(cls):
                            raise IOError
                        monkeypatch.setattr(
                            'wandb.wandb_run.Run.from_environment_or_defaults',
                            error)
                    elif err == "socket":
                        class Error(object):
                            # Fake socket server whose listen() always fails.
                            @property
                            def port(self):
                                return 123

                            def listen(self, secs):
                                return False, None
                        monkeypatch.setattr("wandb.wandb_socket.Server", Error)
                if kwargs.get('k8s') is not None:
                    token_path = "/var/run/secrets/kubernetes.io/serviceaccount/token"
                    crt_path = "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt"
                    orig_exist = os.path.exists

                    def exists(path):
                        # NOTE: `path in token_path` is a substring test, not
                        # membership in a collection.
                        return True if path in token_path else orig_exist(path)

                    def magic(path, *args, **kwargs):
                        if path == token_path:
                            return six.StringIO('token')
                    mocker.patch('wandb.util.open', magic, create=True)
                    mocker.patch('wandb.util.os.path.exists', exists)
                    os.environ["KUBERNETES_SERVICE_HOST"] = "k8s"
                    os.environ["KUBERNETES_PORT_443_TCP_PORT"] = "123"
                    os.environ["HOSTNAME"] = "test"
                    if kwargs["k8s"]:
                        # NOTE(review): this line appears corrupted/redacted in
                        # the source under review (unbalanced braces); the
                        # intended call presumably mirrors the 500-status branch
                        # below with a containerStatuses JSON body — confirm
                        # against upstream before relying on it.
                        request_mocker.register_uri("GET", "https://*****:*****@sha256:1234"}]}}')
                    else:
                        request_mocker.register_uri("GET", "https://k8s:123/api/v1/namespaces/default/pods/test",
                                                    content=b'{}', status_code=500)
                    del kwargs["k8s"]
                if kwargs.get('sagemaker'):
                    del kwargs['sagemaker']
                    config_path = "/opt/ml/input/config/hyperparameters.json"
                    resource_path = "/opt/ml/input/config/resourceconfig.json"
                    secrets_path = "secrets.env"
                    os.environ['TRAINING_JOB_NAME'] = 'sage'
                    os.environ['CURRENT_HOST'] = 'maker'
                    orig_exist = os.path.exists

                    def exists(path):
                        return True if path in (config_path, secrets_path) else orig_exist(path)
                    mocker.patch('wandb.os.path.exists', exists)

                    def magic(path, *args, **kwargs):
                        # Serve canned SageMaker config/resource/secret files.
                        if path == config_path:
                            return six.StringIO('{"f****n": "A"}')
                        elif path == resource_path:
                            return six.StringIO('{"hosts":["a", "b"]}')
                        elif path == secrets_path:
                            return six.StringIO('WANDB_TEST_SECRET=TRUE')
                        else:
                            return six.StringIO()
                    mocker.patch('wandb.open', magic, create=True)
                    mocker.patch('wandb.util.open', magic, create=True)
                elif kwargs.get("tf_config"):
                    os.environ['TF_CONFIG'] = json.dumps(kwargs['tf_config'])
                    del kwargs['tf_config']
                elif kwargs.get("env"):
                    for k, v in six.iteritems(kwargs["env"]):
                        os.environ[k] = v
                    del kwargs["env"]
            else:
                kwargs = {}

            if request.node.get_closest_marker('resume'):
                # env was leaking when running the whole suite...
                if os.getenv(env.RUN_ID):
                    del os.environ[env.RUN_ID]
                query_run_resume_status(request_mocker)
                os.mkdir(wandb.wandb_dir())
                with open(os.path.join(wandb.wandb_dir(), wandb_run.RESUME_FNAME), "w") as f:
                    f.write(json.dumps({"run_id": "test"}))
            try:
                print("Initializing with", kwargs)
                run = wandb.init(**kwargs)
                api.set_current_run_id(run.id)
                if request.node.get_closest_marker('resume') or request.node.get_closest_marker('mocked_run_manager'):
                    # Reset history
                    run._history = None
                    rm = wandb.run_manager.RunManager(api, run)
                    rm.init_run(os.environ)
                if request.node.get_closest_marker('mock_socket'):
                    run.socket = mocker.MagicMock()
                assert run is wandb.run
                assert run.config is wandb.config
            except wandb.LaunchError as e:
                print("!!! wandb LaunchError raised")
                run = e
            yield run
            if hasattr(run, "run_manager"):
                print("Shutting down run manager")
                run.run_manager.test_shutdown()
    finally:
        # restore the original environment
        os.environ.clear()
        os.environ.update(orig_environ)
        wandb.uninit()
        wandb.get_ipython = lambda: None
        assert vars(wandb) == orig_namespace
class Run(object):
    """Represents a single W&B run: its id, config, local directory, and
    lazily-created history/summary/events files.

    The run id is stored in the "name" attribute on the GQL backend; the
    GQL database id is kept in ``storage_id``.
    """

    def __init__(self, run_id=None, mode=None, dir=None, group=None,
                 job_type=None, config=None, sweep_id=None, storage_id=None,
                 description=None, resume=None, program=None, args=None,
                 wandb_dir=None, tags=None, name=None, notes=None):
        # self.id is actually stored in the "name" attribute in GQL
        self.id = run_id if run_id else util.generate_id()
        self.display_name = self.id
        self.resume = resume if resume else 'never'
        self.mode = mode if mode else 'run'
        self.group = group
        self.job_type = job_type
        self.pid = os.getpid()
        self.resumed = False  # we set resume when history is first accessed
        self._api = None
        self.run_name = None
        self.notes = None

        # Best-effort detection of the entry-point script when no program
        # name was passed in.
        self.program = program
        if not self.program:
            try:
                import __main__
                self.program = __main__.__file__
            except (ImportError, AttributeError):
                # probably `python -c`, an embedded interpreter or something
                self.program = '<python with no main file>'
        self.args = args
        if self.args is None:
            self.args = sys.argv[1:]
        self.wandb_dir = wandb_dir

        # Tag the Sentry scope with run metadata for error reports.
        with configure_scope() as scope:
            self.project = self.api.settings("project")
            scope.set_tag("project", self.project)
            scope.set_tag("entity", self.entity)
            scope.set_tag("url", self.get_url(self.api))

        if dir is None:
            self._dir = run_dir_path(self.id, dry=self.mode == 'dryrun')
        else:
            self._dir = os.path.abspath(dir)
        self._mkdir()

        # Persist the run id so a later process can auto-resume this run.
        if self.resume == "auto":
            util.mkdir_exists_ok(wandb.wandb_dir())
            resume_path = os.path.join(wandb.wandb_dir(), RESUME_FNAME)
            with open(resume_path, "w") as f:
                f.write(json.dumps({"run_id": self.id}))

        if config is None:
            self.config = Config()
        else:
            self.config = config

        # this is the GQL ID:
        self.storage_id = storage_id
        # socket server, currently only available in headless mode
        self.socket = None

        # name_and_description packs "<name>\n<description>"; see the
        # `name` / `description` properties below.
        self.name_and_description = ""
        if description is not None:
            self.name_and_description = description
        elif os.path.exists(self.description_path):
            with open(self.description_path) as d_file:
                self.name_and_description = d_file.read()
        if name is not None:
            self.run_name = name
        if notes is not None:
            self.notes = notes

        self.tags = tags if tags else []

        self.sweep_id = sweep_id

        # Lazily-created resources; see the corresponding properties.
        self._history = None
        self._events = None
        self._summary = None
        self._meta = None
        self._run_manager = None
        self._jupyter_agent = None

    @property
    def api(self):
        """Lazily-created InternalApi bound to this run's id."""
        if self._api is None:
            self._api = InternalApi()
            self._api.set_current_run_id(self.id)
        return self._api

    @property
    def entity(self):
        return self.api.settings('entity')

    @entity.setter
    def entity(self, entity):
        self.api.set_setting("entity", entity)

    @property
    def path(self):
        """The "entity/project/run_id" path for this run."""
        # TODO: theres an edge case where self.entity is None
        return "/".join([str(self.entity), self.project_name(), self.id])

    def _init_jupyter_agent(self):
        # Imported lazily so non-notebook environments don't pay for it.
        from wandb.jupyter import JupyterAgent
        self._jupyter_agent = JupyterAgent()

    def _stop_jupyter_agent(self):
        self._jupyter_agent.stop()

    def send_message(self, options):
        """Sends a message to the wandb process changing the policy
        of saved files.  This is primarily used internally by wandb.save.

        Args:
            options: dict with a "save_policy" and/or "tensorboard" key.

        Raises:
            ValueError: if neither supported key is present.
        """
        if not options.get("save_policy") and not options.get("tensorboard"):
            raise ValueError(
                "Only configuring save_policy and tensorboard is supported")
        if self.socket:
            # In the user process
            self.socket.send(options)
        elif self._jupyter_agent:
            # Running in jupyter
            self._jupyter_agent.start()
            if options.get("save_policy"):
                self._jupyter_agent.rm.update_user_file_policy(
                    options["save_policy"])
            elif options.get("tensorboard"):
                self._jupyter_agent.rm.start_tensorboard_watcher(
                    options["tensorboard"]["logdir"],
                    options["tensorboard"]["save"])
        elif self._run_manager:
            # Running in the wandb process, used for tfevents saving
            if options.get("save_policy"):
                self._run_manager.update_user_file_policy(
                    options["save_policy"])
        else:
            wandb.termerror(
                "wandb.init hasn't been called, can't configure run")

    @classmethod
    def from_environment_or_defaults(cls, environment=None):
        """Create a Run object taking values from the local environment where possible.

        The run ID comes from WANDB_RUN_ID or is randomly generated.
        The run mode ("dryrun", or "run") comes from WANDB_MODE or defaults
        to "dryrun". The run directory comes from WANDB_RUN_DIR or is
        generated from the run ID.

        The Run will have a .config attribute but its run directory won't be
        set by default.

        Args:
            environment: mapping to read settings from; defaults to os.environ.
        """
        if environment is None:
            environment = os.environ
        run_id = environment.get(env.RUN_ID)
        resume = environment.get(env.RESUME)
        storage_id = environment.get(env.RUN_STORAGE_ID)
        mode = environment.get(env.MODE)
        api = InternalApi()
        disabled = api.disabled()
        # `wandb off` forces dryrun mode unless the user explicitly set one.
        if not mode and disabled:
            mode = "dryrun"
        elif disabled and mode != "dryrun":
            wandb.termwarn(
                "WANDB_MODE is set to run, but W&B was disabled. Run `wandb on` to remove this message")
        elif disabled:
            wandb.termlog(
                'W&B is disabled in this directory. Run `wandb on` to enable cloud syncing.')

        group = environment.get(env.RUN_GROUP)
        job_type = environment.get(env.JOB_TYPE)
        run_dir = environment.get(env.RUN_DIR)
        sweep_id = environment.get(env.SWEEP_ID)
        program = environment.get(env.PROGRAM)
        description = environment.get(env.DESCRIPTION)
        name = environment.get(env.NAME)
        notes = environment.get(env.NOTES)
        args = env.get_args()
        wandb_dir = env.get_dir()
        tags = env.get_tags()
        config = Config.from_environment_or_defaults()
        run = cls(run_id, mode, run_dir, group, job_type, config, sweep_id,
                  storage_id, program=program, description=description,
                  args=args, wandb_dir=wandb_dir, tags=tags, name=name,
                  notes=notes, resume=resume)
        return run

    @classmethod
    def from_directory(cls, directory, project=None, entity=None,
                       run_id=None, api=None, ignore_globs=None):
        """Create and sync a Run from an existing local run directory.

        Upserts the run, streams history (or converts tfevents), and
        uploads every file in the directory not excluded by ignore_globs.

        Raises:
            ValueError: if no project can be determined.
        """
        api = api or InternalApi()
        run_id = run_id or util.generate_id()
        run = Run(run_id=run_id, dir=directory)
        run_name = None
        project_from_meta = None
        snap = DirectorySnapshot(directory)
        meta = next((p for p in snap.paths if METADATA_FNAME in p), None)
        if meta:
            meta = json.load(open(meta))
            run_name = meta.get("name")
            project_from_meta = meta.get("project")
        project = project or project_from_meta or api.settings(
            "project") or run.auto_project_name(api=api)
        if project is None:
            raise ValueError("You must specify project")
        api.set_current_run_id(run_id)
        api.set_setting("project", project)
        if entity:
            api.set_setting("entity", entity)
        res = api.upsert_run(name=run_id, project=project, entity=entity,
                             display_name=run_name)
        entity = res["project"]["entity"]["name"]
        wandb.termlog("Syncing {} to:".format(directory))
        wandb.termlog(res["displayName"] + " " + run.get_url(api))

        file_api = api.get_file_stream_api()
        file_api.start()
        paths = [os.path.relpath(abs_path, directory)
                 for abs_path in snap.paths if os.path.isfile(abs_path)]
        if ignore_globs:
            paths = set(paths)
            for g in ignore_globs:
                paths = paths - set(fnmatch.filter(paths, g))
            paths = list(paths)
        run_update = {"id": res["id"]}
        tfevents = sorted([p for p in snap.paths if ".tfevents." in p])
        history = next((p for p in snap.paths if HISTORY_FNAME in p), None)
        event = next((p for p in snap.paths if EVENTS_FNAME in p), None)
        config = next((p for p in snap.paths if CONFIG_FNAME in p), None)
        user_config = next(
            (p for p in snap.paths if USER_CONFIG_FNAME in p), None)
        summary = next((p for p in snap.paths if SUMMARY_FNAME in p), None)
        if history:
            wandb.termlog("Uploading history metrics")
            file_api.stream_file(history)
            snap.paths.remove(history)
        elif len(tfevents) > 0:
            # No native history file; reconstruct it from tfevents.
            from wandb import tensorflow as wbtf
            wandb.termlog("Found tfevents file, converting...")
            summary = {}
            for path in tfevents:
                filename = os.path.basename(path)
                # Namespace metrics by the tfevents file's subdirectory.
                namespace = path.replace(filename, "").replace(
                    directory, "").strip(os.sep)
                summary.update(wbtf.stream_tfevents(
                    path, file_api, run, namespace=namespace))
            for path in glob.glob(os.path.join(directory, "media/**/*"),
                                  recursive=True):
                if os.path.isfile(path):
                    paths.append(path)
        else:
            wandb.termerror(
                "No history or tfevents files found, only syncing files")
        if event:
            file_api.stream_file(event)
            snap.paths.remove(event)
        if config:
            run_update["config"] = util.load_yaml(
                open(config))
        elif user_config:
            # TODO: half-baked support for config.json
            # NOTE(review): `user_config` is a file *path* here (from
            # snap.paths), so six.iteritems(user_config) looks wrong —
            # presumably the file should be parsed (json.load) first.
            # TODO confirm and fix.
            run_update["config"] = {k: {"value": v}
                                    for k, v in six.iteritems(user_config)}
        if isinstance(summary, dict):
            # TODO: summary should already have data_types converted here...
            run_update["summary_metrics"] = util.json_dumps_safer(summary)
        elif summary:
            run_update["summary_metrics"] = open(summary).read()
        if meta:
            if meta.get("git"):
                run_update["commit"] = meta["git"].get("commit")
                run_update["repo"] = meta["git"].get("remote")
            run_update["host"] = meta["host"]
            run_update["program_path"] = meta["program"]
            run_update["job_type"] = meta.get("jobType")
            run_update["notes"] = meta.get("notes")
        else:
            run_update["host"] = socket.gethostname()
        wandb.termlog("Updating run and uploading files")
        api.upsert_run(**run_update)
        pusher = FilePusher(api)
        for k in paths:
            path = os.path.abspath(os.path.join(directory, k))
            pusher.update_file(k, path)
            pusher.file_changed(k, path)
        pusher.finish()
        pusher.print_status()
        file_api.finish(0)
        # Remove temporary media images generated from tfevents
        if history is None and os.path.exists(os.path.join(directory, "media")):
            shutil.rmtree(os.path.join(directory, "media"))
        wandb.termlog("Finished!")
        return run

    def auto_project_name(self, api):
        """Derive a project name from the git repo, or None outside git."""
        # if we're in git, set project name to git repo name + relative path
        # within repo
        root_dir = api.git.root_dir
        if root_dir is None:
            return None
        repo_name = os.path.basename(root_dir)
        program = self.program
        if program is None:
            return repo_name
        if not os.path.isabs(program):
            program = os.path.join(os.curdir, program)
        prog_dir = os.path.dirname(os.path.abspath(program))
        if not prog_dir.startswith(root_dir):
            # Program lives outside the repo; just use the repo name.
            return repo_name
        project = repo_name
        sub_path = os.path.relpath(prog_dir, root_dir)
        if sub_path != '.':
            project += '-' + sub_path
        return project.replace(os.sep, '_')

    def save(self, id=None, program=None, summary_metrics=None,
             num_retries=None, api=None):
        """Upsert this run to the backend and record the returned ids.

        Returns:
            The upsert_run response dict.
        """
        api = api or self.api
        project = api.settings('project')
        if project is None:
            project = self.auto_project_name(api)
        upsert_result = api.upsert_run(
            id=id or self.storage_id,
            name=self.id,
            commit=api.git.last_commit,
            project=project,
            entity=self.entity,
            group=self.group,
            tags=self.tags if len(self.tags) > 0 else None,
            config=self.config.as_dict(),
            description=self.name_and_description,
            host=socket.gethostname(),
            program_path=program or self.program,
            repo=api.git.remote_url,
            sweep_name=self.sweep_id,
            display_name=self.run_name,
            notes=self.notes,
            summary_metrics=summary_metrics,
            job_type=self.job_type,
            num_retries=num_retries)
        self.storage_id = upsert_result['id']
        self.display_name = upsert_result.get('displayName') or self.id
        return upsert_result

    def set_environment(self, environment=None):
        """Set environment variables needed to reconstruct this object inside
        a user scripts (eg. in `wandb.init()`).

        Args:
            environment: mapping to mutate; defaults to os.environ.
        """
        if environment is None:
            environment = os.environ
        environment[env.RUN_ID] = self.id
        environment[env.RESUME] = self.resume
        if self.storage_id:
            environment[env.RUN_STORAGE_ID] = self.storage_id
        environment[env.MODE] = self.mode
        environment[env.RUN_DIR] = self.dir

        # Optional attributes are only exported when set.
        if self.group:
            environment[env.RUN_GROUP] = self.group
        if self.job_type:
            environment[env.JOB_TYPE] = self.job_type
        if self.wandb_dir:
            environment[env.DIR] = self.wandb_dir
        if self.sweep_id is not None:
            environment[env.SWEEP_ID] = self.sweep_id
        if self.program is not None:
            environment[env.PROGRAM] = self.program
        if self.args is not None:
            environment[env.ARGS] = json.dumps(self.args)
        if self.name_and_description is not None:
            environment[env.DESCRIPTION] = self.name_and_description
        if self.run_name is not None:
            environment[env.NAME] = self.run_name
        if self.notes is not None:
            environment[env.NOTES] = self.notes
        if len(self.tags) > 0:
            environment[env.TAGS] = ",".join(self.tags)

    def _mkdir(self):
        # Create the run directory (idempotent).
        util.mkdir_exists_ok(self._dir)

    def project_name(self, api=None):
        """Project name from settings, git inference, or "uncategorized"."""
        api = api or self.api
        return api.settings('project') or self.auto_project_name(api) \
            or "uncategorized"

    def get_url(self, api=None):
        """Return the app URL for this run, or a human-readable placeholder
        when the entity or api key is not available."""
        api = api or self.api
        if api.api_key:
            if api.settings('entity') is None:
                # Fall back to the entity of the logged-in viewer.
                viewer = api.viewer()
                if viewer.get('entity'):
                    api.set_setting('entity', viewer['entity'])
            if api.settings('entity'):
                return "{base}/{entity}/{project}/runs/{run}".format(
                    base=api.app_url,
                    entity=urllib.parse.quote_plus(api.settings('entity')),
                    project=urllib.parse.quote_plus(self.project_name(api)),
                    run=urllib.parse.quote_plus(self.id)
                )
            else:
                # TODO: I think this could only happen if the api key is
                # invalid
                return "run pending creation, url not known"
        else:
            return "not logged in, run wandb login or set WANDB_API_KEY"

    def upload_debug(self):
        """Uploads the debug log to cloud storage"""
        if os.path.exists(self.log_fname):
            pusher = FilePusher(self.api)
            pusher.update_file("wandb-debug.log", self.log_fname)
            pusher.file_changed("wandb-debug.log", self.log_fname)
            pusher.finish()

    def __repr__(self):
        return "W&B Run: %s" % self.get_url()

    @property
    def name(self):
        """We assume the first line of the description is the name users
        want to use, and we automatically set it to id if the user didn't
        specify
        """
        if self.run_name is not None:
            return self.run_name
        return self.name_and_description.split("\n")[0]

    @name.setter
    def name(self, name):
        self.run_name = name
        # Keep the packed name_and_description string in sync.
        parts = self.name_and_description.split("\n", 1)
        parts[0] = name
        self.name_and_description = "\n".join(parts)

    # deprecate description in future release in favor of notes
    @property
    def description(self):
        # Everything after the first line of name_and_description.
        parts = self.name_and_description.split("\n", 1)
        if len(parts) > 1:
            return parts[1]
        else:
            return ""

    @description.setter
    def description(self, desc):
        parts = self.name_and_description.split("\n", 1)
        if len(parts) == 1:
            parts.append("")
        parts[1] = desc
        self.name_and_description = "\n".join(parts)
        # Persist to the description file in the run directory.
        with open(self.description_path, 'w') as d_file:
            d_file.write(self.name_and_description)

    @property
    def host(self):
        return socket.gethostname()

    @property
    def dir(self):
        return self._dir

    @property
    def log_fname(self):
        # TODO: we started work to log to a file in the run dir, but it had
        # issues.  For now all logs goto the same place.
        return util.get_log_file_path()

    def enable_logging(self):
        """Enable logging to the global debug log.  This adds a run_id to the
        log, in case of muliple processes on the same machine.

        Currently no way to disable logging after it's enabled.
        """
        handler = logging.FileHandler(self.log_fname)
        handler.setLevel(logging.INFO)
        run_id = self.id

        class WBFilter(logging.Filter):
            # Stamps every record with this run's id for disambiguation.
            def filter(self, record):
                record.run_id = run_id
                return True

        formatter = logging.Formatter(
            '%(asctime)s %(levelname)-7s %(threadName)-10s:%(process)d [%(run_id)s:%(filename)s:%(funcName)s():%(lineno)s] %(message)s')

        handler.setFormatter(formatter)
        handler.addFilter(WBFilter())

        root = logging.getLogger()
        root.addHandler(handler)

    @property
    def summary(self):
        # Lazily open the summary file.
        if self._summary is None:
            self._summary = summary.FileSummary(self)
        return self._summary

    @property
    def has_summary(self):
        # Truthy if a summary object exists or the file is on disk.
        return self._summary or os.path.exists(
            os.path.join(self._dir, summary.SUMMARY_FNAME))

    def _history_added(self, row):
        # History add_callback: mirror new history rows into the summary
        # without overwriting user-set summary values.
        if self._summary is None:
            self._summary = summary.FileSummary(self)
        self._summary.update(row, overwrite=False)

    @property
    def history(self):
        # Lazily open the history file; mark the run resumed if it already
        # contains steps.
        if self._history is None:
            jupyter_callback = self._jupyter_agent.start \
                if self._jupyter_agent else None
            self._history = history.History(
                self, add_callback=self._history_added,
                jupyter_callback=jupyter_callback)
            if self._history._steps > 0:
                self.resumed = True
        return self._history

    @property
    def step(self):
        return self.history._steps

    @property
    def has_history(self):
        # Truthy if a history object exists or the file is on disk.
        return self._history or os.path.exists(
            os.path.join(self._dir, HISTORY_FNAME))

    @property
    def events(self):
        # Lazily open the events file.
        if self._events is None:
            self._events = jsonlfile.JsonlEventsFile(EVENTS_FNAME, self._dir)
        return self._events

    @property
    def has_events(self):
        return self._events or os.path.exists(
            os.path.join(self._dir, EVENTS_FNAME))

    @property
    def description_path(self):
        return os.path.join(self.dir, DESCRIPTION_FNAME)

    def close_files(self):
        """Close open files to avoid Python warnings on termination:

        Exception ignored in: <_io.FileIO name='wandb/dryrun-20180130_144602-9vmqjhgy/wandb-history.jsonl' mode='wb' closefd=True>
        ResourceWarning: unclosed file <_io.TextIOWrapper name='wandb/dryrun-20180130_144602-9vmqjhgy/wandb-history.jsonl' mode='w' encoding='UTF-8'>
        """
        if self._events is not None:
            self._events.close()
            self._events = None
        if self._history is not None:
            self._history.close()
            self._history = None