class Api(object):
    """W&B Internal Api wrapper

    Note:
        Settings are automatically overridden by looking for
        a `wandb/settings` file in the current working directory
        or its parent directory. If none can be found, we look
        in the current user's home directory.

    Args:
        default_settings(:obj:`dict`, optional): If you aren't using a settings
            file or you wish to override the section to use in the settings
            file, override the settings here.
        load_settings (bool, optional): Passed through to `Settings`.
        retry_timedelta (datetime.timedelta, optional): Total time budget for
            retrying GraphQL calls made through `self.gql`.
        environ (mapping, optional): Environment to read W&B env vars from;
            defaults to `os.environ`.
    """

    HTTP_TIMEOUT = env.get_http_timeout(10)

    def __init__(self, default_settings=None, load_settings=True,
                 retry_timedelta=datetime.timedelta(days=1), environ=os.environ):
        self._environ = environ
        self.default_settings = {
            'section': "default",
            'run': "latest",
            'git_remote': "origin",
            'ignore_globs': [],
            'base_url': "https://api.wandb.ai"
        }
        self.retry_timedelta = retry_timedelta
        self.default_settings.update(default_settings or {})
        self.retry_uploads = 10
        self._settings = Settings(load_settings=load_settings)
        self.git = GitRepo(remote=self.settings("git_remote"))
        # Mutable settings set by the _file_stream_api
        self.dynamic_settings = {
            'system_sample_seconds': 2,
            'system_samples': 15,
            'heartbeat_seconds': 30,
        }
        self.client = Client(
            transport=RequestsHTTPTransport(
                headers={'User-Agent': self.user_agent,
                         'X-WANDB-USERNAME': env.get_username(env=self._environ)},
                use_json=True,
                # this timeout won't apply when the DNS lookup fails.
                # in that case, it will be 60s
                # https://bugs.python.org/issue22889
                timeout=self.HTTP_TIMEOUT,
                auth=("api", self.api_key or ""),
                url='%s/graphql' % self.settings('base_url')
            )
        )
        # All GraphQL calls go through this retrying wrapper around execute().
        self.gql = retry.Retry(self.execute,
                               retry_timedelta=retry_timedelta,
                               check_retry_fn=util.no_retry_auth,
                               retryable_exceptions=(RetryError, requests.RequestException))
        self._current_run_id = None
        self._file_stream_api = None

    def reauth(self):
        """Ensures the current api key is set in the transport"""
        self.client.transport.auth = ("api", self.api_key or "")

    def execute(self, *args, **kwargs):
        """Wrapper around execute that logs in cases of failure."""
        try:
            return self.client.execute(*args, **kwargs)
        except requests.exceptions.HTTPError as err:
            res = err.response
            logger.error("%s response executing GraphQL." % res.status_code)
            logger.error(res.text)
            self.display_gorilla_error_if_found(res)
            six.reraise(*sys.exc_info())

    def display_gorilla_error_if_found(self, res):
        """Print each GraphQL error message found in a response body.

        Silently returns if the body is not JSON.
        """
        try:
            data = res.json()
        except ValueError:
            return

        if 'errors' in data and isinstance(data['errors'], list):
            for err in data['errors']:
                if not err.get('message'):
                    continue
                wandb.termerror('Error while calling W&B API: {} ({})'.format(err['message'], res))

    def disabled(self):
        """Return the 'disabled' setting from the default section (False if unset)."""
        return self._settings.get(Settings.DEFAULT_SECTION, 'disabled', fallback=False)

    def sync_spell(self, run, env=None):
        """Syncs this run with spell

        Args:
            run: The run whose URL is registered with spell.
            env (mapping, optional): Environment to read SPELL_* variables
                from; defaults to `os.environ`. (Shadows the module `env`
                inside this method only.)

        Returns:
            The `requests` response of the PUT, or False on failure.
        """
        try:
            env = env or os.environ
            run.config._set_wandb("spell_url", env.get("SPELL_RUN_URL"))
            run.config.persist()
            try:
                url = run.get_url()
            except CommError as e:
                wandb.termerror("Unable to register run with spell.run: %s" % e.message)
                return False
            return requests.put(env.get("SPELL_API_URL", "https://api.spell.run") + "/wandb_url", json={
                "access_token": env.get("WANDB_ACCESS_TOKEN"),
                "url": url
            }, timeout=2)
        except requests.RequestException:
            return False

    def save_pip(self, out_dir):
        """Saves the current working set of pip packages to requirements.txt"""
        try:
            import pkg_resources

            installed_packages_list = sorted(
                "%s==%s" % (dist.key, dist.version)
                for dist in pkg_resources.working_set
            )
            with open(os.path.join(out_dir, 'requirements.txt'), 'w') as f:
                f.write("\n".join(installed_packages_list))
        except Exception as e:
            # Best effort: a failure here must not break the run, but keep the cause.
            logger.error("Error saving pip packages: %s", e)

    def _write_git_diff(self, root, against, patch_path):
        """Write `git diff <against>` (run in `root`) to `patch_path`.

        Adds `--submodule=diff` when the repo reports submodule changes so the
        patch includes the submodule contents.
        """
        args = ['git', 'diff']
        if self.git.has_submodule_diff:
            args.append('--submodule=diff')
        args.append(against)
        with open(patch_path, 'wb') as patch:
            subprocess.check_call(args, stdout=patch, cwd=root, timeout=5)

    def save_patches(self, out_dir):
        """Save the current state of this repository to one or more patches.

        Makes one patch against HEAD and another one against the most recent
        commit that occurs in an upstream branch. This way we can be robust
        to history editing as long as the user never does "push -f" to break
        history on an upstream branch.

        Writes the first patch to <out_dir>/diff.patch and the second to
        <out_dir>/upstream_diff_<commit_id>.patch.

        Args:
            out_dir (str): Directory to write the patch files.

        Returns:
            False if git is not enabled, otherwise None. Diff failures are
            logged, not raised.
        """
        if not self.git.enabled:
            return False

        try:
            root = self.git.root
            if self.git.dirty:
                # we diff against HEAD to ensure we get changes in the index
                self._write_git_diff(root, 'HEAD', os.path.join(out_dir, 'diff.patch'))

            upstream_commit = self.git.get_upstream_fork_point()
            if upstream_commit and upstream_commit != self.git.repo.head.commit:
                sha = upstream_commit.hexsha
                upstream_patch_path = os.path.join(
                    out_dir, 'upstream_diff_{}.patch'.format(sha))
                self._write_git_diff(root, sha, upstream_patch_path)
        # TODO: A customer saw `ValueError: Reference at 'refs/remotes/origin/foo' does not exist`
        # so we now catch ValueError. Catching this error feels too generic.
        except (ValueError, subprocess.CalledProcessError, subprocess.TimeoutExpired) as e:
            logger.error('Error generating diff: %s' % e)

    def set_current_run_id(self, run_id):
        self._current_run_id = run_id

    @property
    def current_run_id(self):
        return self._current_run_id

    @property
    def user_agent(self):
        return 'W&B Internal Client %s' % __version__

    @property
    def api_key(self):
        """The API key: the env var API key if set, else the netrc password."""
        auth = requests.utils.get_netrc_auth(self.api_url)
        key = None
        if auth:
            key = auth[-1]
        # Environment should take precedence
        if self._environ.get(env.API_KEY):
            key = self._environ.get(env.API_KEY)
        return key

    @property
    def api_url(self):
        return self.settings('base_url')

    @property
    def app_url(self):
        """Best-guess URL of the web app corresponding to `api_url`."""
        api_url = self.api_url
        # Development
        if api_url.endswith('.test') or self.settings().get("dev_prod"):
            return 'http://app.test'
        # On-prem VM
        if api_url.endswith(':11001'):
            return api_url.replace(':11001', ':11000')
        # Normal: only rewrite the leading host label, not any later "api." in the URL
        if api_url.startswith('https://api.'):
            return api_url.replace('api.', 'app.', 1)
        # Unexpected
        return api_url

    def settings(self, key=None, section=None):
        """The settings overridden from the wandb/settings file.

        Args:
            key (str, optional): If provided only this setting is returned
            section (str, optional): If provided this section of the setting file is
            used, defaults to "default"

        Returns:
            A dict with the current settings (or the single value for `key`)

                {
                    "entity": "models",
                    "base_url": "https://api.wandb.ai",
                    "project": None
                }
        """
        result = self.default_settings.copy()
        result.update(self._settings.items(section=section))
        # entity/project/base_url/ignore_globs can additionally be overridden
        # by the environment via the env.get_* helpers.
        result.update({
            'entity': env.get_entity(
                self._settings.get(Settings.DEFAULT_SECTION, "entity", fallback=result.get('entity')),
                env=self._environ),
            'project': env.get_project(
                self._settings.get(Settings.DEFAULT_SECTION, "project", fallback=result.get('project')),
                env=self._environ),
            'base_url': env.get_base_url(
                self._settings.get(Settings.DEFAULT_SECTION, "base_url", fallback=result.get('base_url')),
                env=self._environ),
            'ignore_globs': env.get_ignore(
                self._settings.get(Settings.DEFAULT_SECTION, "ignore_globs", fallback=result.get('ignore_globs')),
                env=self._environ),
        })

        return result if key is None else result[key]

    def clear_setting(self, key):
        self._settings.clear(Settings.DEFAULT_SECTION, key)

    def set_setting(self, key, value, globally=False):
        """Persist a setting; entity/project are also mirrored into the environment."""
        self._settings.set(Settings.DEFAULT_SECTION, key, value, globally=globally)
        if key == 'entity':
            env.set_entity(value, env=self._environ)
        elif key == 'project':
            env.set_project(value, env=self._environ)

    def parse_slug(self, slug, project=None, run=None):
        """Split "project/run" into (project, run), falling back to defaults.

        Raises:
            CommError: if no project can be determined.
        """
        if slug and "/" in slug:
            parts = slug.split("/")
            project = parts[0]
            run = parts[1]
        else:
            project = project or self.settings().get("project")
            if project is None:
                raise CommError("No default project configured.")
            run = run or slug or env.get_run(env=self._environ)
            if run is None:
                run = "latest"
        return (project, run)

    @normalize_exceptions
    def viewer(self):
        """Return the current viewer (id, entity, teams) or {}."""
        query = gql('''
        query Viewer{
            viewer {
                id
                entity
                teams {
                    edges {
                        node {
                            name
                        }
                    }
                }
            }
        }
        ''')

        res = self.gql(query)
        return res.get('viewer') or {}

    @normalize_exceptions
    def list_projects(self, entity=None):
        """Lists projects in W&B scoped by entity.

        Args:
            entity (str, optional): The entity to scope this project to.

        Returns:
                [{"id","name","description"}]
        """
        query = gql('''
        query Models($entity: String!) {
            models(first: 10, entityName: $entity) {
                edges {
                    node {
                        id
                        name
                        description
                    }
                }
            }
        }
        ''')
        return self._flatten_edges(self.gql(query, variable_values={
            'entity': entity or self.settings('entity')})['models'])

    @normalize_exceptions
    def project(self, project, entity=None):
        """Retrieve project

        Args:
            project (str): The project to get details for
            entity (str, optional): The entity to scope this project to.

        Returns:
                {"id","name","repo","dockerImage","description"}
        """
        query = gql('''
        query Models($entity: String, $project: String!) {
            model(name: $project, entityName: $entity) {
                id
                name
                repo
                dockerImage
                description
            }
        }
        ''')
        return self.gql(query, variable_values={
            'entity': entity, 'project': project})['model']

    @normalize_exceptions
    def sweep(self, sweep, specs, project=None, entity=None):
        """Retrieve sweep.

        Args:
            sweep (str): The sweep to get details for
            specs (str): history specs
            project (str, optional): The project to scope this sweep to.
            entity (str, optional): The entity to scope this sweep to.

        Returns:
                The sweep dict (with `runs` flattened to a list), or the falsy
                value returned by the server when the sweep is not found.
        """
        query = gql('''
        query Models($entity: String, $project: String!, $sweep: String!, $specs: [JSONString!]!) {
            model(name: $project, entityName: $entity) {
                sweep(sweepName: $sweep) {
                    id
                    name
                    method
                    state
                    description
                    config
                    createdAt
                    heartbeatAt
                    updatedAt
                    earlyStopJobRunning
                    bestLoss
                    controller
                    scheduler
                    runs {
                        edges {
                            node {
                                name
                                state
                                config
                                exitcode
                                heartbeatAt
                                shouldStop
                                failed
                                stopped
                                running
                                summaryMetrics
                                sampledHistory(specs: $specs)
                            }
                        }
                    }
                }
            }
        }
        ''')
        data = self.gql(query, variable_values={
            'entity': entity or self.settings('entity'),
            'project': project or self.settings('project'),
            'sweep': sweep,
            'specs': specs})['model']['sweep']
        if data:
            data['runs'] = self._flatten_edges(data['runs'])
        return data

    @normalize_exceptions
    def list_runs(self, project, entity=None):
        """Lists runs in W&B scoped by project.

        Args:
            project (str): The project to scope the runs to
            entity (str, optional): The entity to scope this project to.  Defaults to public models

        Returns:
                [{"id","name","displayName","description"}]
        """
        query = gql('''
        query Buckets($model: String!, $entity: String!) {
            model(name: $model, entityName: $entity) {
                buckets(first: 10) {
                    edges {
                        node {
                            id
                            name
                            displayName
                            description
                        }
                    }
                }
            }
        }
        ''')
        return self._flatten_edges(self.gql(query, variable_values={
            'entity': entity or self.settings('entity'),
            'model': project or self.settings('project')})['model']['buckets'])

    @normalize_exceptions
    def launch_run(self, command, project=None, entity=None, run_id=None):
        """Launch a run in the cloud.

        Sends the current uncommitted `git diff` as the patch and the cwd
        relative to the repo root.

        Args:
            command (str): The command to run
            project (str): The project to scope the runs to
            entity (str, optional): The entity to scope this project to.  Defaults to public models
            run_id (str, optional): The run_id to scope to

        Returns:
                The raw mutation response, containing
                {"launchRun": {"podName","status","runId"}}
        """
        query = gql('''
        mutation launchRun(
            $entity: String
            $model: String
            $runId: String
            $image: String
            $command: String
            $patch: String
            $cwd: String
            $datasets: [String]
        ) {
            launchRun(input: {id: $runId, entityName: $entity, patch: $patch, modelName: $model,
                image: $image, command: $command, datasets: $datasets, cwd: $cwd}) {
                podName
                status
                runId
            }
        }
        ''')
        patch = BytesIO()
        if self.git.dirty:
            self.git.repo.git.execute(['git', 'diff'], output_stream=patch)
            patch.seek(0)
        cwd = "."
        if self.git.enabled:
            cwd = cwd + os.getcwd().replace(self.git.repo.working_dir, "")
        return self.gql(query, variable_values={
            'entity': entity or self.settings('entity'),
            'model': project or self.settings('project'),
            'command': command,
            'runId': run_id,
            'patch': patch.read().decode("utf8"),
            'cwd': cwd
        })

    @normalize_exceptions
    def run_config(self, project, run=None, entity=None):
        """Get the relevant configs for a run

        Args:
            project (str): The project to download, (can include bucket)
            run (str, optional): The run to download
            entity (str, optional): The entity to scope this project to.

        Returns:
            (commit, config, patch, metadata) where `config` is the decoded
            JSON config and `metadata` is the decoded wandb-metadata.json
            (or {} if the run has none).

        Raises:
            ValueError: if the project or run does not exist.
        """
        query = gql('''
        query Model($name: String!, $entity: String!, $run: String!) {
            model(name: $name, entityName: $entity) {
                bucket(name: $run) {
                    config
                    commit
                    patch
                    files(names: ["wandb-metadata.json"]) {
                        edges {
                            node {
                                url
                            }
                        }
                    }
                }
            }
        }
        ''')

        response = self.gql(query, variable_values={
            'name': project, 'run': run, 'entity': entity
        })
        model = response['model']
        # A missing run previously surfaced as an opaque TypeError on run['commit'].
        if model is None or model.get('bucket') is None:
            raise ValueError("Run {}/{}/{} not found".format(entity, project, run))
        run = model['bucket']
        commit = run['commit']
        patch = run['patch']
        config = json.loads(run['config'] or '{}')
        if len(run['files']['edges']) > 0:
            url = run['files']['edges'][0]['node']['url']
            res = requests.get(url)
            res.raise_for_status()
            metadata = res.json()
        else:
            metadata = {}
        return (commit, config, patch, metadata)

    @normalize_exceptions
    def run_resume_status(self, entity, project_name, name):
        """Check if a run exists and get resume information.

        Also stores the project (and entity, when returned) as settings.

        Args:
            entity (str, optional): The entity to scope this project to.
            project_name (str): The project to download, (can include bucket)
            name (str): The run to look up

        Returns:
            The run ("bucket") dict, or None if the project is not found.
        """
        query = gql('''
        query Model($project: String!, $entity: String, $name: String!) {
            model(name: $project, entityName: $entity) {
                id
                name
                entity {
                    id
                    name
                }

                bucket(name: $name, missingOk: true) {
                    id
                    name
                    summaryMetrics
                    displayName
                    logLineCount
                    historyLineCount
                    eventsLineCount
                    historyTail
                    eventsTail
                    config
                }
            }
        }
        ''')

        response = self.gql(query, variable_values={
            'entity': entity, 'project': project_name, 'name': name,
        })

        if 'model' not in response or 'bucket' not in (response['model'] or {}):
            return None

        project = response['model']
        self.set_setting('project', project_name)
        # The key is always present in the response; guard against a null value.
        project_entity = project.get('entity')
        if project_entity:
            self.set_setting('entity', project_entity['name'])

        return project['bucket']

    @normalize_exceptions
    def check_stop_requested(self, project_name, entity_name, run_id):
        """Return the run's `stopped` flag, or False if the project/run is missing."""
        query = gql('''
        query Model($projectName: String, $entityName: String, $runId: String!) {
            project(name:$projectName, entityName:$entityName) {
                run(name:$runId) {
                    stopped
                }
            }
        }
        ''')

        response = self.gql(query, variable_values={
            'projectName': project_name, 'entityName': entity_name, 'runId': run_id,
        })

        project = response.get('project', None)
        if not project:
            return False
        run = project.get('run', None)
        if not run:
            return False

        return run['stopped']

    def format_project(self, project):
        """Lowercase and replace non-word runs with '-', trimming '-'/'_' at the ends."""
        return re.sub(r'\W+', '-', project.lower()).strip("-_")

    @normalize_exceptions
    def upsert_project(self, project, id=None, description=None, entity=None):
        """Create a new project

        Args:
            project (str): The project to create
            id (str, optional): The id of an existing project to update
            description (str, optional): A description of this project
            entity (str, optional): The entity to scope this project to.

        Returns:
            {"name","description"} of the upserted project.
        """
        mutation = gql('''
        mutation UpsertModel($name: String!, $id: String, $entity: String!, $description: String, $repo: String)  {
            upsertModel(input: { id: $id, name: $name, entityName: $entity, description: $description, repo: $repo }) {
                model {
                    name
                    description
                }
            }
        }
        ''')
        response = self.gql(mutation, variable_values={
            'name': self.format_project(project), 'entity': entity or self.settings('entity'),
            'description': description, 'repo': self.git.remote_url, 'id': id})
        return response['upsertModel']['model']

    @normalize_exceptions
    def upsert_run(self, id=None, name=None, project=None, host=None,
                   group=None, tags=None,
                   config=None, description=None, entity=None, state=None,
                   display_name=None, notes=None,
                   repo=None, job_type=None, program_path=None, commit=None,
                   sweep_name=None, summary_metrics=None, num_retries=None):
        """Update a run

        Also stores the project/entity returned by the server as settings.

        Args:
            id (str, optional): The existing run to update
            name (str, optional): The name of the run to create
            group (str, optional): Name of the group this run is a part of
            project (str, optional): The name of the project
            host (str, optional): Host name; not sent when the "anonymous"
                setting is 'true'
            tags (list of str, optional): Tags for the run
            config (dict, optional): The latest config params (JSON-encoded before sending)
            description (str, optional): A description of this project
            entity (str, optional): The entity to scope this project to.
            display_name (str, optional): Display name of the run
            notes (str, optional): Notes for the run
            repo (str, optional): Url of the program's repository.
            state (str, optional): State of the program.
            job_type (str, optional): Type of job, e.g 'train'.
            program_path (str, optional): Path to the program.
            commit (str, optional): The Git SHA to associate the run with
            sweep_name (str, optional): The sweep this run belongs to
            summary_metrics (str, optional): The JSON summary metrics
            num_retries (int, optional): Passed through to the retrying gql call

        Returns:
            The upserted run ("bucket") dict.
        """
        mutation = gql('''
        mutation UpsertBucket(
            $id: String, $name: String,
            $project: String,
            $entity: String!,
            $groupName: String,
            $description: String,
            $displayName: String,
            $notes: String,
            $commit: String,
            $config: JSONString,
            $host: String,
            $debug: Boolean,
            $program: String,
            $repo: String,
            $jobType: String,
            $state: String,
            $sweep: String,
            $tags: [String!],
            $summaryMetrics: JSONString,
        ) {
            upsertBucket(input: {
                id: $id,
                name: $name,
                groupName: $groupName,
                modelName: $project,
                entityName: $entity,
                description: $description,
                displayName: $displayName,
                notes: $notes,
                config: $config,
                commit: $commit,
                host: $host,
                debug: $debug,
                jobProgram: $program,
                jobRepo: $repo,
                jobType: $jobType,
                state: $state,
                sweep: $sweep,
                tags: $tags,
                summaryMetrics: $summaryMetrics,
            }) {
                bucket {
                    id
                    name
                    displayName
                    description
                    config
                    project {
                        id
                        name
                        entity {
                            id
                            name
                        }
                    }
                }
            }
        }
        ''')
        if config is not None:
            config = json.dumps(config)
        if not description or description.isspace():
            description = None

        kwargs = {}
        if num_retries is not None:
            kwargs['num_retries'] = num_retries

        variable_values = {
            'id': id, 'entity': entity or self.settings('entity'), 'name': name, 'project': project,
            'groupName': group, 'tags': tags,
            'description': description, 'config': config, 'commit': commit,
            'displayName': display_name, 'notes': notes,
            'host': None if self.settings().get('anonymous') == 'true' else host,
            'debug': env.is_debug(env=self._environ), 'repo': repo, 'program': program_path, 'jobType': job_type,
            'state': state, 'sweep': sweep_name, 'summaryMetrics': summary_metrics
        }

        response = self.gql(
            mutation, variable_values=variable_values, **kwargs)

        run = response['upsertBucket']['bucket']
        project = run.get('project')
        if project:
            self.set_setting('project', project['name'])
            entity = project.get('entity')
            if entity:
                self.set_setting('entity', entity['name'])

        return run

    @normalize_exceptions
    def upload_urls(self, project, files, run=None, entity=None, description=None):
        """Generate temporary resumeable upload urls

        Args:
            project (str): The project to download
            files (list or dict): The filenames to upload
            run (str, optional): The run to upload to
            entity (str, optional): The entity to scope this project to.  Defaults to wandb models
            description (str, optional): Passed as the bucket `desc`

        Returns:
            (bucket_id, file_info)
            bucket_id: id of bucket we uploaded to
            file_info: A dict of filenames and urls, also indicates if this revision already has uploaded files.
                {
                    'weights.h5': { "url": "https://weights.url" },
                    'model.json': { "url": "https://model.json", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' },
                }

        Raises:
            CommError: if the run does not exist.
        """
        query = gql('''
        query Model($name: String!, $files: [String]!, $entity: String!, $run: String!, $description: String) {
            model(name: $name, entityName: $entity) {
                bucket(name: $run, desc: $description) {
                    id
                    files(names: $files) {
                        edges {
                            node {
                                name
                                url(upload: true)
                                updatedAt
                            }
                        }
                    }
                }
            }
        }
        ''')
        run_id = run or self.settings('run')
        entity = entity or self.settings('entity')
        query_result = self.gql(query, variable_values={
            'name': project, 'run': run_id,
            'entity': entity,
            'description': description,
            # For a dict this yields the keys, i.e. the filenames.
            'files': list(files)
        })

        run = query_result['model']['bucket']
        if run:
            result = {file['name']: file for file in self._flatten_edges(run['files'])}
            return run['id'], result
        else:
            raise CommError("Run does not exist {}/{}/{}.".format(entity, project, run_id))

    @normalize_exceptions
    def download_urls(self, project, run=None, entity=None):
        """Generate download urls

        Args:
            project (str): The project to download
            run (str, optional): The run to upload to
            entity (str, optional): The entity to scope this project to.  Defaults to wandb models

        Returns:
            A dict of extensions and urls

                {
                    'weights.h5': { "url": "https://weights.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' },
                    'model.json': { "url": "https://model.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' }
                }
        """
        query = gql('''
        query Model($name: String!, $entity: String!, $run: String!)  {
            model(name: $name, entityName: $entity) {
                bucket(name: $run) {
                    files {
                        edges {
                            node {
                                name
                                url
                                md5
                                updatedAt
                            }
                        }
                    }
                }
            }
        }
        ''')
        query_result = self.gql(query, variable_values={
            'name': project, 'run': run or self.settings('run'),
            'entity': entity or self.settings('entity')})
        files = self._flatten_edges(query_result['model']['bucket']['files'])
        return {file['name']: file for file in files if file}

    @normalize_exceptions
    def download_url(self, project, file_name, run=None, entity=None):
        """Generate download urls

        Args:
            project (str): The project to download
            file_name (str): The name of the file to download
            run (str, optional): The run to upload to
            entity (str, optional): The entity to scope this project to.  Defaults to wandb models

        Returns:
            A dict of extensions and urls, or None if the project is missing
            or the file has no `updatedAt` (i.e. was never uploaded)

                { "url": "https://weights.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' }
        """
        query = gql('''
        query Model($name: String!, $fileName: String!, $entity: String!, $run: String!)  {
            model(name: $name, entityName: $entity) {
                bucket(name: $run) {
                    files(names: [$fileName]) {
                        edges {
                            node {
                                name
                                url
                                md5
                                updatedAt
                            }
                        }
                    }
                }
            }
        }
        ''')
        query_result = self.gql(query, variable_values={
            'name': project, 'run': run or self.settings('run'), 'fileName': file_name,
            'entity': entity or self.settings('entity')})
        if query_result['model']:
            files = self._flatten_edges(query_result['model']['bucket']['files'])
            return files[0] if len(files) > 0 and files[0].get('updatedAt') else None
        else:
            return None

    @normalize_exceptions
    def download_file(self, url):
        """Initiate a streaming download

        Args:
            url (str): The url to download

        Returns:
            A tuple of the content length and the streaming response
        """
        response = requests.get(url, stream=True)
        response.raise_for_status()
        return (int(response.headers.get('content-length', 0)), response)

    @normalize_exceptions
    def download_write_file(self, metadata, out_dir=None):
        """Download a file from a run and write it to wandb/

        Args:
            metadata (obj): The metadata object for the file to download. Comes from Api.download_urls().
            out_dir (str, optional): Destination directory; defaults to wandb_dir().

        Returns:
            A tuple of the file's local path and the streaming response. The streaming response is None if the file already existed and was up to date.
        """
        fileName = metadata['name']
        path = os.path.join(out_dir or wandb_dir(), fileName)
        # Check the destination we would write to, not a cwd-relative path.
        if self.file_current(path, metadata['md5']):
            return path, None

        _, response = self.download_file(metadata['url'])

        with open(path, "wb") as file:
            for data in response.iter_content(chunk_size=1024):
                file.write(data)

        return path, response

    def upload_file(self, url, file, callback=None, extra_headers=None):
        """Uploads a file to W&B with failure resumption

        Args:
            url (str): The url to upload to
            file (file): An open binary file object to upload
            callback (:obj:`func`, optional): A callback which is passed the number of
            bytes uploaded since the last time it was called, used to report progress
            extra_headers (dict, optional): Extra HTTP headers for the PUT.

        Returns:
            The requests library response object

        Raises:
            CommError: if the file is empty.
            retry.TransientException: (via util.sentry_reraise) for errors
                worth retrying; other request errors are re-raised as-is.
        """
        extra_headers = dict(extra_headers) if extra_headers else {}
        response = None
        progress = Progress(file, callback=callback)
        if progress.len == 0:
            raise CommError("%s is an empty file" % file.name)
        try:
            response = requests.put(
                url, data=progress, headers=extra_headers)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            status_code = e.response.status_code if e.response is not None else 0
            # Retry errors from cloud storage or local network issues
            if status_code in (308, 409, 429, 500, 502, 503, 504) or isinstance(e, (requests.exceptions.Timeout, requests.exceptions.ConnectionError)):
                util.sentry_reraise(retry.TransientException(exc=e))
            else:
                util.sentry_reraise(e)

        return response

    upload_file_retry = normalize_exceptions(retry.retriable(num_retries=5)(upload_file))

    @staticmethod
    def _no_retry_4xx(e):
        """Retry predicate used for mutations whose 4xx errors are user errors.

        Returns True (keep retrying) for anything that is not an HTTP 4xx
        error. For a 4xx response, raises UsageError with the first GraphQL
        error message from the response body instead of retrying.
        """
        if not isinstance(e, requests.HTTPError):
            return True
        if not (400 <= e.response.status_code < 500):
            return True
        body = json.loads(e.response.content)
        raise UsageError(body['errors'][0]['message'])

    @normalize_exceptions
    def register_agent(self, host, sweep_id=None, project_name=None, entity=None):
        """Register a new agent

        Args:
            host (str): hostname
            sweep_id (str): sweep id
            project_name: (str): model that contains sweep; defaults to the
                project setting
            entity (str, optional): defaults to the entity setting

        Returns:
            The created agent dict ({"id"}).
        """
        mutation = gql('''
        mutation CreateAgent(
            $host: String!
            $projectName: String!,
            $entityName: String!,
            $sweep: String!
        ) {
            createAgent(input: {
                host: $host,
                projectName: $projectName,
                entityName: $entityName,
                sweep: $sweep,
            }) {
                agent {
                    id
                }
            }
        }
        ''')
        if entity is None:
            entity = self.settings("entity")
        if project_name is None:
            project_name = self.settings('project')

        # don't retry on validation or not found errors
        response = self.gql(mutation, variable_values={
            'host': host,
            'entityName': entity,
            'projectName': project_name,
            'sweep': sweep_id}, check_retry_fn=self._no_retry_4xx)
        return response['createAgent']['agent']

    def agent_heartbeat(self, agent_id, metrics, run_states):
        """Notify server about agent state, receive commands.

        Args:
            agent_id (str): agent_id
            metrics (dict): system metrics
            run_states (dict): run_id: state mapping
        Returns:
            List of commands to execute; [] if the call failed (the error is logged).
        """
        mutation = gql('''
        mutation Heartbeat(
            $id: ID!,
            $metrics: JSONString,
            $runState: JSONString
        ) {
            agentHeartbeat(input: {
                id: $id,
                metrics: $metrics,
                runState: $runState
            }) {
                agent {
                    id
                }
                commands
            }
        }
        ''')
        try:
            response = self.gql(mutation, variable_values={
                'id': agent_id,
                'metrics': json.dumps(metrics),
                'runState': json.dumps(run_states)})
        except Exception as e:
            # GQL raises exceptions with stringified python dictionaries :/
            # Other failures (e.g. network errors) don't follow that format;
            # fall back to str(e) so we still log and return [] instead of crashing.
            try:
                message = ast.literal_eval(e.args[0])["message"]
            except (IndexError, KeyError, TypeError, ValueError, SyntaxError):
                message = str(e)
            logger.error('Error communicating with W&B: %s', message)
            return []
        else:
            return json.loads(response['agentHeartbeat']['commands'])

    @normalize_exceptions
    def upsert_sweep(self, config, controller=None, scheduler=None, obj_id=None, project=None, entity=None):
        """Upsert a sweep object.

        Tries a mutation that also returns the sweep's project first, then
        falls back to one without it for servers whose schema lacks that field.
        Stores the returned project/entity as settings.

        Args:
            config (str): sweep config (will be converted to yaml)

        Returns:
            The sweep name.

        Raises:
            UsageError: on a 4xx response from the server.
        """
        project_query = '''
            project {
                id
                name
                entity {
                    id
                    name
                }
            }
        '''
        mutation_str = '''
        mutation UpsertSweep(
            $id: ID,
            $config: String,
            $description: String,
            $entityName: String!,
            $projectName: String!,
            $controller: JSONString,
            $scheduler: JSONString
        ) {
            upsertSweep(input: {
                id: $id,
                config: $config,
                description: $description,
                entityName: $entityName,
                projectName: $projectName,
                controller: $controller,
                scheduler: $scheduler
            }) {
                sweep {
                    name
                    _PROJECT_QUERY_
                }
            }
        }
        '''
        # FIXME(jhr): we need protocol versioning to know schema is not supported
        # for now we will just try both new and old query
        mutation_new = gql(mutation_str.replace("_PROJECT_QUERY_", project_query))
        mutation_old = gql(mutation_str.replace("_PROJECT_QUERY_", ""))

        # don't retry on validation errors
        # TODO(jhr): generalize error handling routines
        err = None
        response = None
        for mutation in mutation_new, mutation_old:
            try:
                response = self.gql(mutation, variable_values={
                    'id': obj_id,
                    'config': yaml.dump(config),
                    'description': config.get("description"),
                    'entityName': entity or self.settings("entity"),
                    'projectName': project or self.settings("project"),
                    'controller': controller,
                    'scheduler': scheduler},
                    check_retry_fn=self._no_retry_4xx)
            except UsageError:
                raise
            except Exception as e:
                # graphql schema exception is generic
                err = e
                continue
            err = None
            break
        if err:
            raise err

        sweep = response['upsertSweep']['sweep']
        project = sweep.get('project')
        if project:
            self.set_setting('project', project['name'])
            entity = project.get('entity')
            if entity:
                self.set_setting('entity', entity['name'])

        return sweep['name']

    @normalize_exceptions
    def create_anonymous_api_key(self):
        """Creates a new API key belonging to a new anonymous user."""
        mutation = gql('''
        mutation CreateAnonymousApiKey {
            createAnonymousEntity(input: {}) {
                apiKey {
                    name
                }
            }
        }
        ''')

        response = self.gql(mutation, variable_values={})
        return response['createAnonymousEntity']['apiKey']['name']

    def file_current(self, fname, md5):
        """Checksum a file and compare the md5 with the known md5
        """
        return os.path.isfile(fname) and util.md5_file(fname) == md5

    @normalize_exceptions
    def pull(self, project, run=None, entity=None):
        """Download files from W&B

        Args:
            project (str): The project to download
            run (str, optional): The run to upload to
            entity (str, optional): The entity to scope this project to.  Defaults to wandb models

        Returns:
            A list of requests library response objects, one per file that
            was actually downloaded (up-to-date files are skipped).
        """
        project, run = self.parse_slug(project, run=run)
        urls = self.download_urls(project, run, entity)
        responses = []
        for fileName in urls:
            _, response = self.download_write_file(urls[fileName])
            if response:
                responses.append(response)

        return responses

    def get_project(self):
        return self.settings('project')

    @normalize_exceptions
    def push(self, files, run=None, entity=None, project=None, description=None, force=True, progress=False):
        """Uploads multiple files to W&B

        Args:
            files (list or dict): The filenames to upload; a dict maps
                filenames to already-open file objects
            run (str, optional): The run to upload to
            entity (str, optional): The entity to scope this project to.  Defaults to wandb models
            project (str, optional): The name of the project to upload to. Defaults to the one in settings.
            description (str, optional): The description of the changes
            force (bool, optional): Currently unused.
            progress (callable, or stream): If callable, will be called with (chunk_bytes,
                total_bytes) as argument else if True, renders a progress bar to stream.

        Returns:
            A list of requests library response objects, one per uploaded file.
        """
        if project is None:
            project = self.get_project()
        if project is None:
            raise CommError("No project configured.")
        if run is None:
            run = self.current_run_id

        # TODO(adrian): we use a retriable version of self.upload_file() so
        # will never retry self.upload_urls() here. Instead, maybe we should
        # make push itself retriable.
        run_id, result = self.upload_urls(
            project, files, run, entity, description)
        responses = []
        for file_name, file_info in result.items():
            file_url = file_info['url']

            # If the upload URL is relative, fill it in with the base URL,
            # since its a proxied file store like the on-prem VM.
            if file_url.startswith('/'):
                file_url = '{}{}'.format(self.api_url, file_url)

            try:
                # To handle Windows paths
                # TODO: this doesn't handle absolute paths...
                normal_name = os.path.join(*file_name.split("/"))
                open_file = files[file_name] if isinstance(
                    files, dict) else open(normal_name, "rb")
            except IOError:
                print("%s does not exist" % file_name)
                continue
            # Close the file even when the upload raises.
            try:
                if progress:
                    if callable(progress):
                        responses.append(self.upload_file_retry(
                            file_url, open_file, progress))
                    else:
                        length = os.fstat(open_file.fileno()).st_size
                        with click.progressbar(file=progress, length=length, label='Uploading file: %s' % (file_name),
                                               fill_char=click.style('&', fg='green')) as bar:
                            responses.append(self.upload_file_retry(
                                file_url, open_file, lambda bites, _: bar.update(bites)))
                else:
                    # Use the resolved URL so relative (proxied) URLs work here too.
                    responses.append(self.upload_file_retry(file_url, open_file))
            finally:
                open_file.close()
        return responses

    def get_file_stream_api(self):
        """This creates a new file pusher thread.  Call start to initiate the thread that talks to W&B

        The instance is created once and cached.

        Raises:
            UsageError: if no current run id has been set.
        """
        if not self._file_stream_api:
            if self._current_run_id is None:
                raise UsageError(
                    'Must have a current run to use file stream API.')
            # FileStreamApi requires a start time (used for rate limiting).
            self._file_stream_api = FileStreamApi(self, self._current_run_id, time.time())
        return self._file_stream_api

    def _status_request(self, url, length):
        """Ask google how much we've uploaded"""
        return requests.put(
            url=url,
            headers={'Content-Length': '0',
                     'Content-Range': 'bytes */%i' % length}
        )

    def _flatten_edges(self, response):
        """Return an array from the nested graphql relay structure"""
        return [node['node'] for node in response['edges']]
