def no_retry_auth(e):
    """Decide whether a failed request should be retried.

    Args:
        e: The raised error, or a wrapper object exposing the real error
            through an ``exception`` attribute.

    Returns:
        True when the request should be retried (anything that is not a
        ``requests.HTTPError``, or an HTTP error other than 400/401/403/404),
        False for a 400 response.

    Raises:
        CommError: On 401, 403 and 404 responses, with a message describing
            the authentication or permission problem.
    """
    error = e.exception if hasattr(e, "exception") else e
    if not isinstance(error, requests.HTTPError):
        return True

    status = error.response.status_code
    # Bad requests are never retried; the caller raises immediately.
    if status == 400:
        return False
    # Everything except unauthorized/forbidden/not-found is retried.
    if status not in (401, 403, 404):
        return True

    if status == 401:
        hint = ""
        if wandb.run and str(wandb.run.api.api_key).startswith("local-"):
            # A "local-" key only works against a local instance, so point
            # the user at the local host in the login hint.
            hint = " --host=http://localhost:8080"
            if wandb.run.api.api_url == "https://api.wandb.ai":
                raise CommError(
                    "Attempting to authenticate with the cloud using a local API key. Set WANDB_BASE_URL to your local instance."
                )
        raise CommError("Invalid or missing api_key. Run wandb login" + hint)

    # 403 or 404: name the run path when a run is active.
    if wandb.run:
        raise CommError("Permission denied to access {}".format(wandb.run.path))
    raise CommError(
        "Permission denied, ask the project owner to grant you access")
def _load_entity(self, api, network):
    """Return the entity configured on ``api``, querying the viewer if needed.

    Args:
        api: API client providing ``api_key``, ``settings``, ``set_setting``
            and ``viewer``.
        network (bool): When truthy and no entity is configured, the viewer
            is fetched (3 second timeout) and its entity stored in settings.

    Returns:
        The configured entity. When ``network`` is falsy this is whatever
        ``api.settings('entity')`` returns.

    Raises:
        CommError: If no API key is set, if the viewer lookup is still
            running after the timeout, or if no entity could be resolved
            while ``network`` is truthy.
    """
    if not api.api_key:
        raise CommError(
            "Can't find API key, run wandb login or set WANDB_API_KEY")

    entity = api.settings('entity')
    if not network:
        return entity

    if api.settings('entity') is None:
        # Kaggle has internet disabled by default, this checks for that case
        fetch_viewer = util.async_call(api.viewer, timeout=3)
        viewer, worker = fetch_viewer()
        if worker.is_alive():
            # The lookup did not finish within the timeout.
            if is_kaggle():
                raise CommError(
                    "To use W&B in kaggle you must enable internet in the settings panel on the right."
                )
            raise CommError(
                "Can't connect to network to query entity from API key"
            )
        if viewer.get('entity'):
            api.set_setting('entity', viewer['entity'])

    entity = api.settings('entity')
    if not entity:
        # This can happen on network failure
        raise CommError(
            "Can't connect to network to query entity from API key")
    return entity
def upload_file(self, url, file, callback=None, extra_headers=None):
    """Uploads a file to W&B with an HTTP PUT

    Args:
        url (str): The url to upload to
        file: An open file object to upload; it is wrapped in ``Progress``
            and its ``name`` is used in the empty-file error message
        callback (:obj:`func`, optional): A callback which is passed the number of
            bytes uploaded since the last time it was called, used to report progress
        extra_headers (dict, optional): Headers to send with the request.
            The mapping is copied and never modified.

    Returns:
        The requests library response object

    Raises:
        CommError: If the wrapped file reports a length of zero.

    Request failures are handed to ``util.sentry_reraise``; those that look
    transient are first wrapped in ``retry.TransientException``.
    """
    # Copy so the caller's mapping is never shared with the request.
    extra_headers = extra_headers.copy() if extra_headers is not None else {}
    response = None
    progress = Progress(file, callback=callback)
    if progress.len == 0:
        raise CommError("%s is an empty file" % file.name)
    try:
        response = requests.put(
            url, data=progress, headers=extra_headers)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        # Connection-level failures carry no response; treat them as status 0.
        status_code = e.response.status_code if e.response is not None else 0
        # Retry errors from cloud storage or local network issues
        is_network_error = isinstance(
            e, (requests.exceptions.Timeout, requests.exceptions.ConnectionError))
        if status_code in (308, 409, 429, 500, 502, 503, 504) or is_network_error:
            util.sentry_reraise(retry.TransientException(exc=e))
        else:
            util.sentry_reraise(e)
    return response
