示例#1
0
class Api(object):
    """W&B Internal Api wrapper

    Note:
        Settings are automatically overridden by looking for
        a `wandb/settings` file in the current working directory or its parent
        directory. If none can be found, we look in the current user's home
        directory.

    Args:
        default_settings (:obj:`dict`, optional): If you aren't using a settings
            file, or you wish to override the section to use in the settings
            file, override the settings here.
    """

    HTTP_TIMEOUT = env.get_http_timeout(10)

    def __init__(self, default_settings=None, load_settings=True, retry_timedelta=datetime.timedelta(days=1), environ=os.environ):
        """Set up settings, git metadata and the retrying GraphQL client.

        Args:
            default_settings (dict, optional): Entries merged over the built-in
                defaults (``section``, ``run``, ``git_remote``, ``ignore_globs``,
                ``base_url``) defined below.
            load_settings (bool): Forwarded to ``Settings(load_settings=...)``;
                presumably controls whether settings files are read -- TODO confirm.
            retry_timedelta (datetime.timedelta): Stored on the instance and
                passed to ``retry.Retry`` as the retry window for ``self.gql``.
            environ (mapping): Environment consulted for the username, API key
                and setting overrides; defaults to ``os.environ``.
        """
        # Statement order matters: self.settings() reads _environ,
        # default_settings and _settings, and the client construction below
        # reads self.user_agent, self.api_key and self.settings().
        self._environ = environ
        self.default_settings = {
            'section': "default",
            'run': "latest",
            'git_remote': "origin",
            'ignore_globs': [],
            'base_url': "https://api.wandb.ai"
        }
        self.retry_timedelta = retry_timedelta
        self.default_settings.update(default_settings or {})
        self.retry_uploads = 10
        self._settings = Settings(load_settings=load_settings)
        self.git = GitRepo(remote=self.settings("git_remote"))
        # Mutable settings set by the _file_stream_api
        self.dynamic_settings = {
            'system_sample_seconds': 2,
            'system_samples': 15,
            'heartbeat_seconds': 30,
        }
        self.client = Client(
            transport=RequestsHTTPTransport(
                headers={'User-Agent': self.user_agent, 'X-WANDB-USERNAME': env.get_username(env=self._environ)},
                use_json=True,
                # this timeout won't apply when the DNS lookup fails. in that case, it will be 60s
                # https://bugs.python.org/issue22889
                timeout=self.HTTP_TIMEOUT,
                auth=("api", self.api_key or ""),
                url='%s/graphql' % self.settings('base_url')
            )
        )
        # self.gql is the retrying entry point used by the query methods; it
        # wraps self.execute. util.no_retry_auth decides which failures are not
        # retried (presumably authentication errors -- confirm in util).
        self.gql = retry.Retry(self.execute,
            retry_timedelta=retry_timedelta,
            check_retry_fn=util.no_retry_auth,
            retryable_exceptions=(RetryError, requests.RequestException))
        self._current_run_id = None
        self._file_stream_api = None

    def reauth(self):
        """Push the currently resolved API key into the GraphQL transport's auth.

        Call this after the key changes (environment or netrc) so subsequent
        requests authenticate as ``("api", <key>)``; an empty string is used
        when no key is available.
        """
        current_key = self.api_key or ""
        self.client.transport.auth = ("api", current_key)

    def execute(self, *args, **kwargs):
        """Run a GraphQL request, logging diagnostics if it fails with an HTTP error.

        All arguments are forwarded unchanged to ``self.client.execute``.

        Raises:
            requests.exceptions.HTTPError: the original error is always
                re-raised unchanged after logging, even if producing the
                diagnostics fails or the error carries no response.
        """
        try:
            return self.client.execute(*args, **kwargs)
        except requests.exceptions.HTTPError as err:
            # Capture the original exception first so nothing done while
            # logging can replace it with a secondary error.
            exc_info = sys.exc_info()
            res = err.response
            if res is None:
                # HTTPError can be raised without an attached response object.
                logger.error("HTTP error executing GraphQL: %s" % err)
            else:
                try:
                    logger.error("%s response executing GraphQL." % res.status_code)
                    logger.error(res.text)
                    self.display_gorilla_error_if_found(res)
                except Exception:
                    # Diagnostics are best-effort; never mask the real failure.
                    logger.exception("Error while reporting GraphQL failure")
            six.reraise(*exc_info)

    def display_gorilla_error_if_found(self, res):
        """Print any GraphQL ``errors`` messages contained in a failed response.

        Args:
            res (requests.Response): The failed HTTP response.

        Bodies that are not JSON, or whose JSON is not an object holding an
        ``errors`` list, are ignored. Error entries that are not objects, or
        that have no ``message``, are skipped. This is a best-effort
        diagnostic called while handling another exception, so it must not
        raise on unexpected payload shapes.
        """
        try:
            data = res.json()
        except ValueError:
            return

        # The body may be any JSON value (list, string, number); only an
        # object can carry a GraphQL ``errors`` array. A string body would
        # otherwise pass a substring ``'errors' in data`` check and crash.
        if not isinstance(data, dict):
            return
        errors = data.get('errors')
        if not isinstance(errors, list):
            return
        for err in errors:
            if not isinstance(err, dict) or not err.get('message'):
                continue
            wandb.termerror('Error while calling W&B API: {} ({})'.format(err['message'], res))


    def disabled(self):
        """Return the ``disabled`` value from the default settings section.

        Falls back to False when the setting is absent.
        """
        default_section = Settings.DEFAULT_SECTION
        return self._settings.get(default_section, 'disabled', fallback=False)

    def sync_spell(self, run, env=None):
        """Register this run's W&B URL with spell.run.

        Stores ``SPELL_RUN_URL`` in the run's wandb config, persists the
        config, then PUTs the run URL and ``WANDB_ACCESS_TOKEN`` to
        ``<SPELL_API_URL>/wandb_url`` with a 2 second timeout.

        Args:
            run: The run to register.
            env (mapping, optional): Where the ``SPELL_*`` and
                ``WANDB_ACCESS_TOKEN`` variables are read from; defaults to
                ``os.environ``.

        Returns:
            The ``requests`` response from spell, or False when the run URL
            cannot be determined or the HTTP request fails.
        """
        env = env or os.environ
        try:
            run.config._set_wandb("spell_url", env.get("SPELL_RUN_URL"))
            run.config.persist()
            try:
                run_url = run.get_url()
            except CommError as e:
                wandb.termerror("Unable to register run with spell.run: %s" % e.message)
                return False
            endpoint = env.get("SPELL_API_URL", "https://api.spell.run") + "/wandb_url"
            payload = {
                "access_token": env.get("WANDB_ACCESS_TOKEN"),
                "url": run_url,
            }
            return requests.put(endpoint, json=payload, timeout=2)
        except requests.RequestException:
            return False

    def save_pip(self, out_dir):
        """Write the installed pip packages to ``<out_dir>/requirements.txt``.

        Each line is ``<key>==<version>`` for a distribution in the
        ``pkg_resources`` working set, sorted alphabetically. This is
        best-effort: any failure is logged together with its cause and then
        swallowed so run startup is not interrupted.

        Args:
            out_dir (str): Directory to write ``requirements.txt`` into.
        """
        try:
            import pkg_resources

            installed_packages_list = sorted(
                "%s==%s" % (dist.key, dist.version) for dist in pkg_resources.working_set
            )
            with open(os.path.join(out_dir, 'requirements.txt'), 'w') as f:
                f.write("\n".join(installed_packages_list))
        except Exception as e:
            # Record why the file is missing instead of discarding the error.
            logger.error("Error saving pip packages: %s" % e)

    def save_patches(self, out_dir):
        """Save the current state of this repository to one or more patches.

        Up to two patch files are written:

        * ``<out_dir>/diff.patch``: ``git diff HEAD``, which includes staged
          changes. It is written only when the working tree is dirty.
        * ``<out_dir>/upstream_diff_<commit_id>.patch``: a diff against the
          most recent commit that also occurs on an upstream branch. It is
          written only when that commit differs from HEAD. This keeps the run
          reproducible even if local history is edited, as long as upstream
          history is never force-pushed.

        When the repository has submodule changes, ``--submodule=diff`` is
        added so submodule contents are inlined. Each ``git`` call has a
        5 second timeout.

        Args:
            out_dir (str): Directory to write the patch files.

        Returns:
            False if git is not enabled, otherwise None. Diff failures are
            logged rather than raised.
        """
        if not self.git.enabled:
            return False

        def write_diff(target_path, base):
            # Stream `git diff [--submodule=diff] <base>` straight into the file.
            cmd = ['git', 'diff']
            if self.git.has_submodule_diff:
                cmd.append('--submodule=diff')
            cmd.append(base)
            with open(target_path, 'wb') as out:
                subprocess.check_call(cmd, stdout=out, cwd=root, timeout=5)

        try:
            root = self.git.root
            if self.git.dirty:
                # we diff against HEAD to ensure we get changes in the index
                write_diff(os.path.join(out_dir, 'diff.patch'), 'HEAD')

            upstream_commit = self.git.get_upstream_fork_point()
            if upstream_commit and upstream_commit != self.git.repo.head.commit:
                sha = upstream_commit.hexsha
                write_diff(
                    os.path.join(out_dir, 'upstream_diff_{}.patch'.format(sha)), sha)
        # TODO: A customer saw `ValueError: Reference at 'refs/remotes/origin/foo' does not exist`
        # so we now catch ValueError.  Catching this error feels too generic.
        except (ValueError, subprocess.CalledProcessError, subprocess.TimeoutExpired) as e:
            logger.error('Error generating diff: %s' % e)

    def set_current_run_id(self, run_id):
        """Store ``run_id``; it is returned afterwards by ``current_run_id``."""
        self._current_run_id = run_id

    @property
    def current_run_id(self):
        """Run id last stored with ``set_current_run_id`` (None until then)."""
        return self._current_run_id

    @property
    def user_agent(self):
        """Value for the User-Agent header: ``W&B Internal Client <version>``."""
        return 'W&B Internal Client {}'.format(__version__)

    @property
    def api_key(self):
        """The API key to authenticate with, or None.

        A non-empty ``env.API_KEY`` value in the environment wins. Otherwise
        the password stored in netrc for ``api_url`` is used.
        """
        netrc_auth = requests.utils.get_netrc_auth(self.api_url)
        netrc_key = netrc_auth[-1] if netrc_auth else None
        # Environment should take precedence
        return self._environ.get(env.API_KEY) or netrc_key

    @property
    def api_url(self):
        """Base URL of the W&B API (the resolved ``base_url`` setting)."""
        return self.settings(key='base_url')

    @property
    def app_url(self):
        """Best-guess URL of the W&B web app corresponding to ``api_url``.

        * Hosts ending in ``.test``, or the ``dev_prod`` setting, map to
          ``http://app.test``.
        * An on-prem VM (``...:11001``) maps to the same host on port 11000.
        * ``https://api.<domain>`` maps to ``https://app.<domain>``.
        * Anything else is returned unchanged.
        """
        api_url = self.api_url
        # Development
        if api_url.endswith('.test') or self.settings().get("dev_prod"):
            return 'http://app.test'
        # On-prem VM: swap only the trailing port, not other ':11001' text.
        if api_url.endswith(':11001'):
            return api_url[:-len(':11001')] + ':11000'
        # Normal: rewrite only the leading "api." host label. A plain
        # str.replace would also rewrite later matches such as
        # "https://api.myapi.example.com" -> "https://app.myapp.example.com".
        if api_url.startswith('https://api.'):
            return 'https://app.' + api_url[len('https://api.'):]
        # Unexpected
        return api_url

    def settings(self, key=None, section=None):
        """The settings overridden from the wandb/settings file.

        Values are layered with the lowest precedence first:

        1. ``default_settings``.
        2. The items of the requested settings-file ``section``.
        3. ``entity``, ``project``, ``base_url`` and ``ignore_globs``, which
           are each passed through the matching ``env.get_*`` helper. The
           helper receives, as its default, the default section's value
           (falling back to the layered value).

        Args:
            key (str, optional): If provided only this setting is returned
            section (str, optional): If provided this section of the setting file is
                used, defaults to "default"

        Returns:
            A dict with the current settings, or only the value for ``key``
            (KeyError if ``key`` is not present).

                {
                    "entity": "models",
                    "base_url": "https://api.wandb.ai",
                    "project": None
                }
        """
        merged = dict(self.default_settings)
        merged.update(self._settings.items(section=section))

        resolvers = (
            ('entity', env.get_entity),
            ('project', env.get_project),
            ('base_url', env.get_base_url),
            ('ignore_globs', env.get_ignore),
        )
        overrides = {}
        for name, resolve in resolvers:
            file_value = self._settings.get(
                Settings.DEFAULT_SECTION, name, fallback=merged.get(name))
            overrides[name] = resolve(file_value, env=self._environ)
        merged.update(overrides)

        if key is None:
            return merged
        return merged[key]

    def clear_setting(self, key):
        """Clear ``key`` from the default settings section via ``Settings.clear``."""
        default_section = Settings.DEFAULT_SECTION
        self._settings.clear(default_section, key)

    def set_setting(self, key, value, globally=False):
        """Store ``key = value`` in the default settings section.

        ``entity`` and ``project`` are additionally passed to
        ``env.set_entity`` / ``env.set_project`` with ``self._environ`` so the
        environment-backed lookups see the new value.

        Args:
            key (str): Setting name.
            value: Value to store.
            globally (bool): Forwarded to ``Settings.set``.
        """
        self._settings.set(Settings.DEFAULT_SECTION, key, value, globally=globally)
        env_setter = None
        if key == 'entity':
            env_setter = env.set_entity
        elif key == 'project':
            env_setter = env.set_project
        if env_setter is not None:
            env_setter(value, env=self._environ)

    def parse_slug(self, slug, project=None, run=None):
        """Resolve a ``project/run`` slug into a ``(project, run)`` tuple.

        If ``slug`` contains ``/``, its first two components are returned and
        the ``project``/``run`` arguments are ignored. Otherwise the project
        defaults to the ``project`` setting. The run is taken from ``run``,
        then ``slug``, then ``env.get_run``, and finally ``"latest"``.

        Raises:
            CommError: if no project can be determined.
        """
        if slug and "/" in slug:
            pieces = slug.split("/")
            return (pieces[0], pieces[1])

        project = project or self.settings().get("project")
        if project is None:
            raise CommError("No default project configured.")
        run = run or slug or env.get_run(env=self._environ)
        return (project, "latest" if run is None else run)

    @normalize_exceptions
    def viewer(self):
        """Fetch the authenticated viewer.

        Returns:
            dict: The ``viewer`` object (``id``, ``entity`` and team names as
            edges), or ``{}`` when the server returns no viewer.
        """
        query = gql('''
        query Viewer{
            viewer {
                id
                entity
                teams {
                    edges {
                        node {
                            name
                        }
                    }
                }
            }
        }
        ''')
        viewer = self.gql(query).get('viewer')
        return viewer or {}

    @normalize_exceptions
    def list_projects(self, entity=None):
        """List the first 10 projects of an entity.

        Args:
            entity (str, optional): The entity to scope this project to.
                Defaults to the ``entity`` setting.

        Returns:
                [{"id","name","description"}]
        """
        query = gql('''
        query Models($entity: String!) {
            models(first: 10, entityName: $entity) {
                edges {
                    node {
                        id
                        name
                        description
                    }
                }
            }
        }
        ''')
        variables = {'entity': entity or self.settings('entity')}
        result = self.gql(query, variable_values=variables)
        return self._flatten_edges(result['models'])

    @normalize_exceptions
    def project(self, project, entity=None):
        """Retrieve a single project.

        Args:
            project (str): The project to get details for
            entity (str, optional): The entity to scope this project to.
                Sent as given; unlike most methods here it is not defaulted
                from settings.

        Returns:
            dict: The ``model`` object with ``id``, ``name``, ``repo``,
            ``dockerImage`` and ``description``, exactly as the server returns
            it (presumably None when the project does not exist).
        """
        query = gql('''
        query Models($entity: String, $project: String!) {
            model(name: $project, entityName: $entity) {
                id
                name
                repo
                dockerImage
                description
            }
        }
        ''')
        variables = {'entity': entity, 'project': project}
        return self.gql(query, variable_values=variables)['model']

    @normalize_exceptions
    def sweep(self, sweep, specs, project=None, entity=None):
        """Retrieve a sweep and its runs.

        Args:
            sweep (str): The sweep to get details for
            specs (str): history specs, passed to each run's ``sampledHistory``
            project (str, optional): The project to scope this sweep to.
                Defaults to the ``project`` setting.
            entity (str, optional): The entity to scope this sweep to.
                Defaults to the ``entity`` setting.

        Returns:
            dict or None: The sweep fields (``id``, ``name``, ``method``,
            ``state``, ``config``, ...), with ``runs`` flattened from edges
            into a list of run dicts. None when the project or the sweep
            cannot be found.
        """
        query = gql('''
        query Models($entity: String, $project: String!, $sweep: String!, $specs: [JSONString!]!) {
            model(name: $project, entityName: $entity) {
                sweep(sweepName: $sweep) {
                    id
                    name
                    method
                    state
                    description
                    config
                    createdAt
                    heartbeatAt
                    updatedAt
                    earlyStopJobRunning
                    bestLoss
                    controller
                    scheduler
                    runs {
                        edges {
                            node {
                                name
                                state
                                config
                                exitcode
                                heartbeatAt
                                shouldStop
                                failed
                                stopped
                                running
                                summaryMetrics
                                sampledHistory(specs: $specs)
                            }
                        }
                    }
                }
            }
        }
        ''')
        response = self.gql(query, variable_values={
            'entity': entity or self.settings('entity'),
            'project': project or self.settings('project'),
            'sweep': sweep,
            'specs': specs,
        })
        model = response['model']
        if model is None:
            # Unknown project/entity: report "no sweep", as for an unknown
            # sweep name, instead of failing with a TypeError on None.
            return None
        data = model['sweep']
        if data:
            data['runs'] = self._flatten_edges(data['runs'])
        return data

    @normalize_exceptions
    def list_runs(self, project, entity=None):
        """List the first 10 runs of a project.

        Args:
            project (str): The project to scope the runs to. Falls back to
                the ``project`` setting when empty.
            entity (str, optional): The entity to scope this project to.
                Defaults to the ``entity`` setting.

        Returns:
                [{"id","name","displayName","description"}]
        """
        query = gql('''
        query Buckets($model: String!, $entity: String!) {
            model(name: $model, entityName: $entity) {
                buckets(first: 10) {
                    edges {
                        node {
                            id
                            name
                            displayName
                            description
                        }
                    }
                }
            }
        }
        ''')
        variables = {
            'entity': entity or self.settings('entity'),
            'model': project or self.settings('project'),
        }
        result = self.gql(query, variable_values=variables)
        return self._flatten_edges(result['model']['buckets'])

    @normalize_exceptions
    def launch_run(self, command, project=None, entity=None, run_id=None):
        """Launch a run in the cloud.

        The uncommitted ``git diff`` (when the repo is dirty) is sent as the
        patch. The working directory is sent relative to the repository root,
        as ``./<subdir>``.

        Args:
            command (str): The command to run
            project (str, optional): The project to scope the runs to.
                Defaults to the ``project`` setting.
            entity (str, optional): The entity to scope this project to.
                Defaults to the ``entity`` setting.
            run_id (str, optional): The run_id to scope to

        Returns:
            dict: The mutation response; ``launchRun`` holds ``podName``,
            ``status`` and ``runId``.
        """
        query = gql('''
        mutation launchRun(
            $entity: String
            $model: String
            $runId: String
            $image: String
            $command: String
            $patch: String
            $cwd: String
            $datasets: [String]
        ) {
            launchRun(input: {id: $runId, entityName: $entity, patch: $patch, modelName: $model,
                image: $image, command: $command, datasets: $datasets, cwd: $cwd}) {
                podName
                status
                runId
            }
        }
        ''')
        patch = BytesIO()
        if self.git.dirty:
            self.git.repo.git.execute(['git', 'diff'], output_stream=patch)
            patch.seek(0)
        cwd = "."
        if self.git.enabled:
            working_dir = self.git.repo.working_dir
            current_dir = os.getcwd()
            # Strip the repo root only as a leading path prefix. str.replace
            # would also delete later occurrences of the same text
            # (e.g. /a/b/a/c with root /a became /b/c).
            if current_dir == working_dir or current_dir.startswith(working_dir + os.sep):
                cwd = cwd + current_dir[len(working_dir):]
            else:
                cwd = cwd + current_dir.replace(working_dir, "")
        return self.gql(query, variable_values={
            'entity': entity or self.settings('entity'),
            'model': project or self.settings('project'),
            'command': command,
            'runId': run_id,
            'patch': patch.read().decode("utf8"),
            'cwd': cwd
        })

    @normalize_exceptions
    def run_config(self, project, run=None, entity=None):
        """Get the relevant configs for a run.

        Args:
            project (str): The project to download, (can include bucket)
            run (str, optional): The run to download
            entity (str, optional): The entity to scope this project to.

        Returns:
            tuple: ``(commit, config, patch, metadata)``. ``config`` is the
            run config decoded from JSON (``{}`` when empty). ``metadata`` is
            the parsed ``wandb-metadata.json`` file, or ``{}`` when the run
            has no such file.

        Raises:
            ValueError: if the project or the run does not exist.
        """
        query = gql('''
        query Model($name: String!, $entity: String!, $run: String!) {
            model(name: $name, entityName: $entity) {
                bucket(name: $run) {
                    config
                    commit
                    patch
                    files(names: ["wandb-metadata.json"]) {
                        edges {
                            node {
                                url
                            }
                        }
                    }
                }
            }
        }
        ''')

        response = self.gql(query, variable_values={
            'name': project, 'run': run, 'entity': entity
        })
        model = response['model']
        # Either level can be null: an unknown project, or an unknown run in
        # an existing project. Both mean "not found".
        bucket = model['bucket'] if model is not None else None
        if bucket is None:
            raise ValueError("Run {}/{}/{} not found".format(entity, project, run))
        commit = bucket['commit']
        patch = bucket['patch']
        config = json.loads(bucket['config'] or '{}')
        edges = bucket['files']['edges']
        if edges:
            res = requests.get(edges[0]['node']['url'])
            res.raise_for_status()
            metadata = res.json()
        else:
            metadata = {}
        return (commit, config, patch, metadata)

    @normalize_exceptions
    def run_resume_status(self, entity, project_name, name):
        """Check if a run exists and get resume information.

        As a side effect, stores ``project_name`` as the ``project`` setting.
        It also stores the project's entity name as the ``entity`` setting
        when the server returns one.

        Args:
            entity (str): The entity to scope this project to (may be None).
            project_name (str): The project to download, (can include bucket)
            name (str): The run to look up

        Returns:
            dict or None: The run's ``bucket`` fields (line counts, history
            and events tails, config, ...). None when the project is not
            found or the response carries no ``bucket``. The bucket itself may
            be null for a missing run, since it is queried with
            ``missingOk: true``.
        """
        query = gql('''
        query Model($project: String!, $entity: String, $name: String!) {
            model(name: $project, entityName: $entity) {
                id
                name
                entity {
                    id
                    name
                }

                bucket(name: $name, missingOk: true) {
                    id
                    name
                    summaryMetrics
                    displayName
                    logLineCount
                    historyLineCount
                    eventsLineCount
                    historyTail
                    eventsTail
                    config
                }
            }
        }
        ''')

        response = self.gql(query, variable_values={
            'entity': entity, 'project': project_name, 'name': name,
        })

        if 'model' not in response or 'bucket' not in (response['model'] or {}):
            return None

        project = response['model']
        self.set_setting('project', project_name)
        # The key is always present in the response, but its value can be
        # null; indexing a null entity would raise TypeError.
        if project.get('entity'):
            self.set_setting('entity', project['entity']['name'])

        return project['bucket']

    @normalize_exceptions
    def check_stop_requested(self, project_name, entity_name, run_id):
        """Return the server's ``stopped`` flag for ``run_id``.

        Returns False when the project or the run is not found.
        """
        query = gql('''
        query Model($projectName: String, $entityName: String, $runId: String!) {
            project(name:$projectName, entityName:$entityName) {
                run(name:$runId) {
                    stopped
                }
            }
        }
        ''')

        variables = {
            'projectName': project_name, 'entityName': entity_name, 'runId': run_id,
        }
        response = self.gql(query, variable_values=variables)

        run = (response.get('project') or {}).get('run')
        return run['stopped'] if run else False


    def format_project(self, project):
        """Slugify a project name.

        The name is lowercased, every run of non-word characters becomes a
        single ``-``, and leading/trailing ``-``/``_`` are trimmed.
        """
        lowered = project.lower()
        dashed = re.sub(r'\W+', '-', lowered)
        return dashed.strip("-_")

    @normalize_exceptions
    def upsert_project(self, project, id=None, description=None, entity=None):
        """Create or update a project.

        The name is normalized with ``format_project``, and the current git
        remote URL is sent as the project's repo.

        Args:
            project (str): The project to create
            id (str, optional): Id of an existing project to update.
            description (str, optional): A description of this project
            entity (str, optional): The entity to scope this project to.
                Defaults to the ``entity`` setting.

        Returns:
            dict: The ``model`` object (``name``, ``description``).
        """
        mutation = gql('''
        mutation UpsertModel($name: String!, $id: String, $entity: String!, $description: String, $repo: String)  {
            upsertModel(input: { id: $id, name: $name, entityName: $entity, description: $description, repo: $repo }) {
                model {
                    name
                    description
                }
            }
        }
        ''')
        variables = {
            'name': self.format_project(project),
            'entity': entity or self.settings('entity'),
            'description': description,
            'repo': self.git.remote_url,
            'id': id,
        }
        response = self.gql(mutation, variable_values=variables)
        return response['upsertModel']['model']

    @normalize_exceptions
    def upsert_run(self, id=None, name=None, project=None, host=None,
                   group=None, tags=None,
                   config=None, description=None, entity=None, state=None,
                   display_name=None, notes=None,
                   repo=None, job_type=None, program_path=None, commit=None,
                   sweep_name=None, summary_metrics=None, num_retries=None):
        """Create or update a run via the ``upsertBucket`` mutation.

        After the mutation, the project and entity names reported for the run
        are stored as the ``project``/``entity`` settings via ``set_setting``.

        Args:
            id (str, optional): The existing run to update
            name (str, optional): The name of the run to create
            project (str, optional): The name of the project
            host (str, optional): Host name. Sent as null when the
                ``anonymous`` setting is ``'true'``.
            group (str, optional): Name of the group this run is a part of
            tags (list of str, optional): Tags for the run.
            config (dict, optional): The latest config params, JSON-encoded
                before sending.
            description (str, optional): A description of this run. Empty or
                whitespace-only values are sent as null.
            entity (str, optional): The entity to scope this project to.
                Defaults to the ``entity`` setting.
            state (str, optional): State of the program.
            display_name (str, optional): Display name of the run.
            notes (str, optional): Notes for the run.
            repo (str, optional): Url of the program's repository.
            job_type (str, optional): Type of job, e.g 'train'.
            program_path (str, optional): Path to the program.
            commit (str, optional): The Git SHA to associate the run with
            sweep_name (str, optional): Name of the sweep the run belongs to.
            summary_metrics (str, optional): The JSON summary metrics
            num_retries (int, optional): Forwarded to ``self.gql`` only when
                it is not None.

        Returns:
            dict: The ``bucket`` object returned by the mutation.
        """
        mutation = gql('''
        mutation UpsertBucket(
            $id: String, $name: String,
            $project: String,
            $entity: String!,
            $groupName: String,
            $description: String,
            $displayName: String,
            $notes: String,
            $commit: String,
            $config: JSONString,
            $host: String,
            $debug: Boolean,
            $program: String,
            $repo: String,
            $jobType: String,
            $state: String,
            $sweep: String,
            $tags: [String!],
            $summaryMetrics: JSONString,
        ) {
            upsertBucket(input: {
                id: $id,
                name: $name,
                groupName: $groupName,
                modelName: $project,
                entityName: $entity,
                description: $description,
                displayName: $displayName,
                notes: $notes,
                config: $config,
                commit: $commit,
                host: $host,
                debug: $debug,
                jobProgram: $program,
                jobRepo: $repo,
                jobType: $jobType,
                state: $state,
                sweep: $sweep,
                tags: $tags,
                summaryMetrics: $summaryMetrics,
            }) {
                bucket {
                    id
                    name
                    displayName
                    description
                    config
                    project {
                        id
                        name
                        entity {
                            id
                            name
                        }
                    }
                }
            }
        }
        ''')
        encoded_config = None if config is None else json.dumps(config)
        has_text = bool(description) and not description.isspace()
        clean_description = description if has_text else None

        gql_kwargs = {} if num_retries is None else {'num_retries': num_retries}

        variable_values = {
            'id': id,
            'entity': entity or self.settings('entity'),
            'name': name,
            'project': project,
            'groupName': group,
            'tags': tags,
            'description': clean_description,
            'config': encoded_config,
            'commit': commit,
            'displayName': display_name,
            'notes': notes,
            'host': None if self.settings().get('anonymous') == 'true' else host,
            'debug': env.is_debug(env=self._environ),
            'repo': repo,
            'program': program_path,
            'jobType': job_type,
            'state': state,
            'sweep': sweep_name,
            'summaryMetrics': summary_metrics,
        }

        response = self.gql(mutation, variable_values=variable_values, **gql_kwargs)

        bucket = response['upsertBucket']['bucket']
        owner_project = bucket.get('project')
        if owner_project:
            self.set_setting('project', owner_project['name'])
            owner_entity = owner_project.get('entity')
            if owner_entity:
                self.set_setting('entity', owner_entity['name'])

        return bucket

    @normalize_exceptions
    def upload_urls(self, project, files, run=None, entity=None, description=None):
        """Generate temporary resumeable upload urls

        Args:
            project (str): The project to download
            files (list or dict): The filenames to upload (a dict contributes its keys)
            run (str, optional): The run to upload to. Defaults to the ``run`` setting.
            entity (str, optional): The entity to scope this project to.
                Defaults to the ``entity`` setting.
            description (str, optional): Forwarded as the bucket ``desc`` argument.

        Returns:
            (bucket_id, file_info)
            bucket_id: id of bucket we uploaded to
            file_info: A dict of filenames and urls, also indicates if this revision already has uploaded files.
                {
                    'weights.h5': { "url": "https://weights.url" },
                    'model.json': { "url": "https://model.json", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' },
                }

        Raises:
            CommError: if the run does not exist.
        """
        query = gql('''
        query Model($name: String!, $files: [String]!, $entity: String!, $run: String!, $description: String) {
            model(name: $name, entityName: $entity) {
                bucket(name: $run, desc: $description) {
                    id
                    files(names: $files) {
                        edges {
                            node {
                                name
                                url(upload: true)
                                updatedAt
                            }
                        }
                    }
                }
            }
        }
        ''')
        run_id = run or self.settings('run')
        entity = entity or self.settings('entity')
        query_result = self.gql(query, variable_values={
            'name': project, 'run': run_id,
            'entity': entity,
            'description': description,
            'files': list(files),
        })

        bucket = query_result['model']['bucket']
        if not bucket:
            raise CommError("Run does not exist {}/{}/{}.".format(entity, project, run_id))
        file_info = {node['name']: node for node in self._flatten_edges(bucket['files'])}
        return bucket['id'], file_info

    @normalize_exceptions
    def download_urls(self, project, run=None, entity=None):
        """Generate download urls

        Args:
            project (str): The project to download
            run (str, optional): The run to download from. Defaults to the ``run`` setting.
            entity (str, optional): The entity to scope this project to.
                Defaults to the ``entity`` setting.

        Returns:
            A dict of file names to file info

                {
                    'weights.h5': { "url": "https://weights.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' },
                    'model.json': { "url": "https://model.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' }
                }

        Raises:
            CommError: if the project or run does not exist.
        """
        query = gql('''
        query Model($name: String!, $entity: String!, $run: String!)  {
            model(name: $name, entityName: $entity) {
                bucket(name: $run) {
                    files {
                        edges {
                            node {
                                name
                                url
                                md5
                                updatedAt
                            }
                        }
                    }
                }
            }
        }
        ''')
        run_id = run or self.settings('run')
        entity = entity or self.settings('entity')
        query_result = self.gql(query, variable_values={
            'name': project, 'run': run_id,
            'entity': entity})
        model = query_result['model']
        bucket = model['bucket'] if model else None
        if not bucket:
            # Match upload_urls: a clear error instead of a TypeError on None.
            raise CommError("Run does not exist {}/{}/{}.".format(entity, project, run_id))
        files = self._flatten_edges(bucket['files'])
        return {file['name']: file for file in files if file}

    @normalize_exceptions
    def download_url(self, project, file_name, run=None, entity=None):
        """Look up the download info for a single file of a run.

        Args:
            project (str): The project to download
            file_name (str): The name of the file to download
            run (str, optional): The run to download from. Defaults to the ``run`` setting.
            entity (str, optional): The entity to scope this project to.
                Defaults to the ``entity`` setting.

        Returns:
            The file info dict, or None when the project, run or file does
            not exist or the file has no ``updatedAt`` (never uploaded)

                { "url": "https://weights.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' }

        """
        query = gql('''
        query Model($name: String!, $fileName: String!, $entity: String!, $run: String!)  {
            model(name: $name, entityName: $entity) {
                bucket(name: $run) {
                    files(names: [$fileName]) {
                        edges {
                            node {
                                name
                                url
                                md5
                                updatedAt
                            }
                        }
                    }
                }
            }
        }
        ''')
        query_result = self.gql(query, variable_values={
            'name': project, 'run': run or self.settings('run'), 'fileName': file_name,
            'entity': entity or self.settings('entity')})
        model = query_result['model']
        # A missing run yields a null bucket inside an existing project;
        # treat it like a missing project instead of raising TypeError.
        bucket = model['bucket'] if model else None
        if not bucket:
            return None
        files = self._flatten_edges(bucket['files'])
        if files and files[0].get('updatedAt'):
            return files[0]
        return None

    @normalize_exceptions
    def download_file(self, url):
        """Open a streamed GET request for ``url``.

        Args:
            url (str): Location of the file to fetch

        Returns:
            A ``(content_length, response)`` tuple: the ``Content-Length`` header
            as an int (0 when the server does not send it) and the open,
            not-yet-consumed streaming response.
        """
        # NOTE(review): no timeout is passed, so a stalled server can block here.
        resp = requests.get(url, stream=True)
        # Error statuses raise before any size information is returned.
        resp.raise_for_status()
        content_length = int(resp.headers.get('content-length', 0))
        return content_length, resp

    @normalize_exceptions
    def download_write_file(self, metadata, out_dir=None):
        """Download a file from a run and write it to disk.

        Args:
            metadata (dict): The metadata for the file to download, as returned
                by Api.download_urls() / Api.download_url(). The ``name``,
                ``md5`` and ``url`` keys are read.
            out_dir (str, optional): Directory to write into. Defaults to
                wandb_dir().

        Returns:
            A tuple of the file's local path and the streaming response. The
            streaming response is None if the file already existed at that path
            and its checksum matched ``metadata['md5']``.
        """
        file_name = metadata['name']
        path = os.path.join(out_dir or wandb_dir(), file_name)
        # Check the file we are about to (over)write, not a same-named file
        # relative to the current working directory.
        if self.file_current(path, metadata['md5']):
            return path, None

        _, response = self.download_file(metadata['url'])

        # Run file names may include sub-directories (e.g. "media/x.png");
        # make sure they exist before opening the destination.
        parent_dir = os.path.dirname(path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        with open(path, "wb") as out:
            for data in response.iter_content(chunk_size=1024):
                out.write(data)

        return path, response

    def upload_file(self, url, file, callback=None, extra_headers=None):
        """Upload a file to W&B with a single PUT request.

        Args:
            url (str): The url to upload to
            file (file object): An open binary file object to upload; its
                ``name`` is used in the empty-file error message
            callback (:obj:`func`, optional): A callback which is passed the number of
                bytes uploaded since the last time it was called, used to report progress
            extra_headers (dict, optional): Additional HTTP headers to send with
                the PUT. The mapping is copied, never modified.

        Returns:
            The requests library response object

        Raises:
            CommError: if the file is empty.
            Request failures are handed to util.sentry_reraise. Status codes
            308/409/429/500/502/503/504, timeouts and connection errors are
            first wrapped in retry.TransientException so that the retrying
            wrapper (upload_file_retry) treats them as transient.
        """
        # None instead of a shared mutable ``{}`` default.
        headers = extra_headers.copy() if extra_headers else {}
        response = None
        progress = Progress(file, callback=callback)
        if progress.len == 0:
            raise CommError("%s is an empty file" % file.name)
        try:
            response = requests.put(url, data=progress, headers=headers)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            status_code = e.response.status_code if e.response is not None else 0
            # Retry errors from cloud storage or local network issues
            if status_code in (308, 409, 429, 500, 502, 503, 504) or isinstance(e, (requests.exceptions.Timeout, requests.exceptions.ConnectionError)):
                util.sentry_reraise(retry.TransientException(exc=e))
            else:
                util.sentry_reraise(e)

        return response

    # Retrying variant of upload_file: retry.retriable(num_retries=5) wraps the
    # plain function (presumably retrying on retry.TransientException — TODO
    # confirm against retry.retriable), and normalize_exceptions is applied
    # outermost. push() uploads through this wrapper.
    upload_file_retry = normalize_exceptions(retry.retriable(num_retries=5)(upload_file))

    @normalize_exceptions
    def register_agent(self, host, sweep_id=None, project_name=None, entity=None):
        """Register a new sweep agent with the server.

        Args:
            host (str): hostname of the machine running the agent
            sweep_id (str): id of the sweep the agent works on
            project_name (str, optional): project that contains the sweep;
                defaults to the "project" setting
            entity (str, optional): entity that owns the project; defaults to
                the "entity" setting

        Returns:
            The ``agent`` dict returned by the createAgent mutation (its ``id``).

        Raises:
            UsageError: when the server answers with a 4xx status. The first
                GraphQL error message is used, or the HTTP error text when the
                body is not a GraphQL error document.
        """
        mutation = gql('''
        mutation CreateAgent(
            $host: String!
            $projectName: String!,
            $entityName: String!,
            $sweep: String!
        ) {
            createAgent(input: {
                host: $host,
                projectName: $projectName,
                entityName: $entityName,
                sweep: $sweep,
            }) {
                agent {
                    id
                }
            }
        }
        ''')
        if entity is None:
            entity = self.settings("entity")
        if project_name is None:
            project_name = self.settings('project')

        # don't retry on validation or not found errors
        def no_retry_4xx(e):
            if not isinstance(e, requests.HTTPError):
                return True
            if not 400 <= e.response.status_code < 500:
                return True
            try:
                message = json.loads(e.response.content)['errors'][0]['message']
            except (ValueError, KeyError, IndexError, TypeError):
                # Not a GraphQL error body (e.g. a proxy page); still a client
                # error, so report it as a usage error instead of a parse error.
                message = str(e)
            raise UsageError(message)

        response = self.gql(mutation, variable_values={
            'host': host,
            'entityName': entity,
            'projectName': project_name,
            'sweep': sweep_id}, check_retry_fn=no_retry_4xx)
        return response['createAgent']['agent']

    def agent_heartbeat(self, agent_id, metrics, run_states):
        """Notify server about agent state, receive commands.

        Args:
            agent_id (str): agent_id
            metrics (dict): system metrics (sent JSON-encoded)
            run_states (dict): run_id: state mapping (sent JSON-encoded)
        Returns:
            List of commands to execute, decoded from the JSON ``commands``
            field of the response. Any error talking to the server is logged
            and an empty list is returned.
        """
        mutation = gql('''
        mutation Heartbeat(
            $id: ID!,
            $metrics: JSONString,
            $runState: JSONString
        ) {
            agentHeartbeat(input: {
                id: $id,
                metrics: $metrics,
                runState: $runState
            }) {
                agent {
                    id
                }
                commands
            }
        }
        ''')
        try:
            response = self.gql(mutation, variable_values={
                'id': agent_id,
                'metrics': json.dumps(metrics),
                'runState': json.dumps(run_states)})
        except Exception as e:
            # GQL raises exceptions with stringified python dictionaries :/
            # Other failures (network errors, ...) do not, so fall back to the
            # plain exception text instead of crashing inside the handler.
            detail = e.args[0] if e.args else None
            if isinstance(detail, str):
                try:
                    detail = ast.literal_eval(detail)
                except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                    pass
            if isinstance(detail, dict) and "message" in detail:
                message = detail["message"]
            else:
                message = str(e)
            logger.error('Error communicating with W&B: %s', message)
            return []
        else:
            return json.loads(response['agentHeartbeat']['commands'])

    @normalize_exceptions
    def upsert_sweep(self, config, controller=None, scheduler=None, obj_id=None, project=None, entity=None):
        """Create or update a sweep object.

        Args:
            config (dict): sweep config; serialized with ``yaml.dump``. Its
                optional "description" key is sent as the sweep description.
            controller (optional): value of the JSONString ``$controller`` variable
            scheduler (optional): value of the JSONString ``$scheduler`` variable
            obj_id (str, optional): id of an existing sweep to update
            project (str, optional): project name; defaults to the "project" setting
            entity (str, optional): entity name; defaults to the "entity" setting

        Returns:
            str: the name of the upserted sweep. When the server reports the
            sweep's project/entity, they are stored via set_setting.

        Raises:
            UsageError: when the server answers with a 4xx status.
        """
        project_query = '''
                    project {
                        id
                        name
                        entity {
                            id
                            name
                        }
                    }
        '''
        mutation_str = '''
        mutation UpsertSweep(
            $id: ID,
            $config: String,
            $description: String,
            $entityName: String!,
            $projectName: String!,
            $controller: JSONString,
            $scheduler: JSONString
        ) {
            upsertSweep(input: {
                id: $id,
                config: $config,
                description: $description,
                entityName: $entityName,
                projectName: $projectName,
                controller: $controller,
                scheduler: $scheduler
            }) {
                sweep {
                    name
                    _PROJECT_QUERY_
                }
            }
        }
        '''
        # FIXME(jhr): we need protocol versioning to know schema is not supported
        # for now we will just try both new and old query
        mutation_new = gql(mutation_str.replace("_PROJECT_QUERY_", project_query))
        mutation_old = gql(mutation_str.replace("_PROJECT_QUERY_", ""))

        # don't retry on validation errors
        # TODO(jhr): generalize error handling routines
        def no_retry_4xx(e):
            if not isinstance(e, requests.HTTPError):
                return True
            if not 400 <= e.response.status_code < 500:
                return True
            try:
                message = json.loads(e.response.content)['errors'][0]['message']
            except (ValueError, KeyError, IndexError, TypeError):
                # Not a GraphQL error body; still a client error.
                message = str(e)
            raise UsageError(message)

        err = None
        for mutation in (mutation_new, mutation_old):
            try:
                response = self.gql(mutation, variable_values={
                    'id': obj_id,
                    'config': yaml.dump(config),
                    'description': config.get("description"),
                    'entityName': entity or self.settings("entity"),
                    'projectName': project or self.settings("project"),
                    'controller': controller,
                    'scheduler': scheduler},
                    check_retry_fn=no_retry_4xx)
            except UsageError:
                raise
            except Exception as e:
                # graphql schema exception is generic; try the next variant
                err = e
                continue
            break
        else:
            # Every mutation variant failed: surface the last error.
            raise err

        sweep = response['upsertSweep']['sweep']
        sweep_project = sweep.get('project')
        if sweep_project:
            self.set_setting('project', sweep_project['name'])
            sweep_entity = sweep_project.get('entity')
            if sweep_entity:
                self.set_setting('entity', sweep_entity['name'])

        return sweep['name']

    @normalize_exceptions
    def create_anonymous_api_key(self):
        """Ask the server for a fresh anonymous user and return its API key.

        Returns:
            str: the ``name`` field of the newly created ``apiKey``.
        """
        mutation = gql('''
        mutation CreateAnonymousApiKey {
            createAnonymousEntity(input: {}) {
                apiKey {
                    name
                }
            }
        }
        ''')

        result = self.gql(mutation, variable_values={})
        api_key = result['createAnonymousEntity']['apiKey']
        return api_key['name']

    def file_current(self, fname, md5):
        """Return True if ``fname`` is an existing regular file whose checksum,
        as computed by util.md5_file, equals ``md5``; False otherwise.
        """
        if not os.path.isfile(fname):
            # Missing files are never current; skip hashing entirely.
            return False
        return util.md5_file(fname) == md5

    @normalize_exceptions
    def pull(self, project, run=None, entity=None):
        """Download all files of a run from W&B.

        Each file is written with download_write_file (into its default
        directory); files whose local copy is already current are skipped.

        Args:
            project (str): The project to download; passed through
                self.parse_slug together with ``run``
            run (str, optional): The run to download from
            entity (str, optional): The entity to scope this project to.

        Returns:
            list: the streaming responses of the files that were actually
            downloaded.
        """
        project, run = self.parse_slug(project, run=run)
        urls = self.download_urls(project, run, entity)
        responses = []
        for metadata in urls.values():
            _, response = self.download_write_file(metadata)
            # None means the local file was already up to date.
            if response:
                responses.append(response)

        return responses

    def get_project(self):
        """Return the "project" setting (callers such as push() handle None)."""
        project_name = self.settings('project')
        return project_name

    @normalize_exceptions
    def push(self, files, run=None, entity=None, project=None, description=None, force=True, progress=False):
        """Uploads multiple files to W&B

        Args:
            files (list or dict): The filenames to upload, or a dict mapping
                filenames to already-open file objects
            run (str, optional): The run to upload to. Defaults to current_run_id.
            entity (str, optional): The entity to scope this project to.  Defaults to wandb models
            project (str, optional): The name of the project to upload to. Defaults to the one in settings.
            description (str, optional): The description of the changes
            force (bool, optional): Whether to prevent push if git has uncommitted changes
                (NOTE(review): not read by this method)
            progress (callable, or stream): If callable, will be called with (chunk_bytes,
                total_bytes) as argument else if True, renders a progress bar to stream.

        Returns:
            list: the requests library response objects, one per uploaded file.
            Every file object is closed after its upload, even on failure.

        Raises:
            CommError: if no project is configured.
        """
        if project is None:
            project = self.get_project()
        if project is None:
            raise CommError("No project configured.")
        if run is None:
            run = self.current_run_id

        # TODO(adrian): we use a retriable version of self.upload_file() so
        # will never retry self.upload_urls() here. Instead, maybe we should
        # make push itself retriable.
        _, result = self.upload_urls(
            project, files, run, entity, description)
        responses = []
        for file_name, file_info in result.items():
            file_url = file_info['url']

            # If the upload URL is relative, fill it in with the base URL,
            # since its a proxied file store like the on-prem VM.
            if file_url.startswith('/'):
                file_url = '{}{}'.format(self.api_url, file_url)

            try:
                # To handle Windows paths
                # TODO: this doesn't handle absolute paths...
                normal_name = os.path.join(*file_name.split("/"))
                open_file = files[file_name] if isinstance(
                    files, dict) else open(normal_name, "rb")
            except IOError:
                print("%s does not exist" % file_name)
                continue
            try:
                if progress:
                    if callable(progress):
                        responses.append(self.upload_file_retry(
                            file_url, open_file, progress))
                    else:
                        length = os.fstat(open_file.fileno()).st_size
                        with click.progressbar(file=progress, length=length, label='Uploading file: %s' % (file_name),
                                               fill_char=click.style('&', fg='green')) as bar:
                            responses.append(self.upload_file_retry(
                                file_url, open_file, lambda bites, _: bar.update(bites)))
                else:
                    # Use the resolved URL here as well: the raw one may be relative.
                    responses.append(self.upload_file_retry(file_url, open_file))
            finally:
                open_file.close()
        return responses

    def get_file_stream_api(self):
        """Return the FileStreamApi for the current run, creating it on first use.

        The instance is cached on ``self._file_stream_api``. Call its ``start``
        method to launch the thread that talks to W&B.

        Raises:
            UsageError: if no current run id has been set.
        """
        import time  # start timestamp required by FileStreamApi

        if not self._file_stream_api:
            if self._current_run_id is None:
                raise UsageError(
                    'Must have a current run to use file stream API.')
            # FileStreamApi(api, run_id, start_time, settings=None) needs the
            # start time for its run-time based rate limiting.
            self._file_stream_api = FileStreamApi(self, self._current_run_id, time.time())
        return self._file_stream_api

    def _status_request(self, url, length):
        """Query the upload server for how much of ``length`` bytes it has.

        Sends an empty PUT with ``Content-Range: bytes */<length>`` (the original
        note names Google as the storage backend) and returns the response.
        """
        status_headers = {
            'Content-Length': '0',
            'Content-Range': 'bytes */%i' % length,
        }
        return requests.put(url=url, headers=status_headers)

    def _flatten_edges(self, response):
        """Return an array from the nested graphql relay structure"""
        return [node['node'] for node in response['edges']]
class FileStreamApi(object):
    """Pushes chunks of files to our streaming endpoint.

    This class is used as a singleton. It has a thread that serializes access to
    the streaming endpoint and performs rate-limiting and batching.

    push(), enqueue_preempting() and finish() only put items on an internal
    queue; the daemon "FileStreamThread" launched by start() drains that queue
    and POSTs to ``<base_url>/files/<entity>/<project>/<run>/file_stream``.

    TODO: Differentiate between binary/text encoding.
    """

    # Queue item: flush pending chunks, then post the final "complete" message
    # carrying this exit code and stop the push thread.
    Finish = collections.namedtuple("Finish", ("exitcode",))
    # Queue item: tell the server the run is being preempted.
    Preempting = collections.namedtuple("Preempting", ())

    HTTP_TIMEOUT = env.get_http_timeout(10)
    # Upper bound on the number of queued items one _read_queue() call returns.
    MAX_ITEMS_PER_PUSH = 10000

    def __init__(self, api, run_id, start_time, settings=None):
        """Prepare the HTTP session, queue and (not yet started) push thread.

        Args:
            api: the internal Api; its api_key, user_agent, settings(),
                dynamic_settings and retry_callback are used.
            run_id (str): id of the run whose files are streamed.
            start_time (float): a time.time() timestamp; rate_limit_seconds()
                measures elapsed run time from it.
            settings (dict, optional): values merged over api.settings() when
                building the endpoint URL.
        """
        if settings is None:
            settings = dict()
        # NOTE: exc_info is set in thread_except_body context and readable by calling threads
        self._exc_info = None
        self._settings = settings
        self._api = api
        self._run_id = run_id
        self._start_time = start_time
        self._client = requests.Session()
        self._client.auth = ("api", api.api_key)
        # NOTE(review): requests.Session has no session-level timeout setting,
        # so this attribute is not applied to requests made through the session.
        self._client.timeout = self.HTTP_TIMEOUT
        self._client.headers.update({
            "User-Agent": api.user_agent,
            "X-WANDB-USERNAME": env.get_username(),
            "X-WANDB-USER-EMAIL": env.get_user_email(),
        })
        self._file_policies = {}
        self._queue = queue.Queue()
        self._thread = threading.Thread(target=self._thread_except_body)
        # It seems we need to make this a daemon thread to get sync.py's atexit handler to run, which
        # cleans this thread up.
        self._thread.name = "FileStreamThread"
        self._thread.daemon = True
        self._init_endpoint()

    def _init_endpoint(self):
        """(Re)build the file_stream endpoint URL from the api settings merged
        with this instance's settings overrides."""
        settings = self._api.settings()
        settings.update(self._settings)
        self._endpoint = "{base}/files/{entity}/{project}/{run}/file_stream".format(
            base=settings["base_url"],
            entity=settings["entity"],
            project=settings["project"],
            run=self._run_id,
        )

    def start(self):
        """Refresh the endpoint (settings may have changed) and start the push thread."""
        self._init_endpoint()
        self._thread.start()

    def set_default_file_policy(self, filename, file_policy):
        """Set an upload policy for a file unless one has already been set.
        """
        if filename not in self._file_policies:
            self._file_policies[filename] = file_policy

    def set_file_policy(self, filename, file_policy):
        """Set the upload policy for ``filename``, replacing any existing one."""
        self._file_policies[filename] = file_policy

    @property
    def heartbeat_seconds(self):
        # Defaults to 30
        return self._api.dynamic_settings["heartbeat_seconds"]

    def rate_limit_seconds(self):
        """Minimum seconds between data posts.

        Grows with the run's age (first minute, first five minutes, after) and
        scales with heartbeat_seconds, which the server may adjust via "limits".
        """
        run_time = time.time() - self._start_time
        if run_time < 60:
            return max(1, self.heartbeat_seconds / 15)
        elif run_time < 300:
            return max(2.5, self.heartbeat_seconds / 3)
        else:
            return max(5, self.heartbeat_seconds)

    def _read_queue(self):
        # called from the push thread (_thread_body), this does an initial read
        # that'll block for up to rate_limit_seconds. Then it tries to read
        # as much out of the queue as it can. We do this because the http post
        # to the server happens within _thread_body, and can take longer than
        # our rate limit. So next time we get a chance to read the queue we want
        # read all the stuff that queue'd up since last time.
        #
        # If we have more than MAX_ITEMS_PER_PUSH in the queue then the push thread
        # will get behind and data will buffer up in the queue.
        return util.read_many_from_queue(self._queue, self.MAX_ITEMS_PER_PUSH,
                                         self.rate_limit_seconds())

    def _thread_body(self):
        """Push-thread main loop: batch chunks, post them at most once per
        rate-limit interval, send heartbeats when idle, and post the final
        "complete" message once a Finish item is read."""
        posted_data_time = time.time()
        posted_anything_time = time.time()
        ready_chunks = []
        finished = None
        while finished is None:
            items = self._read_queue()
            for item in items:
                if isinstance(item, self.Finish):
                    finished = item
                elif isinstance(item, self.Preempting):
                    request_with_retry(
                        self._client.post,
                        self._endpoint,
                        json={
                            "complete": False,
                            "preempting": True
                        },
                    )
                else:
                    # item is Chunk
                    ready_chunks.append(item)

            cur_time = time.time()

            # Flush immediately on finish, otherwise respect the rate limit.
            if ready_chunks and (finished or cur_time - posted_data_time >
                                 self.rate_limit_seconds()):
                posted_data_time = cur_time
                posted_anything_time = cur_time
                self._send(ready_chunks)
                ready_chunks = []

            # Nothing posted for a while: send a heartbeat (its response may
            # carry updated limits).
            if cur_time - posted_anything_time > self.heartbeat_seconds:
                posted_anything_time = cur_time
                self._handle_response(
                    request_with_retry(
                        self._client.post,
                        self._endpoint,
                        json={
                            "complete": False,
                            "failed": False
                        },
                    ))
        # post the final close message. (item is self.Finish instance now)
        request_with_retry(
            self._client.post,
            self._endpoint,
            json={
                "complete": True,
                "exitcode": int(finished.exitcode)
            },
        )

    def _thread_except_body(self):
        """Thread target: run _thread_body, recording any exception in
        ``self._exc_info`` so finish() can re-raise it in the calling thread."""
        # TODO: Consolidate with internal_util.ExceptionThread
        try:
            self._thread_body()
        except Exception:
            exc_info = sys.exc_info()
            self._exc_info = exc_info
            logger.exception("generic exception in filestream thread")
            util.sentry_exc(exc_info, delay=True)
            raise

    def _handle_response(self, response):
        """Logs dropped chunks and updates dynamic settings.

        ``response`` is what request_with_retry returned: an Exception (logged
        and re-raised) or a response whose optional JSON ``limits`` object is
        merged into the api's dynamic_settings.
        """
        if isinstance(response, Exception):
            wandb.termerror(
                "Dropped streaming file chunk (see wandb/debug.log)")
            logger.error("dropped chunk %s", response)
            raise response
        parsed = None
        try:
            parsed = response.json()
        except Exception:
            # A non-JSON body carries no limits; nothing to update.
            pass
        if isinstance(parsed, dict):
            limits = parsed.get("limits")
            if isinstance(limits, dict):
                self._api.dynamic_settings.update(limits)

    def _send(self, chunks):
        """Group ``chunks`` by file, apply each file's policy, and post the
        result in payloads of at most ~10 MB each."""
        # create files dict. dict of <filename: chunks> pairs where chunks is a list of
        # [chunk_id, chunk_data] tuples (as lists since this will be json).
        files = {}
        # Groupby needs group keys to be consecutive, so sort first.
        chunks.sort(key=lambda c: c.filename)
        for filename, file_chunks in itertools.groupby(chunks,
                                                       lambda c: c.filename):
            file_chunks = list(file_chunks)  # groupby returns iterator
            # Specific file policies are set by internal/sender.py
            self.set_default_file_policy(filename, DefaultFilePolicy())
            files[filename] = self._file_policies[filename].process_chunks(
                file_chunks)
            if not files[filename]:
                del files[filename]

        for fs in file_stream_utils.split_files(files, max_mb=10):
            self._handle_response(
                request_with_retry(
                    self._client.post,
                    self._endpoint,
                    json={"files": fs},
                    retry_callback=self._api.retry_callback,
                ))

    def stream_file(self, path):
        """Synchronously send every line of the file at ``path``, keyed by its
        base name (bypasses the queue and the push thread)."""
        # basename handles both "/" and the platform separator (e.g. Windows).
        name = os.path.basename(path)
        with open(path) as f:
            self._send([Chunk(name, line) for line in f])

    def enqueue_preempting(self):
        """Queue a preemption notice for the push thread to post."""
        self._queue.put(self.Preempting())

    def push(self, filename, data):
        """Push a chunk of a file to the streaming endpoint.

        The chunk is only queued; the push thread batches and sends it.

        Arguments:
            filename: Name of file that this is a chunk of.
            data: File data for this chunk.
        """
        self._queue.put(Chunk(filename, data))

    def finish(self, exitcode):
        """Cleans up.

        Anything pushed after finish will be dropped.

        Arguments:
            exitcode: The exitcode of the watched process.

        Raises:
            The exception recorded by the push thread, if it failed.
        """
        self._queue.put(self.Finish(exitcode))
        # TODO(jhr): join on a thread which exited with an exception is a noop, clean up this path
        self._thread.join()
        if self._exc_info:
            logger.error("FileStream exception", exc_info=self._exc_info)
            # reraising the original exception, will get recaught in internal.py for the sender thread
            six.reraise(*self._exc_info)
示例#3
0
class Api(object):
    """W&B Internal Api wrapper

    Note:
        Settings are automatically overridden by looking for
        a `wandb/settings` file in the current working directory or it's parent
        directory.  If none can be found, we look in the current users home
        directory.

    Arguments:
        default_settings(`dict`, optional): If you aren't using a settings
        file or you wish to override the section to use in the settings file
        Override the settings here.
    """

    HTTP_TIMEOUT = env.get_http_timeout(10)

    def __init__(
        self,
        default_settings=None,
        load_settings=True,
        retry_timedelta=None,
        environ=os.environ,
    ):
        # A retry window of one day unless the caller supplies one.
        retry_timedelta = (datetime.timedelta(days=1)
                           if retry_timedelta is None else retry_timedelta)
        self._environ = environ
        base_defaults = {
            "section": "default",
            "git_remote": "origin",
            "ignore_globs": [],
            "base_url": "https://api.wandb.ai",
        }
        self.default_settings = base_defaults
        self.retry_timedelta = retry_timedelta
        # Caller-provided defaults override the built-in ones.
        self.default_settings.update(default_settings or {})
        self.retry_uploads = 10
        self._settings = Settings(
            load_settings=load_settings,
            root_dir=self.default_settings.get("root_dir"))
        # Git integration is not set up here (no GitRepo is constructed).
        self.git = None
        # Mutable settings set by the _file_stream_api
        self.dynamic_settings = dict(
            system_sample_seconds=2,
            system_samples=15,
            heartbeat_seconds=30,
        )
        request_headers = {
            "User-Agent": self.user_agent,
            "X-WANDB-USERNAME": env.get_username(env=self._environ),
            "X-WANDB-USER-EMAIL": env.get_user_email(env=self._environ),
        }
        transport = RequestsHTTPTransport(
            headers=request_headers,
            use_json=True,
            # this timeout won't apply when the DNS lookup fails. in that case, it will be 60s
            # https://bugs.python.org/issue22889
            timeout=self.HTTP_TIMEOUT,
            auth=("api", self.api_key or ""),
            url="%s/graphql" % self.settings("base_url"),
        )
        self.client = Client(transport=transport)
        # self.gql is self.execute wrapped in retry logic bounded by retry_timedelta.
        self.gql = retry.Retry(
            self.execute,
            retry_timedelta=retry_timedelta,
            check_retry_fn=util.no_retry_auth,
            retryable_exceptions=(RetryError, requests.RequestException),
        )
        self._current_run_id = None
        self._file_stream_api = None

    def reauth(self):
        """Refresh the GraphQL transport's credentials from the current api key
        (an empty string when no key is available)."""
        key = self.api_key or ""
        self.client.transport.auth = ("api", key)

    def relocate(self):
        """Point the GraphQL transport at ``<base_url>/graphql`` for the
        currently configured base_url."""
        base_url = self.settings("base_url")
        self.client.transport.url = "%s/graphql" % base_url

    def execute(self, *args, **kwargs):
        """Forward to ``self.client.execute``, logging HTTP failures.

        On ``requests.exceptions.HTTPError`` the status code and response body
        are logged, GraphQL error messages are echoed to the terminal via
        display_gorilla_error_if_found, and the original exception is re-raised.
        """
        try:
            return self.client.execute(*args, **kwargs)
        except requests.exceptions.HTTPError as http_err:
            failed = http_err.response
            logger.error("%s response executing GraphQL." % failed.status_code)
            logger.error(failed.text)
            self.display_gorilla_error_if_found(failed)
            raise

    def display_gorilla_error_if_found(self, res):
        """Print every GraphQL error message contained in a failed response.

        Arguments:
            res: a requests Response whose JSON body may look like
                ``{"errors": [{"message": ...}, ...]}``.

        Bodies that are not JSON, or JSON of any other shape, are ignored. This
        runs inside execute()'s error handler, so it must never raise itself.
        """
        try:
            data = res.json()
        except ValueError:
            return

        # Only a JSON object can carry an "errors" list; a JSON string that
        # merely contains "errors" would otherwise raise TypeError below.
        if not isinstance(data, dict):
            return
        errors = data.get("errors")
        if not isinstance(errors, list):
            return
        for err in errors:
            if not isinstance(err, dict) or not err.get("message"):
                continue
            wandb.termerror("Error while calling W&B API: {} ({})".format(
                err["message"], res))

    def disabled(self):
        """Return the "disabled" value of the default settings section, or
        False when it is not set."""
        section = Settings.DEFAULT_SECTION
        return self._settings.get(section, "disabled", fallback=False)

    def sync_spell(self, run, env=None):
        """Register this run's URL with spell.run.

        Stores SPELL_RUN_URL under ``run.config["_wandb"]["spell_url"]``,
        persists the config, then PUTs the run URL and WANDB_ACCESS_TOKEN to
        ``<SPELL_API_URL, default https://api.spell.run>/wandb_url``.

        Arguments:
            run: the run to register.
            env (mapping, optional): environment to read; defaults to os.environ.

        Returns:
            The requests response, or False if the run URL could not be
            determined or the request failed.
        """
        try:
            environment = env or os.environ
            run.config["_wandb"]["spell_url"] = environment.get("SPELL_RUN_URL")
            run.config.persist()
            try:
                run_url = run.get_url()
            except CommError as e:
                wandb.termerror("Unable to register run with spell.run: %s" %
                                str(e))
                return False
            endpoint = environment.get("SPELL_API_URL", "https://api.spell.run") + "/wandb_url"
            payload = {
                "access_token": environment.get("WANDB_ACCESS_TOKEN"),
                "url": run_url,
            }
            return requests.put(endpoint, json=payload, timeout=2)
        except requests.RequestException:
            return False

    def save_patches(self, out_dir):
        """Save the current state of this repository to one or more patches.

        Makes one patch against HEAD and another one against the most recent
        commit that occurs in an upstream branch. This way we can be robust
        to history editing as long as the user never does "push -f" to break
        history on an upstream branch.

        Writes the first patch to <out_dir>/<DIFF_FNAME> and the second to
        <out_dir>/upstream_diff_<commit_id>.patch.

        Arguments:
            out_dir (str): Directory to write the patch files.

        Returns:
            False when git integration is unavailable (``self.git`` is None) or
            disabled; None otherwise. Diff failures are logged, not raised.
        """
        # __init__ leaves self.git as None, so guard before touching it.
        if self.git is None or not self.git.enabled:
            return False

        try:
            root = self.git.root
            if self.git.dirty:
                patch_path = os.path.join(out_dir,
                                          wandb_lib.filenames.DIFF_FNAME)
                # we diff against HEAD to ensure we get changes in the index
                self._write_git_diff(patch_path, root, "HEAD")

            upstream_commit = self.git.get_upstream_fork_point()
            if upstream_commit and upstream_commit != self.git.repo.head.commit:
                sha = upstream_commit.hexsha
                upstream_patch_path = os.path.join(
                    out_dir, "upstream_diff_{}.patch".format(sha))
                self._write_git_diff(upstream_patch_path, root, sha)
        # TODO: A customer saw `ValueError: Reference at 'refs/remotes/origin/foo' does not exist`
        # so we now catch ValueError.  Catching this error feels too generic.
        except (
                ValueError,
                subprocess.CalledProcessError,
                subprocess.TimeoutExpired,
        ) as e:
            logger.error("Error generating diff: %s" % e)

    def _write_git_diff(self, path, cwd, ref):
        """Write ``git diff [--submodule=diff] <ref>`` run in ``cwd`` to ``path``."""
        cmd = ["git", "diff"]
        if self.git.has_submodule_diff:
            cmd.append("--submodule=diff")
        cmd.append(ref)
        with open(path, "wb") as patch:
            subprocess.check_call(cmd, stdout=patch, cwd=cwd, timeout=5)

    def set_current_run_id(self, run_id):
        """Record ``run_id`` as the current run; it is read back through the
        ``current_run_id`` property."""
        self._current_run_id = run_id

    @property
    def current_run_id(self):
        """The run id last given to set_current_run_id (None until then)."""
        run_id = self._current_run_id
        return run_id

    @property
    def user_agent(self):
        """User-Agent header value naming this client and its version."""
        version = __version__
        return "W&B Internal Client %s" % version

    @property
    def api_key(self):
        """The API key to authenticate with.

        A non-empty value of the env.API_KEY variable in ``self._environ`` wins.
        Otherwise the password stored in netrc for api_url is used, or None
        when neither exists.
        """
        netrc_auth = requests.utils.get_netrc_auth(self.api_url)
        env_key = self._environ.get(env.API_KEY)
        if env_key:
            return env_key
        return netrc_auth[-1] if netrc_auth else None

    @property
    def api_url(self):
        return self.settings("base_url")

    @property
    def app_url(self):
        """Web app URL, derived from ``api_url`` by ``wandb.util.app_url``."""
        api = self.api_url
        return wandb.util.app_url(api)

    def settings(self, key=None, section=None):
        """Return the effective settings, or a single setting.

        Values are layered: ``default_settings`` first, then the items of the
        settings file section. ``entity``, ``project``, ``base_url`` and
        ``ignore_globs`` are then resolved through the matching ``env`` helper,
        which receives the default-section settings-file value (falling back to
        the layered value) as its first argument.

        Arguments:
            key (str, optional): If provided only this setting is returned.
            section (str, optional): Settings-file section to read; passed
                through to ``Settings.items``.

        Returns:
            The full settings dict when ``key`` is None, e.g.

                {
                    "entity": "models",
                    "base_url": "https://api.wandb.ai",
                    "project": None
                }

            otherwise ``result[key]`` (a KeyError for an unknown key).
        """
        merged = self.default_settings.copy()
        merged.update(self._settings.items(section=section))

        env_resolvers = (
            ("entity", env.get_entity),
            ("project", env.get_project),
            ("base_url", env.get_base_url),
            ("ignore_globs", env.get_ignore),
        )
        # Collect every override before applying any, so each fallback reads
        # the layered value rather than an already-resolved one.
        overrides = {}
        for name, resolve in env_resolvers:
            file_value = self._settings.get(
                Settings.DEFAULT_SECTION,
                name,
                fallback=merged.get(name),
            )
            overrides[name] = resolve(file_value, env=self._environ)
        merged.update(overrides)

        if key is None:
            return merged
        return merged[key]

    def clear_setting(self, key, globally=False, persist=False):
        """Remove ``key`` from the default settings section.

        ``globally`` and ``persist`` are forwarded unchanged to ``Settings.clear``.
        """
        default_section = Settings.DEFAULT_SECTION
        self._settings.clear(default_section, key, globally=globally, persist=persist)

    def set_setting(self, key, value, globally=False, persist=False):
        """Store ``key = value`` in the default settings section.

        ``globally`` and ``persist`` are forwarded to ``Settings.set``. For
        ``entity``/``project`` the value is also handed to
        ``env.set_entity``/``env.set_project`` with this object's environ, and
        changing ``base_url`` calls ``relocate()``.
        """
        self._settings.set(
            Settings.DEFAULT_SECTION, key, value, globally=globally, persist=persist
        )
        if key in ("entity", "project"):
            env_setter = env.set_entity if key == "entity" else env.set_project
            env_setter(value, env=self._environ)
        elif key == "base_url":
            self.relocate()

    def parse_slug(self, slug, project=None, run=None):
        if slug and "/" in slug:
            parts = slug.split("/")
            project = parts[0]
            run = parts[1]
        else:
            project = project or self.settings().get("project")
            if project is None:
                raise CommError("No default project configured.")
            run = run or slug or env.get_run(env=self._environ)
            if run is None:
                run = "latest"
        return (project, run)

    @normalize_exceptions
    def viewer(self):
        """Return the authenticated viewer: id, entity and team names.

        An empty dict is returned when the response has no (or an empty)
        ``viewer``.
        """
        query = gql("""
        query Viewer{
            viewer {
                id
                entity
                teams {
                    edges {
                        node {
                            name
                        }
                    }
                }
            }
        }
        """)
        viewer_info = self.gql(query).get("viewer")
        return viewer_info if viewer_info else {}

    @normalize_exceptions
    def list_projects(self, entity=None):
        """List projects belonging to an entity.

        Only the first 10 projects are requested (``first: 10``).

        Arguments:
            entity (str, optional): Entity to list; defaults to the configured
                entity.

        Returns:
            A list of dicts with "id", "name" and "description".
        """
        query = gql("""
        query Models($entity: String!) {
            models(first: 10, entityName: $entity) {
                edges {
                    node {
                        id
                        name
                        description
                    }
                }
            }
        }
        """)
        variables = {"entity": entity or self.settings("entity")}
        response = self.gql(query, variable_values=variables)
        return self._flatten_edges(response["models"])

    @normalize_exceptions
    def project(self, project, entity=None):
        """Retrieve a single project.

        Arguments:
            project (str): The project to get details for
            entity (str, optional): The entity to scope this project to; sent
                as-is (no default is filled in here).

        Returns:
            The ``model`` field of the response: a dict with "id", "name",
            "repo", "dockerImage" and "description".
        """
        query = gql("""
        query Models($entity: String, $project: String!) {
            model(name: $project, entityName: $entity) {
                id
                name
                repo
                dockerImage
                description
            }
        }
        """)
        variables = {"entity": entity, "project": project}
        response = self.gql(query, variable_values=variables)
        return response["model"]

    @normalize_exceptions
    def sweep(self, sweep, specs, project=None, entity=None):
        """Retrieve a sweep together with its runs.

        Arguments:
            sweep (str): The sweep to get details for
            specs (list): History specs forwarded to each run's ``sampledHistory``
            project (str, optional): Project of the sweep; defaults to the
                configured project.
            entity (str, optional): Entity of the sweep; defaults to the
                configured entity.

        Returns:
            The sweep dict, with ``runs`` flattened from edges into a list.

        Raises:
            ValueError: if the project or the sweep is not found.
        """
        query = gql("""
        query Models($entity: String, $project: String!, $sweep: String!, $specs: [JSONString!]!) {
            model(name: $project, entityName: $entity) {
                sweep(sweepName: $sweep) {
                    id
                    name
                    method
                    state
                    description
                    config
                    createdAt
                    heartbeatAt
                    updatedAt
                    earlyStopJobRunning
                    bestLoss
                    controller
                    scheduler
                    runs {
                        edges {
                            node {
                                name
                                state
                                config
                                exitcode
                                heartbeatAt
                                shouldStop
                                failed
                                stopped
                                running
                                summaryMetrics
                                sampledHistory(specs: $specs)
                            }
                        }
                    }
                }
            }
        }
        """)
        entity = entity or self.settings("entity")
        project = project or self.settings("project")
        variables = {
            "entity": entity,
            "project": project,
            "sweep": sweep,
            "specs": specs,
        }
        response = self.gql(query, variable_values=variables)

        model = response["model"]
        if model is None or model["sweep"] is None:
            raise ValueError("Sweep {}/{}/{} not found".format(
                entity, project, sweep))
        sweep_data = model["sweep"]
        if sweep_data:
            sweep_data["runs"] = self._flatten_edges(sweep_data["runs"])
        return sweep_data

    @normalize_exceptions
    def list_runs(self, project, entity=None):
        """List runs in a project.

        Only the first 10 runs are requested (``first: 10``).

        Arguments:
            project (str): The project to list; falls back to the configured
                project when falsy.
            entity (str, optional): Entity of the project; defaults to the
                configured entity.

        Returns:
            A list of dicts with "id", "name", "displayName" and "description".
        """
        query = gql("""
        query Buckets($model: String!, $entity: String!) {
            model(name: $model, entityName: $entity) {
                buckets(first: 10) {
                    edges {
                        node {
                            id
                            name
                            displayName
                            description
                        }
                    }
                }
            }
        }
        """)
        variables = {
            "entity": entity or self.settings("entity"),
            "model": project or self.settings("project"),
        }
        response = self.gql(query, variable_values=variables)
        buckets = response["model"]["buckets"]
        return self._flatten_edges(buckets)

    @normalize_exceptions
    def launch_run(self, command, project=None, entity=None, run_id=None):
        """Launch a run in the cloud.

        When the git repo is dirty, its ``git diff`` is sent as a patch. When
        git is enabled, the current directory is sent relative to the repo root
        (e.g. ``"./sub/dir"``).

        Arguments:
            command (str): The command to run
            project (str, optional): The project to scope the run to; defaults
                to the configured project.
            entity (str, optional): The entity to scope this project to;
                defaults to the configured entity.
            run_id (str, optional): The run id to launch; defaults to
                ``current_run_id``.

        Returns:
            The GraphQL response; its ``launchRun`` field holds ``podName``,
            ``status`` and ``runId``.

        Raises:
            AssertionError: if no run id is given and none is current.
        """
        query = gql("""
        mutation launchRun(
            $entity: String
            $model: String
            $runId: String
            $image: String
            $command: String
            $patch: String
            $cwd: String
            $datasets: [String]
        ) {
            launchRun(input: {id: $runId, entityName: $entity, patch: $patch, modelName: $model,
                image: $image, command: $command, datasets: $datasets, cwd: $cwd}) {
                podName
                status
                runId
            }
        }
        """)
        run_id = run_id or self.current_run_id
        assert run_id, "run_id must be specified"
        patch = BytesIO()
        if self.git.dirty:
            self.git.repo.git.execute(["git", "diff"], output_stream=patch)
            patch.seek(0)
        cwd = "."
        if self.git.enabled:
            working_dir = self.git.repo.working_dir
            current_dir = os.getcwd()
            if current_dir.startswith(working_dir):
                # Strip only the leading repo-root prefix; an occurrence of the
                # root path deeper in the directory must stay intact.
                cwd += current_dir[len(working_dir):]
            else:
                # NOTE(review): cwd is outside the repo root; the full path is
                # appended unchanged (matches the previous behaviour here).
                cwd += current_dir
        return self.gql(
            query,
            variable_values={
                "entity": entity or self.settings("entity"),
                "model": project or self.settings("project"),
                "command": command,
                "runId": run_id,
                "patch": patch.read().decode("utf8"),
                "cwd": cwd,
            },
        )

    @normalize_exceptions
    def run_config(self, project, run=None, entity=None):
        """Get the relevant configs for a run.

        Arguments:
            project (str): The project to download, (can include bucket)
            run (str, optional): The run to download; defaults to
                ``current_run_id``.
            entity (str, optional): The entity to scope this project to.
                Defaults to the configured entity (the query requires one).

        Returns:
            A tuple ``(commit, config, patch, metadata)``. ``config`` is the run
            config parsed from JSON (``{}`` when empty). ``metadata`` is the
            downloaded ``wandb-metadata.json`` (``{}`` when the run has none).

        Raises:
            AssertionError: if no run is given and none is current.
            ValueError: if the project or the run is not found.
        """
        query = gql("""
        query Model($name: String!, $entity: String!, $run: String!) {
            model(name: $name, entityName: $entity) {
                bucket(name: $run) {
                    config
                    commit
                    patch
                    files(names: ["wandb-metadata.json"]) {
                        edges {
                            node {
                                url
                            }
                        }
                    }
                }
            }
        }
        """)
        run = run or self.current_run_id
        assert run, "run must be specified"
        # $entity is declared String! above, so fall back to the configured
        # entity like the sibling methods instead of sending null.
        entity = entity or self.settings("entity")
        response = self.gql(query,
                            variable_values={
                                "name": project,
                                "run": run,
                                "entity": entity
                            })
        model = response["model"]
        bucket = model["bucket"] if model is not None else None
        if bucket is None:
            raise ValueError("Run {}/{}/{} not found".format(
                entity, project, run))
        commit = bucket["commit"]
        patch = bucket["patch"]
        config = json.loads(bucket["config"] or "{}")
        metadata_edges = bucket["files"]["edges"]
        if metadata_edges:
            res = requests.get(metadata_edges[0]["node"]["url"])
            res.raise_for_status()
            metadata = res.json()
        else:
            metadata = {}
        return (commit, config, patch, metadata)

    @normalize_exceptions
    def run_resume_status(self, entity, project_name, name):
        """Check whether a run exists and fetch the data needed to resume it.

        When the project is found, ``project`` (and, if present, ``entity``)
        are stored via ``set_setting`` before the bucket is returned.

        Arguments:
            entity (str, optional): The entity to scope this project to.
            project_name (str): The project containing the run.
            name (str): The run to look up.

        Returns:
            The run's bucket dict, or None when the project is missing or the
            bucket is absent/null (the query uses ``missingOk: true``).
        """
        query = gql("""
        query Model($project: String!, $entity: String, $name: String!) {
            model(name: $project, entityName: $entity) {
                id
                name
                entity {
                    id
                    name
                }

                bucket(name: $name, missingOk: true) {
                    id
                    name
                    summaryMetrics
                    displayName
                    logLineCount
                    historyLineCount
                    eventsLineCount
                    historyTail
                    eventsTail
                    config
                }
            }
        }
        """)

        variables = {
            "entity": entity,
            "project": project_name,
            "name": name,
        }
        response = self.gql(query, variable_values=variables)

        if "model" not in response:
            return None
        project = response["model"] or {}
        if "bucket" not in project:
            return None

        self.set_setting("project", project_name)
        if "entity" in project:
            self.set_setting("entity", project["entity"]["name"])

        return project["bucket"]

    @normalize_exceptions
    def check_stop_requested(self, project_name, entity_name, run_id):
        """Return the run's ``stopped`` flag, or False if project/run is missing.

        ``run_id`` falls back to ``current_run_id`` when falsy.

        Raises:
            AssertionError: if no run id is available.
        """
        query = gql("""
        query Model($projectName: String, $entityName: String, $runId: String!) {
            project(name:$projectName, entityName:$entityName) {
                run(name:$runId) {
                    stopped
                }
            }
        }
        """)
        run_id = run_id or self.current_run_id
        assert run_id, "run_id must be specified"
        variables = {
            "projectName": project_name,
            "entityName": entity_name,
            "runId": run_id,
        }
        response = self.gql(query, variable_values=variables)

        project = response.get("project", None) or {}
        run = project.get("run", None)
        if not run:
            return False
        return run["stopped"]

    def format_project(self, project):
        """Normalize a project name into a slug.

        The name is lowercased and each run of non-word characters becomes a
        single ``-``. Leading and trailing ``-``/``_`` are then trimmed.
        """
        lowered = project.lower()
        dashed = re.sub(r"\W+", "-", lowered)
        return dashed.strip("_-")

    @normalize_exceptions
    def upsert_project(self, project, id=None, description=None, entity=None):
        """Create or update a project.

        The name is normalized with ``format_project`` and the git remote URL
        is sent as the project's repo.

        Arguments:
            project (str): The project to create
            id (str, optional): Id of an existing project to update
            description (str, optional): A description of this project
            entity (str, optional): The entity to scope this project to;
                defaults to the configured entity.

        Returns:
            The upserted model dict ("name", "description").
        """
        mutation = gql("""
        mutation UpsertModel($name: String!, $id: String, $entity: String!, $description: String, $repo: String)  {
            upsertModel(input: { id: $id, name: $name, entityName: $entity, description: $description, repo: $repo }) {
                model {
                    name
                    description
                }
            }
        }
        """)
        variables = {
            "name": self.format_project(project),
            "entity": entity or self.settings("entity"),
            "description": description,
            "repo": self.git.remote_url,
            "id": id,
        }
        response = self.gql(mutation, variable_values=variables)
        upserted = response["upsertModel"]
        return upserted["model"]

    @normalize_exceptions
    def pop_from_run_queue(self, entity=None, project=None):
        """Pop the next item from a project's run queue.

        NOTE(review): both variables are declared ``String!`` in the mutation
        but default to None here; callers presumably always pass them.

        Returns:
            The ``popFromRunQueue`` payload ("runQueueItemId", "runSpec").
        """
        mutation = gql("""
        mutation popFromRunQueue($entity: String!, $project: String!)  {
            popFromRunQueue(input: { entityName: $entity, projectName: $project }) {
                runQueueItemId
                runSpec
            }
        }
        """)
        variables = {"entity": entity, "project": project}
        result = self.gql(mutation, variable_values=variables)
        return result["popFromRunQueue"]

    @normalize_exceptions
    def upsert_run(
        self,
        id=None,
        name=None,
        project=None,
        host=None,
        group=None,
        tags=None,
        config=None,
        description=None,
        entity=None,
        state=None,
        display_name=None,
        notes=None,
        repo=None,
        job_type=None,
        program_path=None,
        commit=None,
        sweep_name=None,
        summary_metrics=None,
        num_retries=None,
    ):
        """Create or update a run.

        After the upsert, the returned project name (and its entity name, when
        present) are stored via ``set_setting``.

        Arguments:
            id (str, optional): The existing run to update
            name (str, optional): The name of the run to create
            project (str, optional): The name of the project
            host (str, optional): Hostname; omitted when the ``anonymous``
                setting is the string "true".
            group (str, optional): Name of the group this run is a part of
            tags (list, optional): Tags for the run
            config (dict, optional): The latest config params (JSON-encoded
                before sending)
            description (str, optional): Run description; blank/whitespace-only
                values are sent as None.
            entity (str, optional): The entity to scope this project to;
                defaults to the configured entity.
            state (str, optional): State of the program.
            display_name (str, optional): Human-readable run name
            notes (str, optional): Run notes
            repo (str, optional): Url of the program's repository.
            job_type (str, optional): Type of job, e.g 'train'.
            program_path (str, optional): Path to the program.
            commit (str, optional): The Git SHA to associate the run with
            sweep_name (str, optional): Sweep this run belongs to
            summary_metrics (str, optional): The JSON summary metrics
            num_retries (int, optional): Forwarded to ``self.gql`` when given.

        Returns:
            The upserted bucket dict.
        """
        mutation = gql("""
        mutation UpsertBucket(
            $id: String, $name: String,
            $project: String,
            $entity: String!,
            $groupName: String,
            $description: String,
            $displayName: String,
            $notes: String,
            $commit: String,
            $config: JSONString,
            $host: String,
            $debug: Boolean,
            $program: String,
            $repo: String,
            $jobType: String,
            $state: String,
            $sweep: String,
            $tags: [String!],
            $summaryMetrics: JSONString,
        ) {
            upsertBucket(input: {
                id: $id,
                name: $name,
                groupName: $groupName,
                modelName: $project,
                entityName: $entity,
                description: $description,
                displayName: $displayName,
                notes: $notes,
                config: $config,
                commit: $commit,
                host: $host,
                debug: $debug,
                jobProgram: $program,
                jobRepo: $repo,
                jobType: $jobType,
                state: $state,
                sweep: $sweep,
                tags: $tags,
                summaryMetrics: $summaryMetrics,
            }) {
                bucket {
                    id
                    name
                    displayName
                    description
                    config
                    project {
                        id
                        name
                        entity {
                            id
                            name
                        }
                    }
                }
            }
        }
        """)
        config = None if config is None else json.dumps(config)
        description = (description
                       if description and not description.isspace() else None)

        retry_kwargs = {} if num_retries is None else {"num_retries": num_retries}

        variable_values = {
            "id": id,
            "entity": entity or self.settings("entity"),
            "name": name,
            "project": project,
            "groupName": group,
            "tags": tags,
            "description": description,
            "config": config,
            "commit": commit,
            "displayName": display_name,
            "notes": notes,
            "host":
            None if self.settings().get("anonymous") == "true" else host,
            "debug": env.is_debug(env=self._environ),
            "repo": repo,
            "program": program_path,
            "jobType": job_type,
            "state": state,
            "sweep": sweep_name,
            "summaryMetrics": summary_metrics,
        }

        response = self.gql(mutation,
                            variable_values=variable_values,
                            **retry_kwargs)

        bucket = response["upsertBucket"]["bucket"]
        bucket_project = bucket.get("project")
        if bucket_project:
            self.set_setting("project", bucket_project["name"])
            bucket_entity = bucket_project.get("entity")
            if bucket_entity:
                self.set_setting("entity", bucket_entity["name"])
        return bucket

    @normalize_exceptions
    def upload_urls(self,
                    project,
                    files,
                    run=None,
                    entity=None,
                    description=None):
        """Generate temporary resumable upload urls.

        Arguments:
            project (str): The project the run belongs to
            files (list or dict): The filenames to upload (a dict's keys are used)
            run (str, optional): The run to upload to; defaults to
                ``current_run_id``.
            entity (str, optional): The entity to scope this project to;
                defaults to the configured entity.
            description (str, optional): Sent as the bucket ``desc`` argument.

        Returns:
            (bucket_id, upload_headers, file_info)
            bucket_id: id of bucket we uploaded to
            upload_headers: the ``uploadHeaders`` value returned by the server
            file_info: A dict of filenames to file info (name, upload url,
                updatedAt); ``updatedAt`` indicates a file already uploaded.
                {
                    'weights.h5': { "name": "weights.h5", "url": "https://weights.url", "updatedAt": None },
                }

        Raises:
            AssertionError: if no run is given and none is current.
            CommError: if the project or the run does not exist.
        """
        query = gql("""
        query Model($name: String!, $files: [String]!, $entity: String!, $run: String!, $description: String) {
            model(name: $name, entityName: $entity) {
                bucket(name: $run, desc: $description) {
                    id
                    files(names: $files) {
                        uploadHeaders
                        edges {
                            node {
                                name
                                url(upload: true)
                                updatedAt
                            }
                        }
                    }
                }
            }
        }
        """)
        run_id = run or self.current_run_id
        # Validate the resolved id: asserting `run` would reject the
        # current_run_id fallback.
        assert run_id, "run must be specified"
        entity = entity or self.settings("entity")
        query_result = self.gql(
            query,
            variable_values={
                "name": project,
                "run": run_id,
                "entity": entity,
                "description": description,
                "files": list(files),
            },
        )

        model = query_result["model"]
        bucket = model["bucket"] if model else None
        if not bucket:
            raise CommError("Run does not exist {}/{}/{}.".format(
                entity, project, run_id))
        file_info = {
            file["name"]: file
            for file in self._flatten_edges(bucket["files"])
        }
        return bucket["id"], bucket["files"]["uploadHeaders"], file_info

    @normalize_exceptions
    def download_urls(self, project, run=None, entity=None):
        """Generate download urls for all files of a run.

        Arguments:
            project (str): The project the run belongs to
            run (str, optional): The run to download from; defaults to
                ``current_run_id``.
            entity (str, optional): The entity to scope this project to;
                defaults to the configured entity.

        Returns:
            A dict mapping file names to file info (falsy entries are skipped):

                {
                    'weights.h5': { "url": "https://weights.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' },
                    'model.json': { "url": "https://model.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' }
                }

        Raises:
            AssertionError: if no run is given and none is current.
        """
        query = gql("""
        query Model($name: String!, $entity: String!, $run: String!)  {
            model(name: $name, entityName: $entity) {
                bucket(name: $run) {
                    files {
                        edges {
                            node {
                                name
                                url
                                md5
                                updatedAt
                            }
                        }
                    }
                }
            }
        }
        """)
        run = run or self.current_run_id
        assert run, "run must be specified"
        variables = {
            "name": project,
            "run": run,
            "entity": entity or self.settings("entity"),
        }
        query_result = self.gql(query, variable_values=variables)
        bucket_files = query_result["model"]["bucket"]["files"]
        return {
            entry["name"]: entry
            for entry in self._flatten_edges(bucket_files)
            if entry
        }

    @normalize_exceptions
    def download_url(self, project, file_name, run=None, entity=None):
        """Generate a download url for a single file of a run.

        Arguments:
            project (str): The project the run belongs to
            file_name (str): The name of the file to download
            run (str, optional): The run to download from; defaults to
                ``current_run_id``.
            entity (str, optional): The entity to scope this project to;
                defaults to the configured entity.

        Returns:
            The file info dict, e.g.

                { "url": "https://weights.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' }

            or None when the project, the run or the file is missing, or the
            file has no ``updatedAt``.

        Raises:
            AssertionError: if no run is given and none is current.
        """
        query = gql("""
        query Model($name: String!, $fileName: String!, $entity: String!, $run: String!)  {
            model(name: $name, entityName: $entity) {
                bucket(name: $run) {
                    files(names: [$fileName]) {
                        edges {
                            node {
                                name
                                url
                                md5
                                updatedAt
                            }
                        }
                    }
                }
            }
        }
        """)
        run = run or self.current_run_id
        assert run, "run must be specified"
        query_result = self.gql(
            query,
            variable_values={
                "name": project,
                "run": run,
                "fileName": file_name,
                "entity": entity or self.settings("entity"),
            },
        )
        model = query_result["model"]
        # A missing run (null bucket) is "not found" just like a missing project.
        if not model or not model.get("bucket"):
            return None
        files = self._flatten_edges(model["bucket"]["files"])
        if files and files[0].get("updatedAt"):
            return files[0]
        return None

    @normalize_exceptions
    def download_file(self, url):
        """Start a streaming download of ``url``.

        Arguments:
            url (str): The url to download

        Returns:
            ``(content_length, response)``; the length comes from the
            ``content-length`` header (0 when absent).

        Raises:
            requests.HTTPError: for an error status (via ``raise_for_status``).
        """
        response = requests.get(url, stream=True)
        response.raise_for_status()
        content_length = int(response.headers.get("content-length", 0))
        return content_length, response

    @normalize_exceptions
    def download_write_file(self, metadata, out_dir=None):
        """Download a run file and write it to ``out_dir`` (or ``wandb_dir``).

        Arguments:
            metadata (dict): File info as produced by ``download_urls`` (uses
                "name", "md5" and "url").
            out_dir (str, optional): Target directory; defaults to the
                ``wandb_dir`` setting.

        Returns:
            ``(path, response)``; ``response`` is None when ``file_current``
            reported the file as already up to date.
        """
        file_name = metadata["name"]
        target_dir = out_dir or self.settings("wandb_dir")
        path = os.path.join(target_dir, file_name)
        # NOTE(review): the freshness check is given file_name, not path;
        # confirm file_current resolves it against the same directory.
        if self.file_current(file_name, metadata["md5"]):
            return path, None

        _, response = self.download_file(metadata["url"])

        with util.fsync_open(path, "wb") as out_file:
            for chunk in response.iter_content(chunk_size=1024):
                out_file.write(chunk)

        return path, response

    @normalize_exceptions
    def register_agent(self,
                       host,
                       sweep_id=None,
                       project_name=None,
                       entity=None):
        """Register a new sweep agent.

        Arguments:
            host (str): hostname
            sweep_id (str): sweep id
            project_name (str, optional): project that contains the sweep;
                defaults to the configured project.
            entity (str, optional): entity of the project; defaults to the
                configured entity.

        Returns:
            The created agent dict (with its ``id``).

        Raises:
            UsageError: when the server rejects the request with a 4xx status.
        """
        mutation = gql("""
        mutation CreateAgent(
            $host: String!
            $projectName: String!,
            $entityName: String!,
            $sweep: String!
        ) {
            createAgent(input: {
                host: $host,
                projectName: $projectName,
                entityName: $entityName,
                sweep: $sweep,
            }) {
                agent {
                    id
                }
            }
        }
        """)
        if entity is None:
            entity = self.settings("entity")
        if project_name is None:
            project_name = self.settings("project")

        # don't retry on validation or not found errors
        def no_retry_4xx(e):
            """Retry predicate: retry unless the error is an HTTP 4xx, which
            is surfaced as a UsageError carrying the server's message."""
            if not isinstance(e, requests.HTTPError):
                return True
            response = e.response
            if response is None:
                return True
            if not 400 <= response.status_code < 500:
                return True
            # The error body is normally GraphQL JSON; fall back to the
            # exception text when it is not, instead of crashing while
            # decoding.
            try:
                message = json.loads(response.content)["errors"][0]["message"]
            except (ValueError, KeyError, IndexError, TypeError):
                message = str(e)
            raise UsageError(message)

        response = self.gql(
            mutation,
            variable_values={
                "host": host,
                "entityName": entity,
                "projectName": project_name,
                "sweep": sweep_id,
            },
            check_retry_fn=no_retry_4xx,
        )
        return response["createAgent"]["agent"]

    def agent_heartbeat(self, agent_id, metrics, run_states):
        """Notify server about agent state, receive commands.

        Any error during the request is logged and an empty command list is
        returned.

        Arguments:
            agent_id (str): agent_id
            metrics (dict): system metrics
            run_states (dict): run_id: state mapping
        Returns:
            List of commands to execute.
        """
        mutation = gql("""
        mutation Heartbeat(
            $id: ID!,
            $metrics: JSONString,
            $runState: JSONString
        ) {
            agentHeartbeat(input: {
                id: $id,
                metrics: $metrics,
                runState: $runState
            }) {
                agent {
                    id
                }
                commands
            }
        }
        """)
        try:
            response = self.gql(
                mutation,
                variable_values={
                    "id": agent_id,
                    "metrics": json.dumps(metrics),
                    "runState": json.dumps(run_states),
                },
            )
        except Exception as e:
            # GQL raises exceptions with stringified python dictionaries :/
            # Parse defensively so a differently-shaped error cannot escape
            # this handler; fall back to the exception text.
            message = str(e)
            if e.args and isinstance(e.args[0], str):
                try:
                    payload = ast.literal_eval(e.args[0])
                except (ValueError, SyntaxError, TypeError, MemoryError,
                        RecursionError):
                    payload = None
                if isinstance(payload, dict) and "message" in payload:
                    message = payload["message"]
            logger.error("Error communicating with W&B: %s", message)
            return []
        else:
            return json.loads(response["agentHeartbeat"]["commands"])

    @normalize_exceptions
    def upsert_sweep(
        self,
        config,
        controller=None,
        scheduler=None,
        obj_id=None,
        project=None,
        entity=None,
    ):
        """Upsert a sweep object.

        The mutation is first sent with a selection of the sweep's project
        (name and entity). If that fails with a non-usage error, which is what
        an older server schema produces, it is retried without that
        selection. When the server does return the project, the "project" and
        "entity" settings are updated from it.

        Arguments:
            config (dict): sweep config; serialized to YAML for the request,
                and its optional "description" key is sent separately.
            controller (optional): passed through as the `controller`
                JSONString variable.
            scheduler (optional): passed through as the `scheduler`
                JSONString variable.
            obj_id (str, optional): id of the sweep to update, sent as `id`.
            project (str, optional): project name; defaults to the "project"
                setting.
            entity (str, optional): entity name; defaults to the "entity"
                setting.

        Returns:
            str: the name of the upserted sweep.

        Raises:
            UsageError: if the server rejects the request with a 4xx status.
        """
        project_query = """
                    project {
                        id
                        name
                        entity {
                            id
                            name
                        }
                    }
        """
        mutation_str = """
        mutation UpsertSweep(
            $id: ID,
            $config: String,
            $description: String,
            $entityName: String!,
            $projectName: String!,
            $controller: JSONString,
            $scheduler: JSONString
        ) {
            upsertSweep(input: {
                id: $id,
                config: $config,
                description: $description,
                entityName: $entityName,
                projectName: $projectName,
                controller: $controller,
                scheduler: $scheduler
            }) {
                sweep {
                    name
                    _PROJECT_QUERY_
                }
            }
        }
        """
        # TODO(jhr): we need protocol versioning to know schema is not supported
        # for now we will just try both new and old query
        mutation_new = gql(
            mutation_str.replace("_PROJECT_QUERY_", project_query))
        mutation_old = gql(mutation_str.replace("_PROJECT_QUERY_", ""))

        # don't retry on validation errors
        # TODO(jhr): generalize error handling routines
        def no_retry_4xx(e):
            if not isinstance(e, requests.HTTPError):
                return True
            if not 400 <= e.response.status_code < 500:
                return True
            # Surface the server's message. A 4xx body that is not the
            # expected GraphQL error JSON must still become a UsageError
            # rather than a JSON/KeyError raised from inside the retry hook.
            try:
                body = json.loads(e.response.content)
                message = body["errors"][0]["message"]
            except (ValueError, KeyError, IndexError, TypeError):
                message = str(e)
            raise UsageError(message)

        # The request variables are identical for both mutation variants.
        variables = {
            "id": obj_id,
            "config": yaml.dump(config),
            "description": config.get("description"),
            "entityName": entity or self.settings("entity"),
            "projectName": project or self.settings("project"),
            "controller": controller,
            "scheduler": scheduler,
        }

        err = None
        for mutation in (mutation_new, mutation_old):
            try:
                response = self.gql(
                    mutation,
                    variable_values=variables,
                    check_retry_fn=no_retry_4xx,
                )
                break
            except UsageError:
                raise
            except Exception as e:
                # graphql schema exception is generic; try the next variant
                err = e
        else:
            raise err

        sweep = response["upsertSweep"]["sweep"]
        sweep_project = sweep.get("project")
        if sweep_project:
            self.set_setting("project", sweep_project["name"])
            sweep_entity = sweep_project.get("entity")
            if sweep_entity:
                self.set_setting("entity", sweep_entity["name"])

        return sweep["name"]

    @normalize_exceptions
    def create_anonymous_api_key(self):
        """Create a new anonymous user on the server and return its API key.

        Returns:
            str: the ``name`` of the API key attached to the newly created
            anonymous entity.
        """
        query_text = """
        mutation CreateAnonymousApiKey {
            createAnonymousEntity(input: {}) {
                apiKey {
                    name
                }
            }
        }
        """
        result = self.gql(gql(query_text), variable_values={})
        api_key = result["createAnonymousEntity"]["apiKey"]
        return api_key["name"]

    def file_current(self, fname, md5):
        """Return True if ``fname`` is an existing regular file whose MD5 equals ``md5``.

        A missing path (or a directory) is never current, and no checksum is
        computed for it.
        """
        if not os.path.isfile(fname):
            return False
        return util.md5_file(fname) == md5

    @normalize_exceptions
    def pull(self, project, run=None, entity=None):
        """Download files from W&B

        Each file URL returned by ``download_urls`` is handed to
        ``download_write_file``.

        Arguments:
            project (str): The project to download from; passed through
                ``parse_slug`` together with ``run``.
            run (str, optional): The run to download from
            entity (str, optional): The entity to scope this project to

        Returns:
            list: the requests library response objects for the files that
            were downloaded. Files for which ``download_write_file`` returned
            no response are omitted.

        Raises:
            ValueError: if no run could be determined.
        """
        project, run = self.parse_slug(project, run=run)
        # An explicit raise instead of `assert`, which is stripped under -O.
        if not run:
            raise ValueError("run must be specified")
        urls = self.download_urls(project, run, entity)
        responses = []
        for file_info in urls.values():
            _, response = self.download_write_file(file_info)
            if response:
                responses.append(response)

        return responses

    def get_project(self):
        """Return the currently configured project name (the "project" setting)."""
        project_name = self.settings("project")
        return project_name

    def _status_request(self, url, length):
        """Ask google how much we've uploaded.

        Sends an empty PUT with ``Content-Range: bytes */<length>``. This is
        the Google Cloud Storage way of querying a resumable upload session
        for its progress.

        Args:
            url (str): presumably the resumable upload session URL.
            length (int): total size in bytes of the upload.

        Returns:
            requests.Response: the raw status response; the caller interprets it.
        """
        return requests.put(
            url=url,
            headers={
                "Content-Length": "0",
                "Content-Range": "bytes */%i" % length
            },
            # Without a timeout requests may block forever on a stalled connection.
            timeout=self.HTTP_TIMEOUT,
        )

    def _flatten_edges(self, response):
        """Return an array from the nested graphql relay structure"""
        return [node["node"] for node in response["edges"]]
示例#4
0
class Api(object):
    """
    Used for querying the wandb server.

    Examples:
        Most common way to initialize
        ```
            wandb.Api()
        ```

    Args:
        overrides (dict, optional): You can set `base_url` if you are using a
            wandb server other than https://api.wandb.ai.
            You can also set defaults for `entity`, `project`, and `run`.
            The deprecated key `username` is used as `entity` when no
            `entity` is given.
    """

    _HTTP_TIMEOUT = env.get_http_timeout(9)

    def __init__(self, overrides=None):
        # None instead of a shared mutable `{}` default.
        if overrides is None:
            overrides = {}
        self.settings = {
            'entity': None,
            'project': None,
            'run': "latest",
            'base_url': env.get_base_url("https://api.wandb.ai")
        }
        self.settings.update(overrides)
        if 'username' in overrides and 'entity' not in overrides:
            wandb.termwarn(
                'Passing "username" to Api is deprecated. please use "entity" instead.'
            )
            self.settings['entity'] = overrides['username']
        # Local caches: projects by entity, runs/sweeps by their lookup key.
        self._projects = {}
        self._runs = {}
        self._sweeps = {}
        self._base_client = Client(transport=RequestsHTTPTransport(
            headers={
                'User-Agent': self.user_agent,
                'Use-Admin-Privileges': "true"
            },
            use_json=True,
            # this timeout won't apply when the DNS lookup fails. in that case, it will be 60s
            # https://bugs.python.org/issue22889
            timeout=self._HTTP_TIMEOUT,
            auth=("api", self.api_key),
            url='%s/graphql' % self.settings['base_url']))
        self._client = RetryingClient(self._base_client)

    def create_run(self, **kwargs):
        return Run.create(self, **kwargs)

    @property
    def client(self):
        return self._client

    @property
    def user_agent(self):
        return 'W&B Public Client %s' % __version__

    @property
    def api_key(self):
        """API key from netrc for `base_url`, overridden by $WANDB_API_KEY."""
        auth = requests.utils.get_netrc_auth(self.settings['base_url'])
        key = None
        if auth:
            key = auth[-1]
        # Environment should take precedence
        if os.getenv("WANDB_API_KEY"):
            key = os.environ["WANDB_API_KEY"]
        return key

    def flush(self):
        """
        The api object keeps a local cache of runs, so if the state of the run may
        change while executing your script you must clear the local cache with
        `api.flush()` to get the latest values associated with the run."""
        self._runs = {}

    def _parse_path(self, path):
        """Parses paths in the following formats:

        url: entity/project/runs/run_id
        path: entity/project/run_id
        docker: entity/project:run_id

        entity is optional and will fallback to the current logged in user.
        A two-part path is read as project/run_id when a default entity is
        set, and as entity/project otherwise.
        """
        run = self.settings['run']
        project = self.settings['project']
        entity = self.settings['entity']
        parts = path.replace("/runs/", "/").strip("/ ").split("/")
        if ":" in parts[-1]:
            run = parts[-1].split(":")[-1]
            parts[-1] = parts[-1].split(":")[0]
        elif parts[-1]:
            run = parts[-1]
        if len(parts) > 1:
            project = parts[1]
            if entity and run == project:
                project = parts[0]
            else:
                entity = parts[0]
        else:
            project = parts[0]
        return (entity, project, run)

    def _parse_project_path(self, path):
        """Parses a project path of the form entity/project or project.

        Unlike `_parse_path`, two parts are always read as entity/project.
        A bare project name uses the default entity, and an empty path uses
        the default entity and project from settings.

        Returns:
            tuple: (entity, project)
        """
        entity = self.settings['entity']
        project = self.settings['project']
        parts = path.strip("/ ").split("/")
        if len(parts) > 1:
            entity, project = parts[0], parts[1]
        elif parts[0]:
            project = parts[0]
        return (entity, project)

    def projects(self, entity=None, per_page=None):
        """Get projects for a given entity.
        Args:
            entity (str): Name of the entity requested.  If None will fallback to
                default entity passed to :obj:`Api`.  If no default entity, will raise a `ValueError`.
            per_page (int): Sets the page size for query pagination.  None will use the default size.
                Usually there is no reason to change this.

        Returns:
            A :obj:`Projects` object which is an iterable collection of :obj:`Project` objects.

        """
        if entity is None:
            entity = self.settings['entity']
            if entity is None:
                raise ValueError(
                    'entity must be passed as a parameter, or set in settings')
        if entity not in self._projects:
            self._projects[entity] = Projects(self.client,
                                              entity,
                                              per_page=per_page)
        return self._projects[entity]

    def runs(self, path="", filters=None, order="-created_at", per_page=None):
        """Return a set of runs from a project that match the filters provided.
        You can filter by `config.*`, `summary.*`, `state`, `entity`, `createdAt`, etc.

        Examples:
            Find runs in my_project where config.experiment_name has been set to "foo"
            ```
            api.runs("my_entity/my_project", {"config.experiment_name": "foo"})
            ```

            Find runs in my_project where config.experiment_name has been set to "foo" or "bar"
            ```
            api.runs("my_entity/my_project",
                {"$or": [{"config.experiment_name": "foo"}, {"config.experiment_name": "bar"}]})
            ```

            Find runs in my_project sorted by ascending loss
            ```
            api.runs("my_entity/my_project", order="+summary.loss")
            ```


        Args:
            path (str): path to project, should be in the form: "entity/project".
                If api.entity is set this can just be "project".
            filters (dict): queries for specific runs using the MongoDB query language.
                You can filter by run properties such as config.key, summary.key, state, entity, createdAt, etc.
                For example: {"config.experiment_name": "foo"} would find runs with a config entry
                    of experiment name set to "foo"
                You can compose operations to make more complicated queries,
                    see the reference for the language at https://docs.mongodb.com/manual/reference/operator/query
            order (str): Order can be `created_at`, `heartbeat_at`, `config.*.value`, or `summary.*`.
                If you prepend order with a + order is ascending.
                If you prepend order with a - order is descending (default).
                The default order is run.created_at from newest to oldest.
            per_page (int): Sets the page size for query pagination.  None will use the default size.

        Returns:
            A :obj:`Runs` object, which is an iterable collection of :obj:`Run` objects.
            Results are cached per (path, filters, order); use `api.flush()` to clear.
        """
        if filters is None:
            filters = {}
        entity, project = self._parse_project_path(path)
        # Check and store under the same key: different filters/orders on one
        # path must not share a result set.
        key = path + str(filters) + str(order)
        if key not in self._runs:
            self._runs[key] = Runs(
                self.client,
                entity,
                project,
                filters=filters,
                order=order,
                per_page=per_page)
        return self._runs[key]

    @normalize_exceptions
    def run(self, path=""):
        """Returns a single run by parsing path in the form entity/project/run_id.

        Args:
            path (str): path to run in the form entity/project/run_id.
                If api.entity is set, this can be in the form project/run_id
                and if api.project is set this can just be the run_id.

        Returns:
            A :obj:`Run` object.
        """
        entity, project, run = self._parse_path(path)
        if path not in self._runs:
            self._runs[path] = Run(self.client, entity, project, run)
        return self._runs[path]

    @normalize_exceptions
    def sweep(self, path=""):
        """
        Returns a sweep by parsing path in the form entity/project/sweep_id.

        Args:
            path (str, optional): path to sweep in the form entity/project/sweep_id.  If api.entity
                is set, this can be in the form project/sweep_id and if api.project is set
                this can just be the sweep_id.

        Returns:
            A :obj:`Sweep` object.
        """
        entity, project, sweep_id = self._parse_path(path)
        # Check and store under the same key (the path).
        if path not in self._sweeps:
            self._sweeps[path] = Sweep(self.client, entity, project, sweep_id)
        return self._sweeps[path]
示例#5
0
class FileStreamApi(object):
    """Pushes chunks of files to our streaming endpoint.

    This class is used as a singleton. It has a thread that serializes access to
    the streaming endpoint and performs rate-limiting and batching.

    TODO: Differentiate between binary/text encoding.
    """
    # Queued by finish() to tell the push thread to send the final message and stop.
    Finish = collections.namedtuple('Finish', ('exitcode',))

    HTTP_TIMEOUT = env.get_http_timeout(10)
    MAX_ITEMS_PER_PUSH = 10000

    def __init__(self, api, run_id):
        self._api = api
        self._run_id = run_id
        self._client = requests.Session()
        self._client.auth = ('api', api.api_key)
        # NOTE: requests.Session does not apply this attribute to requests;
        # the timeout is actually enforced by passing it per call in _post().
        self._client.timeout = self.HTTP_TIMEOUT
        self._client.headers.update({
            'User-Agent': api.user_agent,
            'X-WANDB-USERNAME': env.get_username()
        })
        self._file_policies = {}
        self._queue = queue.Queue()
        self._thread = threading.Thread(target=self._thread_body)
        # It seems we need to make this a daemon thread to get sync.py's atexit handler to run, which
        # cleans this thread up.
        self._thread.daemon = True
        self._init_endpoint()

    def _init_endpoint(self):
        """Build the file_stream URL from the api's current settings."""
        settings = self._api.settings()
        self._endpoint = "{base}/files/{entity}/{project}/{run}/file_stream".format(
            base=settings['base_url'],
            entity=settings['entity'],
            project=settings['project'],
            run=self._run_id)

    def start(self):
        # Rebuild the endpoint: settings presumably may have changed since __init__.
        self._init_endpoint()
        self._thread.start()

    def set_default_file_policy(self, filename, file_policy):
        """Set an upload policy for a file unless one has already been set.
        """
        if filename not in self._file_policies:
            self._file_policies[filename] = file_policy

    def set_file_policy(self, filename, file_policy):
        self._file_policies[filename] = file_policy

    @property
    def heartbeat_seconds(self):
        # Defaults to 30
        return self._api.dynamic_settings["heartbeat_seconds"]

    def rate_limit_seconds(self):
        run_time = time.time() - wandb.START_TIME
        if run_time < 60:
            return max(1, self.heartbeat_seconds / 15)
        elif run_time < 300:
            return max(2.5, self.heartbeat_seconds / 3)
        else:
            return max(5, self.heartbeat_seconds)

    def _read_queue(self):
        # called from the push thread (_thread_body), this does an initial read
        # that'll block for up to rate_limit_seconds. Then it tries to read
        # as much out of the queue as it can. We do this because the http post
        # to the server happens within _thread_body, and can take longer than
        # our rate limit. So next time we get a chance to read the queue we want
        # read all the stuff that queue'd up since last time.
        #
        # If we have more than MAX_ITEMS_PER_PUSH in the queue then the push thread
        # will get behind and data will buffer up in the queue.
        return util.read_many_from_queue(self._queue, self.MAX_ITEMS_PER_PUSH,
                                         self.rate_limit_seconds())

    def _post(self, payload):
        """POST a JSON payload to the file_stream endpoint, with retries.

        Returns the result of util.request_with_retry; _handle_response treats
        an Exception result as a failed post.
        """
        return util.request_with_retry(self._client.post,
                                       self._endpoint,
                                       json=payload,
                                       timeout=self.HTTP_TIMEOUT)

    def _thread_body(self):
        posted_data_time = time.time()
        posted_anything_time = time.time()
        ready_chunks = []
        finished = None
        while finished is None:
            items = self._read_queue()
            for item in items:
                if isinstance(item, self.Finish):
                    finished = item
                else:
                    # item is Chunk
                    ready_chunks.append(item)

            cur_time = time.time()

            if ready_chunks and (finished or cur_time - posted_data_time >
                                 self.rate_limit_seconds()):
                posted_data_time = cur_time
                posted_anything_time = cur_time
                self._send(ready_chunks)
                ready_chunks = []

            if cur_time - posted_anything_time > self.heartbeat_seconds:
                # Nothing posted for a heartbeat interval: send a
                # not-complete/not-failed message.
                posted_anything_time = cur_time
                self._handle_response(self._post({
                    'complete': False,
                    'failed': False
                }))
        # post the final close message. (item is self.Finish instance now)
        self._post({
            'complete': True,
            'exitcode': int(finished.exitcode)
        })

    def _handle_response(self, response):
        """Logs dropped chunks and updates dynamic settings"""
        if isinstance(response, Exception):
            logging.error("dropped chunk %s", response)
            return
        try:
            parsed = response.json()
        except ValueError:
            # A non-JSON body must not kill the push thread.
            logging.error("unable to parse file_stream response")
            return
        limits = parsed.get("limits") if isinstance(parsed, dict) else None
        if limits:
            self._api.dynamic_settings.update(limits)

    def _send(self, chunks):
        # create files dict. dict of <filename: chunks> pairs where chunks is a list of
        # [chunk_id, chunk_data] tuples (as lists since this will be json).
        files = {}
        # Groupby needs group keys to be consecutive, so sort first.
        chunks.sort(key=lambda c: c.filename)
        for filename, file_chunks in itertools.groupby(chunks,
                                                       lambda c: c.filename):
            file_chunks = list(file_chunks)  # groupby returns iterator
            self.set_default_file_policy(filename, DefaultFilePolicy())
            files[filename] = self._file_policies[filename].process_chunks(
                file_chunks)
        self._handle_response(self._post({'files': files}))

    def stream_file(self, path):
        """Send every line of the file at ``path`` immediately, named by its basename."""
        name = path.split("/")[-1]
        # Close the file before the (possibly slow) network send.
        with open(path) as f:
            chunks = [Chunk(name, line) for line in f]
        self._send(chunks)

    def push(self, filename, data):
        """Push a chunk of a file to the streaming endpoint.

        Args:
            filename: Name of file that this is a chunk of.
            data: File data for this chunk.
        """
        self._queue.put(Chunk(filename, data))

    def finish(self, exitcode):
        """Cleans up.

        Anything pushed after finish will be dropped.

        Args:
            exitcode: The exitcode of the watched process.
        """
        self._queue.put(self.Finish(exitcode))
        self._thread.join()
示例#6
0
class Api(object):
    """W&B Public API

    Args:
        overrides(:obj:`dict`, optional): You can set defaults such as
        entity, project, and run here as well as which api server to use
        (`base_url`). The deprecated key `username` is used as `entity`
        when no `entity` is given.
    """

    HTTP_TIMEOUT = env.get_http_timeout(9)

    def __init__(self, overrides=None):
        # None instead of a shared mutable `{}` default.
        if overrides is None:
            overrides = {}
        self.settings = {
            'entity': None,
            'project': None,
            'run': "latest",
            'base_url': env.get_base_url("https://api.wandb.ai")
        }
        self.settings.update(overrides)
        if 'username' in overrides and 'entity' not in overrides:
            wandb.termwarn('Passing "username" to Api is deprecated. please use "entity" instead.')
            self.settings['entity'] = overrides['username']
        # Local caches: projects by entity, runs/sweeps by their lookup key.
        self._projects = {}
        self._runs = {}
        self._sweeps = {}
        self._base_client = Client(
            transport=RequestsHTTPTransport(
                headers={'User-Agent': self.user_agent, 'Use-Admin-Privileges': "true"},
                use_json=True,
                # this timeout won't apply when the DNS lookup fails. in that case, it will be 60s
                # https://bugs.python.org/issue22889
                timeout=self.HTTP_TIMEOUT,
                auth=("api", self.api_key),
                url='%s/graphql' % self.settings['base_url']
            )
        )
        self._client = RetryingClient(self._base_client)

    def create_run(self, **kwargs):
        return Run.create(self, **kwargs)

    @property
    def client(self):
        return self._client

    @property
    def user_agent(self):
        return 'W&B Public Client %s' % __version__

    @property
    def api_key(self):
        """API key from netrc for `base_url`, overridden by $WANDB_API_KEY."""
        auth = requests.utils.get_netrc_auth(self.settings['base_url'])
        key = None
        if auth:
            key = auth[-1]
        # Environment should take precedence
        if os.getenv("WANDB_API_KEY"):
            key = os.environ["WANDB_API_KEY"]
        return key

    def flush(self):
        """Clear the local cache of runs"""
        self._runs = {}

    def _parse_path(self, path):
        """Parses paths in the following formats:

        url: entity/project/runs/run_id
        path: entity/project/run_id
        docker: entity/project:run_id

        entity is optional and will fallback to the current logged in user.
        A two-part path is read as project/run_id when a default entity is
        set, and as entity/project otherwise.
        """
        run = self.settings['run']
        project = self.settings['project']
        entity = self.settings['entity']
        parts = path.replace("/runs/", "/").strip("/ ").split("/")
        if ":" in parts[-1]:
            run = parts[-1].split(":")[-1]
            parts[-1] = parts[-1].split(":")[0]
        elif parts[-1]:
            run = parts[-1]
        if len(parts) > 1:
            project = parts[1]
            if entity and run == project:
                project = parts[0]
            else:
                entity = parts[0]
        else:
            project = parts[0]
        return (entity, project, run)

    def _parse_project_path(self, path):
        """Parses a project path of the form entity/project or project.

        Unlike `_parse_path`, two parts are always read as entity/project.
        A bare project name uses the default entity, and an empty path uses
        the default entity and project from settings. Returns (entity, project).
        """
        entity = self.settings['entity']
        project = self.settings['project']
        parts = path.strip("/ ").split("/")
        if len(parts) > 1:
            entity, project = parts[0], parts[1]
        elif parts[0]:
            project = parts[0]
        return (entity, project)

    def projects(self, entity=None, per_page=None):
        """Return a list of projects for a given entity (defaults to the configured entity)."""
        if entity is None:
            entity = self.settings['entity']
            if entity is None:
                raise ValueError('entity must be passed as a parameter, or set in settings')
        if entity not in self._projects:
            self._projects[entity] = Projects(self.client, entity, per_page=per_page)
        return self._projects[entity]

    def runs(self, path="", filters=None, order="-created_at", per_page=None):
        """Return a set of runs from a project that match the filters provided.
        You can filter by config.*, summary.*, state, entity, createdAt, etc.

        The project is given by `path` as "entity/project", or just "project"
        to use the default entity.

        The filters use the same query language as MongoDB:

        https://docs.mongodb.com/manual/reference/operator/query

        Order can be created_at, heartbeat_at, config.*.value, or summary.*.  By default
        the order is descending, if you prepend order with a + order becomes ascending.

        Results are cached per (path, filters, order); call flush() to clear.
        """
        if filters is None:
            filters = {}
        entity, project = self._parse_project_path(path)
        # Check and store under the same key.
        key = path + str(filters) + str(order)
        if key not in self._runs:
            self._runs[key] = Runs(self.client, entity, project,
                                   filters=filters, order=order, per_page=per_page)
        return self._runs[key]

    @normalize_exceptions
    def run(self, path=""):
        """Returns a run by parsing path in the form entity/project/run, if
        defaults were set on the Api, only overrides what's passed.  I.E. you can just pass
        run_id if you set entity and project on the Api"""
        entity, project, run = self._parse_path(path)
        if path not in self._runs:
            self._runs[path] = Run(self.client, entity, project, run)
        return self._runs[path]

    @normalize_exceptions
    def sweep(self, path=""):
        """Returns a sweep by parsing path in the form entity/project/sweep_id.
        If defaults were set on the Api, you can pass project/sweep_id or just
        the sweep_id. Sweeps are cached by path."""
        entity, project, sweep_id = self._parse_path(path)
        # Check and store under the same key (the path).
        if path not in self._sweeps:
            self._sweeps[path] = Sweep(self.client, entity, project, sweep_id)
        return self._sweeps[path]
示例#7
0
class Api(object):
    """W&B Internal Api wrapper

    Note:
        Settings are automatically overridden by looking for
        a `wandb/settings` file in the current working directory or it's parent
        directory.  If none can be found, we look in the current users home
        directory.

    Args:
        default_settings(:obj:`dict`, optional): If you aren't using a settings
        file or you wish to override the section to use in the settings file
        Override the settings here.
    """

    HTTP_TIMEOUT = env.get_http_timeout(10)

    def __init__(
        self,
        default_settings=None,
        load_settings=True,
        retry_timedelta=None,
        environ=os.environ,
    ):
        """Create the internal API client.

        Args:
            default_settings (dict, optional): Entries merged over the built-in
                defaults ("section", "git_remote", "ignore_globs", "base_url").
                A "root_dir" entry, if present, is forwarded to ``Settings``.
            load_settings (bool): Forwarded to ``Settings``.
            retry_timedelta (datetime.timedelta, optional): Passed to the
                retrying GraphQL wrapper stored as ``self.gql``. Defaults to
                one day.
            environ (mapping): Environment used to look up the username, the
                user email, the API key and setting overrides. Defaults to
                ``os.environ``.
        """
        if retry_timedelta is None:
            retry_timedelta = datetime.timedelta(days=1)
        self._environ = environ

        defaults = {
            "section": "default",
            "git_remote": "origin",
            "ignore_globs": [],
            "base_url": "https://api.wandb.ai",
        }
        defaults.update(default_settings or {})
        self.default_settings = defaults
        self.retry_timedelta = retry_timedelta
        self.retry_uploads = 10
        self._settings = Settings(
            load_settings=load_settings,
            root_dir=self.default_settings.get("root_dir"),
        )
        # Git repository detection is disabled here, so ``self.git`` stays None.
        self.git = None
        # Mutable settings set by the _file_stream_api
        self.dynamic_settings = {
            "system_sample_seconds": 2,
            "system_samples": 15,
            "heartbeat_seconds": 30,
        }

        request_headers = {
            "User-Agent": self.user_agent,
            "X-WANDB-USERNAME": env.get_username(env=self._environ),
            "X-WANDB-USER-EMAIL": env.get_user_email(env=self._environ),
        }
        transport = RequestsHTTPTransport(
            headers=request_headers,
            use_json=True,
            # this timeout won't apply when the DNS lookup fails. in that case, it will be 60s
            # https://bugs.python.org/issue22889
            timeout=self.HTTP_TIMEOUT,
            auth=("api", self.api_key or ""),
            url="%s/graphql" % self.settings("base_url"),
        )
        self.client = Client(transport=transport)

        self.gql = retry.Retry(
            self.execute,
            retry_timedelta=retry_timedelta,
            check_retry_fn=util.no_retry_auth,
            retryable_exceptions=(RetryError, requests.RequestException),
        )
        self._current_run_id = None
        self._file_stream_api = None

    def reauth(self):
        """Install the currently resolved API key on the GraphQL transport.

        An unset key is sent as an empty password.
        """
        current_key = self.api_key or ""
        self.client.transport.auth = ("api", current_key)

    def relocate(self):
        """Re-point the GraphQL transport at the currently configured base_url."""
        graphql_endpoint = "%s/graphql" % self.settings("base_url")
        self.client.transport.url = graphql_endpoint

    def execute(self, *args, **kwargs):
        """Run ``self.client.execute``, logging HTTP failures before re-raising.

        All arguments are forwarded unchanged to ``self.client.execute``.

        Raises:
            requests.exceptions.HTTPError: Re-raised with its original
                traceback. If the error carries a response, its status code
                and body are logged first, and any GraphQL error messages in
                it are shown to the user.
        """
        try:
            return self.client.execute(*args, **kwargs)
        except requests.exceptions.HTTPError as err:
            res = err.response
            # HTTPError can be constructed without a response. Touching
            # ``res.status_code`` would then replace the real error with an
            # AttributeError.
            if res is not None:
                logger.error("%s response executing GraphQL.", res.status_code)
                logger.error(res.text)
                self.display_gorilla_error_if_found(res)
            raise

    def display_gorilla_error_if_found(self, res):
        """Print GraphQL error messages carried by a failed response.

        Args:
            res: The HTTP response. Its body is parsed with ``res.json()``.

        Bodies that are not valid JSON, or whose JSON is not an object with an
        ``"errors"`` list, are ignored. Each error entry that is an object
        with a non-empty ``"message"`` is reported through ``wandb.termerror``.
        """
        try:
            data = res.json()
        except ValueError:
            # Non-JSON body (e.g. an HTML error page): nothing to display.
            return

        # JSON bodies may be lists, strings or null. Without these checks they
        # would raise TypeError/AttributeError inside this error-reporting path.
        if not isinstance(data, dict):
            return
        errors = data.get("errors")
        if not isinstance(errors, list):
            return
        for err in errors:
            if not isinstance(err, dict) or not err.get("message"):
                continue
            wandb.termerror("Error while calling W&B API: {} ({})".format(
                err["message"], res))

    def disabled(self):
        """Return the "disabled" setting of the default section (False if unset)."""
        section = Settings.DEFAULT_SECTION
        return self._settings.get(section, "disabled", fallback=False)

    def sync_spell(self, run, env=None):
        """Register this run's URL with Spell.

        Stores SPELL_RUN_URL in ``run.config["_wandb"]["spell_url"]``,
        persists the config, then PUTs the run URL to the Spell API.

        Args:
            run: Run whose ``config`` is updated and whose ``get_url()`` is sent.
            env (mapping, optional): Environment holding SPELL_RUN_URL,
                SPELL_API_URL and WANDB_ACCESS_TOKEN. Defaults to ``os.environ``.
                (This parameter shadows the module-level ``env`` inside the method.)

        Returns:
            The ``requests`` response of the PUT. Returns False if the run URL
            could not be obtained or the HTTP request raised.
        """
        try:
            environment = env or os.environ
            run.config["_wandb"]["spell_url"] = environment.get("SPELL_RUN_URL")
            run.config.persist()
            try:
                url = run.get_url()
            except CommError as e:
                wandb.termerror("Unable to register run with spell.run: %s" %
                                str(e))
                return False
            endpoint = environment.get("SPELL_API_URL", "https://api.spell.run") + "/wandb_url"
            payload = {
                "access_token": environment.get("WANDB_ACCESS_TOKEN"),
                "url": url,
            }
            return requests.put(endpoint, json=payload, timeout=2)
        except requests.RequestException:
            return False

    def save_patches(self, out_dir):
        """Save the current state of this repository to one or more patches.

        Makes one patch against HEAD and another one against the most recent
        commit that occurs in an upstream branch. This way we can be robust
        to history editing as long as the user never does "push -f" to break
        history on an upstream branch.

        Writes the first patch to <out_dir>/<DIFF_FNAME> and the second to
        <out_dir>/upstream_diff_<commit_id>.patch.

        Args:
            out_dir (str): Directory to write the patch files.

        Returns:
            False when there is no usable git repository, otherwise None.
            ValueError, CalledProcessError and TimeoutExpired raised while
            generating a diff are logged rather than propagated.
        """
        # ``__init__`` leaves ``self.git`` as None. Treat that like a disabled
        # repository instead of failing with AttributeError.
        if self.git is None or not self.git.enabled:
            return False

        def write_diff(patch_path, ref, root):
            # Write `git diff <ref>` to patch_path, including submodule
            # content diffs when the repository reports any.
            cmd = ["git", "diff"]
            if self.git.has_submodule_diff:
                cmd.append("--submodule=diff")
            cmd.append(ref)
            with open(patch_path, "wb") as patch:
                subprocess.check_call(cmd, stdout=patch, cwd=root, timeout=5)

        try:
            root = self.git.root
            if self.git.dirty:
                # we diff against HEAD to ensure we get changes in the index
                write_diff(
                    os.path.join(out_dir, wandb_lib.filenames.DIFF_FNAME),
                    "HEAD",
                    root,
                )

            upstream_commit = self.git.get_upstream_fork_point()
            if upstream_commit and upstream_commit != self.git.repo.head.commit:
                sha = upstream_commit.hexsha
                write_diff(
                    os.path.join(out_dir, "upstream_diff_{}.patch".format(sha)),
                    sha,
                    root,
                )
        # TODO: A customer saw `ValueError: Reference at 'refs/remotes/origin/foo' does not exist`
        # so we now catch ValueError.  Catching this error feels too generic.
        except (
                ValueError,
                subprocess.CalledProcessError,
                subprocess.TimeoutExpired,
        ) as e:
            logger.error("Error generating diff: %s" % e)

    def set_current_run_id(self, run_id):
        """Remember ``run_id`` as the fallback run for methods given no run id."""
        self._current_run_id = run_id

    @property
    def current_run_id(self):
        """The run id last stored via ``set_current_run_id`` (None initially)."""
        run_id = self._current_run_id
        return run_id

    @property
    def user_agent(self):
        """User-Agent header value identifying this client and its version."""
        return "W&B Internal Client %s" % (__version__,)

    @property
    def api_key(self):
        """The API key, or None if none is configured.

        The value of ``env.API_KEY`` in the environment takes precedence.
        Otherwise the password of the netrc entry for ``api_url`` is used.
        """
        netrc_auth = requests.utils.get_netrc_auth(self.api_url)
        # Environment should take precedence over netrc.
        env_key = self._environ.get(env.API_KEY)
        if env_key:
            return env_key
        return netrc_auth[-1] if netrc_auth else None

    @property
    def api_url(self):
        """The resolved "base_url" setting."""
        base_url = self.settings("base_url")
        return base_url

    @property
    def app_url(self):
        """Web app URL derived from ``api_url`` by ``wandb.util.app_url``."""
        api_base = self.api_url
        return wandb.util.app_url(api_base)

    def settings(self, key=None, section=None):
        """The settings overridden from the wandb/settings file.

        Layers, in order: ``self.default_settings``, the items of the given
        settings-file section, and then resolved values for "entity",
        "project", "base_url" and "ignore_globs". Each resolved value is read
        from the settings file's default section (falling back to the value
        accumulated so far) and passed through the matching ``env`` helper.
        Those helpers presumably apply environment-variable overrides.

        Args:
            key (str, optional): If provided only this setting is returned
            section (str, optional): If provided this section of the setting file is
            used, defaults to "default"

        Returns:
            A dict with the current settings, or the single value for ``key``:

                {
                    "entity": "models",
                    "base_url": "https://api.wandb.ai",
                    "project": None
                }
        """
        merged = dict(self.default_settings)
        merged.update(self._settings.items(section=section))

        resolvers = (
            ("entity", env.get_entity),
            ("project", env.get_project),
            ("base_url", env.get_base_url),
            ("ignore_globs", env.get_ignore),
        )
        overrides = {}
        for name, resolve in resolvers:
            from_file = self._settings.get(
                Settings.DEFAULT_SECTION,
                name,
                fallback=merged.get(name),
            )
            overrides[name] = resolve(from_file, env=self._environ)
        merged.update(overrides)

        if key is None:
            return merged
        return merged[key]

    def clear_setting(self, key, globally=False, persist=False):
        """Remove ``key`` from the default settings section.

        ``globally`` and ``persist`` are forwarded unchanged to ``Settings.clear``.
        """
        section = Settings.DEFAULT_SECTION
        self._settings.clear(section, key, globally=globally, persist=persist)

    def set_setting(self, key, value, globally=False, persist=False):
        """Set ``key`` to ``value`` in the default settings section.

        Side effects:
        - "entity" and "project" are also written to ``self._environ`` via
          the corresponding ``env`` setter.
        - "base_url" re-points the GraphQL transport through ``relocate``.
        """
        section = Settings.DEFAULT_SECTION
        self._settings.set(section, key, value, globally=globally, persist=persist)

        if key == "entity":
            env.set_entity(value, env=self._environ)
            return
        if key == "project":
            env.set_project(value, env=self._environ)
            return
        if key == "base_url":
            self.relocate()

    def parse_slug(self, slug, project=None, run=None):
        """Resolve ``slug`` into a ``(project, run)`` pair.

        A slug containing "/" is split on it, and its first two components
        are returned as project and run. Otherwise:
        - the project is ``project``, or else the "project" setting;
        - the run is ``run``, else ``slug``, else the run from the
          environment, else "latest".

        Raises:
            CommError: If no project is given and none is configured.
        """
        if slug and "/" in slug:
            pieces = slug.split("/")
            return (pieces[0], pieces[1])

        project = project or self.settings().get("project")
        if project is None:
            raise CommError("No default project configured.")
        run = run or slug or env.get_run(env=self._environ)
        return (project, "latest" if run is None else run)

    @normalize_exceptions
    def viewer(self):
        """Fetch the authenticated viewer: id, entity and team names.

        Returns:
            dict: The "viewer" object of the GraphQL response, or an empty dict
            when it is missing or null.
        """
        query = gql("""
        query Viewer{
            viewer {
                id
                entity
                teams {
                    edges {
                        node {
                            name
                        }
                    }
                }
            }
        }
        """)
        viewer_data = self.gql(query).get("viewer")
        return viewer_data or {}

    @normalize_exceptions
    def list_projects(self, entity=None):
        """List up to ten projects of an entity.

        Args:
            entity (str, optional): Entity whose projects are listed. Defaults
                to the "entity" setting.

        Returns:
            The "models" connection (id, name, description per node) passed
            through ``self._flatten_edges``, e.g. [{"id","name","description"}].
        """
        query = gql("""
        query Models($entity: String!) {
            models(first: 10, entityName: $entity) {
                edges {
                    node {
                        id
                        name
                        description
                    }
                }
            }
        }
        """)
        variables = {"entity": entity or self.settings("entity")}
        result = self.gql(query, variable_values=variables)
        return self._flatten_edges(result["models"])

    @normalize_exceptions
    def project(self, project, entity=None):
        """Retrieve a single project.

        Args:
            project (str): The project to get details for.
            entity (str, optional): The entity to scope this project to. It is
                sent as given; no settings fallback is applied.

        Returns:
            The "model" object of the response, with "id", "name", "repo",
            "dockerImage" and "description" (None if the project is not found).
        """
        query = gql("""
        query Models($entity: String, $project: String!) {
            model(name: $project, entityName: $entity) {
                id
                name
                repo
                dockerImage
                description
            }
        }
        """)
        variables = {"entity": entity, "project": project}
        response = self.gql(query, variable_values=variables)
        return response["model"]

    @normalize_exceptions
    def sweep(self, sweep, specs, project=None, entity=None):
        """Retrieve a sweep together with its runs.

        Args:
            sweep (str): The sweep to get details for.
            specs (str): History specs forwarded to each run's ``sampledHistory``.
            project (str, optional): Project scope. Defaults to the "project" setting.
            entity (str, optional): Entity scope. Defaults to the "entity" setting.

        Returns:
            dict: The sweep fields from the query, with "runs" replaced by the
            run nodes passed through ``self._flatten_edges``.

        Raises:
            ValueError: If the project or the sweep is not found.
        """
        query = gql("""
        query Models($entity: String, $project: String!, $sweep: String!, $specs: [JSONString!]!) {
            model(name: $project, entityName: $entity) {
                sweep(sweepName: $sweep) {
                    id
                    name
                    method
                    state
                    description
                    config
                    createdAt
                    heartbeatAt
                    updatedAt
                    earlyStopJobRunning
                    bestLoss
                    controller
                    scheduler
                    runs {
                        edges {
                            node {
                                name
                                state
                                config
                                exitcode
                                heartbeatAt
                                shouldStop
                                failed
                                stopped
                                running
                                summaryMetrics
                                sampledHistory(specs: $specs)
                            }
                        }
                    }
                }
            }
        }
        """)
        entity = entity or self.settings("entity")
        project = project or self.settings("project")
        variables = {
            "entity": entity,
            "project": project,
            "sweep": sweep,
            "specs": specs,
        }
        model = self.gql(query, variable_values=variables)["model"]
        if model is None or model["sweep"] is None:
            raise ValueError("Sweep {}/{}/{} not found".format(
                entity, project, sweep))
        sweep_data = model["sweep"]
        if sweep_data:
            sweep_data["runs"] = self._flatten_edges(sweep_data["runs"])
        return sweep_data

    @normalize_exceptions
    def list_runs(self, project, entity=None):
        """List up to ten runs of a project.

        Args:
            project (str): The project to scope the runs to. Falls back to the
                "project" setting when falsy.
            entity (str, optional): The entity to scope this project to.
                Defaults to the "entity" setting.

        Returns:
            The "buckets" connection (id, name, displayName, description per
            node) passed through ``self._flatten_edges``.
        """
        query = gql("""
        query Buckets($model: String!, $entity: String!) {
            model(name: $model, entityName: $entity) {
                buckets(first: 10) {
                    edges {
                        node {
                            id
                            name
                            displayName
                            description
                        }
                    }
                }
            }
        }
        """)
        variables = {
            "entity": entity or self.settings("entity"),
            "model": project or self.settings("project"),
        }
        response = self.gql(query, variable_values=variables)
        return self._flatten_edges(response["model"]["buckets"])

    @normalize_exceptions
    def launch_run(self, command, project=None, entity=None, run_id=None):
        """Launch a run in the cloud.

        When a git repository is available and dirty, its ``git diff`` is sent
        as the patch. The working directory is sent relative to the repository
        root.

        Args:
            command (str): The command to run.
            project (str, optional): The project to scope the run to. Defaults
                to the "project" setting.
            entity (str, optional): The entity to scope this project to.
                Defaults to the "entity" setting.
            run_id (str, optional): The run id. Defaults to ``current_run_id``.

        Returns:
            The GraphQL response. Its "launchRun" entry holds "podName",
            "status" and "runId".
        """
        query = gql("""
        mutation launchRun(
            $entity: String
            $model: String
            $runId: String
            $image: String
            $command: String
            $patch: String
            $cwd: String
            $datasets: [String]
        ) {
            launchRun(input: {id: $runId, entityName: $entity, patch: $patch, modelName: $model,
                image: $image, command: $command, datasets: $datasets, cwd: $cwd}) {
                podName
                status
                runId
            }
        }
        """)
        run_id = run_id or self.current_run_id
        assert run_id, "run_id must be specified"
        git = self.git
        patch = BytesIO()
        # ``self.git`` may be None. In that case no patch is sent and the
        # command runs from ".".
        if git is not None and git.dirty:
            git.repo.git.execute(["git", "diff"], output_stream=patch)
            patch.seek(0)
        cwd = "."
        if git is not None and git.enabled:
            here = os.getcwd()
            repo_root = git.repo.working_dir
            # Strip only a leading repo-root prefix. str.replace would also
            # remove matching text elsewhere in the path.
            if here.startswith(repo_root):
                here = here[len(repo_root):]
            cwd += here
        return self.gql(
            query,
            variable_values={
                "entity": entity or self.settings("entity"),
                "model": project or self.settings("project"),
                "command": command,
                "runId": run_id,
                "patch": patch.read().decode("utf8"),
                "cwd": cwd,
            },
        )

    @normalize_exceptions
    def run_config(self, project, run=None, entity=None):
        """Get the relevant configs for a run.

        Args:
            project (str): The project to download, (can include bucket)
            run (str, optional): The run to download. Defaults to ``current_run_id``.
            entity (str, optional): The entity to scope this project to.

        Returns:
            tuple: ``(commit, config, patch, metadata)``.
            - ``config`` is the run's JSON config decoded to a dict (``{}`` if empty).
            - ``metadata`` is the downloaded wandb-metadata.json (``{}`` if absent).

        Raises:
            ValueError: If the project or the run is not found.
        """
        query = gql("""
        query Model($name: String!, $entity: String!, $run: String!) {
            model(name: $name, entityName: $entity) {
                bucket(name: $run) {
                    config
                    commit
                    patch
                    files(names: ["wandb-metadata.json"]) {
                        edges {
                            node {
                                url
                            }
                        }
                    }
                }
            }
        }
        """)
        run = run or self.current_run_id
        assert run, "run must be specified"
        response = self.gql(query,
                            variable_values={
                                "name": project,
                                "run": run,
                                "entity": entity
                            })
        model = response["model"]
        # Treat a missing bucket like a missing project, instead of failing
        # later with "'NoneType' object is not subscriptable".
        if model is None or model.get("bucket") is None:
            raise ValueError("Run {}/{}/{} not found".format(
                entity, project, run))
        bucket = model["bucket"]
        commit = bucket["commit"]
        patch = bucket["patch"]
        config = json.loads(bucket["config"] or "{}")
        edges = bucket["files"]["edges"]
        if edges:
            url = edges[0]["node"]["url"]
            # Bound the wait so a stalled connection cannot hang the caller;
            # the same timeout is already used for the GraphQL transport.
            res = requests.get(url, timeout=self.HTTP_TIMEOUT)
            res.raise_for_status()
            metadata = res.json()
        else:
            metadata = {}
        return (commit, config, patch, metadata)

    @normalize_exceptions
    def run_resume_status(self, entity, project_name, name):
        """Check if a run exists and get resume information.

        When the project exists, the "project" setting is updated to
        ``project_name``. The "entity" setting is updated to the project's
        entity name when the response includes one.

        Args:
            entity (str, optional): The entity to scope this project to.
            project_name (str): The project the run belongs to.
            name (str): The run name.

        Returns:
            The run's "bucket" object (summary, line counts, history/events
            tails, config). Returns None if the project is missing or the run
            does not exist.
        """
        query = gql("""
        query Model($project: String!, $entity: String, $name: String!) {
            model(name: $project, entityName: $entity) {
                id
                name
                entity {
                    id
                    name
                }

                bucket(name: $name, missingOk: true) {
                    id
                    name
                    summaryMetrics
                    displayName
                    logLineCount
                    historyLineCount
                    eventsLineCount
                    historyTail
                    eventsTail
                    config
                }
            }
        }
        """)

        response = self.gql(
            query,
            variable_values={
                "entity": entity,
                "project": project_name,
                "name": name,
            },
        )

        if "model" not in response or "bucket" not in (response["model"]
                                                       or {}):
            return None

        project = response["model"]
        self.set_setting("project", project_name)
        # The entity field may be null. Indexing it would then raise TypeError.
        if project.get("entity"):
            self.set_setting("entity", project["entity"]["name"])

        return project["bucket"]

    @normalize_exceptions
    def check_stop_requested(self, project_name, entity_name, run_id):
        """Return whether a stop has been requested for a run.

        Args:
            project_name (str): Project containing the run.
            entity_name (str): Entity owning the project.
            run_id (str): The run id. Falls back to ``current_run_id`` when falsy.

        Returns:
            The run's "stopped" field, or False when the project or run is
            not found.
        """
        query = gql("""
        query Model($projectName: String, $entityName: String, $runId: String!) {
            project(name:$projectName, entityName:$entityName) {
                run(name:$runId) {
                    stopped
                }
            }
        }
        """)
        run_id = run_id or self.current_run_id
        assert run_id, "run_id must be specified"
        variables = {
            "projectName": project_name,
            "entityName": entity_name,
            "runId": run_id,
        }
        response = self.gql(query, variable_values=variables)

        project = response.get("project")
        run = project.get("run") if project else None
        if not run:
            return False
        return run["stopped"]

    def format_project(self, project):
        """Normalise a project name into a lowercase, dash-separated slug.

        Runs of non-word characters become a single "-". Leading and
        trailing "-" and "_" are removed.
        """
        slug = re.sub(r"\W+", "-", project.lower())
        return slug.strip("-_")

    @normalize_exceptions
    def upsert_project(self, project, id=None, description=None, entity=None):
        """Create or update a project.

        Args:
            project (str): The project to create. It is normalised with
                ``format_project``.
            id (str, optional): Id of an existing project to update.
            description (str, optional): A description of this project.
            entity (str, optional): The entity to scope this project to.
                Defaults to the "entity" setting.

        Returns:
            The "model" object (name, description) of the mutation response.
        """
        mutation = gql("""
        mutation UpsertModel($name: String!, $id: String, $entity: String!, $description: String, $repo: String)  {
            upsertModel(input: { id: $id, name: $name, entityName: $entity, description: $description, repo: $repo }) {
                model {
                    name
                    description
                }
            }
        }
        """)
        response = self.gql(
            mutation,
            variable_values={
                "name": self.format_project(project),
                "entity": entity or self.settings("entity"),
                "description": description,
                # ``self.git`` is None when repository detection is disabled
                # (as ``__init__`` sets it). Send no repo URL in that case
                # rather than raising AttributeError.
                "repo": self.git.remote_url if self.git is not None else None,
                "id": id,
            },
        )
        return response["upsertModel"]["model"]

    @normalize_exceptions
    def pop_from_run_queue(self, entity=None, project=None):
        """Pop the next item from a project's run queue.

        Args:
            entity (str, optional): Entity owning the queue. Defaults to the
                "entity" setting.
            project (str, optional): Project owning the queue. Defaults to the
                "project" setting.

        Returns:
            The "popFromRunQueue" object ("runQueueItemId", "runSpec").
        """
        mutation = gql("""
        mutation popFromRunQueue($entity: String!, $project: String!)  {
            popFromRunQueue(input: { entityName: $entity, projectName: $project }) {
                runQueueItemId
                runSpec
            }
        }
        """)
        # The mutation declares both variables non-null (String!). Fall back
        # to the configured defaults, as the other methods here do, instead
        # of sending null.
        response = self.gql(mutation,
                            variable_values={
                                "entity": entity or self.settings("entity"),
                                "project": project or self.settings("project"),
                            })
        return response["popFromRunQueue"]

    @normalize_exceptions
    def upsert_run(
        self,
        id=None,
        name=None,
        project=None,
        host=None,
        group=None,
        tags=None,
        config=None,
        description=None,
        entity=None,
        state=None,
        display_name=None,
        notes=None,
        repo=None,
        job_type=None,
        program_path=None,
        commit=None,
        sweep_name=None,
        summary_metrics=None,
        num_retries=None,
    ):
        """Update a run

        Args:
            id (str, optional): The existing run to update
            name (str, optional): The name of the run to create
            group (str, optional): Name of the group this run is a part of
            project (str, optional): The name of the project
            config (dict, optional): The latest config params
            description (str, optional): A description of this project
            entity (str, optional): The entity to scope this project to.
            repo (str, optional): Url of the program's repository.
            state (str, optional): State of the program.
            job_type (str, optional): Type of job, e.g 'train'.
            program_path (str, optional): Path to the program.
            commit (str, optional): The Git SHA to associate the run with
            summary_metrics (str, optional): The JSON summary metrics
        """
        mutation = gql("""
        mutation UpsertBucket(
            $id: String, $name: String,
            $project: String,
            $entity: String!,
            $groupName: String,
            $description: String,
            $displayName: String,
            $notes: String,
            $commit: String,
            $config: JSONString,
            $host: String,
            $debug: Boolean,
            $program: String,
            $repo: String,
            $jobType: String,
            $state: String,
            $sweep: String,
            $tags: [String!],
            $summaryMetrics: JSONString,
        ) {
            upsertBucket(input: {
                id: $id,
                name: $name,
                groupName: $groupName,
                modelName: $project,
                entityName: $entity,
                description: $description,
                displayName: $displayName,
                notes: $notes,
                config: $config,
                commit: $commit,
                host: $host,
                debug: $debug,
                jobProgram: $program,
                jobRepo: $repo,
                jobType: $jobType,
                state: $state,
                sweep: $sweep,
                tags: $tags,
                summaryMetrics: $summaryMetrics,
            }) {
                bucket {
                    id
                    name
                    displayName
                    description
                    config
                    project {
                        id
                        name
                        entity {
                            id
                            name
                        }
                    }
                }
            }
        }
        """)
        if config is not None:
            config = json.dumps(config)
        if not description or description.isspace():
            description = None

        kwargs = {}
        if num_retries is not None:
            kwargs["num_retries"] = num_retries

        variable_values = {
            "id": id,
            "entity": entity or self.settings("entity"),
            "name": name,
            "project": project,
            "groupName": group,
            "tags": tags,
            "description": description,
            "config": config,
            "commit": commit,
            "displayName": display_name,
            "notes": notes,
            "host":
            None if self.settings().get("anonymous") == "true" else host,
            "debug": env.is_debug(env=self._environ),
            "repo": repo,
            "program": program_path,
            "jobType": job_type,
            "state": state,
            "sweep": sweep_name,
            "summaryMetrics": summary_metrics,
        }

        response = self.gql(mutation,
                            variable_values=variable_values,
                            **kwargs)

        run = response["upsertBucket"]["bucket"]
        if project := run.get("project"):
            self.set_setting("project", project["name"])
            if entity := project.get("entity"):
                self.set_setting("entity", entity["name"])