class Artifact(object):
    """An artifact object you can write files into, and pass to log_artifact."""

    def __init__(self, name, type, description=None, metadata=None):
        if not re.match(r"^[a-zA-Z0-9_\-.]+$", name):
            raise ValueError(
                'Artifact name may only contain alphanumeric characters, '
                'dashes, underscores, and dots. Invalid name: "%s"' % name)
        # TODO: this shouldn't be a property of the artifact. It's more like an
        # argument to log_artifact.
        storage_layout = StorageLayout.V2
        if env.get_use_v1_artifacts():
            storage_layout = StorageLayout.V1
        self._storage_policy = WandbStoragePolicy(
            config={
                "storageLayout": storage_layout,
                # TODO: storage region
            })
        self._api = InternalApi()
        self._final = False
        self._digest = None
        self._file_entries = None
        self._manifest = ArtifactManifestV1(self, self._storage_policy)
        self._cache = get_artifacts_cache()
        self._added_new = False
        self._added_objs = {}
        # You can write into this directory when creating artifact files
        self._artifact_dir = compat_tempfile.TemporaryDirectory(
            missing_ok_on_cleanup=True)
        self.type = type
        self.name = name
        self.description = description
        self.metadata = metadata

    @property
    def id(self):
        # The artifact hasn't been saved, so an ID doesn't exist yet.
        return None

    @property
    def entity(self):
        # TODO: is querying for the default entity a good idea here?
        return self._api.settings("entity") or self._api.viewer().get("entity")

    @property
    def project(self):
        return self._api.settings("project")

    @property
    def manifest(self):
        self.finalize()
        return self._manifest

    @property
    def digest(self):
        self.finalize()
        # Digest will be None if the artifact hasn't been saved yet.
        return self._digest

    def _ensure_can_add(self):
        if self._final:
            raise ValueError("Can't add to finalized artifact.")

    def new_file(self, name, mode="w"):
        self._ensure_can_add()
        path = os.path.join(self._artifact_dir.name, name.lstrip("/"))
        if os.path.exists(path):
            raise ValueError(
                'File with name "%s" already exists at "%s"' % (name, path))
        util.mkdir_exists_ok(os.path.dirname(path))
        self._added_new = True
        return open(path, mode)

    def add_file(self, local_path, name=None, is_tmp=False):
        """Adds a local file to the artifact.

        Args:
            local_path (str): path to the file
            name (str, optional): new path and filename to assign inside the
                artifact. Defaults to None.
            is_tmp (bool, optional): If true, the file is renamed
                deterministically. Defaults to False.

        Returns:
            ArtifactManifestEntry: the added entry
        """
        self._ensure_can_add()
        if not os.path.isfile(local_path):
            raise ValueError("Path is not a file: %s" % local_path)

        name = name or os.path.basename(local_path)
        digest = md5_file_b64(local_path)

        if is_tmp:
            file_path, file_name = os.path.split(name)
            file_name_parts = file_name.split(".")
            file_name_parts[0] = b64_string_to_hex(digest)[:8]
            name = os.path.join(file_path, ".".join(file_name_parts))

        entry = ArtifactManifestEntry(
            name,
            None,
            digest=digest,
            size=os.path.getsize(local_path),
            local_path=local_path,
        )
        self._manifest.add_entry(entry)
        return entry

    def add_dir(self, local_path, name=None):
        self._ensure_can_add()
        if not os.path.isdir(local_path):
            raise ValueError("Path is not a directory: %s" % local_path)

        termlog(
            "Adding directory to artifact (%s)... "
            % os.path.join(".", os.path.normpath(local_path)),
            newline=False,
        )
        start_time = time.time()

        paths = []
        for dirpath, _, filenames in os.walk(local_path, followlinks=True):
            for fname in filenames:
                physical_path = os.path.join(dirpath, fname)
                logical_path = os.path.relpath(physical_path, start=local_path)
                if name is not None:
                    logical_path = os.path.join(name, logical_path)
                paths.append((logical_path, physical_path))

        def add_manifest_file(log_phy_path):
            logical_path, physical_path = log_phy_path
            self._manifest.add_entry(
                ArtifactManifestEntry(
                    logical_path,
                    None,
                    digest=md5_file_b64(physical_path),
                    size=os.path.getsize(physical_path),
                    local_path=physical_path,
                ))

        import multiprocessing.dummy  # this uses threads

        NUM_THREADS = 8
        pool = multiprocessing.dummy.Pool(NUM_THREADS)
        pool.map(add_manifest_file, paths)
        pool.close()
        pool.join()

        termlog("Done. %.1fs" % (time.time() - start_time), prefix=False)

    def add_reference(self, uri, name=None, checksum=True, max_objects=None):
        """Adds `uri` to the artifact via a reference, located at `name`.

        You can use Artifact#get_path(`name`) to retrieve this object.

        Arguments:
        - `uri` (str): The URI path of the reference to add. Can be an object
          returned from Artifact.get_path to store a reference to another
          artifact's entry.
        - `name` (str): The path within the artifact to save the reference under.
        """
        # This is a bit of a hack: we want to check if the uri is of the type
        # ArtifactEntry, which is a private class returned by Artifact.get_path
        # in wandb/apis/public.py. If so, recover the reference URL.
        if (isinstance(uri, object) and hasattr(uri, "parent_artifact")
                and uri.parent_artifact != self):
            ref_url_fn = getattr(uri, "ref_url")
            uri = ref_url_fn()
        url = urlparse(uri)
        if not url.scheme:
            raise ValueError(
                "References must be URIs. To reference a local file, use file://")
        if self._final:
            raise ValueError("Can't add to finalized artifact.")

        manifest_entries = self._storage_policy.store_reference(
            self, uri, name=name, checksum=checksum, max_objects=max_objects)
        for entry in manifest_entries:
            self._manifest.add_entry(entry)

        return manifest_entries

    def add(self, obj, name):
        """Adds `obj` to the artifact, located at `name`.

        You can use Artifact#get(`name`) after downloading the artifact to
        retrieve this object.

        Arguments:
            obj (wandb.WBValue): The object to save in an artifact
            name (str): The path to save
        """
        # Validate that the object is a wandb.WBValue subclass.
        if not isinstance(obj, WBValue):
            raise ValueError("Can only add `obj` which subclasses wandb.WBValue")

        obj_id = id(obj)
        if obj_id in self._added_objs:
            return self._added_objs[obj_id]

        # If the object is coming from another artifact, save it as a reference.
        if obj.artifact_source is not None:
            ref_path = obj.artifact_source["artifact"].get_path(
                type(obj).with_suffix(obj.artifact_source["name"]))
            return self.add_reference(ref_path, type(obj).with_suffix(name))[0]

        val = obj.to_json(self)
        name = obj.with_suffix(name)
        entry = self._manifest.get_entry_by_path(name)
        if entry is not None:
            return entry
        with self.new_file(name) as f:
            import json
            # TODO: Do we need to open with utf-8 codec?
            f.write(json.dumps(val, sort_keys=True))

        # Note: we add the file from our temp directory. It will be added
        # again later on finalize, but that succeeds since the checksum
        # should match.
        entry = self.add_file(os.path.join(self._artifact_dir.name, name), name)
        self._added_objs[obj_id] = entry

        return entry

    def get_added_local_path_name(self, local_path):
        """If local_path was already added to the artifact, return its internal name."""
        entry = self._manifest.get_entry_by_local_path(local_path)
        if entry is None:
            return None
        return entry.path

    def get_path(self, name):
        raise ValueError(
            "Cannot load paths from an artifact before it has been saved")

    def download(self):
        raise ValueError(
            "Cannot call download on an artifact before it has been saved")

    def get(self):
        raise ValueError(
            "Cannot call get on an artifact before it has been saved")

    def finalize(self):
        if self._final:
            return self._file_entries

        # Record any created files in the manifest.
        if self._added_new:
            self.add_dir(self._artifact_dir.name)

        # Mark final after all files are added.
        self._final = True
        self._digest = self._manifest.digest()

        # If there are new files, move them into the artifact cache now. Our
        # temp self._artifact_dir may not be available by the time the file
        # pusher syncs these files.
        if self._added_new:
            # Update the file entries for new files to point at their new location.
            def remap_entry(entry):
                if entry.local_path is None or not entry.local_path.startswith(
                        self._artifact_dir.name):
                    return entry
                rel_path = os.path.relpath(entry.local_path,
                                           start=self._artifact_dir.name)
                local_path = os.path.join(self._artifact_dir.name, rel_path)
                cache_path, hit = self._cache.check_md5_obj_path(
                    entry.digest, entry.size)
                if not hit:
                    shutil.copyfile(local_path, cache_path)
                entry.local_path = cache_path

            for entry in self._manifest.entries.values():
                remap_entry(entry)
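# Usage sketch (illustrative only, not part of this module): exercising the
# add_* entry points above. Assumes a live run from `wandb.init()`; the file
# names and the bucket URI are placeholders.
#
#     import wandb
#
#     run = wandb.init(job_type="dataset-build")
#     artifact = Artifact("my-dataset", type="dataset")
#     with artifact.new_file("README.md") as f:  # staged under _artifact_dir
#         f.write("My dataset.")
#     artifact.add_file("data/train.csv")  # checksummed into the manifest
#     artifact.add_reference("s3://my-bucket/images", name="images")
#     run.log_artifact(artifact)  # saving triggers finalize()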
class WandbStoragePolicy(StoragePolicy):
    @classmethod
    def name(cls):
        return "wandb-storage-policy-v1"

    @classmethod
    def from_config(cls, config):
        return cls(config=config)

    def __init__(self, config=None):
        self._cache = get_artifacts_cache()
        self._config = config or {}
        self._session = requests.Session()
        adapter = requests.adapters.HTTPAdapter(
            max_retries=_REQUEST_RETRY_STRATEGY,
            pool_connections=_REQUEST_POOL_CONNECTIONS,
            pool_maxsize=_REQUEST_POOL_MAXSIZE,
        )
        self._session.mount("http://", adapter)
        self._session.mount("https://", adapter)

        s3 = S3Handler()
        gcs = GCSHandler()
        http = HTTPHandler(self._session)
        https = HTTPHandler(self._session, scheme="https")
        artifact = WBArtifactHandler()
        file_handler = LocalFileHandler()

        self._api = InternalApi()
        self._handler = MultiHandler(
            handlers=[
                s3,
                gcs,
                http,
                https,
                artifact,
                file_handler,
            ],
            default_handler=TrackingHandler(),
        )

    def config(self):
        return self._config

    def load_file(self, artifact, name, manifest_entry):
        path, hit = self._cache.check_md5_obj_path(manifest_entry.digest,
                                                   manifest_entry.size)
        if hit:
            return path

        response = self._session.get(
            self._file_url(self._api, artifact.entity, manifest_entry),
            auth=("api", self._api.api_key),
            stream=True,
        )
        response.raise_for_status()

        with open(path, "wb") as file:
            for data in response.iter_content(chunk_size=16 * 1024):
                file.write(data)
        return path

    def store_reference(self, artifact, path, name=None, checksum=True,
                        max_objects=None):
        return self._handler.store_path(
            artifact, path, name=name, checksum=checksum,
            max_objects=max_objects)

    def load_reference(self, artifact, name, manifest_entry, local=False):
        return self._handler.load_path(self._cache, manifest_entry, local)

    def _file_url(self, api, entity_name, manifest_entry):
        storage_layout = self._config.get("storageLayout", StorageLayout.V1)
        storage_region = self._config.get("storageRegion", "default")
        md5_hex = util.bytes_to_hex(base64.b64decode(manifest_entry.digest))

        if storage_layout == StorageLayout.V1:
            return "{}/artifacts/{}/{}".format(
                api.settings("base_url"), entity_name, md5_hex)
        elif storage_layout == StorageLayout.V2:
            return "{}/artifactsV2/{}/{}/{}/{}".format(
                api.settings("base_url"),
                storage_region,
                entity_name,
                quote(manifest_entry.birth_artifact_id),
                md5_hex,
            )
        else:
            raise Exception(
                "unrecognized storage layout: {}".format(storage_layout))

    def store_file(self, artifact_id, entry, preparer, progress_callback=None):
        # write-through cache
        cache_path, hit = self._cache.check_md5_obj_path(
            entry.digest, entry.size)
        if not hit:
            shutil.copyfile(entry.local_path, cache_path)

        resp = preparer.prepare(lambda: {
            "artifactID": artifact_id,
            "name": entry.path,
            "md5": entry.digest,
        })

        entry.birth_artifact_id = resp.birth_artifact_id
        exists = resp.upload_url is None
        if not exists:
            with open(entry.local_path, "rb") as file:
                # This fails if we don't send the first byte before the signed
                # URL expires.
                self._api.upload_file_retry(
                    resp.upload_url,
                    file,
                    progress_callback,
                    extra_headers={
                        header.split(":", 1)[0]: header.split(":", 1)[1]
                        for header in (resp.upload_headers or {})
                    },
                )
        return exists
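# For orientation, the two download-URL layouts _file_url can produce, with
# `BASE` standing in for api.settings("base_url") (paths shown schematically):
#
#     V1: BASE/artifacts/<entity>/<md5_hex>
#     V2: BASE/artifactsV2/<region>/<entity>/<birth_artifact_id>/<md5_hex>
#
# The V2 layout scopes a file to the artifact that first uploaded it, which is
# why store_file records resp.birth_artifact_id on the manifest entry before
# uploading.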
def test_meta_thread(git_repo):
    meta = Meta(InternalApi(), "wandb")
    meta.start()
    meta.shutdown()
    print("GO", glob.glob("**"))
    assert os.path.exists("wandb/wandb-metadata.json")
def test_disable_code(git_repo):
    os.environ[env.DISABLE_CODE] = "true"
    meta = Meta(InternalApi())
    assert meta.data.get("git") is None
    del os.environ[env.DISABLE_CODE]
def project_name(self, api=None):
    if api is None:
        api = InternalApi()
    return api.settings('project') or self.auto_project_name(api) or "uncategorized"
def __init__(self,
             run_id=None,
             mode=None,
             dir=None,
             group=None,
             job_type=None,
             config=None,
             sweep_id=None,
             storage_id=None,
             description=None,
             resume=None,
             program=None,
             args=None,
             wandb_dir=None,
             tags=None):
    # self.id is actually stored in the "name" attribute in GQL
    self.id = run_id if run_id else util.generate_id()
    self.resume = resume if resume else 'never'
    self.mode = mode if mode else 'run'
    self.group = group
    self.job_type = job_type
    self.pid = os.getpid()
    self.resumed = False  # we set resume when history is first accessed
    self.program = program
    if not self.program:
        try:
            import __main__
            self.program = __main__.__file__
        except (ImportError, AttributeError):
            # probably `python -c`, an embedded interpreter or something
            self.program = '<python with no main file>'
    self.args = args
    if self.args is None:
        self.args = sys.argv[1:]
    self.wandb_dir = wandb_dir

    with configure_scope() as scope:
        api = InternalApi()
        self.project = api.settings("project")
        self.entity = api.settings("entity")
        scope.set_tag("project", self.project)
        scope.set_tag("entity", self.entity)
        scope.set_tag("url", self.get_url(api))

    if dir is None:
        self._dir = run_dir_path(self.id, dry=self.mode == 'dryrun')
    else:
        self._dir = os.path.abspath(dir)
    self._mkdir()

    if self.resume == "auto":
        util.mkdir_exists_ok(wandb.wandb_dir())
        resume_path = os.path.join(wandb.wandb_dir(), RESUME_FNAME)
        with open(resume_path, "w") as f:
            f.write(json.dumps({"run_id": self.id}))

    if config is None:
        self.config = Config()
    else:
        self.config = config

    # this is the GQL ID:
    self.storage_id = storage_id
    # socket server, currently only available in headless mode
    self.socket = None

    self.name_and_description = None
    if description is not None:
        self.name_and_description = description
    elif os.path.exists(self.description_path):
        with open(self.description_path) as d_file:
            self.name_and_description = d_file.read()
    # An empty description.md may have been created by RunManager(), so it's
    # important that we overwrite empty strings here.
    if not self.name_and_description:
        self.name_and_description = self.id

    self.tags = tags if tags else []

    self.sweep_id = sweep_id

    self._history = None
    self._events = None
    self._summary = None
    self._meta = None
    self._jupyter_agent = None
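# Usage sketch (illustrative): constructing a Run directly. Most callers go
# through wandb.init(), which assembles these arguments from environment
# variables; the values below are placeholders.
#
#     run = Run(run_id="abc123", mode="dryrun", job_type="train",
#               tags=["baseline"], resume="auto")
#     # With resume="auto", __init__ also writes {"run_id": "abc123"} to the
#     # RESUME_FNAME file under the wandb dir so a restarted process can pick
#     # the run back up.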
def from_directory(cls, directory, project=None, entity=None, run_id=None,
                   api=None):
    api = api or InternalApi()
    run_id = run_id or util.generate_id()
    run = Run(run_id=run_id, dir=directory)

    project = project or api.settings("project") or run.auto_project_name(
        api=api)
    if project is None:
        raise ValueError("You must specify project")

    api.set_current_run_id(run_id)
    api.set_setting("project", project)
    if entity:
        api.set_setting("entity", entity)

    res = api.upsert_run(name=run_id, project=project, entity=entity)
    entity = res["project"]["entity"]["name"]
    wandb.termlog("Syncing {} to:".format(directory))
    wandb.termlog(run.get_url(api))

    file_api = api.get_file_stream_api()
    snap = DirectorySnapshot(directory)
    paths = [
        os.path.relpath(abs_path, directory)
        for abs_path in snap.paths
        if os.path.isfile(abs_path)
    ]
    run_update = {"id": res["id"]}
    tfevents = sorted([p for p in snap.paths if ".tfevents." in p])
    history = next((p for p in snap.paths if HISTORY_FNAME in p), None)
    event = next((p for p in snap.paths if EVENTS_FNAME in p), None)
    config = next((p for p in snap.paths if CONFIG_FNAME in p), None)
    user_config = next((p for p in snap.paths if USER_CONFIG_FNAME in p), None)
    summary = next((p for p in snap.paths if SUMMARY_FNAME in p), None)
    meta = next((p for p in snap.paths if METADATA_FNAME in p), None)

    if history:
        wandb.termlog("Uploading history metrics")
        file_api.stream_file(history)
        snap.paths.remove(history)
    elif len(tfevents) > 0:
        from wandb import tensorflow as wbtf
        wandb.termlog("Found tfevents file, converting.")
        for file in tfevents:
            summary = wbtf.stream_tfevents(file, file_api)
    else:
        wandb.termerror(
            "No history or tfevents files found, only syncing files")

    if event:
        file_api.stream_file(event)
        snap.paths.remove(event)

    if config:
        run_update["config"] = util.load_yaml(open(config))
    elif user_config:
        # TODO: half-baked support for config.json
        run_update["config"] = {
            k: {"value": v}
            for k, v in six.iteritems(json.load(open(user_config)))
        }

    if summary:
        run_update["summary_metrics"] = open(summary).read()

    if meta:
        meta = json.load(open(meta))
        if meta.get("git"):
            run_update["commit"] = meta["git"].get("commit")
            run_update["repo"] = meta["git"].get("remote")
        run_update["host"] = meta["host"]
        run_update["program_path"] = meta["program"]
        run_update["job_type"] = meta.get("jobType")
    else:
        run_update["host"] = socket.gethostname()

    wandb.termlog("Updating run and uploading files")
    api.upsert_run(**run_update)
    pusher = FilePusher(api)
    for k in paths:
        path = os.path.abspath(os.path.join(directory, k))
        pusher.update_file(k, path)
        pusher.file_changed(k, path)
    pusher.finish()
    pusher.print_status()
    wandb.termlog("Finished!")

    return run
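# Usage sketch (illustrative): syncing a previously offline run directory; the
# directory, project, and entity names are placeholders.
#
#     run = Run.from_directory("wandb/dryrun-20200101_120000-abc123",
#                              project="my-project", entity="my-team")
#
# History (or converted tfevents), config, summary, and metadata are streamed
# first; all remaining files in the snapshot go through the FilePusher.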
def init(job_type=None,
         dir=None,
         config=None,
         project=None,
         entity=None,
         reinit=None,
         tags=None,
         group=None,
         allow_val_change=False,
         resume=False,
         force=False,
         tensorboard=False,
         sync_tensorboard=False,
         name=None,
         notes=None,
         id=None,
         magic=None):
    """Initialize W&B

    If called from within Jupyter, initializes a new run and waits for a call
    to `wandb.log` to begin pushing metrics. Otherwise, spawns a new process
    to communicate with W&B.

    Args:
        job_type (str, optional): The type of job running, defaults to 'train'
        config (dict, argparse, or tf.FLAGS, optional): The hyperparameters to
            store with the run
        project (str, optional): The project to push metrics to
        entity (str, optional): The entity to push metrics to
        dir (str, optional): An absolute path to a directory where metadata
            will be stored
        group (str, optional): A unique string shared by all runs in a given group
        tags (list, optional): A list of tags to apply to the run
        id (str, optional): A globally unique (per project) identifier for the run
        name (str, optional): A display name which does not have to be unique
        notes (str, optional): A multiline string associated with the run
        reinit (bool, optional): Allow multiple calls to init in the same process
        resume (bool, str, optional): Automatically resume this run if run from
            the same machine; you can also pass a unique run_id
        sync_tensorboard (bool, optional): Synchronize wandb logs to tensorboard
            or tensorboardX
        force (bool, optional): Force authentication with wandb, defaults to False
        magic (bool, dict, or str, optional): magic configuration as bool, dict,
            json string, or yaml filename

    Returns:
        A wandb.run object for metric and config logging.
    """
    trigger.call('on_init', **locals())
    global run
    global __stage_dir__

    # We allow re-initialization when we're in Jupyter or explicitly opt in to it.
    in_jupyter = _get_python_type() != "python"
    if reinit or (in_jupyter and reinit != False):
        reset_env(exclude=env.immutable_keys())
        run = None

    # TODO: deprecate tensorboard
    if (tensorboard or sync_tensorboard) and len(patched["tensorboard"]) == 0:
        util.get_module("wandb.tensorboard").patch()

    sagemaker_config = util.parse_sm_config()
    tf_config = util.parse_tfjob_config()
    if group is None:
        group = os.getenv(env.RUN_GROUP)
    if job_type is None:
        job_type = os.getenv(env.JOB_TYPE)
    if sagemaker_config:
        # Set run_id and potentially grouping if we're in SageMaker
        run_id = os.getenv('TRAINING_JOB_NAME')
        if run_id:
            os.environ[env.RUN_ID] = '-'.join(
                [run_id, os.getenv('CURRENT_HOST', socket.gethostname())])
        conf = json.load(open("/opt/ml/input/config/resourceconfig.json"))
        if group is None and len(conf["hosts"]) > 1:
            group = os.getenv('TRAINING_JOB_NAME')
        # Set secret variables
        if os.path.exists("secrets.env"):
            for line in open("secrets.env", "r"):
                key, val = line.strip().split('=', 1)
                os.environ[key] = val
    elif tf_config:
        cluster = tf_config.get('cluster')
        job_name = tf_config.get('task', {}).get('type')
        task_index = tf_config.get('task', {}).get('index')
        if job_name is not None and task_index is not None:
            # TODO: set run_id for resuming?
            run_id = cluster[job_name][task_index].rsplit(":")[0]
            if job_type is None:
                job_type = job_name
            if group is None and len(cluster.get("worker", [])) > 0:
                group = cluster[job_name][0].rsplit("-" + job_name, 1)[0]
    image = util.image_id_from_k8s()
    if image:
        os.environ[env.DOCKER] = image
    if project:
        os.environ[env.PROJECT] = project
    if entity:
        os.environ[env.ENTITY] = entity
    if group:
        os.environ[env.RUN_GROUP] = group
    if job_type:
        os.environ[env.JOB_TYPE] = job_type
    if tags:
        os.environ[env.TAGS] = ",".join(tags)
    if id:
        os.environ[env.RUN_ID] = id
        if name is None:
            # We do this because of https://github.com/wandb/core/issues/2170
            # to ensure that the run's name is explicitly set to match its
            # id. If we don't do this and the id is eight characters long, the
            # backend will set the name to a generated human-friendly value.
            #
            # In any case, if the user is explicitly setting `id` but not
            # `name`, their id is probably a meaningful string that we can
            # use to label the run.
            name = os.environ.get(
                env.NAME, id)  # environment variable takes precedence over this.
    if name:
        os.environ[env.NAME] = name
    if notes:
        os.environ[env.NOTES] = notes
    if magic is not None and magic is not False:
        if isinstance(magic, dict):
            os.environ[env.MAGIC] = json.dumps(magic)
        elif isinstance(magic, str):
            os.environ[env.MAGIC] = magic
        elif isinstance(magic, bool):
            pass
        else:
            termwarn("wandb.init called with invalid magic parameter type",
                     repeat=False)
        from wandb import magic_impl
        magic_impl.magic_install()
    if dir:
        os.environ[env.DIR] = dir
        util.mkdir_exists_ok(wandb_dir())

    resume_path = os.path.join(wandb_dir(), wandb_run.RESUME_FNAME)
    if resume == True:
        os.environ[env.RESUME] = "auto"
    elif resume:
        os.environ[env.RESUME] = os.environ.get(env.RESUME, "allow")
        # TODO: remove allowing resume as a string in the future
        os.environ[env.RUN_ID] = id or resume
    elif os.path.exists(resume_path):
        os.remove(resume_path)
    if os.environ.get(env.RESUME) == 'auto' and os.path.exists(resume_path):
        if not os.environ.get(env.RUN_ID):
            os.environ[env.RUN_ID] = json.load(open(resume_path))["run_id"]

    # the following line is useful to ensure that no W&B logging happens in the
    # user process that might interfere with what they do
    # logging.basicConfig(format='user process %(asctime)s - %(name)s - %(levelname)s - %(message)s')

    # If a thread calls wandb.init() it will get the same Run object as
    # the parent. If a child process with distinct memory space calls
    # wandb.init(), it won't get an error, but it will get a result of
    # None.
    # This check ensures that a child process can safely call wandb.init()
    # after a parent has (only the parent will create the Run object).
    # This doesn't protect against the case where the parent doesn't call
    # wandb.init but two children do.
    if run or os.getenv(env.INITED):
        return run

    if __stage_dir__ is None:
        __stage_dir__ = "wandb"
        util.mkdir_exists_ok(wandb_dir())

    try:
        signal.signal(signal.SIGQUIT, _debugger)
    except AttributeError:
        pass

    try:
        run = wandb_run.Run.from_environment_or_defaults()
    except IOError as e:
        termerror('Failed to create run directory: {}'.format(e))
        raise LaunchError("Could not write to filesystem.")

    run.set_environment()

    def set_global_config(run):
        global config  # because we already have a local config
        config = run.config
    set_global_config(run)
    global summary
    summary = run.summary

    # set this immediately after setting the run and the config. if there is an
    # exception after this it'll probably break the user script anyway
    os.environ[env.INITED] = '1'

    # we do these checks after setting the run and the config because user
    # scripts may depend on those things
    if sys.platform == 'win32' and run.mode != 'clirun':
        termerror(
            'To use wandb on Windows, you need to run the command "wandb run python <your_train_script>.py"')
        return run

    if in_jupyter:
        _init_jupyter(run)
    elif run.mode == 'clirun':
        pass
    elif run.mode == 'run':
        api = InternalApi()
        # let init_jupyter handle this itself
        if not in_jupyter and not api.api_key:
            termlog(
                "W&B is a tool that helps track and visualize machine learning experiments")
            if force:
                termerror(
                    "No credentials found. Run \"wandb login\" or \"wandb off\" to disable wandb")
            else:
                if run.check_anonymous():
                    _init_headless(run)
                else:
                    termlog(
                        "No credentials found. Run \"wandb login\" to visualize your metrics")
                    run.mode = "dryrun"
                    _init_headless(run, False)
        else:
            _init_headless(run)
    elif run.mode == 'dryrun':
        termlog('Dry run mode, not syncing to the cloud.')
        _init_headless(run, False)
    else:
        termerror('Invalid run mode "%s". Please unset WANDB_MODE.' % run.mode)
        raise LaunchError("The WANDB_MODE environment variable is invalid.")

    # set the run directory in the config so it actually gets persisted
    run.config.set_run_dir(run.dir)

    if sagemaker_config:
        run.config.update(sagemaker_config)
        allow_val_change = True
    if config:
        run.config.update(config, allow_val_change=allow_val_change)

    # Access history to ensure resumed is set when resuming
    run.history
    # Load the summary to support resuming
    run.summary.load()

    atexit.register(run.close_files)

    return run
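# Usage sketch (illustrative): a typical training-script entry point; the
# project name and config values are placeholders.
#
#     import wandb
#
#     run = wandb.init(project="my-project", job_type="train",
#                      config={"lr": 1e-3, "epochs": 10},
#                      resume=True)  # resume=True sets WANDB_RESUME=auto
#     for epoch in range(run.config.epochs):
#         wandb.log({"epoch": epoch, "loss": 1.0 / (epoch + 1)})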
def ensure_configured():
    global GLOBAL_LOG_FNAME, api
    # We re-initialize here for tests
    api = InternalApi()
    GLOBAL_LOG_FNAME = os.path.abspath(os.path.join(wandb_dir(), 'debug.log'))