def push(self, files, run=None, entity=None, project=None, description=None, force=True, progress=False):
    """Uploads multiple files to W&B

    Args:
        files (list or dict): The filenames to upload. When a dict, it maps
            each filename to an already-open file object; otherwise each
            filename is opened from disk in binary mode.
        run (str, optional): The run to upload to. Defaults to ``self.current_run_id``.
        entity (str, optional): The entity to scope this project to.
        project (str, optional): The name of the project to upload to. Defaults to
            ``self.get_project()``.
        description (str, optional): The description of the changes
        force (bool, optional): Not read by this method.
        progress (callable, or stream): If callable, it is passed to the uploader as the
            progress callback; else if truthy, a progress bar is rendered to this stream.

    Returns:
        A list with the result of each upload, in the order the files were processed.
        Files that could not be opened are reported on stdout and skipped.

    Raises:
        CommError: If no project is given and none is configured.
    """
    if project is None:
        project = self.get_project()
    if project is None:
        raise CommError("No project configured.")
    if run is None:
        run = self.current_run_id

    # TODO(adrian): we use a retriable version of self.upload_file() so
    # will never retry self.upload_urls() here. Instead, maybe we should
    # make push itself retriable.
    # The per-file info mapping is the last element of the returned tuple;
    # indexing from the end works whether or not upload headers are included.
    result = self.upload_urls(
        project, files, run, entity, description)[-1]

    responses = []
    for file_name, file_info in result.items():
        file_url = file_info['url']

        # If the upload URL is relative, fill it in with the base URL,
        # since its a proxied file store like the on-prem VM.
        if file_url.startswith('/'):
            file_url = '{}{}'.format(self.api_url, file_url)

        try:
            if isinstance(files, dict):
                open_file = files[file_name]
            else:
                # To handle Windows paths
                # TODO: this doesn't handle absolute paths...
                normal_name = os.path.join(*file_name.split("/"))
                open_file = open(normal_name, "rb")
        except IOError:
            print("%s does not exist" % file_name)
            continue

        # Close the handle even when the upload raises.
        try:
            if not progress:
                # Use the resolved URL so relative upload URLs work here too.
                responses.append(self.upload_file_retry(file_url, open_file))
            elif callable(progress):
                responses.append(self.upload_file_retry(
                    file_url, open_file, progress))
            else:
                length = os.fstat(open_file.fileno()).st_size
                with click.progressbar(file=progress, length=length, label='Uploading file: %s' % (file_name),
                                       fill_char=click.style('&', fg='green')) as bar:
                    responses.append(self.upload_file_retry(
                        file_url, open_file, lambda bites, _: bar.update(bites)))
        finally:
            open_file.close()
    return responses
def upload_file(self, url, file, callback=None, extra_headers=None):
    """Uploads a file to W&B with an HTTP PUT

    Args:
        url (str): The url to upload to
        file: An open file object to upload; ``file.name`` is stat'ed to
            reject empty files before any request is made
        callback (:obj:`func`, optional): A callback which is passed the number of
            bytes uploaded since the last time it was called, used to report progress
        extra_headers (dict, optional): Headers to send with the request.
            The mapping is copied and never modified.

    Returns:
        The requests library response object

    Raises:
        CommError: If the file on disk has a size of zero.

    On a request failure the upload status is queried through
    ``self._status_request`` and the error is handed to
    ``util.sentry_reraise``, wrapped in ``retry.TransientException`` when the
    status code is one of the retriable ones.
    """
    # Copy so the caller's mapping is never shared with the request.
    extra_headers = extra_headers.copy() if extra_headers is not None else {}
    response = None
    if os.stat(file.name).st_size == 0:
        raise CommError("%s is an empty file" % file.name)
    # Built before the try so the except handler can always read its length.
    progress = Progress(file, callback=callback)
    try:
        response = requests.put(url, data=progress, headers=extra_headers)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        total = progress.len
        status = self._status_request(url, total)
        # TODO(adrian): there's probably even more stuff we should add here
        # like if we're offline, we should retry then too
        if status.status_code in (308, 408, 500, 502, 503, 504):
            util.sentry_reraise(retry.TransientException(exc=e))
        else:
            util.sentry_reraise(e)
    return response