class FileStreamApi(object):
    """Pushes chunks of files to our streaming endpoint.

    This class is used as a singleton. It has a thread that serializes access
    to the streaming endpoint and performs rate-limiting and batching.

    Producer-side methods (``push``, ``enqueue_preempting``, ``finish``) only
    put items on ``self._queue``; the background thread launched by ``start``
    drains that queue and performs the HTTP posts. ``stream_file`` is the
    exception: it posts synchronously from the calling thread.

    TODO: Differentiate between binary/text encoding.
    """

    # Queue sentinel: makes the push thread flush pending chunks and post the
    # final "complete" message carrying the process exit code.
    Finish = collections.namedtuple("Finish", ("exitcode",))
    # Queue sentinel: makes the push thread post a "preempting" notice.
    Preempting = collections.namedtuple("Preempting", ())

    HTTP_TIMEOUT = env.get_http_timeout(10)
    # Upper bound on how many queued items one _read_queue call returns.
    MAX_ITEMS_PER_PUSH = 10000

    def __init__(self, api, run_id, start_time, settings=None):
        """Create the streaming client; the push thread is not started yet.

        Arguments:
            api: The internal ``Api`` instance. Supplies the API key, user
                agent, ``settings()``, ``retry_callback`` and the mutable
                ``dynamic_settings`` dict.
            run_id (str): Id of the run whose files are streamed.
            start_time (float): Run start as a ``time.time()`` timestamp;
                ``rate_limit_seconds`` measures elapsed run time from it.
            settings (dict, optional): Overrides merged on top of
                ``api.settings()`` when the endpoint URL is built.
        """
        if settings is None:
            settings = dict()
        # NOTE: exc_info is set in thread_except_body context and readable by calling threads
        self._exc_info = None
        self._settings = settings
        self._api = api
        self._run_id = run_id
        self._start_time = start_time
        self._client = requests.Session()
        # Same convention as Api.reauth(): never hand requests a None password.
        self._client.auth = ("api", api.api_key or "")
        # NOTE(review): requests.Session does not read a ``timeout`` attribute;
        # a timeout only takes effect when passed to each request call, so this
        # assignment does not bound the posts made through self._client.
        # Left unchanged because enforcing a real timeout would alter the
        # retry/failure behavior of request_with_retry — confirm before wiring.
        self._client.timeout = self.HTTP_TIMEOUT
        self._client.headers.update({
            "User-Agent": api.user_agent,
            "X-WANDB-USERNAME": env.get_username(),
            "X-WANDB-USER-EMAIL": env.get_user_email(),
        })
        # filename -> policy object exposing process_chunks(list_of_chunks)
        self._file_policies = {}
        self._queue = queue.Queue()
        self._thread = threading.Thread(target=self._thread_except_body)
        # It seems we need to make this a daemon thread to get sync.py's atexit handler to run, which
        # cleans this thread up.
        self._thread.name = "FileStreamThread"
        self._thread.daemon = True
        self._init_endpoint()

    def _init_endpoint(self):
        """(Re)build ``self._endpoint`` from the api settings plus overrides."""
        settings = self._api.settings()
        settings.update(self._settings)
        self._endpoint = "{base}/files/{entity}/{project}/{run}/file_stream".format(
            base=settings["base_url"],
            entity=settings["entity"],
            project=settings["project"],
            run=self._run_id,
        )

    def start(self):
        """Refresh the endpoint URL and launch the push thread."""
        # Rebuilt here as well as in __init__, presumably so entity/project
        # changes made after construction are picked up — confirm.
        self._init_endpoint()
        self._thread.start()

    def set_default_file_policy(self, filename, file_policy):
        """Set an upload policy for a file unless one has already been set."""
        if filename not in self._file_policies:
            self._file_policies[filename] = file_policy

    def set_file_policy(self, filename, file_policy):
        """Set the upload policy for a file, replacing any existing one."""
        self._file_policies[filename] = file_policy

    @property
    def heartbeat_seconds(self):
        """Seconds between keep-alive posts; server-adjustable via "limits"."""
        # Defaults to 30
        return self._api.dynamic_settings["heartbeat_seconds"]

    def rate_limit_seconds(self):
        """Minimum delay between data posts, growing with run age.

        Under 1 minute: max(1, heartbeat/15); under 5 minutes:
        max(2.5, heartbeat/3); afterwards: max(5, heartbeat).
        """
        run_time = time.time() - self._start_time
        if run_time < 60:
            return max(1, self.heartbeat_seconds / 15)
        elif run_time < 300:
            return max(2.5, self.heartbeat_seconds / 3)
        else:
            return max(5, self.heartbeat_seconds)

    def _read_queue(self):
        # called from the push thread (_thread_body), this does an initial read
        # that'll block for up to rate_limit_seconds. Then it tries to read
        # as much out of the queue as it can. We do this because the http post
        # to the server happens within _thread_body, and can take longer than
        # our rate limit. So next time we get a chance to read the queue we want
        # read all the stuff that queue'd up since last time.
        #
        # If we have more than MAX_ITEMS_PER_PUSH in the queue then the push thread
        # will get behind and data will buffer up in the queue.
        return util.read_many_from_queue(self._queue, self.MAX_ITEMS_PER_PUSH,
                                         self.rate_limit_seconds())

    def _thread_body(self):
        """Push-thread main loop: batch chunks, rate-limit posts, heartbeat.

        Runs until a ``Finish`` item is dequeued, then flushes any pending
        chunks and posts the final ``complete`` message with the exit code.
        """
        posted_data_time = time.time()
        posted_anything_time = time.time()
        ready_chunks = []
        finished = None
        while finished is None:
            items = self._read_queue()
            for item in items:
                if isinstance(item, self.Finish):
                    finished = item
                elif isinstance(item, self.Preempting):
                    request_with_retry(
                        self._client.post,
                        self._endpoint,
                        json={
                            "complete": False,
                            "preempting": True
                        },
                    )
                else:
                    # item is Chunk
                    ready_chunks.append(item)

            cur_time = time.time()

            # Send data when the rate limit has elapsed, or immediately once
            # finishing so nothing queued before Finish is lost.
            if ready_chunks and (finished or cur_time - posted_data_time > self.rate_limit_seconds()):
                posted_data_time = cur_time
                posted_anything_time = cur_time
                self._send(ready_chunks)
                ready_chunks = []

            # Keep-alive when nothing has been posted for a heartbeat period.
            if cur_time - posted_anything_time > self.heartbeat_seconds:
                posted_anything_time = cur_time
                self._handle_response(
                    request_with_retry(
                        self._client.post,
                        self._endpoint,
                        json={
                            "complete": False,
                            "failed": False
                        },
                    ))
        # post the final close message. (item is self.Finish instance now)
        request_with_retry(
            self._client.post,
            self._endpoint,
            json={
                "complete": True,
                "exitcode": int(finished.exitcode)
            },
        )

    def _thread_except_body(self):
        """Thread target: run _thread_body, recording any exception.

        The exception info is stored on ``self._exc_info`` so ``finish`` can
        re-raise it in the calling thread, then re-raised here as well.
        """
        # TODO: Consolidate with internal_util.ExceptionThread
        try:
            self._thread_body()
        except Exception:
            exc_info = sys.exc_info()
            self._exc_info = exc_info
            logger.exception("generic exception in filestream thread")
            util.sentry_exc(exc_info, delay=True)
            raise

    def _handle_response(self, response):
        """Logs dropped chunks and updates dynamic settings.

        Arguments:
            response: Result of ``request_with_retry`` — either a response
                object or, as checked below, an exception instance.

        Raises:
            The exception itself when ``response`` is an exception.
        """
        if isinstance(response, Exception):
            wandb.termerror(
                "Dropped streaming file chunk (see wandb/debug.log)")
            logger.error("dropped chunk %s", response)
            raise response

        parsed = None
        try:
            parsed = response.json()
        except Exception:
            # Best effort: a non-JSON body simply carries no limit updates.
            pass
        if isinstance(parsed, dict):
            limits = parsed.get("limits")
            if isinstance(limits, dict):
                # Server-side overrides, e.g. heartbeat_seconds.
                self._api.dynamic_settings.update(limits)

    def _send(self, chunks):
        """Group chunks by filename, apply file policies, and post them.

        Arguments:
            chunks (list): Chunk records exposing a ``filename`` attribute.
                The list is not modified.
        """
        # create files dict. dict of <filename: chunks> pairs where chunks is a list of
        # [chunk_id, chunk_data] tuples (as lists since this will be json).
        files = {}
        # groupby needs group keys to be consecutive, so sort first.
        by_name = sorted(chunks, key=lambda c: c.filename)
        for filename, file_chunks in itertools.groupby(by_name, lambda c: c.filename):
            # Specific file policies are set by internal/sender.py
            self.set_default_file_policy(filename, DefaultFilePolicy())
            # groupby yields an iterator; policies receive a concrete list.
            processed = self._file_policies[filename].process_chunks(
                list(file_chunks))
            # Files whose policy produced nothing are left out of the payload.
            if processed:
                files[filename] = processed

        # Presumably splits the payload into pieces of at most ~10 MB each —
        # confirm against file_stream_utils.split_files.
        for fs in file_stream_utils.split_files(files, max_mb=10):
            self._handle_response(
                request_with_retry(
                    self._client.post,
                    self._endpoint,
                    json={"files": fs},
                    retry_callback=self._api.retry_callback,
                ))

    def stream_file(self, path):
        """Synchronously post every line of the file at ``path``.

        Bypasses the queue: the lines are sent from the calling thread under
        the file's base name.
        """
        name = os.path.basename(path)
        with open(path) as f:
            self._send([Chunk(name, line) for line in f])

    def enqueue_preempting(self):
        """Ask the push thread to notify the server that the run is preempting."""
        self._queue.put(self.Preempting())

    def push(self, filename, data):
        """Push a chunk of a file to the streaming endpoint.

        The chunk is only queued here; the push thread sends it later.

        Arguments:
            filename: Name of file that this is a chunk of.
            data: File data for this chunk.
        """
        self._queue.put(Chunk(filename, data))

    def finish(self, exitcode):
        """Cleans up.

        Anything pushed after finish will be dropped.

        Arguments:
            exitcode: The exitcode of the watched process.

        Raises:
            Re-raises any exception captured by the push thread.
        """
        self._queue.put(self.Finish(exitcode))
        # TODO(jhr): join on a thread which exited with an exception is a noop, clean up this path
        self._thread.join()
        if self._exc_info:
            logger.error("FileStream exception", exc_info=self._exc_info)
            # reraising the original exception, will get recaught in internal.py for the sender thread
            six.reraise(*self._exc_info)
