예제 #1
0
class Artifact(object):
    """An artifact object you can write files into, and pass to log_artifact."""
    def __init__(self, name, type, description=None, metadata=None):
        if not re.match(r"^[a-zA-Z0-9_\-.]+$", name):
            raise ValueError(
                'Artifact name may only contain alphanumeric characters, dashes, underscores, and dots. Invalid name: "%s"'
                % name)
        # TODO: this shouldn't be a property of the artifact. It's a more like an
        # argument to log_artifact.
        storage_layout = StorageLayout.V2
        if env.get_use_v1_artifacts():
            storage_layout = StorageLayout.V1

        self._storage_policy = WandbStoragePolicy(
            config={
                "storageLayout": storage_layout,
                #  TODO: storage region
            })
        self._api = InternalApi()
        self._final = False
        self._digest = None
        self._file_entries = None
        self._manifest = ArtifactManifestV1(self, self._storage_policy)
        self._cache = get_artifacts_cache()
        self._added_new = False
        self._added_objs = {}
        # You can write into this directory when creating artifact files
        self._artifact_dir = compat_tempfile.TemporaryDirectory(
            missing_ok_on_cleanup=True)
        self.type = type
        self.name = name
        self.description = description
        self.metadata = metadata

    @property
    def id(self):
        # The artifact hasn't been saved so an ID doesn't exist yet.
        return None

    @property
    def entity(self):
        # TODO: querying for default entity a good idea here?
        return self._api.settings("entity") or self._api.viewer().get("entity")

    @property
    def project(self):
        return self._api.settings("project")

    @property
    def manifest(self):
        self.finalize()
        return self._manifest

    @property
    def digest(self):
        self.finalize()
        # Digest will be none if the artifact hasn't been saved yet.
        return self._digest

    def _ensure_can_add(self):
        if self._final:
            raise ValueError("Can't add to finalized artifact.")

    def new_file(self, name, mode="w"):
        self._ensure_can_add()
        path = os.path.join(self._artifact_dir.name, name.lstrip("/"))
        if os.path.exists(path):
            raise ValueError('File with name "%s" already exists at "%s"' %
                             (name, path))
        util.mkdir_exists_ok(os.path.dirname(path))
        self._added_new = True
        return open(path, mode)

    def add_file(self, local_path, name=None, is_tmp=False):
        """Adds a local file to the artifact

        Args:
            local_path (str): path to the file
            name (str, optional): new path and filename to assign inside artifact. Defaults to None.
            is_tmp (bool, optional): If true, then the file is renamed deterministically. Defaults to False.

        Returns:
            ArtifactManifestEntry: the added entry
        """
        self._ensure_can_add()
        if not os.path.isfile(local_path):
            raise ValueError("Path is not a file: %s" % local_path)

        name = name or os.path.basename(local_path)
        digest = md5_file_b64(local_path)

        if is_tmp:
            file_path, file_name = os.path.split(name)
            file_name_parts = file_name.split(".")
            file_name_parts[0] = b64_string_to_hex(digest)[:8]
            name = os.path.join(file_path, ".".join(file_name_parts))

        entry = ArtifactManifestEntry(
            name,
            None,
            digest=digest,
            size=os.path.getsize(local_path),
            local_path=local_path,
        )

        self._manifest.add_entry(entry)
        return entry

    def add_dir(self, local_path, name=None):
        self._ensure_can_add()
        if not os.path.isdir(local_path):
            raise ValueError("Path is not a directory: %s" % local_path)

        termlog(
            "Adding directory to artifact (%s)... " %
            os.path.join(".", os.path.normpath(local_path)),
            newline=False,
        )
        start_time = time.time()

        paths = []
        for dirpath, _, filenames in os.walk(local_path, followlinks=True):
            for fname in filenames:
                physical_path = os.path.join(dirpath, fname)
                logical_path = os.path.relpath(physical_path, start=local_path)
                if name is not None:
                    logical_path = os.path.join(name, logical_path)
                paths.append((logical_path, physical_path))

        def add_manifest_file(log_phy_path):
            logical_path, physical_path = log_phy_path
            self._manifest.add_entry(
                ArtifactManifestEntry(
                    logical_path,
                    None,
                    digest=md5_file_b64(physical_path),
                    size=os.path.getsize(physical_path),
                    local_path=physical_path,
                ))

        import multiprocessing.dummy  # this uses threads

        NUM_THREADS = 8
        pool = multiprocessing.dummy.Pool(NUM_THREADS)
        pool.map(add_manifest_file, paths)
        pool.close()
        pool.join()

        termlog("Done. %.1fs" % (time.time() - start_time), prefix=False)

    def add_reference(self, uri, name=None, checksum=True, max_objects=None):
        """adds `uri` to the artifact via a reference, located at `name`. 
        You can use Artifact#get_path(`name`) to retrieve this object.
        
        Arguments:
        - `uri`:str - the URI path of the reference to add. Can be an object returned from
            Artifact.get_path to store a reference to another artifact's entry.
        - `name`:str - the path to save
        """

        # This is a bit of a hack, we want to check if the uri is a of the type
        # ArtifactEntry which is a private class returned by Artifact.get_path in
        # wandb/apis/public.py. If so, then recover the reference URL.
        if (isinstance(uri, object) and hasattr(uri, "parent_artifact")
                and uri.parent_artifact != self):
            ref_url_fn = getattr(uri, "ref_url")
            uri = ref_url_fn()
        url = urlparse(uri)
        if not url.scheme:
            raise ValueError(
                "References must be URIs. To reference a local file, use file://"
            )
        if self._final:
            raise ValueError("Can't add to finalized artifact.")
        manifest_entries = self._storage_policy.store_reference(
            self, uri, name=name, checksum=checksum, max_objects=max_objects)
        for entry in manifest_entries:
            self._manifest.add_entry(entry)

        return manifest_entries

    def add(self, obj, name):
        """Adds `obj` to the artifact, located at `name`. You can use Artifact#get(`name`) after downloading
        the artifact to retrieve this object.

        Arguments:
            obj (wandb.WBValue): The object to save in an artifact
            name (str): The path to save

        Returns:
            The manifest entry for the object. Adding the same object again
            returns the entry recorded the first time.

        Raises:
            ValueError: If `obj` is not a WBValue instance.
        """

        # Validate that the object is a WBValue
        if not isinstance(obj, WBValue):
            raise ValueError("Can only add `obj` which subclass wandb.WBValue")

        # Objects are remembered by identity so re-adding is a cheap no-op.
        cache_key = id(obj)
        if cache_key in self._added_objs:
            return self._added_objs[cache_key]

        # If the object is coming from another artifact, save it as a reference
        source = obj.artifact_source
        if source is not None:
            value_cls = type(obj)
            ref_path = source["artifact"].get_path(
                value_cls.with_suffix(source["name"]))
            references = self.add_reference(ref_path,
                                            value_cls.with_suffix(name))
            return references[0]

        json_value = obj.to_json(self)
        name = obj.with_suffix(name)
        existing = self._manifest.get_entry_by_path(name)
        if existing is not None:
            return existing

        with self.new_file(name) as f:
            import json

            # TODO: Do we need to open with utf-8 codec?
            f.write(json.dumps(json_value, sort_keys=True))

        # Note, we add the file from our temp directory.
        # It will be added again later on finalize, but succeed since
        # the checksum should match
        staged_path = os.path.join(self._artifact_dir.name, name)
        entry = self.add_file(staged_path, name)
        self._added_objs[cache_key] = entry
        return entry

    def get_added_local_path_name(self, local_path):
        """If local_path was already added to artifact, return its internal name."""
        entry = self._manifest.get_entry_by_local_path(local_path)
        if entry is None:
            return None
        return entry.path

    def get_path(self, name):
        raise ValueError(
            "Cannot load paths from an artifact before it has been saved")

    def download(self):
        raise ValueError(
            "Cannot call download on an artifact before it has been saved")

    def get(self):
        raise ValueError(
            "Cannot call get on an artifact before it has been saved")

    def finalize(self):
        if self._final:
            return self._file_entries

        # Record any created files in the manifest.
        if self._added_new:
            self.add_dir(self._artifact_dir.name)

        # mark final after all files are added
        self._final = True
        self._digest = self._manifest.digest()

        # If there are new files, move them into the artifact cache now. Our temp
        # self._artifact_dir may not be available by the time file pusher syncs
        # these files.
        if self._added_new:
            # Update the file entries for new files to point at their new location.
            def remap_entry(entry):
                if entry.local_path is None or not entry.local_path.startswith(
                        self._artifact_dir.name):
                    return entry
                rel_path = os.path.relpath(entry.local_path,
                                           start=self._artifact_dir.name)
                local_path = os.path.join(self._artifact_dir.name, rel_path)
                cache_path, hit = self._cache.check_md5_obj_path(
                    entry.digest, entry.size)
                if not hit:
                    shutil.copyfile(local_path, cache_path)
                entry.local_path = cache_path

            for entry in self._manifest.entries.values():
                remap_entry(entry)