def no_retry_auth(e):
    """Decide whether a failed request should be retried.

    Args:
        e: The raised error, or a wrapper object exposing the real error
            through an ``exception`` attribute.

    Returns:
        False for a 400 response; True for any error that is not a
        ``requests.HTTPError`` and for every other status except 401/403.

    Raises:
        CommError: On 401 (bad or missing API key) and 403 (permission
            denied) responses.
    """
    cause = getattr(e, "exception", e)
    if not isinstance(cause, requests.HTTPError):
        return True

    code = cause.response.status_code
    # Don't retry bad request errors; the caller raises immediately.
    if code == 400:
        return False
    # Crash w/message on forbidden/unauthorized errors.
    if code == 401:
        raise CommError("Invalid or missing api_key. Run wandb login")
    if code == 403:
        raise CommError("Permission denied, ask the project owner to grant you access")
    # Every other HTTP error is retried.
    return True
def store_path(self, artifact, path, name=None, checksum=True, max_objects=None):
    """Build manifest entries for an S3 path.

    Args:
        artifact: Not read by this method.
        path: The URI handed to ``self._parse_uri`` to obtain bucket and key.
        name (optional): Entry name; forwarded to ``self._entry_from_obj``,
            or used directly (falling back to the key) when ``checksum`` is falsy.
        checksum (bool): When falsy, no S3 request is made and a single
            entry whose digest is ``path`` is returned.
        max_objects (int, optional): Cap on objects listed under a prefix.
            Falls back to ``DEFAULT_MAX_OBJECTS`` when falsy.

    Returns:
        A list of entries, one per object with a size greater than zero.

    Raises:
        CommError: If loading the object fails with a code other than "404".
        ValueError: If the number of entries reaches ``max_objects``.
    """
    self.init_boto()
    bucket, key = self._parse_uri(path)
    if not max_objects:
        max_objects = DEFAULT_MAX_OBJECTS
    if not checksum:
        return [ArtifactManifestEntry(name or key, path, digest=path)]

    head = self._s3.Object(bucket, key)
    candidates = [head]
    is_prefix = False
    started_at = None
    try:
        head.load()
    except self._botocore.exceptions.ClientError as err:
        error = err.response['Error']
        if error['Code'] != "404":
            raise CommError("Unable to connect to S3 (%s): %s" % (error['Code'], error['Message']))
        # No object at this exact key: treat the key as a prefix and list
        # the objects under it instead.
        is_prefix = True
        started_at = time.time()
        termlog('Generating checksum for up to %i objects with prefix "%s"... ' % (max_objects, key), newline=False)
        candidates = self._s3.Bucket(bucket).objects.filter(Prefix=key).limit(max_objects)

    entries = []
    for obj in candidates:
        # Objects with a size of zero are left out.
        if self._size_from_obj(obj) > 0:
            entries.append(self._entry_from_obj(obj, path, name, prefix=key, multi=is_prefix))

    if started_at is not None:
        termlog('Done. %.1fs' % (time.time() - started_at), prefix=False)
    if len(entries) >= max_objects:
        raise ValueError('Exceeded %i objects tracked, pass max_objects to add_reference' % max_objects)
    return entries
def _load_entity(self, api, network):
    """Return the entity configured on ``api``, asking the viewer if needed.

    Args:
        api: API client providing ``api_key``, ``settings``, ``set_setting``
            and ``viewer``.
        network (bool): When truthy and no entity is configured, the viewer
            is fetched and its entity stored in settings.

    Returns:
        The configured entity. When ``network`` is falsy this is whatever
        ``api.settings('entity')`` returns.

    Raises:
        CommError: If no API key is set, or if no entity could be resolved
            while ``network`` is truthy.
    """
    if not api.api_key:
        raise CommError("Can't find API key, run wandb login or set WANDB_API_KEY")

    resolved = api.settings('entity')
    if network:
        if api.settings('entity') is None:
            viewer = api.viewer()
            if viewer.get('entity'):
                # Persist the viewer's entity so the lookup below sees it.
                api.set_setting('entity', viewer['entity'])

        resolved = api.settings('entity')
        if not resolved:
            # This can happen on network failure
            raise CommError("Can't connect to network to query entity from API key")
    return resolved
def _load_viewer(self):
    """Fetch the viewer and cache it, together with its flags, on ``self``.

    Nothing happens in "dryrun" mode, when the API reports itself disabled,
    or when no API key is set. Otherwise the viewer is requested with a
    timeout taken from ``env.get_http_timeout(5)``; if the request completes,
    ``self._viewer`` and ``self._flags`` (parsed from the viewer's JSON
    "flags" value, defaulting to an empty object) are set.

    Raises:
        CommError: If the request is still running after the timeout and
            ``is_kaggle()`` is true.
    """
    if self.mode == "dryrun" or self._api.disabled() or not self._api.api_key:
        return

    # Kaggle has internet disabled by default, this checks for that case
    fetch_viewer = util.async_call(self._api.viewer, timeout=env.get_http_timeout(5))
    viewer, worker = fetch_viewer()
    if not worker.is_alive():
        self._viewer = viewer
        self._flags = json.loads(viewer.get("flags", "{}"))
        return

    # The request timed out; outside Kaggle this is silently ignored.
    if is_kaggle():
        raise CommError("To use W&B in kaggle you must enable internet in the settings panel on the right.")