class Api(object): """W&B Internal Api wrapper Note: Settings are automatically overridden by looking for a `wandb/settings` file in the current working directory or it's parent directory. If none can be found, we look in the current users home directory. Arguments: default_settings(`dict`, optional): If you aren't using a settings file or you wish to override the section to use in the settings file Override the settings here. """ HTTP_TIMEOUT = env.get_http_timeout(10) def __init__( self, default_settings=None, load_settings=True, retry_timedelta=None, environ=os.environ, ): if retry_timedelta is None: retry_timedelta = datetime.timedelta(days=1) self._environ = environ self.default_settings = { "section": "default", "git_remote": "origin", "ignore_globs": [], "base_url": "https://api.wandb.ai", } self.retry_timedelta = retry_timedelta self.default_settings.update(default_settings or {}) self.retry_uploads = 10 self._settings = Settings( load_settings=load_settings, root_dir=self.default_settings.get("root_dir")) # self.git = GitRepo(remote=self.settings("git_remote")) self.git = None # Mutable settings set by the _file_stream_api self.dynamic_settings = { "system_sample_seconds": 2, "system_samples": 15, "heartbeat_seconds": 30, } self.client = Client(transport=RequestsHTTPTransport( headers={ "User-Agent": self.user_agent, "X-WANDB-USERNAME": env.get_username(env=self._environ), "X-WANDB-USER-EMAIL": env.get_user_email(env=self._environ), }, use_json=True, # this timeout won't apply when the DNS lookup fails. 
in that case, it will be 60s # https://bugs.python.org/issue22889 timeout=self.HTTP_TIMEOUT, auth=("api", self.api_key or ""), url="%s/graphql" % self.settings("base_url"), )) self.gql = retry.Retry( self.execute, retry_timedelta=retry_timedelta, check_retry_fn=util.no_retry_auth, retryable_exceptions=(RetryError, requests.RequestException), ) self._current_run_id = None self._file_stream_api = None def reauth(self): """Ensures the current api key is set in the transport""" self.client.transport.auth = ("api", self.api_key or "") def relocate(self): """Ensures the current api points to the right server""" self.client.transport.url = "%s/graphql" % self.settings("base_url") def execute(self, *args, **kwargs): """Wrapper around execute that logs in cases of failure.""" try: return self.client.execute(*args, **kwargs) except requests.exceptions.HTTPError as err: res = err.response logger.error("%s response executing GraphQL." % res.status_code) logger.error(res.text) self.display_gorilla_error_if_found(res) six.reraise(*sys.exc_info()) def display_gorilla_error_if_found(self, res): try: data = res.json() except ValueError: return if "errors" in data and isinstance(data["errors"], list): for err in data["errors"]: if not err.get("message"): continue wandb.termerror("Error while calling W&B API: {} ({})".format( err["message"], res)) def disabled(self): return self._settings.get(Settings.DEFAULT_SECTION, "disabled", fallback=False) def sync_spell(self, run, env=None): """Syncs this run with spell""" try: env = env or os.environ run.config["_wandb"]["spell_url"] = env.get("SPELL_RUN_URL") run.config.persist() try: url = run.get_url() except CommError as e: wandb.termerror("Unable to register run with spell.run: %s" % str(e)) return False return requests.put( env.get("SPELL_API_URL", "https://api.spell.run") + "/wandb_url", json={ "access_token": env.get("WANDB_ACCESS_TOKEN"), "url": url }, timeout=2, ) except requests.RequestException: return False def save_patches(self, 
out_dir): """Save the current state of this repository to one or more patches. Makes one patch against HEAD and another one against the most recent commit that occurs in an upstream branch. This way we can be robust to history editing as long as the user never does "push -f" to break history on an upstream branch. Writes the first patch to <out_dir>/<DIFF_FNAME> and the second to <out_dir>/upstream_diff_<commit_id>.patch. Arguments: out_dir (str): Directory to write the patch files. """ if not self.git.enabled: return False try: root = self.git.root if self.git.dirty: patch_path = os.path.join(out_dir, wandb_lib.filenames.DIFF_FNAME) if self.git.has_submodule_diff: with open(patch_path, "wb") as patch: # we diff against HEAD to ensure we get changes in the index subprocess.check_call( ["git", "diff", "--submodule=diff", "HEAD"], stdout=patch, cwd=root, timeout=5, ) else: with open(patch_path, "wb") as patch: subprocess.check_call(["git", "diff", "HEAD"], stdout=patch, cwd=root, timeout=5) upstream_commit = self.git.get_upstream_fork_point() if upstream_commit and upstream_commit != self.git.repo.head.commit: sha = upstream_commit.hexsha upstream_patch_path = os.path.join( out_dir, "upstream_diff_{}.patch".format(sha)) if self.git.has_submodule_diff: with open(upstream_patch_path, "wb") as upstream_patch: subprocess.check_call( ["git", "diff", "--submodule=diff", sha], stdout=upstream_patch, cwd=root, timeout=5, ) else: with open(upstream_patch_path, "wb") as upstream_patch: subprocess.check_call( ["git", "diff", sha], stdout=upstream_patch, cwd=root, timeout=5, ) # TODO: A customer saw `ValueError: Reference at 'refs/remotes/origin/foo' does not exist` # so we now catch ValueError. Catching this error feels too generic. 
except ( ValueError, subprocess.CalledProcessError, subprocess.TimeoutExpired, ) as e: logger.error("Error generating diff: %s" % e) def set_current_run_id(self, run_id): self._current_run_id = run_id @property def current_run_id(self): return self._current_run_id @property def user_agent(self): return "W&B Internal Client %s" % __version__ @property def api_key(self): auth = requests.utils.get_netrc_auth(self.api_url) key = None if auth: key = auth[-1] # Environment should take precedence if self._environ.get(env.API_KEY): key = self._environ.get(env.API_KEY) return key @property def api_url(self): return self.settings("base_url") @property def app_url(self): return wandb.util.app_url(self.api_url) def settings(self, key=None, section=None): """The settings overridden from the wandb/settings file. Arguments: key (str, optional): If provided only this setting is returned section (str, optional): If provided this section of the setting file is used, defaults to "default" Returns: A dict with the current settings { "entity": "models", "base_url": "https://api.wandb.ai", "project": None } """ result = self.default_settings.copy() result.update(self._settings.items(section=section)) result.update({ "entity": env.get_entity( self._settings.get( Settings.DEFAULT_SECTION, "entity", fallback=result.get("entity"), ), env=self._environ, ), "project": env.get_project( self._settings.get( Settings.DEFAULT_SECTION, "project", fallback=result.get("project"), ), env=self._environ, ), "base_url": env.get_base_url( self._settings.get( Settings.DEFAULT_SECTION, "base_url", fallback=result.get("base_url"), ), env=self._environ, ), "ignore_globs": env.get_ignore( self._settings.get( Settings.DEFAULT_SECTION, "ignore_globs", fallback=result.get("ignore_globs"), ), env=self._environ, ), }) return result if key is None else result[key] def clear_setting(self, key, globally=False, persist=False): self._settings.clear(Settings.DEFAULT_SECTION, key, globally=globally, persist=persist) def 
set_setting(self, key, value, globally=False, persist=False): self._settings.set(Settings.DEFAULT_SECTION, key, value, globally=globally, persist=persist) if key == "entity": env.set_entity(value, env=self._environ) elif key == "project": env.set_project(value, env=self._environ) elif key == "base_url": self.relocate() def parse_slug(self, slug, project=None, run=None): if slug and "/" in slug: parts = slug.split("/") project = parts[0] run = parts[1] else: project = project or self.settings().get("project") if project is None: raise CommError("No default project configured.") run = run or slug or env.get_run(env=self._environ) if run is None: run = "latest" return (project, run) @normalize_exceptions def viewer(self): query = gql(""" query Viewer{ viewer { id entity teams { edges { node { name } } } } } """) res = self.gql(query) return res.get("viewer") or {} @normalize_exceptions def list_projects(self, entity=None): """Lists projects in W&B scoped by entity. Arguments: entity (str, optional): The entity to scope this project to. Returns: [{"id","name","description"}] """ query = gql(""" query Models($entity: String!) { models(first: 10, entityName: $entity) { edges { node { id name description } } } } """) return self._flatten_edges( self.gql( query, variable_values={"entity": entity or self.settings("entity")})["models"]) @normalize_exceptions def project(self, project, entity=None): """Retrive project Arguments: project (str): The project to get details for entity (str, optional): The entity to scope this project to. Returns: [{"id","name","repo","dockerImage","description"}] """ query = gql(""" query Models($entity: String, $project: String!) { model(name: $project, entityName: $entity) { id name repo dockerImage description } } """) return self.gql(query, variable_values={ "entity": entity, "project": project })["model"] @normalize_exceptions def sweep(self, sweep, specs, project=None, entity=None): """Retrieve sweep. 
Arguments: sweep (str): The sweep to get details for specs (str): history specs project (str, optional): The project to scope this sweep to. entity (str, optional): The entity to scope this sweep to. Returns: [{"id","name","repo","dockerImage","description"}] """ query = gql(""" query Models($entity: String, $project: String!, $sweep: String!, $specs: [JSONString!]!) { model(name: $project, entityName: $entity) { sweep(sweepName: $sweep) { id name method state description config createdAt heartbeatAt updatedAt earlyStopJobRunning bestLoss controller scheduler runs { edges { node { name state config exitcode heartbeatAt shouldStop failed stopped running summaryMetrics sampledHistory(specs: $specs) } } } } } } """) entity = entity or self.settings("entity") project = project or self.settings("project") response = self.gql( query, variable_values={ "entity": entity, "project": project, "sweep": sweep, "specs": specs, }, ) if response["model"] is None or response["model"]["sweep"] is None: raise ValueError("Sweep {}/{}/{} not found".format( entity, project, sweep)) data = response["model"]["sweep"] if data: data["runs"] = self._flatten_edges(data["runs"]) return data @normalize_exceptions def list_runs(self, project, entity=None): """Lists runs in W&B scoped by project. Arguments: project (str): The project to scope the runs to entity (str, optional): The entity to scope this project to. Defaults to public models Returns: [{"id",name","description"}] """ query = gql(""" query Buckets($model: String!, $entity: String!) { model(name: $model, entityName: $entity) { buckets(first: 10) { edges { node { id name displayName description } } } } } """) return self._flatten_edges( self.gql( query, variable_values={ "entity": entity or self.settings("entity"), "model": project or self.settings("project"), }, )["model"]["buckets"]) @normalize_exceptions def launch_run(self, command, project=None, entity=None, run_id=None): """Launch a run in the cloud. 
Arguments: command (str): The command to run program (str): The file to run project (str): The project to scope the runs to entity (str, optional): The entity to scope this project to. Defaults to public models run_id (str, optional): The run_id to scope to Returns: [{"podName","status"}] """ query = gql(""" mutation launchRun( $entity: String $model: String $runId: String $image: String $command: String $patch: String $cwd: String $datasets: [String] ) { launchRun(input: {id: $runId, entityName: $entity, patch: $patch, modelName: $model, image: $image, command: $command, datasets: $datasets, cwd: $cwd}) { podName status runId } } """) run_id = run_id or self.current_run_id assert run_id, "run_id must be specified" patch = BytesIO() if self.git.dirty: self.git.repo.git.execute(["git", "diff"], output_stream=patch) patch.seek(0) cwd = "." if self.git.enabled: cwd = cwd + os.getcwd().replace(self.git.repo.working_dir, "") return self.gql( query, variable_values={ "entity": entity or self.settings("entity"), "model": project or self.settings("project"), "command": command, "runId": run_id, "patch": patch.read().decode("utf8"), "cwd": cwd, }, ) @normalize_exceptions def run_config(self, project, run=None, entity=None): """Get the relevant configs for a run Arguments: project (str): The project to download, (can include bucket) run (str): The run to download entity (str, optional): The entity to scope this project to. """ query = gql(""" query Model($name: String!, $entity: String!, $run: String!) 
{ model(name: $name, entityName: $entity) { bucket(name: $run) { config commit patch files(names: ["wandb-metadata.json"]) { edges { node { url } } } } } } """) run = run or self.current_run_id assert run, "run must be specified" response = self.gql(query, variable_values={ "name": project, "run": run, "entity": entity }) if response["model"] is None: raise ValueError("Run {}/{}/{} not found".format( entity, project, run)) run = response["model"]["bucket"] commit = run["commit"] patch = run["patch"] config = json.loads(run["config"] or "{}") if len(run["files"]["edges"]) > 0: url = run["files"]["edges"][0]["node"]["url"] res = requests.get(url) res.raise_for_status() metadata = res.json() else: metadata = {} return (commit, config, patch, metadata) @normalize_exceptions def run_resume_status(self, entity, project_name, name): """Check if a run exists and get resume information. Arguments: entity (str, optional): The entity to scope this project to. project_name (str): The project to download, (can include bucket) name (str): The run to download """ query = gql(""" query Model($project: String!, $entity: String, $name: String!) { model(name: $project, entityName: $entity) { id name entity { id name } bucket(name: $name, missingOk: true) { id name summaryMetrics displayName logLineCount historyLineCount eventsLineCount historyTail eventsTail config } } } """) response = self.gql( query, variable_values={ "entity": entity, "project": project_name, "name": name, }, ) if "model" not in response or "bucket" not in (response["model"] or {}): return None project = response["model"] self.set_setting("project", project_name) if "entity" in project: self.set_setting("entity", project["entity"]["name"]) return project["bucket"] @normalize_exceptions def check_stop_requested(self, project_name, entity_name, run_id): query = gql(""" query Model($projectName: String, $entityName: String, $runId: String!) 
{ project(name:$projectName, entityName:$entityName) { run(name:$runId) { stopped } } } """) run_id = run_id or self.current_run_id assert run_id, "run_id must be specified" response = self.gql( query, variable_values={ "projectName": project_name, "entityName": entity_name, "runId": run_id, }, ) project = response.get("project", None) if not project: return False run = project.get("run", None) if not run: return False return run["stopped"] def format_project(self, project): return re.sub(r"\W+", "-", project.lower()).strip("-_") @normalize_exceptions def upsert_project(self, project, id=None, description=None, entity=None): """Create a new project Arguments: project (str): The project to create description (str, optional): A description of this project entity (str, optional): The entity to scope this project to. """ mutation = gql(""" mutation UpsertModel($name: String!, $id: String, $entity: String!, $description: String, $repo: String) { upsertModel(input: { id: $id, name: $name, entityName: $entity, description: $description, repo: $repo }) { model { name description } } } """) response = self.gql( mutation, variable_values={ "name": self.format_project(project), "entity": entity or self.settings("entity"), "description": description, "repo": self.git.remote_url, "id": id, }, ) return response["upsertModel"]["model"] @normalize_exceptions def pop_from_run_queue(self, entity=None, project=None): mutation = gql(""" mutation popFromRunQueue($entity: String!, $project: String!) 
{ popFromRunQueue(input: { entityName: $entity, projectName: $project }) { runQueueItemId runSpec } } """) response = self.gql(mutation, variable_values={ "entity": entity, "project": project }) return response["popFromRunQueue"] @normalize_exceptions def upsert_run( self, id=None, name=None, project=None, host=None, group=None, tags=None, config=None, description=None, entity=None, state=None, display_name=None, notes=None, repo=None, job_type=None, program_path=None, commit=None, sweep_name=None, summary_metrics=None, num_retries=None, ): """Update a run Arguments: id (str, optional): The existing run to update name (str, optional): The name of the run to create group (str, optional): Name of the group this run is a part of project (str, optional): The name of the project config (dict, optional): The latest config params description (str, optional): A description of this project entity (str, optional): The entity to scope this project to. repo (str, optional): Url of the program's repository. state (str, optional): State of the program. job_type (str, optional): Type of job, e.g 'train'. program_path (str, optional): Path to the program. 
commit (str, optional): The Git SHA to associate the run with summary_metrics (str, optional): The JSON summary metrics """ mutation = gql(""" mutation UpsertBucket( $id: String, $name: String, $project: String, $entity: String!, $groupName: String, $description: String, $displayName: String, $notes: String, $commit: String, $config: JSONString, $host: String, $debug: Boolean, $program: String, $repo: String, $jobType: String, $state: String, $sweep: String, $tags: [String!], $summaryMetrics: JSONString, ) { upsertBucket(input: { id: $id, name: $name, groupName: $groupName, modelName: $project, entityName: $entity, description: $description, displayName: $displayName, notes: $notes, config: $config, commit: $commit, host: $host, debug: $debug, jobProgram: $program, jobRepo: $repo, jobType: $jobType, state: $state, sweep: $sweep, tags: $tags, summaryMetrics: $summaryMetrics, }) { bucket { id name displayName description config project { id name entity { id name } } } } } """) if config is not None: config = json.dumps(config) if not description or description.isspace(): description = None kwargs = {} if num_retries is not None: kwargs["num_retries"] = num_retries variable_values = { "id": id, "entity": entity or self.settings("entity"), "name": name, "project": project, "groupName": group, "tags": tags, "description": description, "config": config, "commit": commit, "displayName": display_name, "notes": notes, "host": None if self.settings().get("anonymous") == "true" else host, "debug": env.is_debug(env=self._environ), "repo": repo, "program": program_path, "jobType": job_type, "state": state, "sweep": sweep_name, "summaryMetrics": summary_metrics, } response = self.gql(mutation, variable_values=variable_values, **kwargs) run = response["upsertBucket"]["bucket"] project = run.get("project") if project: self.set_setting("project", project["name"]) entity = project.get("entity") if entity: self.set_setting("entity", entity["name"]) return 
response["upsertBucket"]["bucket"] @normalize_exceptions def upload_urls(self, project, files, run=None, entity=None, description=None): """Generate temporary resumeable upload urls Arguments: project (str): The project to download files (list or dict): The filenames to upload run (str): The run to upload to entity (str, optional): The entity to scope this project to. Defaults to wandb models Returns: (bucket_id, file_info) bucket_id: id of bucket we uploaded to file_info: A dict of filenames and urls, also indicates if this revision already has uploaded files. { 'weights.h5': { "url": "https://weights.url" }, 'model.json': { "url": "https://model.json", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' }, } """ query = gql(""" query Model($name: String!, $files: [String]!, $entity: String!, $run: String!, $description: String) { model(name: $name, entityName: $entity) { bucket(name: $run, desc: $description) { id files(names: $files) { uploadHeaders edges { node { name url(upload: true) updatedAt } } } } } } """) run_id = run or self.current_run_id assert run, "run must be specified" entity = entity or self.settings("entity") query_result = self.gql( query, variable_values={ "name": project, "run": run_id, "entity": entity, "description": description, "files": [file for file in files], }, ) run = query_result["model"]["bucket"] if run: result = { file["name"]: file for file in self._flatten_edges(run["files"]) } return run["id"], run["files"]["uploadHeaders"], result else: raise CommError("Run does not exist {}/{}/{}.".format( entity, project, run_id)) @normalize_exceptions def download_urls(self, project, run=None, entity=None): """Generate download urls Arguments: project (str): The project to download run (str): The run to upload to entity (str, optional): The entity to scope this project to. 
Defaults to wandb models Returns: A dict of extensions and urls { 'weights.h5': { "url": "https://weights.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' }, 'model.json': { "url": "https://model.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' } } """ query = gql(""" query Model($name: String!, $entity: String!, $run: String!) { model(name: $name, entityName: $entity) { bucket(name: $run) { files { edges { node { name url md5 updatedAt } } } } } } """) run = run or self.current_run_id assert run, "run must be specified" query_result = self.gql( query, variable_values={ "name": project, "run": run, "entity": entity or self.settings("entity"), }, ) files = self._flatten_edges(query_result["model"]["bucket"]["files"]) return {file["name"]: file for file in files if file} @normalize_exceptions def download_url(self, project, file_name, run=None, entity=None): """Generate download urls Arguments: project (str): The project to download file_name (str): The name of the file to download run (str): The run to upload to entity (str, optional): The entity to scope this project to. Defaults to wandb models Returns: A dict of extensions and urls { "url": "https://weights.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' } """ query = gql(""" query Model($name: String!, $fileName: String!, $entity: String!, $run: String!) 
{ model(name: $name, entityName: $entity) { bucket(name: $run) { files(names: [$fileName]) { edges { node { name url md5 updatedAt } } } } } } """) run = run or self.current_run_id assert run, "run must be specified" query_result = self.gql( query, variable_values={ "name": project, "run": run, "fileName": file_name, "entity": entity or self.settings("entity"), }, ) if query_result["model"]: files = self._flatten_edges( query_result["model"]["bucket"]["files"]) return files[0] if len(files) > 0 and files[0].get( "updatedAt") else None else: return None @normalize_exceptions def download_file(self, url): """Initiate a streaming download Arguments: url (str): The url to download Returns: A tuple of the content length and the streaming response """ response = requests.get(url, stream=True) response.raise_for_status() return (int(response.headers.get("content-length", 0)), response) @normalize_exceptions def download_write_file(self, metadata, out_dir=None): """Download a file from a run and write it to wandb/ Arguments: metadata (obj): The metadata object for the file to download. Comes from Api.download_urls(). Returns: A tuple of the file's local path and the streaming response. The streaming response is None if the file already existed and was up to date. """ file_name = metadata["name"] path = os.path.join(out_dir or self.settings("wandb_dir"), file_name) if self.file_current(file_name, metadata["md5"]): return path, None size, response = self.download_file(metadata["url"]) with util.fsync_open(path, "wb") as file: for data in response.iter_content(chunk_size=1024): file.write(data) return path, response @normalize_exceptions def register_agent(self, host, sweep_id=None, project_name=None, entity=None): """Register a new agent Arguments: host (str): hostname persistent (bool): long running or oneoff sweep (str): sweep id project_name: (str): model that contains sweep """ mutation = gql(""" mutation CreateAgent( $host: String! 
$projectName: String!, $entityName: String!, $sweep: String! ) { createAgent(input: { host: $host, projectName: $projectName, entityName: $entityName, sweep: $sweep, }) { agent { id } } } """) if entity is None: entity = self.settings("entity") if project_name is None: project_name = self.settings("project") # don't retry on validation or not found errors def no_retry_4xx(e): if not isinstance(e, requests.HTTPError): return True if not (e.response.status_code >= 400 and e.response.status_code < 500): return True body = json.loads(e.response.content) raise UsageError(body["errors"][0]["message"]) response = self.gql( mutation, variable_values={ "host": host, "entityName": entity, "projectName": project_name, "sweep": sweep_id, }, check_retry_fn=no_retry_4xx, ) return response["createAgent"]["agent"] def agent_heartbeat(self, agent_id, metrics, run_states): """Notify server about agent state, receive commands. Arguments: agent_id (str): agent_id metrics (dict): system metrics run_states (dict): run_id: state mapping Returns: List of commands to execute. """ mutation = gql(""" mutation Heartbeat( $id: ID!, $metrics: JSONString, $runState: JSONString ) { agentHeartbeat(input: { id: $id, metrics: $metrics, runState: $runState }) { agent { id } commands } } """) try: response = self.gql( mutation, variable_values={ "id": agent_id, "metrics": json.dumps(metrics), "runState": json.dumps(run_states), }, ) except Exception as e: # GQL raises exceptions with stringified python dictionaries :/ message = ast.literal_eval(e.args[0])["message"] logger.error("Error communicating with W&B: %s", message) return [] else: return json.loads(response["agentHeartbeat"]["commands"]) @normalize_exceptions def upsert_sweep( self, config, controller=None, scheduler=None, obj_id=None, project=None, entity=None, ): """Upsert a sweep object. 
Arguments: config (str): sweep config (will be converted to yaml) """ project_query = """ project { id name entity { id name } } """ mutation_str = """ mutation UpsertSweep( $id: ID, $config: String, $description: String, $entityName: String!, $projectName: String!, $controller: JSONString, $scheduler: JSONString ) { upsertSweep(input: { id: $id, config: $config, description: $description, entityName: $entityName, projectName: $projectName, controller: $controller, scheduler: $scheduler }) { sweep { name _PROJECT_QUERY_ } } } """ # TODO(jhr): we need protocol versioning to know schema is not supported # for now we will just try both new and old query mutation_new = gql( mutation_str.replace("_PROJECT_QUERY_", project_query)) mutation_old = gql(mutation_str.replace("_PROJECT_QUERY_", "")) # don't retry on validation errors # TODO(jhr): generalize error handling routines def no_retry_4xx(e): if not isinstance(e, requests.HTTPError): return True if not (e.response.status_code >= 400 and e.response.status_code < 500): return True body = json.loads(e.response.content) raise UsageError(body["errors"][0]["message"]) for mutation in mutation_new, mutation_old: try: response = self.gql( mutation, variable_values={ "id": obj_id, "config": yaml.dump(config), "description": config.get("description"), "entityName": entity or self.settings("entity"), "projectName": project or self.settings("project"), "controller": controller, "scheduler": scheduler, }, check_retry_fn=no_retry_4xx, ) except UsageError as e: raise (e) except Exception as e: # graphql schema exception is generic err = e continue err = None break if err: raise (err) sweep = response["upsertSweep"]["sweep"] project = sweep.get("project") if project: self.set_setting("project", project["name"]) entity = project.get("entity") if entity: self.set_setting("entity", entity["name"]) return response["upsertSweep"]["sweep"]["name"] @normalize_exceptions def create_anonymous_api_key(self): """Creates a new API key belonging 
to a new anonymous user.""" mutation = gql(""" mutation CreateAnonymousApiKey { createAnonymousEntity(input: {}) { apiKey { name } } } """) response = self.gql(mutation, variable_values={}) return response["createAnonymousEntity"]["apiKey"]["name"] def file_current(self, fname, md5): """Checksum a file and compare the md5 with the known md5""" return os.path.isfile(fname) and util.md5_file(fname) == md5 @normalize_exceptions def pull(self, project, run=None, entity=None): """Download files from W&B Arguments: project (str): The project to download run (str): The run to upload to entity (str, optional): The entity to scope this project to. Defaults to wandb models Returns: The requests library response object """ project, run = self.parse_slug(project, run=run) assert run, "run must be specified" urls = self.download_urls(project, run, entity) responses = [] for file_name in urls: _, response = self.download_write_file(urls[file_name]) if response: responses.append(response) return responses def get_project(self): return self.settings("project") def _status_request(self, url, length): """Ask google how much we've uploaded""" return requests.put( url=url, headers={ "Content-Length": "0", "Content-Range": "bytes */%i" % length }, ) def _flatten_edges(self, response): """Return an array from the nested graphql relay structure""" return [node["node"] for node in response["edges"]]
class Api(object):
    """
    Used for querying the wandb server.

    Examples:
        Most common way to initialize
        ```
            wandb.Api()
        ```

    Args:
        overrides (dict): You can set `base_url` if you are using a wandb server
            other than https://api.wandb.ai.
            You can also set defaults for `entity`, `project`, and `run`.
    """

    _HTTP_TIMEOUT = env.get_http_timeout(9)

    def __init__(self, overrides=None):
        # `None` instead of a `{}` default so no dict is shared between calls.
        if overrides is None:
            overrides = {}
        self.settings = {
            'entity': None,
            'project': None,
            'run': "latest",
            'base_url': env.get_base_url("https://api.wandb.ai")
        }
        self.settings.update(overrides)
        if 'username' in overrides and 'entity' not in overrides:
            wandb.termwarn(
                'Passing "username" to Api is deprecated. please use "entity" instead.'
            )
            self.settings['entity'] = overrides['username']
        # Local object caches; `flush` clears the runs cache.
        self._projects = {}
        self._runs = {}
        self._sweeps = {}
        self._base_client = Client(transport=RequestsHTTPTransport(
            headers={
                'User-Agent': self.user_agent,
                'Use-Admin-Privileges': "true"
            },
            use_json=True,
            # this timeout won't apply when the DNS lookup fails. in that case, it will be 60s
            # https://bugs.python.org/issue22889
            timeout=self._HTTP_TIMEOUT,
            auth=("api", self.api_key),
            url='%s/graphql' % self.settings['base_url']))
        self._client = RetryingClient(self._base_client)

    def create_run(self, **kwargs):
        """Create a run via `Run.create`, passing this Api and `kwargs` through."""
        return Run.create(self, **kwargs)

    @property
    def client(self):
        """The retrying GraphQL client used for all queries."""
        return self._client

    @property
    def user_agent(self):
        """User-Agent header value sent with every request."""
        return 'W&B Public Client %s' % __version__

    @property
    def api_key(self):
        """The API key from the netrc entry for `base_url`.

        A non-empty WANDB_API_KEY environment variable takes precedence.
        Returns None when neither source provides a key.
        """
        auth = requests.utils.get_netrc_auth(self.settings['base_url'])
        key = None
        if auth:
            key = auth[-1]
        # Environment should take precedence
        if os.getenv("WANDB_API_KEY"):
            key = os.environ["WANDB_API_KEY"]
        return key

    def flush(self):
        """
        The api object keeps a local cache of runs, so if the state of the run may
        change while executing your script you must clear the local cache with `api.flush()`
        to get the latest values associated with the run."""
        self._runs = {}

    def _parse_path(self, path):
        """Parses paths in the following formats:

        url: entity/project/runs/run_id
        path: entity/project/run_id
        docker: entity/project:run_id

        entity is optional and will fallback to the current logged in user.
        A bare id (no "/" and no ":") only sets the run/sweep id; entity and
        project keep their configured defaults. An empty path uses all defaults.

        Returns:
            An (entity, project, run) tuple.
        """
        run = self.settings['run']
        project = self.settings['project']
        entity = self.settings['entity']
        parts = path.replace("/runs/", "/").strip("/ ").split("/")
        tagged = ":" in parts[-1]
        if tagged:
            run = parts[-1].split(":")[-1]
            parts[-1] = parts[-1].split(":")[0]
        elif parts[-1]:
            run = parts[-1]
        if len(parts) >= 3:
            # entity/project/run_id is unambiguous.
            entity, project = parts[0], parts[1]
        elif len(parts) == 2:
            # Ambiguous: project/run_id when a default entity is set (and no
            # ":run" tag was given), otherwise entity/project.
            project = parts[1]
            if entity and run == project:
                project = parts[0]
            else:
                entity = parts[0]
        elif tagged and parts[0]:
            # "project:run_id" names the project explicitly.
            project = parts[0]
        return (entity, project, run)

    def _parse_project_path(self, path):
        """Parses a project path of the form entity/project or project.

        Missing components fall back to the `entity` / `project` settings.
        Unlike `_parse_path`, a two-part path is always read as entity/project.

        Returns:
            An (entity, project) tuple.
        """
        entity = self.settings['entity']
        project = self.settings['project']
        parts = [part for part in (path or "").strip("/ ").split("/") if part]
        if len(parts) == 1:
            project = parts[0]
        elif len(parts) >= 2:
            entity, project = parts[0], parts[1]
        return (entity, project)

    def projects(self, entity=None, per_page=None):
        """Get projects for a given entity.

        Args:
            entity (str): Name of the entity requested. If None will fallback to
                default entity passed to :obj:`Api`. If no default entity, will raise a `ValueError`.
            per_page (int): Sets the page size for query pagination. None will use the default size.
                Usually there is no reason to change this.

        Returns:
            A :obj:`Projects` object which is an iterable collection of :obj:`Project` objects.
        """
        if entity is None:
            entity = self.settings['entity']
            if entity is None:
                raise ValueError(
                    'entity must be passed as a parameter, or set in settings')
        if entity not in self._projects:
            self._projects[entity] = Projects(self.client,
                                              entity,
                                              per_page=per_page)
        return self._projects[entity]

    def runs(self, path="", filters=None, order="-created_at", per_page=None):
        """Return a set of runs from a project that match the filters provided.

        You can filter by `config.*`, `summary.*`, `state`, `entity`, `createdAt`, etc.

        Examples:
            Find runs in my_project where config.experiment_name has been set to "foo"
            ```
            api.runs(path="my_entity/my_project", filters={"config.experiment_name": "foo"})
            ```

            Find runs in my_project where config.experiment_name has been set to "foo" or "bar"
            ```
            api.runs(path="my_entity/my_project",
                     filters={"$or": [{"config.experiment_name": "foo"},
                                      {"config.experiment_name": "bar"}]})
            ```

            Find runs in my_project sorted by ascending loss
            ```
            api.runs(path="my_entity/my_project", order="+summary.loss")
            ```

        Args:
            path (str): path to project, should be in the form: "entity/project".
                "project" alone uses the default entity; an empty path uses the
                default entity and project.
            filters (dict): queries for specific runs using the MongoDB query language.
                You can filter by run properties such as config.key, summary.key, state, entity, createdAt, etc.
                For example: {"config.experiment_name": "foo"} would find runs with a config entry
                of experiment name set to "foo".
                You can compose operations to make more complicated queries; the
                reference for the language is at https://docs.mongodb.com/manual/reference/operator/query
            order (str): Order can be `created_at`, `heartbeat_at`, `config.*.value`, or `summary.*`.
                If you prepend order with a + order is ascending.
                If you prepend order with a - order is descending (default).
                The default order is run.created_at from newest to oldest.
            per_page (int): Sets the page size for query pagination. None will use the default size.

        Returns:
            A :obj:`Runs` object, which is an iterable collection of :obj:`Run` objects.
        """
        if filters is None:
            filters = {}
        entity, project = self._parse_project_path(path)
        # Check and store under the same key. Use `in` rather than truthiness
        # so a cached (possibly empty) collection is not re-evaluated.
        cache_key = path + str(filters) + str(order)
        if cache_key not in self._runs:
            self._runs[cache_key] = Runs(self.client,
                                         entity,
                                         project,
                                         filters=filters,
                                         order=order,
                                         per_page=per_page)
        return self._runs[cache_key]

    @normalize_exceptions
    def run(self, path=""):
        """Returns a single run by parsing path in the form entity/project/run_id.

        Args:
            path (str): path to run in the form entity/project/run_id.
                If api.entity is set, this can be in the form project/run_id
                and if api.project is set this can just be the run_id.

        Returns:
            A :obj:`Run` object.
        """
        entity, project, run = self._parse_path(path)
        if path not in self._runs:
            self._runs[path] = Run(self.client, entity, project, run)
        return self._runs[path]

    @normalize_exceptions
    def sweep(self, path=""):
        """
        Returns a sweep by parsing path in the form entity/project/sweep_id.

        Args:
            path (str, optional): path to sweep in the form entity/project/sweep_id.  If api.entity
                is set, this can be in the form project/sweep_id and if api.project is set
                this can just be the sweep_id.

        Returns:
            A :obj:`Sweep` object.
        """
        entity, project, sweep_id = self._parse_path(path)
        # Check and store under the same key (`path`).
        if path not in self._sweeps:
            self._sweeps[path] = Sweep(self.client, entity, project, sweep_id)
        return self._sweeps[path]