예제 #2
0
class WandbStoragePolicy(StoragePolicy):
    @classmethod
    def name(cls):
        """Return the identifier string of this storage policy."""
        policy_name = "wandb-storage-policy-v1"
        return policy_name

    @classmethod
    def from_config(cls, config):
        """Alternate constructor: build a policy from a config mapping."""
        policy = cls(config=config)
        return policy

    def __init__(self, config=None):
        """Set up the cache, a retrying HTTP session, and the reference handlers.

        Args:
            config (dict, optional): Policy configuration; a falsy value is
                replaced by an empty dict.
        """
        self._cache = get_artifacts_cache()
        self._config = config or {}

        # One shared adapter serves both schemes so they use the same retry
        # strategy and connection-pool sizing.
        self._session = requests.Session()
        adapter = requests.adapters.HTTPAdapter(
            max_retries=_REQUEST_RETRY_STRATEGY,
            pool_connections=_REQUEST_POOL_CONNECTIONS,
            pool_maxsize=_REQUEST_POOL_MAXSIZE,
        )
        for url_prefix in ("http://", "https://"):
            self._session.mount(url_prefix, adapter)

        reference_handlers = [
            S3Handler(),
            GCSHandler(),
            HTTPHandler(self._session),
            HTTPHandler(self._session, scheme="https"),
            WBArtifactHandler(),
            LocalFileHandler(),
        ]

        self._api = InternalApi()
        self._handler = MultiHandler(
            handlers=reference_handlers,
            default_handler=TrackingHandler(),
        )

    def config(self):
        """Return the configuration dict this policy was created with."""
        current_config = self._config
        return current_config

    def load_file(self, artifact, name, manifest_entry):
        """Return a local cache path for the entry, downloading it on a miss.

        Args:
            artifact: Artifact the entry belongs to; its ``entity`` is used to
                build the download URL.
            name: Unused here.
            manifest_entry: Entry whose ``digest`` and ``size`` select the
                cache location.

        Returns:
            Path of the file inside the artifacts cache.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        path, hit = self._cache.check_md5_obj_path(manifest_entry.digest,
                                                   manifest_entry.size)
        if hit:
            return path
        response = self._session.get(
            self._file_url(self._api, artifact.entity, manifest_entry),
            auth=("api", self._api.api_key),
            stream=True,
        )
        try:
            response.raise_for_status()

            with open(path, "wb") as file:
                for data in response.iter_content(chunk_size=16 * 1024):
                    file.write(data)
        finally:
            # A streamed response holds its connection until the body is
            # fully consumed or it is closed; close explicitly so an HTTP
            # error or failed write does not leak the connection.
            response.close()
        return path

    def store_reference(self,
                        artifact,
                        path,
                        name=None,
                        checksum=True,
                        max_objects=None):
        """Delegate storing a reference to the handler; returns its result."""
        store_path = self._handler.store_path
        return store_path(artifact,
                          path,
                          name=name,
                          checksum=checksum,
                          max_objects=max_objects)

    def load_reference(self, artifact, name, manifest_entry, local=False):
        """Delegate loading a reference to the handler.

        ``artifact`` and ``name`` are accepted but not used here.
        """
        load_path = self._handler.load_path
        return load_path(self._cache, manifest_entry, local)

    def _file_url(self, api, entity_name, manifest_entry):
        """Build the download URL for an entry under the configured layout.

        The layout defaults to V1 and the region to "default" when they are
        absent from the policy config.

        Raises:
            Exception: If the configured storage layout is neither V1 nor V2.
        """
        layout = self._config.get("storageLayout", StorageLayout.V1)
        region = self._config.get("storageRegion", "default")
        # Entry digests are base64 encoded; the URL uses the hex form.
        raw_digest = base64.b64decode(manifest_entry.digest)
        md5_hex = util.bytes_to_hex(raw_digest)

        if layout == StorageLayout.V1:
            return "{}/artifacts/{}/{}".format(api.settings("base_url"),
                                               entity_name, md5_hex)
        if layout == StorageLayout.V2:
            return "{}/artifactsV2/{}/{}/{}/{}".format(
                api.settings("base_url"),
                region,
                entity_name,
                quote(manifest_entry.birth_artifact_id),
                md5_hex,
            )
        raise Exception(
            "unrecognized storage layout: {}".format(layout))

    def store_file(self, artifact_id, entry, preparer, progress_callback=None):
        """Copy the entry into the cache and upload it if the server asks.

        Args:
            artifact_id: ID sent as "artifactID" in the prepare request.
            entry: Manifest entry; its ``birth_artifact_id`` is set from the
                prepare response.
            preparer: Object whose ``prepare`` is called with a callable that
                returns the file spec.
            progress_callback (optional): Forwarded to the upload call.

        Returns:
            bool: True when the prepare response had no upload URL (nothing
            was uploaded), False when the file was uploaded.
        """
        # write-through cache
        cache_path, hit = self._cache.check_md5_obj_path(
            entry.digest, entry.size)
        if not hit:
            shutil.copyfile(entry.local_path, cache_path)

        def build_spec():
            return {
                "artifactID": artifact_id,
                "name": entry.path,
                "md5": entry.digest,
            }

        resp = preparer.prepare(build_spec)

        entry.birth_artifact_id = resp.birth_artifact_id
        exists = resp.upload_url is None
        if exists:
            return exists

        with open(entry.local_path, "rb") as file:
            # Headers arrive as "Key:Value" strings; split on the first colon.
            extra_headers = {}
            for header in (resp.upload_headers or {}):
                pieces = header.split(":", 1)
                extra_headers[pieces[0]] = pieces[1]
            # This fails if we don't send the first byte before the signed URL
            # expires.
            self._api.upload_file_retry(
                resp.upload_url,
                file,
                progress_callback,
                extra_headers=extra_headers,
            )
        return exists
예제 #3
0
파일: test_meta.py 프로젝트: zbpjlc/client
def test_meta_thread(git_repo):
    """Starting and shutting down Meta writes wandb/wandb-metadata.json."""
    api = InternalApi()
    meta = Meta(api, "wandb")
    meta.start()
    meta.shutdown()
    print("GO", glob.glob("**"))
    metadata_path = "wandb/wandb-metadata.json"
    assert os.path.exists(metadata_path)
예제 #4
0
파일: test_meta.py 프로젝트: zbpjlc/client
def test_disable_code(git_repo):
    """With DISABLE_CODE set, Meta must not collect any git information."""
    os.environ[env.DISABLE_CODE] = "true"
    try:
        meta = Meta(InternalApi())
        assert meta.data.get("git") is None
    finally:
        # Always undo the environment change so a failure here cannot leak
        # the flag into tests that run afterwards.
        del os.environ[env.DISABLE_CODE]
예제 #5
0
 def project_name(self, api=None):
     """Return the project name to use.

     Preference order: the "project" setting, then the auto-detected
     project name, then the literal "uncategorized".
     """
     api = InternalApi() if api is None else api
     configured = api.settings('project')
     if configured:
         return configured
     return self.auto_project_name(api) or "uncategorized"
예제 #6
0
    def __init__(self,
                 run_id=None,
                 mode=None,
                 dir=None,
                 group=None,
                 job_type=None,
                 config=None,
                 sweep_id=None,
                 storage_id=None,
                 description=None,
                 resume=None,
                 program=None,
                 args=None,
                 wandb_dir=None,
                 tags=None):
        """Initialise run state, create the run directory and tag the error scope.

        Falsy ``run_id``, ``resume`` and ``mode`` fall back to a generated id,
        'never' and 'run' respectively. ``program`` defaults to the file of
        ``__main__`` and ``args`` to ``sys.argv[1:]``. When ``resume`` is
        "auto" the run id is written to the resume file in the wandb
        directory.
        """
        # self.id is actually stored in the "name" attribute in GQL
        self.id = run_id or util.generate_id()
        self.resume = resume or 'never'
        self.mode = mode or 'run'
        self.group = group
        self.job_type = job_type
        self.pid = os.getpid()
        self.resumed = False  # we set resume when history is first accessed

        self.program = program
        if not self.program:
            try:
                import __main__
                self.program = __main__.__file__
            except (ImportError, AttributeError):
                # probably `python -c`, an embedded interpreter or something
                self.program = '<python with no main file>'
        self.args = sys.argv[1:] if args is None else args
        self.wandb_dir = wandb_dir

        with configure_scope() as scope:
            api = InternalApi()
            self.project = api.settings("project")
            self.entity = api.settings("entity")
            scope_tags = (
                ("project", self.project),
                ("entity", self.entity),
                ("url", self.get_url(api)),
            )
            for tag_key, tag_value in scope_tags:
                scope.set_tag(tag_key, tag_value)

        if dir is not None:
            self._dir = os.path.abspath(dir)
        else:
            self._dir = run_dir_path(self.id, dry=self.mode == 'dryrun')
        self._mkdir()

        if self.resume == "auto":
            # Persist the run id so a later process can pick this run up.
            util.mkdir_exists_ok(wandb.wandb_dir())
            resume_path = os.path.join(wandb.wandb_dir(), RESUME_FNAME)
            with open(resume_path, "w") as f:
                f.write(json.dumps({"run_id": self.id}))

        self.config = Config() if config is None else config

        # this is the GQL ID:
        self.storage_id = storage_id
        # socket server, currently only available in headless mode
        self.socket = None

        self.name_and_description = None
        if description is not None:
            self.name_and_description = description
        elif os.path.exists(self.description_path):
            with open(self.description_path) as d_file:
                self.name_and_description = d_file.read()

        # An empty description.md may have been created by RunManager() so it's
        # important that we overwrite empty strings here.
        if not self.name_and_description:
            self.name_and_description = self.id
        self.tags = tags or []

        self.sweep_id = sweep_id

        # Filled in later; nothing in this constructor populates them.
        self._history = None
        self._events = None
        self._summary = None
        self._meta = None
        self._jupyter_agent = None
예제 #7
0
    def from_directory(cls,
                       directory,
                       project=None,
                       entity=None,
                       run_id=None,
                       api=None):
        """Create a run from an existing directory and sync its contents.

        The run is upserted, history (or converted tfevents) and event files
        are streamed, run metadata is updated from the config, summary and
        metadata files found in the directory, and finally every regular file
        in the directory is pushed.

        Args:
            directory (str): Directory whose contents are synced.
            project (str, optional): Project name. Falls back to the
                "project" setting, then the auto-detected project name.
            entity (str, optional): Entity to sync to.
            run_id (str, optional): Run id; generated when omitted.
            api (optional): API client; a new InternalApi is created when
                omitted.

        Returns:
            The Run created for ``directory``.

        Raises:
            ValueError: If no project could be determined.
        """
        api = api or InternalApi()
        run_id = run_id or util.generate_id()
        run = Run(run_id=run_id, dir=directory)
        project = project or api.settings("project") or run.auto_project_name(
            api=api)
        if project is None:
            raise ValueError("You must specify project")
        api.set_current_run_id(run_id)
        api.set_setting("project", project)
        if entity:
            api.set_setting("entity", entity)
        res = api.upsert_run(name=run_id, project=project, entity=entity)
        entity = res["project"]["entity"]["name"]
        wandb.termlog("Syncing {} to:".format(directory))
        wandb.termlog(run.get_url(api))

        file_api = api.get_file_stream_api()
        snap = DirectorySnapshot(directory)
        # Computed before any path is removed from the snapshot below, so
        # streamed files are still pushed as regular files at the end.
        paths = [
            os.path.relpath(abs_path, directory) for abs_path in snap.paths
            if os.path.isfile(abs_path)
        ]
        run_update = {"id": res["id"]}
        tfevents = sorted([p for p in snap.paths if ".tfevents." in p])
        history = next((p for p in snap.paths if HISTORY_FNAME in p), None)
        event = next((p for p in snap.paths if EVENTS_FNAME in p), None)
        config = next((p for p in snap.paths if CONFIG_FNAME in p), None)
        user_config = next((p for p in snap.paths if USER_CONFIG_FNAME in p),
                           None)
        summary = next((p for p in snap.paths if SUMMARY_FNAME in p), None)
        meta = next((p for p in snap.paths if METADATA_FNAME in p), None)
        if history:
            wandb.termlog("Uploading history metrics")
            file_api.stream_file(history)
            snap.paths.remove(history)
        elif tfevents:
            from wandb import tensorflow as wbtf
            wandb.termlog("Found tfevents file, converting.")
            for tfevents_path in tfevents:
                # NOTE(review): the result replaces the summary *path* and is
                # later passed to open(); confirm stream_tfevents returns a
                # path.
                summary = wbtf.stream_tfevents(tfevents_path, file_api)
        else:
            wandb.termerror(
                "No history or tfevents files found, only syncing files")
        if event:
            file_api.stream_file(event)
            snap.paths.remove(event)
        if config:
            with open(config) as config_file:
                run_update["config"] = util.load_yaml(config_file)
        elif user_config:
            # TODO: half backed support for config.json
            # `user_config` is a file path: parse the file rather than
            # iterating over the path string, which used to raise.
            with open(user_config) as user_config_file:
                user_values = json.load(user_config_file)
            run_update["config"] = {
                k: {
                    "value": v
                }
                for k, v in six.iteritems(user_values)
            }
        if summary:
            with open(summary) as summary_file:
                run_update["summary_metrics"] = summary_file.read()
        if meta:
            with open(meta) as meta_file:
                meta = json.load(meta_file)
            if meta.get("git"):
                run_update["commit"] = meta["git"].get("commit")
                run_update["repo"] = meta["git"].get("remote")
            run_update["host"] = meta["host"]
            run_update["program_path"] = meta["program"]
            run_update["job_type"] = meta.get("jobType")
        else:
            run_update["host"] = socket.gethostname()

        wandb.termlog("Updating run and uploading files")
        api.upsert_run(**run_update)
        pusher = FilePusher(api)
        for k in paths:
            path = os.path.abspath(os.path.join(directory, k))
            pusher.update_file(k, path)
            pusher.file_changed(k, path)
        pusher.finish()
        pusher.print_status()
        wandb.termlog("Finished!")
        return run
예제 #8
0
파일: __init__.py 프로젝트: chmod644/client
def init(job_type=None,
         dir=None,
         config=None,
         project=None,
         entity=None,
         reinit=None,
         tags=None,
         group=None,
         allow_val_change=False,
         resume=False,
         force=False,
         tensorboard=False,
         sync_tensorboard=False,
         name=None,
         notes=None,
         id=None,
         magic=None):
    """Initialize W&B

    If called from within Jupyter, initializes a new run and waits for a call to
    `wandb.log` to begin pushing metrics.  Otherwise, spawns a new process
    to communicate with W&B.

    Args:
        job_type (str, optional): The type of job running, defaults to 'train'
        config (dict, argparse, or tf.FLAGS, optional): The hyper parameters to store with the run
        project (str, optional): The project to push metrics to
        entity (str, optional): The entity to push metrics to
        dir (str, optional): An absolute path to a directory where metadata will be stored
        group (str, optional): A unique string shared by all runs in a given group
        tags (list, optional): A list of tags to apply to the run
        id (str, optional): A globally unique (per project) identifier for the run
        name (str, optional): A display name which does not have to be unique
        notes (str, optional): A multiline string associated with the run
        reinit (bool, optional): Allow multiple calls to init in the same process
        resume (bool, str, optional): Automatically resume this run if run from the same machine,
            you can also pass a unique run_id
        sync_tensorboard (bool, optional): Synchronize wandb logs to tensorboard or tensorboardX
        tensorboard (bool, optional): Has the same effect as sync_tensorboard; slated for
            deprecation
        allow_val_change (bool, optional): Passed to `run.config.update` when applying `config`
        force (bool, optional): Force authentication with wandb, defaults to False
        magic (bool, dict, or str, optional): magic configuration as bool, dict, json string,
            yaml filename

    Returns:
        A wandb.run object for metric and config logging.

    Raises:
        LaunchError: If the run directory cannot be created or the run mode is invalid.
    """
    # Must stay the first statement: locals() is meant to hold only the
    # arguments at this point.
    trigger.call('on_init', **locals())
    global run
    global __stage_dir__

    # We allow re-initialization when we're in Jupyter or explicity opt-in to it.
    in_jupyter = _get_python_type() != "python"
    if reinit or (in_jupyter and reinit != False):
        reset_env(exclude=env.immutable_keys())
        run = None

    # TODO: deprecate tensorboard
    # Parenthesised on purpose: without the parentheses `and` bound tighter
    # than `or`, so tensorboard=True re-applied the patch on every call.
    if (tensorboard or sync_tensorboard) and len(patched["tensorboard"]) == 0:
        util.get_module("wandb.tensorboard").patch()

    sagemaker_config = util.parse_sm_config()
    tf_config = util.parse_tfjob_config()
    if group is None:
        group = os.getenv(env.RUN_GROUP)
    if job_type is None:
        job_type = os.getenv(env.JOB_TYPE)
    if sagemaker_config:
        # Set run_id and potentially grouping if we're in SageMaker
        run_id = os.getenv('TRAINING_JOB_NAME')
        if run_id:
            os.environ[env.RUN_ID] = '-'.join(
                [run_id,
                 os.getenv('CURRENT_HOST', socket.gethostname())])
        with open("/opt/ml/input/config/resourceconfig.json") as conf_file:
            conf = json.load(conf_file)
        if group is None and len(conf["hosts"]) > 1:
            group = os.getenv('TRAINING_JOB_NAME')
        # Set secret variables
        if os.path.exists("secrets.env"):
            with open("secrets.env", "r") as secrets_file:
                for line in secrets_file:
                    line = line.strip()
                    # Skip blank or malformed lines instead of crashing on
                    # the tuple unpacking below.
                    if '=' not in line:
                        continue
                    key, val = line.split('=', 1)
                    os.environ[key] = val
    elif tf_config:
        cluster = tf_config.get('cluster')
        job_name = tf_config.get('task', {}).get('type')
        task_index = tf_config.get('task', {}).get('index')
        if job_name is not None and task_index is not None:
            # TODO: set run_id for resuming?
            run_id = cluster[job_name][task_index].rsplit(":")[0]
            if job_type is None:
                job_type = job_name
            if group is None and len(cluster.get("worker", [])) > 0:
                group = cluster[job_name][0].rsplit("-" + job_name, 1)[0]
    image = util.image_id_from_k8s()
    if image:
        os.environ[env.DOCKER] = image
    if project:
        os.environ[env.PROJECT] = project
    if entity:
        os.environ[env.ENTITY] = entity
    if group:
        os.environ[env.RUN_GROUP] = group
    if job_type:
        os.environ[env.JOB_TYPE] = job_type
    if tags:
        os.environ[env.TAGS] = ",".join(tags)
    if id:
        os.environ[env.RUN_ID] = id
        if name is None:
            # We do this because of https://github.com/wandb/core/issues/2170
            # to ensure that the run's name is explicitly set to match its
            # id. If we don't do this and the id is eight characters long, the
            # backend will set the name to a generated human-friendly value.
            #
            # In any case, if the user is explicitly setting `id` but not
            # `name`, their id is probably a meaningful string that we can
            # use to label the run.
            name = os.environ.get(
                env.NAME,
                id)  # environment variable takes precedence over this.
    if name:
        os.environ[env.NAME] = name
    if notes:
        os.environ[env.NOTES] = notes
    if magic is not None and magic is not False:
        if isinstance(magic, dict):
            os.environ[env.MAGIC] = json.dumps(magic)
        elif isinstance(magic, str):
            os.environ[env.MAGIC] = magic
        elif isinstance(magic, bool):
            pass
        else:
            termwarn("wandb.init called with invalid magic parameter type",
                     repeat=False)
        from wandb import magic_impl
        magic_impl.magic_install()
    if dir:
        os.environ[env.DIR] = dir
        util.mkdir_exists_ok(wandb_dir())
    resume_path = os.path.join(wandb_dir(), wandb_run.RESUME_FNAME)
    if resume == True:
        os.environ[env.RESUME] = "auto"
    elif resume:
        os.environ[env.RESUME] = os.environ.get(env.RESUME, "allow")
        # TODO: remove allowing resume as a string in the future
        os.environ[env.RUN_ID] = id or resume
    elif os.path.exists(resume_path):
        os.remove(resume_path)
    if os.environ.get(env.RESUME) == 'auto' and os.path.exists(resume_path):
        if not os.environ.get(env.RUN_ID):
            with open(resume_path) as resume_file:
                os.environ[env.RUN_ID] = json.load(resume_file)["run_id"]

    # the following line is useful to ensure that no W&B logging happens in the user
    # process that might interfere with what they do
    # logging.basicConfig(format='user process %(asctime)s - %(name)s - %(levelname)s - %(message)s')

    # If a thread calls wandb.init() it will get the same Run object as
    # the parent. If a child process with distinct memory space calls
    # wandb.init(), it won't get an error, but it will get a result of
    # None.
    # This check ensures that a child process can safely call wandb.init()
    # after a parent has (only the parent will create the Run object).
    # This doesn't protect against the case where the parent doesn't call
    # wandb.init but two children do.
    if run or os.getenv(env.INITED):
        return run

    if __stage_dir__ is None:
        __stage_dir__ = "wandb"
        util.mkdir_exists_ok(wandb_dir())

    try:
        signal.signal(signal.SIGQUIT, _debugger)
    except AttributeError:
        # signal.SIGQUIT is not defined on every platform.
        pass

    try:
        run = wandb_run.Run.from_environment_or_defaults()
    except IOError as e:
        termerror('Failed to create run directory: {}'.format(e))
        raise LaunchError("Could not write to filesystem.")

    run.set_environment()

    def set_global_config(run):
        global config  # because we already have a local config
        config = run.config

    set_global_config(run)
    global summary
    summary = run.summary

    # set this immediately after setting the run and the config. if there is an
    # exception after this it'll probably break the user script anyway
    os.environ[env.INITED] = '1'

    # we do these checks after setting the run and the config because users scripts
    # may depend on those things
    if sys.platform == 'win32' and run.mode != 'clirun':
        termerror(
            'To use wandb on Windows, you need to run the command "wandb run python <your_train_script>.py"'
        )
        return run

    if in_jupyter:
        _init_jupyter(run)
    elif run.mode == 'clirun':
        pass
    elif run.mode == 'run':
        api = InternalApi()
        # let init_jupyter handle this itself
        if not in_jupyter and not api.api_key:
            termlog(
                "W&B is a tool that helps track and visualize machine learning experiments"
            )
            if force:
                termerror(
                    "No credentials found.  Run \"wandb login\" or \"wandb off\" to disable wandb"
                )
            else:
                if run.check_anonymous():
                    _init_headless(run)
                else:
                    termlog(
                        "No credentials found.  Run \"wandb login\" to visualize your metrics"
                    )
                    run.mode = "dryrun"
                    _init_headless(run, False)
        else:
            _init_headless(run)
    elif run.mode == 'dryrun':
        termlog('Dry run mode, not syncing to the cloud.')
        _init_headless(run, False)
    else:
        termerror('Invalid run mode "%s". Please unset WANDB_MODE.' % run.mode)
        raise LaunchError("The WANDB_MODE environment variable is invalid.")

    # set the run directory in the config so it actually gets persisted
    run.config.set_run_dir(run.dir)

    if sagemaker_config:
        run.config.update(sagemaker_config)
        allow_val_change = True
    if config:
        run.config.update(config, allow_val_change=allow_val_change)

    # Access history to ensure resumed is set when resuming
    run.history
    # Load the summary to support resuming
    run.summary.load()

    atexit.register(run.close_files)

    return run
예제 #9
0
파일: __init__.py 프로젝트: chmod644/client
def ensure_configured():
    """Rebuild the module-level API client and the absolute debug-log path."""
    global GLOBAL_LOG_FNAME, api
    # We re-initialize here for tests
    api = InternalApi()
    log_path = os.path.join(wandb_dir(), 'debug.log')
    GLOBAL_LOG_FNAME = os.path.abspath(log_path)