def upload_urls(self, project, files, run=None, entity=None, description=None):
    """Generate temporary resumeable upload urls

    Args:
        project (str): The project to upload to
        files (list or dict): The filenames to upload
        run (str, optional): The run to upload to. Defaults to the 'run' setting.
        entity (str, optional): The entity to scope this project to. Defaults to
            the 'entity' setting.
        description (str, optional): Sent to the query as the run description

    Returns:
        (bucket_id, upload_headers, file_info)
        bucket_id: id of bucket we uploaded to
        upload_headers: the ``uploadHeaders`` value returned for the run's files
        file_info: A dict of filenames and urls, built from the nodes returned
            by ``self._flatten_edges``.
            {
                'weights.h5': { "name": "weights.h5", "url": "https://weights.url", "updatedAt": None },
                'model.json': { "name": "model.json", "url": "https://model.json", "updatedAt": '2013-04-26T22:22:23.832Z' },
            }

    Raises:
        CommError: If the query returns no model or no run.
    """
    query = gql('''
    query Model($name: String!, $files: [String]!, $entity: String!, $run: String!, $description: String) {
        model(name: $name, entityName: $entity) {
            bucket(name: $run, desc: $description) {
                id
                files(names: $files) {
                    uploadHeaders
                    edges {
                        node {
                            name
                            url(upload: true)
                            updatedAt
                        }
                    }
                }
            }
        }
    }
    ''')
    run_id = run or self.settings('run')
    entity = entity or self.settings('entity')
    query_result = self.gql(query, variable_values={
        'name': project, 'run': run_id,
        'entity': entity,
        'description': description,
        'files': list(files)
    })

    # A missing model would otherwise fail with a TypeError on the subscript;
    # report it through the same CommError as a missing run.
    model = query_result['model']
    run = model['bucket'] if model else None
    if not run:
        raise CommError("Run does not exist {}/{}/{}.".format(entity, project, run_id))

    result = {node['name']: node for node in self._flatten_edges(run['files'])}
    return run['id'], run['files']['uploadHeaders'], result
def parse_slug(self, slug, project=None, run=None):
    """Resolve a ``project/run`` slug, or a bare run name, into a pair.

    Args:
        slug (str): Either "project/run" (any further "/"-separated parts
            are ignored) or a run name. May be empty or None.
        project (str, optional): Project to use when the slug has no "/".
            Falls back to the "project" setting.
        run (str, optional): Run to use when the slug has no "/". Falls back
            to the slug, then ``env.get_run``, then "latest".

    Returns:
        A ``(project, run)`` tuple.

    Raises:
        CommError: If the slug has no "/" and no project can be determined.
    """
    if slug and "/" in slug:
        # A full slug overrides both keyword arguments.
        project, run = slug.split("/")[:2]
        return (project, run)

    project = project or self.settings().get("project")
    if project is None:
        raise CommError("No default project configured.")

    run = run or slug or env.get_run(env=self._environ)
    if run is None:
        run = "latest"
    return (project, run)
def read(self, size=-1):
    """Read up to ``size`` bytes from the wrapped file and report progress.

    The running total ``self.bytes_read`` is updated, then ``self.callback``
    is called with the size of this chunk and the running total.

    Raises:
        CommError: If the file returns no data before ``self.len`` bytes
            have been read, i.e. it shrank during the upload.
    """
    chunk = self.file.read(size)
    chunk_size = len(chunk)
    self.bytes_read += chunk_size

    hit_eof = not chunk
    if hit_eof and self.bytes_read < self.len:
        # Files shrinking during uploads causes request timeouts. Maybe
        # we could avoid those by updating the self.len in real-time, but
        # files getting truncated while uploading seems like something
        # that shouldn't really be happening anyway.
        raise CommError('File {} size shrank from {} to {} while it was being uploaded.'.format(self.file.name, self.len, self.bytes_read))

    # Growing files are also likely to be bad, but our code didn't break
    # on those in the past so it's riskier to make that an error now.
    self.callback(chunk_size, self.bytes_read)
    return chunk