class FileStreamApi(object):
    """Pushes chunks of files to our streaming endpoint.

    This class is used as a singleton. It has a thread that serializes access to
    the streaming endpoint and performs rate-limiting and batching.

    TODO: Differentiate between binary/text encoding.
    """
    # Queue sentinel: tells the push thread to post the final message and exit.
    # The trailing comma makes the field list an explicit 1-tuple.
    Finish = collections.namedtuple('Finish', ('exitcode',))

    HTTP_TIMEOUT = env.get_http_timeout(10)
    MAX_ITEMS_PER_PUSH = 10000

    def __init__(self, api, run_id):
        self._api = api
        self._run_id = run_id
        self._client = requests.Session()
        self._client.auth = ('api', api.api_key)
        # NOTE(review): requests.Session does not apply a `timeout` attribute to
        # the requests it sends; a timeout must be passed per call. This line is
        # kept for compatibility but does not bound request time. Confirm whether
        # the post calls should pass timeout=self.HTTP_TIMEOUT instead.
        self._client.timeout = self.HTTP_TIMEOUT
        self._client.headers.update({
            'User-Agent': api.user_agent,
            'X-WANDB-USERNAME': env.get_username()
        })
        # filename -> policy object whose process_chunks() builds the upload payload
        self._file_policies = {}
        self._queue = queue.Queue()
        self._thread = threading.Thread(target=self._thread_body)
        # It seems we need to make this a daemon thread to get sync.py's atexit handler to run, which
        # cleans this thread up.
        self._thread.daemon = True
        self._init_endpoint()

    def _init_endpoint(self):
        """(Re)build the file_stream URL from the api's current settings."""
        settings = self._api.settings()
        self._endpoint = "{base}/files/{entity}/{project}/{run}/file_stream".format(
            base=settings['base_url'],
            entity=settings['entity'],
            project=settings['project'],
            run=self._run_id)

    def start(self):
        """Rebuild the endpoint and start the push thread.

        The endpoint is rebuilt presumably because entity/project settings may
        have changed since construction.
        """
        self._init_endpoint()
        self._thread.start()

    def set_default_file_policy(self, filename, file_policy):
        """Set an upload policy for a file unless one has already been set.
        """
        if filename not in self._file_policies:
            self._file_policies[filename] = file_policy

    def set_file_policy(self, filename, file_policy):
        """Set (or replace) the upload policy for a file."""
        self._file_policies[filename] = file_policy

    @property
    def heartbeat_seconds(self):
        # Defaults to 30
        return self._api.dynamic_settings["heartbeat_seconds"]

    def rate_limit_seconds(self):
        """Minimum seconds between data posts.

        The interval grows with run age (measured from wandb.START_TIME). It is
        derived from heartbeat_seconds, with floors of 1, 2.5 and 5 seconds.
        """
        run_time = time.time() - wandb.START_TIME
        if run_time < 60:
            return max(1, self.heartbeat_seconds / 15)
        elif run_time < 300:
            return max(2.5, self.heartbeat_seconds / 3)
        else:
            return max(5, self.heartbeat_seconds)

    def _read_queue(self):
        # called from the push thread (_thread_body), this does an initial read
        # that'll block for up to rate_limit_seconds. Then it tries to read
        # as much out of the queue as it can. We do this because the http post
        # to the server happens within _thread_body, and can take longer than
        # our rate limit. So next time we get a chance to read the queue we want
        # read all the stuff that queue'd up since last time.
        #
        # If we have more than MAX_ITEMS_PER_PUSH in the queue then the push thread
        # will get behind and data will buffer up in the queue.
        return util.read_many_from_queue(self._queue, self.MAX_ITEMS_PER_PUSH,
                                         self.rate_limit_seconds())

    def _thread_body(self):
        """Push-thread loop: batch queued chunks, post them, send heartbeats."""
        posted_data_time = time.time()
        posted_anything_time = time.time()
        ready_chunks = []
        finished = None
        while finished is None:
            items = self._read_queue()
            for item in items:
                if isinstance(item, self.Finish):
                    finished = item
                else:
                    # item is Chunk
                    ready_chunks.append(item)

            cur_time = time.time()

            # Flush buffered chunks once the rate limit allows, or right away on finish.
            if ready_chunks and (finished or cur_time - posted_data_time >
                                 self.rate_limit_seconds()):
                posted_data_time = cur_time
                posted_anything_time = cur_time
                self._send(ready_chunks)
                ready_chunks = []

            # Nothing posted for a while: send an empty heartbeat.
            if cur_time - posted_anything_time > self.heartbeat_seconds:
                posted_anything_time = cur_time
                self._handle_response(
                    util.request_with_retry(self._client.post,
                                            self._endpoint,
                                            json={
                                                'complete': False,
                                                'failed': False
                                            }))
        # post the final close message. (item is self.Finish instance now)
        util.request_with_retry(self._client.post,
                                self._endpoint,
                                json={
                                    'complete': True,
                                    'exitcode': int(finished.exitcode)
                                })

    def _handle_response(self, response):
        """Logs dropped chunks and updates dynamic settings.

        Args:
            response: whatever util.request_with_retry returned; an Exception
                instance means the request ultimately failed.
        """
        if isinstance(response, Exception):
            logging.error("dropped chunk %s" % response)
            return
        try:
            parsed = response.json()
        except ValueError:
            # This runs on the push thread. An unparseable body (e.g. an HTML
            # error page) must not raise here, or the thread dies and all
            # further data is silently lost.
            logging.error("invalid file stream response %s", response)
            return
        if isinstance(parsed, dict) and parsed.get("limits"):
            self._api.dynamic_settings.update(parsed["limits"])

    def _send(self, chunks):
        """POST a batch of chunks, grouped by filename via each file's policy."""
        # create files dict. dict of <filename: chunks> pairs where chunks is a list of
        # [chunk_id, chunk_data] tuples (as lists since this will be json).
        files = {}
        # Groupby needs group keys to be consecutive, so sort first.
        chunks.sort(key=lambda c: c.filename)
        for filename, file_chunks in itertools.groupby(chunks, lambda c: c.filename):
            file_chunks = list(file_chunks)  # groupby returns iterator
            self.set_default_file_policy(filename, DefaultFilePolicy())
            files[filename] = self._file_policies[filename].process_chunks(
                file_chunks)

        self._handle_response(
            util.request_with_retry(self._client.post,
                                    self._endpoint,
                                    json={'files': files}))

    def stream_file(self, path):
        """Send every line of the file at `path` synchronously (bypassing the queue)."""
        name = path.split("/")[-1]
        # Read and close the file before the (possibly slow) network call.
        with open(path) as source:
            lines = source.readlines()
        self._send([Chunk(name, line) for line in lines])

    def push(self, filename, data):
        """Push a chunk of a file to the streaming endpoint.

        Args:
            filename: Name of file that this is a chunk of.
            data: File data for this chunk.
        """
        self._queue.put(Chunk(filename, data))

    def finish(self, exitcode):
        """Cleans up.

        Anything pushed after finish will be dropped.

        Args:
            exitcode: The exitcode of the watched process.
        """
        self._queue.put(self.Finish(exitcode))
        self._thread.join()
class Api(object):
    """W&B Public API

    Args:
        setting_overrides(:obj:`dict`, optional): You can set defaults such as
            entity, project, and run here as well as which api server to use.
    """

    HTTP_TIMEOUT = env.get_http_timeout(9)

    def __init__(self, overrides=None):
        # `None` instead of a `{}` default so no dict is shared between calls.
        if overrides is None:
            overrides = {}
        self.settings = {
            'entity': None,
            'project': None,
            'run': "latest",
            'base_url': env.get_base_url("https://api.wandb.ai")
        }
        self.settings.update(overrides)
        if 'username' in overrides and 'entity' not in overrides:
            wandb.termwarn('Passing "username" to Api is deprecated. please use "entity" instead.')
            self.settings['entity'] = overrides['username']
        # Local object caches; `flush` clears the runs cache.
        self._projects = {}
        self._runs = {}
        self._sweeps = {}
        self._base_client = Client(
            transport=RequestsHTTPTransport(
                headers={'User-Agent': self.user_agent, 'Use-Admin-Privileges': "true"},
                use_json=True,
                # this timeout won't apply when the DNS lookup fails. in that case, it will be 60s
                # https://bugs.python.org/issue22889
                timeout=self.HTTP_TIMEOUT,
                auth=("api", self.api_key),
                url='%s/graphql' % self.settings['base_url']
            )
        )
        self._client = RetryingClient(self._base_client)

    def create_run(self, **kwargs):
        """Create a run via `Run.create`, passing this Api and `kwargs` through."""
        return Run.create(self, **kwargs)

    @property
    def client(self):
        """The retrying GraphQL client used for all queries."""
        return self._client

    @property
    def user_agent(self):
        """User-Agent header value sent with every request."""
        return 'W&B Public Client %s' % __version__

    @property
    def api_key(self):
        """The API key from the netrc entry for `base_url`.

        A non-empty WANDB_API_KEY environment variable takes precedence.
        Returns None when neither source provides a key.
        """
        auth = requests.utils.get_netrc_auth(self.settings['base_url'])
        key = None
        if auth:
            key = auth[-1]
        # Environment should take precedence
        if os.getenv("WANDB_API_KEY"):
            key = os.environ["WANDB_API_KEY"]
        return key

    def flush(self):
        """Clear the local cache"""
        self._runs = {}

    def _parse_path(self, path):
        """Parses paths in the following formats:

        url: entity/project/runs/run_id
        path: entity/project/run_id
        docker: entity/project:run_id

        entity is optional and will fallback to the current logged in user.
        A bare id (no "/" and no ":") only sets the run/sweep id; entity and
        project keep their configured defaults. An empty path uses all defaults.

        Returns:
            An (entity, project, run) tuple.
        """
        run = self.settings['run']
        project = self.settings['project']
        entity = self.settings['entity']
        parts = path.replace("/runs/", "/").strip("/ ").split("/")
        tagged = ":" in parts[-1]
        if tagged:
            run = parts[-1].split(":")[-1]
            parts[-1] = parts[-1].split(":")[0]
        elif parts[-1]:
            run = parts[-1]
        if len(parts) >= 3:
            # entity/project/run_id is unambiguous.
            entity, project = parts[0], parts[1]
        elif len(parts) == 2:
            # Ambiguous: project/run_id when a default entity is set (and no
            # ":run" tag was given), otherwise entity/project.
            project = parts[1]
            if entity and run == project:
                project = parts[0]
            else:
                entity = parts[0]
        elif tagged and parts[0]:
            # "project:run_id" names the project explicitly.
            project = parts[0]
        return (entity, project, run)

    def _parse_project_path(self, path):
        """Parses a project path of the form entity/project or project.

        Missing components fall back to the `entity` / `project` settings.
        Unlike `_parse_path`, a two-part path is always read as entity/project.

        Returns:
            An (entity, project) tuple.
        """
        entity = self.settings['entity']
        project = self.settings['project']
        parts = [part for part in (path or "").strip("/ ").split("/") if part]
        if len(parts) == 1:
            project = parts[0]
        elif len(parts) >= 2:
            entity, project = parts[0], parts[1]
        return (entity, project)

    def projects(self, entity=None, per_page=None):
        """Return a list of projects for a given entity.

        Raises:
            ValueError: if no entity is passed and no default entity is set.
        """
        if entity is None:
            entity = self.settings['entity']
            if entity is None:
                raise ValueError('entity must be passed as a parameter, or set in settings')
        if entity not in self._projects:
            self._projects[entity] = Projects(self.client, entity, per_page=per_page)
        return self._projects[entity]

    def runs(self, path="", filters=None, order="-created_at", per_page=None):
        """Return a set of runs from a project that match the filters provided.

        `path` is "entity/project", or "project" to use the default entity; an
        empty path uses the default entity and project.

        You can filter by config.*, summary.*, state, entity, createdAt, etc.

        The filters use the same query language as MongoDB:

        https://docs.mongodb.com/manual/reference/operator/query

        Order can be created_at, heartbeat_at, config.*.value, or summary.*. By default
        the order is descending, if you prepend order with a + order becomes ascending.
        """
        if filters is None:
            filters = {}
        entity, project = self._parse_project_path(path)
        # Check and store under the same key. Use `in` rather than truthiness
        # so a cached (possibly empty) collection is not re-evaluated.
        cache_key = path + str(filters) + str(order)
        if cache_key not in self._runs:
            self._runs[cache_key] = Runs(self.client, entity, project,
                                         filters=filters, order=order, per_page=per_page)
        return self._runs[cache_key]

    @normalize_exceptions
    def run(self, path=""):
        """Returns a run by parsing path in the form entity/project/run.

        If defaults were set on the Api, only what's passed is overridden,
        i.e. you can just pass run_id if you set entity and project on the Api.
        """
        entity, project, run = self._parse_path(path)
        if path not in self._runs:
            self._runs[path] = Run(self.client, entity, project, run)
        return self._runs[path]

    @normalize_exceptions
    def sweep(self, path=""):
        """Returns a sweep by parsing path in the form entity/project/sweep_id.

        As with `run`, entity and project fall back to the Api defaults.
        """
        entity, project, sweep_id = self._parse_path(path)
        # Check and store under the same key (`path`).
        if path not in self._sweeps:
            self._sweeps[path] = Sweep(self.client, entity, project, sweep_id)
        return self._sweeps[path]
class Api(object): """W&B Internal Api wrapper Note: Settings are automatically overridden by looking for a `wandb/settings` file in the current working directory or it's parent directory. If none can be found, we look in the current users home directory. Args: default_settings(:obj:`dict`, optional): If you aren't using a settings file or you wish to override the section to use in the settings file Override the settings here. """ HTTP_TIMEOUT = env.get_http_timeout(10) def __init__( self, default_settings=None, load_settings=True, retry_timedelta=None, environ=os.environ, ): if retry_timedelta is None: retry_timedelta = datetime.timedelta(days=1) self._environ = environ self.default_settings = { "section": "default", "git_remote": "origin", "ignore_globs": [], "base_url": "https://api.wandb.ai", } self.retry_timedelta = retry_timedelta self.default_settings.update(default_settings or {}) self.retry_uploads = 10 self._settings = Settings( load_settings=load_settings, root_dir=self.default_settings.get("root_dir")) # self.git = GitRepo(remote=self.settings("git_remote")) self.git = None # Mutable settings set by the _file_stream_api self.dynamic_settings = { "system_sample_seconds": 2, "system_samples": 15, "heartbeat_seconds": 30, } self.client = Client(transport=RequestsHTTPTransport( headers={ "User-Agent": self.user_agent, "X-WANDB-USERNAME": env.get_username(env=self._environ), "X-WANDB-USER-EMAIL": env.get_user_email(env=self._environ), }, use_json=True, # this timeout won't apply when the DNS lookup fails. 
in that case, it will be 60s # https://bugs.python.org/issue22889 timeout=self.HTTP_TIMEOUT, auth=("api", self.api_key or ""), url="%s/graphql" % self.settings("base_url"), )) self.gql = retry.Retry( self.execute, retry_timedelta=retry_timedelta, check_retry_fn=util.no_retry_auth, retryable_exceptions=(RetryError, requests.RequestException), ) self._current_run_id = None self._file_stream_api = None def reauth(self): """Ensures the current api key is set in the transport""" self.client.transport.auth = ("api", self.api_key or "") def relocate(self): """Ensures the current api points to the right server""" self.client.transport.url = "%s/graphql" % self.settings("base_url") def execute(self, *args, **kwargs): """Wrapper around execute that logs in cases of failure.""" try: return self.client.execute(*args, **kwargs) except requests.exceptions.HTTPError as err: res = err.response logger.error("%s response executing GraphQL." % res.status_code) logger.error(res.text) self.display_gorilla_error_if_found(res) six.reraise(*sys.exc_info()) def display_gorilla_error_if_found(self, res): try: data = res.json() except ValueError: return if "errors" in data and isinstance(data["errors"], list): for err in data["errors"]: if not err.get("message"): continue wandb.termerror("Error while calling W&B API: {} ({})".format( err["message"], res)) def disabled(self): return self._settings.get(Settings.DEFAULT_SECTION, "disabled", fallback=False) def sync_spell(self, run, env=None): """Syncs this run with spell""" try: env = env or os.environ run.config["_wandb"]["spell_url"] = env.get("SPELL_RUN_URL") run.config.persist() try: url = run.get_url() except CommError as e: wandb.termerror("Unable to register run with spell.run: %s" % str(e)) return False return requests.put( env.get("SPELL_API_URL", "https://api.spell.run") + "/wandb_url", json={ "access_token": env.get("WANDB_ACCESS_TOKEN"), "url": url }, timeout=2, ) except requests.RequestException: return False def save_patches(self, 
out_dir): """Save the current state of this repository to one or more patches. Makes one patch against HEAD and another one against the most recent commit that occurs in an upstream branch. This way we can be robust to history editing as long as the user never does "push -f" to break history on an upstream branch. Writes the first patch to <out_dir>/<DIFF_FNAME> and the second to <out_dir>/upstream_diff_<commit_id>.patch. Args: out_dir (str): Directory to write the patch files. """ if not self.git.enabled: return False try: root = self.git.root if self.git.dirty: patch_path = os.path.join(out_dir, wandb_lib.filenames.DIFF_FNAME) if self.git.has_submodule_diff: with open(patch_path, "wb") as patch: # we diff against HEAD to ensure we get changes in the index subprocess.check_call( ["git", "diff", "--submodule=diff", "HEAD"], stdout=patch, cwd=root, timeout=5, ) else: with open(patch_path, "wb") as patch: subprocess.check_call(["git", "diff", "HEAD"], stdout=patch, cwd=root, timeout=5) upstream_commit = self.git.get_upstream_fork_point() if upstream_commit and upstream_commit != self.git.repo.head.commit: sha = upstream_commit.hexsha upstream_patch_path = os.path.join( out_dir, "upstream_diff_{}.patch".format(sha)) if self.git.has_submodule_diff: with open(upstream_patch_path, "wb") as upstream_patch: subprocess.check_call( ["git", "diff", "--submodule=diff", sha], stdout=upstream_patch, cwd=root, timeout=5, ) else: with open(upstream_patch_path, "wb") as upstream_patch: subprocess.check_call( ["git", "diff", sha], stdout=upstream_patch, cwd=root, timeout=5, ) # TODO: A customer saw `ValueError: Reference at 'refs/remotes/origin/foo' does not exist` # so we now catch ValueError. Catching this error feels too generic. 
except ( ValueError, subprocess.CalledProcessError, subprocess.TimeoutExpired, ) as e: logger.error("Error generating diff: %s" % e) def set_current_run_id(self, run_id): self._current_run_id = run_id @property def current_run_id(self): return self._current_run_id @property def user_agent(self): return "W&B Internal Client %s" % __version__ @property def api_key(self): auth = requests.utils.get_netrc_auth(self.api_url) key = None if auth: key = auth[-1] # Environment should take precedence if self._environ.get(env.API_KEY): key = self._environ.get(env.API_KEY) return key @property def api_url(self): return self.settings("base_url") @property def app_url(self): return wandb.util.app_url(self.api_url) def settings(self, key=None, section=None): """The settings overridden from the wandb/settings file. Args: key (str, optional): If provided only this setting is returned section (str, optional): If provided this section of the setting file is used, defaults to "default" Returns: A dict with the current settings { "entity": "models", "base_url": "https://api.wandb.ai", "project": None } """ result = self.default_settings.copy() result.update(self._settings.items(section=section)) result.update({ "entity": env.get_entity( self._settings.get( Settings.DEFAULT_SECTION, "entity", fallback=result.get("entity"), ), env=self._environ, ), "project": env.get_project( self._settings.get( Settings.DEFAULT_SECTION, "project", fallback=result.get("project"), ), env=self._environ, ), "base_url": env.get_base_url( self._settings.get( Settings.DEFAULT_SECTION, "base_url", fallback=result.get("base_url"), ), env=self._environ, ), "ignore_globs": env.get_ignore( self._settings.get( Settings.DEFAULT_SECTION, "ignore_globs", fallback=result.get("ignore_globs"), ), env=self._environ, ), }) return result if key is None else result[key] def clear_setting(self, key, globally=False, persist=False): self._settings.clear(Settings.DEFAULT_SECTION, key, globally=globally, persist=persist) def 
set_setting(self, key, value, globally=False, persist=False): self._settings.set(Settings.DEFAULT_SECTION, key, value, globally=globally, persist=persist) if key == "entity": env.set_entity(value, env=self._environ) elif key == "project": env.set_project(value, env=self._environ) elif key == "base_url": self.relocate() def parse_slug(self, slug, project=None, run=None): if slug and "/" in slug: parts = slug.split("/") project = parts[0] run = parts[1] else: project = project or self.settings().get("project") if project is None: raise CommError("No default project configured.") run = run or slug or env.get_run(env=self._environ) if run is None: run = "latest" return (project, run) @normalize_exceptions def viewer(self): query = gql(""" query Viewer{ viewer { id entity teams { edges { node { name } } } } } """) res = self.gql(query) return res.get("viewer") or {} @normalize_exceptions def list_projects(self, entity=None): """Lists projects in W&B scoped by entity. Args: entity (str, optional): The entity to scope this project to. Returns: [{"id","name","description"}] """ query = gql(""" query Models($entity: String!) { models(first: 10, entityName: $entity) { edges { node { id name description } } } } """) return self._flatten_edges( self.gql( query, variable_values={"entity": entity or self.settings("entity")})["models"]) @normalize_exceptions def project(self, project, entity=None): """Retrive project Args: project (str): The project to get details for entity (str, optional): The entity to scope this project to. Returns: [{"id","name","repo","dockerImage","description"}] """ query = gql(""" query Models($entity: String, $project: String!) { model(name: $project, entityName: $entity) { id name repo dockerImage description } } """) return self.gql(query, variable_values={ "entity": entity, "project": project })["model"] @normalize_exceptions def sweep(self, sweep, specs, project=None, entity=None): """Retrieve sweep. 
Args: sweep (str): The sweep to get details for specs (str): history specs project (str, optional): The project to scope this sweep to. entity (str, optional): The entity to scope this sweep to. Returns: [{"id","name","repo","dockerImage","description"}] """ query = gql(""" query Models($entity: String, $project: String!, $sweep: String!, $specs: [JSONString!]!) { model(name: $project, entityName: $entity) { sweep(sweepName: $sweep) { id name method state description config createdAt heartbeatAt updatedAt earlyStopJobRunning bestLoss controller scheduler runs { edges { node { name state config exitcode heartbeatAt shouldStop failed stopped running summaryMetrics sampledHistory(specs: $specs) } } } } } } """) entity = entity or self.settings("entity") project = project or self.settings("project") response = self.gql( query, variable_values={ "entity": entity, "project": project, "sweep": sweep, "specs": specs, }, ) if response["model"] is None or response["model"]["sweep"] is None: raise ValueError("Sweep {}/{}/{} not found".format( entity, project, sweep)) data = response["model"]["sweep"] if data: data["runs"] = self._flatten_edges(data["runs"]) return data @normalize_exceptions def list_runs(self, project, entity=None): """Lists runs in W&B scoped by project. Args: project (str): The project to scope the runs to entity (str, optional): The entity to scope this project to. Defaults to public models Returns: [{"id",name","description"}] """ query = gql(""" query Buckets($model: String!, $entity: String!) { model(name: $model, entityName: $entity) { buckets(first: 10) { edges { node { id name displayName description } } } } } """) return self._flatten_edges( self.gql( query, variable_values={ "entity": entity or self.settings("entity"), "model": project or self.settings("project"), }, )["model"]["buckets"]) @normalize_exceptions def launch_run(self, command, project=None, entity=None, run_id=None): """Launch a run in the cloud. 
Args: command (str): The command to run program (str): The file to run project (str): The project to scope the runs to entity (str, optional): The entity to scope this project to. Defaults to public models run_id (str, optional): The run_id to scope to Returns: [{"podName","status"}] """ query = gql(""" mutation launchRun( $entity: String $model: String $runId: String $image: String $command: String $patch: String $cwd: String $datasets: [String] ) { launchRun(input: {id: $runId, entityName: $entity, patch: $patch, modelName: $model, image: $image, command: $command, datasets: $datasets, cwd: $cwd}) { podName status runId } } """) run_id = run_id or self.current_run_id assert run_id, "run_id must be specified" patch = BytesIO() if self.git.dirty: self.git.repo.git.execute(["git", "diff"], output_stream=patch) patch.seek(0) cwd = "." if self.git.enabled: cwd += os.getcwd().replace(self.git.repo.working_dir, "") return self.gql( query, variable_values={ "entity": entity or self.settings("entity"), "model": project or self.settings("project"), "command": command, "runId": run_id, "patch": patch.read().decode("utf8"), "cwd": cwd, }, ) @normalize_exceptions def run_config(self, project, run=None, entity=None): """Get the relevant configs for a run Args: project (str): The project to download, (can include bucket) run (str): The run to download entity (str, optional): The entity to scope this project to. """ query = gql(""" query Model($name: String!, $entity: String!, $run: String!) 
{ model(name: $name, entityName: $entity) { bucket(name: $run) { config commit patch files(names: ["wandb-metadata.json"]) { edges { node { url } } } } } } """) run = run or self.current_run_id assert run, "run must be specified" response = self.gql(query, variable_values={ "name": project, "run": run, "entity": entity }) if response["model"] is None: raise ValueError("Run {}/{}/{} not found".format( entity, project, run)) run = response["model"]["bucket"] commit = run["commit"] patch = run["patch"] config = json.loads(run["config"] or "{}") if len(run["files"]["edges"]) > 0: url = run["files"]["edges"][0]["node"]["url"] res = requests.get(url) res.raise_for_status() metadata = res.json() else: metadata = {} return (commit, config, patch, metadata) @normalize_exceptions def run_resume_status(self, entity, project_name, name): """Check if a run exists and get resume information. Args: entity (str, optional): The entity to scope this project to. project_name (str): The project to download, (can include bucket) name (str): The run to download """ query = gql(""" query Model($project: String!, $entity: String, $name: String!) { model(name: $project, entityName: $entity) { id name entity { id name } bucket(name: $name, missingOk: true) { id name summaryMetrics displayName logLineCount historyLineCount eventsLineCount historyTail eventsTail config } } } """) response = self.gql( query, variable_values={ "entity": entity, "project": project_name, "name": name, }, ) if "model" not in response or "bucket" not in (response["model"] or {}): return None project = response["model"] self.set_setting("project", project_name) if "entity" in project: self.set_setting("entity", project["entity"]["name"]) return project["bucket"] @normalize_exceptions def check_stop_requested(self, project_name, entity_name, run_id): query = gql(""" query Model($projectName: String, $entityName: String, $runId: String!) 
{ project(name:$projectName, entityName:$entityName) { run(name:$runId) { stopped } } } """) run_id = run_id or self.current_run_id assert run_id, "run_id must be specified" response = self.gql( query, variable_values={ "projectName": project_name, "entityName": entity_name, "runId": run_id, }, ) project = response.get("project", None) if not project: return False run = project.get("run", None) if not run: return False return run["stopped"] def format_project(self, project): return re.sub(r"\W+", "-", project.lower()).strip("-_") @normalize_exceptions def upsert_project(self, project, id=None, description=None, entity=None): """Create a new project Args: project (str): The project to create description (str, optional): A description of this project entity (str, optional): The entity to scope this project to. """ mutation = gql(""" mutation UpsertModel($name: String!, $id: String, $entity: String!, $description: String, $repo: String) { upsertModel(input: { id: $id, name: $name, entityName: $entity, description: $description, repo: $repo }) { model { name description } } } """) response = self.gql( mutation, variable_values={ "name": self.format_project(project), "entity": entity or self.settings("entity"), "description": description, "repo": self.git.remote_url, "id": id, }, ) return response["upsertModel"]["model"] @normalize_exceptions def pop_from_run_queue(self, entity=None, project=None): mutation = gql(""" mutation popFromRunQueue($entity: String!, $project: String!) 
{ popFromRunQueue(input: { entityName: $entity, projectName: $project }) { runQueueItemId runSpec } } """) response = self.gql(mutation, variable_values={ "entity": entity, "project": project }) return response["popFromRunQueue"] @normalize_exceptions def upsert_run( self, id=None, name=None, project=None, host=None, group=None, tags=None, config=None, description=None, entity=None, state=None, display_name=None, notes=None, repo=None, job_type=None, program_path=None, commit=None, sweep_name=None, summary_metrics=None, num_retries=None, ): """Update a run Args: id (str, optional): The existing run to update name (str, optional): The name of the run to create group (str, optional): Name of the group this run is a part of project (str, optional): The name of the project config (dict, optional): The latest config params description (str, optional): A description of this project entity (str, optional): The entity to scope this project to. repo (str, optional): Url of the program's repository. state (str, optional): State of the program. job_type (str, optional): Type of job, e.g 'train'. program_path (str, optional): Path to the program. 
commit (str, optional): The Git SHA to associate the run with summary_metrics (str, optional): The JSON summary metrics """ mutation = gql(""" mutation UpsertBucket( $id: String, $name: String, $project: String, $entity: String!, $groupName: String, $description: String, $displayName: String, $notes: String, $commit: String, $config: JSONString, $host: String, $debug: Boolean, $program: String, $repo: String, $jobType: String, $state: String, $sweep: String, $tags: [String!], $summaryMetrics: JSONString, ) { upsertBucket(input: { id: $id, name: $name, groupName: $groupName, modelName: $project, entityName: $entity, description: $description, displayName: $displayName, notes: $notes, config: $config, commit: $commit, host: $host, debug: $debug, jobProgram: $program, jobRepo: $repo, jobType: $jobType, state: $state, sweep: $sweep, tags: $tags, summaryMetrics: $summaryMetrics, }) { bucket { id name displayName description config project { id name entity { id name } } } } } """) if config is not None: config = json.dumps(config) if not description or description.isspace(): description = None kwargs = {} if num_retries is not None: kwargs["num_retries"] = num_retries variable_values = { "id": id, "entity": entity or self.settings("entity"), "name": name, "project": project, "groupName": group, "tags": tags, "description": description, "config": config, "commit": commit, "displayName": display_name, "notes": notes, "host": None if self.settings().get("anonymous") == "true" else host, "debug": env.is_debug(env=self._environ), "repo": repo, "program": program_path, "jobType": job_type, "state": state, "sweep": sweep_name, "summaryMetrics": summary_metrics, } response = self.gql(mutation, variable_values=variable_values, **kwargs) run = response["upsertBucket"]["bucket"] if project := run.get("project"): self.set_setting("project", project["name"]) if entity := project.get("entity"): self.set_setting("entity", entity["name"])