Beispiel #1
0
def no_retry_auth(e):
    """Retry predicate that refuses to retry authentication/permission errors.

    Args:
        e: The raised error. If it carries an ``exception`` attribute, that
            wrapped exception is inspected instead.

    Returns:
        bool: True when the request should be retried (any non-HTTP error and
        any HTTP status other than 400/401/403/404), False for a 400.

    Raises:
        CommError: On 401 (invalid/missing API key, or a local API key used
            against the cloud URL) and on 403/404 (permission denied).
    """
    if hasattr(e, "exception"):
        e = e.exception
    if not isinstance(e, requests.HTTPError):
        return True

    status = e.response.status_code
    # Bad requests are never retried; the caller raises immediately.
    if status == 400:
        return False
    # Anything that is not unauthorized/forbidden/not-found is retried.
    if status not in (401, 403, 404):
        return True

    # 403 / 404: crash with a permission message.
    if status != 401:
        if wandb.run:
            raise CommError("Permission denied to access {}".format(
                wandb.run.path))
        raise CommError(
            "Permission denied, ask the project owner to grant you access")

    # 401: crash with a login hint, tailored for local API keys.
    hint = ""
    if wandb.run and str(wandb.run.api.api_key).startswith("local-"):
        if wandb.run.api.api_url == "https://api.wandb.ai":
            raise CommError(
                "Attempting to authenticate with the cloud using a local API key.  Set WANDB_BASE_URL to your local instance."
            )
        hint = " --host=http://localhost:8080"
    raise CommError("Invalid or missing api_key.  Run wandb login" + hint)
Beispiel #2
0
    def _load_entity(self, api, network):
        if not api.api_key:
            raise CommError(
                "Can't find API key, run wandb login or set WANDB_API_KEY")

        entity = api.settings('entity')
        if network:
            if api.settings('entity') is None:
                # Kaggle has internet disabled by default, this checks for that case
                async_viewer = util.async_call(api.viewer, timeout=3)
                viewer, viewer_thread = async_viewer()
                if viewer_thread.is_alive():
                    if is_kaggle():
                        raise CommError(
                            "To use W&B in kaggle you must enable internet in the settings panel on the right."
                        )
                    else:
                        raise CommError(
                            "Can't connect to network to query entity from API key"
                        )
                if viewer.get('entity'):
                    api.set_setting('entity', viewer['entity'])

            entity = api.settings('entity')

        if not entity:
            # This can happen on network failure
            raise CommError(
                "Can't connect to network to query entity from API key")

        return entity
Beispiel #3
0
    def upload_file(self, url, file, callback=None, extra_headers={}):
        """Uploads a file to W&B with failure resumption

        Args:
            url (str): The url to upload (HTTP PUT) the file to
            file: The open file object to upload; it is wrapped in ``Progress``
                and its ``name`` attribute is used in the empty-file error
            callback (:obj:`func`, optional): A callback which is passed the number of
                bytes uploaded since the last time it was called, used to report progress
            extra_headers (dict, optional): Extra headers sent with the PUT request

        Returns:
            The requests library response object

        Raises:
            CommError: If the file is empty (``Progress`` reports a length of 0).
            Request failures are handed to ``util.sentry_reraise``, wrapped in
            ``retry.TransientException`` when they look retriable (presumably
            re-raised there -- confirm against ``util.sentry_reraise``).
        """
        # Work on a copy so neither the shared default dict nor a
        # caller-owned dict is ever touched.
        extra_headers = extra_headers.copy()
        response = None
        progress = Progress(file, callback=callback)
        if progress.len == 0:
            raise CommError("%s is an empty file" % file.name)
        try:
            response = requests.put(
                url, data=progress, headers=extra_headers)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            # Connection-level failures carry no response; treat them as status 0.
            status_code = e.response.status_code if e.response is not None else 0
            # Retry errors from cloud storage or local network issues
            if status_code in (308, 409, 429, 500, 502, 503, 504) or isinstance(e, (requests.exceptions.Timeout, requests.exceptions.ConnectionError)):
                util.sentry_reraise(retry.TransientException(exc=e))
            else:
                util.sentry_reraise(e)

        return response
Beispiel #4
0
    def push(self, files, run=None, entity=None, project=None, description=None, force=True, progress=False):
        """Uploads multiple files to W&B

        Args:
            files (list or dict): The filenames to upload. When a dict, its
                values are the already-open file objects to upload.
            run (str, optional): The run to upload to. Defaults to ``self.current_run_id``.
            entity (str, optional): The entity to scope this project to.  Defaults to wandb models
            project (str, optional): The name of the project to upload to. Defaults to the one in settings.
            description (str, optional): The description of the changes
            force (bool, optional): Whether to prevent push if git has uncommitted changes.
                NOTE(review): not consulted anywhere in this method.
            progress (callable, or stream): If callable, will be called with (chunk_bytes,
                total_bytes) as argument else if True, renders a progress bar to stream.

        Returns:
            list: One entry per uploaded file, holding whatever
            ``self.upload_file_retry`` returned for it.

        Raises:
            CommError: If no project is given and none is configured.
        """
        if project is None:
            project = self.get_project()
        if project is None:
            raise CommError("No project configured.")
        if run is None:
            run = self.current_run_id

        # TODO(adrian): we use a retriable version of self.upload_file() so
        # will never retry self.upload_urls() here. Instead, maybe we should
        # make push itself retriable.
        run_id, result = self.upload_urls(
            project, files, run, entity, description)
        responses = []
        for file_name, file_info in result.items():
            file_url = file_info['url']

            # If the upload URL is relative, fill it in with the base URL,
            # since its a proxied file store like the on-prem VM.
            if file_url.startswith('/'):
                file_url = '{}{}'.format(self.api_url, file_url)

            try:
                # To handle Windows paths
                # TODO: this doesn't handle absolute paths...
                normal_name = os.path.join(*file_name.split("/"))
                open_file = files[file_name] if isinstance(
                    files, dict) else open(normal_name, "rb")
            except IOError:
                print("%s does not exist" % file_name)
                continue

            # Always release the handle, even when the upload raises.
            try:
                if not progress:
                    # Use the resolved URL here too, so relative (proxied)
                    # upload URLs work without a progress reporter.
                    responses.append(
                        self.upload_file_retry(file_url, open_file))
                elif callable(progress):
                    responses.append(self.upload_file_retry(
                        file_url, open_file, progress))
                else:
                    length = os.fstat(open_file.fileno()).st_size
                    with click.progressbar(file=progress, length=length, label='Uploading file: %s' % (file_name),
                                           fill_char=click.style('&', fg='green')) as bar:
                        responses.append(self.upload_file_retry(
                            file_url, open_file, lambda bites, _: bar.update(bites)))
            finally:
                open_file.close()
        return responses
Beispiel #5
0
    def upload_file(self, url, file, callback=None, extra_headers={}):
        """Uploads a file to W&B with failure resumption

        Args:
            url (str): The url the file is PUT to
            file: The open file object to upload; ``file.name`` is stat'ed to
                reject empty files
            callback (:obj:`func`, optional): A callback which is passed the number of
                bytes uploaded since the last time it was called, used to report progress
            extra_headers (dict, optional): Extra headers sent with the request

        Returns:
            The requests library response object
        """
        headers = extra_headers.copy()
        response = None
        if os.stat(file.name).st_size == 0:
            raise CommError("%s is an empty file" % file.name)
        try:
            progress = Progress(file, callback=callback)
            response = requests.put(url, data=progress, headers=headers)
            response.raise_for_status()
        except requests.exceptions.RequestException as err:
            # Ask for the upload status to decide whether the failure is
            # worth retrying.
            status = self._status_request(url, progress.len)
            # TODO(adrian): there's probably even more stuff we should add here
            # like if we're offline, we should retry then too
            retriable = status.status_code in (308, 408, 500, 502, 503, 504)
            failure = retry.TransientException(exc=err) if retriable else err
            util.sentry_reraise(failure)

        return response
Beispiel #6
0
def no_retry_auth(e):
    """Retry predicate that refuses to retry authentication/permission errors.

    Args:
        e: The raised error. If it carries an ``exception`` attribute, that
            wrapped exception is inspected instead.

    Returns:
        bool: True when the request should be retried (any non-HTTP error and
        any HTTP status other than 400/401/403), False for a 400.

    Raises:
        CommError: On 401 (invalid or missing API key) and 403 (permission
            denied).
    """
    if hasattr(e, "exception"):
        e = e.exception
    if not isinstance(e, requests.HTTPError):
        return True

    status = e.response.status_code
    # Bad requests are never retried; the caller raises immediately.
    if status == 400:
        return False
    # Everything except unauthorized/forbidden is retried.
    if status not in (401, 403):
        return True
    # Crash with a message on unauthorized/forbidden errors.
    if status != 401:
        raise CommError("Permission denied, ask the project owner to grant you access")
    raise CommError("Invalid or missing api_key.  Run wandb login")
Beispiel #7
0
    def store_path(self, artifact, path, name=None, checksum=True, max_objects=None):
        """Build manifest entries for an S3 path (a single object or a key prefix).

        Args:
            artifact: Not used by this method.
            path (str): The S3 URI, split into bucket and key by ``_parse_uri``.
            name (str, optional): Entry name; the key is used when checksumming
                is disabled and no name is given.
            checksum (bool): When falsy, a single entry is returned with the
                path itself as digest and S3 is not queried.
            max_objects (int, optional): Cap on objects listed under a prefix;
                falls back to ``DEFAULT_MAX_OBJECTS``.

        Returns:
            list: Manifest entries for every object with a size above zero.

        Raises:
            CommError: If loading the object fails with anything but a 404.
            ValueError: If the number of entries reaches ``max_objects``.
        """
        self.init_boto()
        bucket, key = self._parse_uri(path)
        max_objects = max_objects or DEFAULT_MAX_OBJECTS
        if not checksum:
            return [ArtifactManifestEntry(name or key, path, digest=path)]

        single = self._s3.Object(bucket, key)
        candidates = [single]
        started_at = None
        is_prefix = False
        try:
            single.load()
        except self._botocore.exceptions.ClientError as err:
            error_info = err.response['Error']
            if error_info['Code'] != "404":
                raise CommError("Unable to connect to S3 (%s): %s" % (error_info['Code'], error_info['Message']))
            # No object at exactly this key: treat the key as a prefix.
            is_prefix = True
            started_at = time.time()
            termlog('Generating checksum for up to %i objects with prefix "%s"... ' % (max_objects, key), newline=False)
            candidates = self._s3.Bucket(bucket).objects.filter(Prefix=key).limit(max_objects)

        entries = []
        for obj in candidates:
            # Skip objects whose reported size is zero.
            if self._size_from_obj(obj) > 0:
                entries.append(self._entry_from_obj(obj, path, name, prefix=key, multi=is_prefix))
        if started_at is not None:
            termlog('Done. %.1fs' % (time.time() - started_at), prefix=False)
        if len(entries) >= max_objects:
            raise ValueError('Exceeded %i objects tracked, pass max_objects to add_reference' % max_objects)
        return entries
Beispiel #8
0
    def _load_entity(self, api, network):
        if not api.api_key:
            raise CommError("Can't find API key, run wandb login or set WANDB_API_KEY")

        entity = api.settings('entity')
        if network:
            if api.settings('entity') is None:
                viewer = api.viewer()
                if viewer.get('entity'):
                    api.set_setting('entity', viewer['entity'])
        
            entity = api.settings('entity')
        
        if not entity:
            # This can happen on network failure
            raise CommError("Can't connect to network to query entity from API key")

        return entity
Beispiel #9
0
 def _load_viewer(self):
     if self.mode != "dryrun" and not self._api.disabled() and self._api.api_key:
         # Kaggle has internet disabled by default, this checks for that case
         async_viewer = util.async_call(self._api.viewer, timeout=env.get_http_timeout(5))
         viewer, viewer_thread = async_viewer()
         if viewer_thread.is_alive():
             if is_kaggle():
                 raise CommError("To use W&B in kaggle you must enable internet in the settings panel on the right.")
         else:
             self._viewer = viewer
             self._flags = json.loads(viewer.get("flags", "{}"))
Beispiel #10
0
    def upload_urls(self, project, files, run=None, entity=None, description=None):
        """Generate temporary resumeable upload urls

        Args:
            project (str): The project (model) name to query
            files (list or dict): The filenames to upload
            run (str, optional): The run to upload to. Defaults to the ``run`` setting.
            entity (str, optional): The entity to scope this project to. Defaults to the ``entity`` setting.
            description (str, optional): The description passed along for the run

        Returns:
            (bucket_id, upload_headers, file_info)
            bucket_id: id of bucket we uploaded to
            upload_headers: the ``uploadHeaders`` value returned for the files
            file_info: A dict of filenames and urls, also indicates if this revision already has uploaded files.
                {
                    'weights.h5': { "url": "https://weights.url" },
                    'model.json': { "url": "https://model.json", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' },
                }

        Raises:
            CommError: If the query returns no bucket for the requested run.
        """
        query = gql('''
        query Model($name: String!, $files: [String]!, $entity: String!, $run: String!, $description: String) {
            model(name: $name, entityName: $entity) {
                bucket(name: $run, desc: $description) {
                    id
                    files(names: $files) {
                        uploadHeaders
                        edges {
                            node {
                                name
                                url(upload: true)
                                updatedAt
                            }
                        }
                    }
                }
            }
        }
        ''')
        run_id = run or self.settings('run')
        entity = entity or self.settings('entity')
        query_result = self.gql(query, variable_values={
            'name': project, 'run': run_id,
            'entity': entity,
            'description': description,
            # Iterating a dict yields its keys, i.e. the file names.
            'files': list(files)
        })

        bucket = query_result['model']['bucket']
        if not bucket:
            raise CommError("Run does not exist {}/{}/{}.".format(entity, project, run_id))

        file_info = {node['name']: node for node in self._flatten_edges(bucket['files'])}
        return bucket['id'], bucket['files']['uploadHeaders'], file_info
Beispiel #11
0
 def parse_slug(self, slug, project=None, run=None):
     """Split a "project/run" slug, falling back to arguments and settings.

     Args:
         slug (str): Either "project/run" or a bare run name (may be empty/None).
         project (str, optional): Project used when the slug has no "/";
             falls back to the ``project`` setting.
         run (str, optional): Run used when the slug has no "/"; falls back
             to the slug, then the environment, then "latest".

     Returns:
         tuple: ``(project, run)``

     Raises:
         CommError: If the slug has no "/" and no project can be determined.
     """
     if slug and "/" in slug:
         # Only the first two path components are used.
         project, run = slug.split("/")[:2]
         return (project, run)

     project = project or self.settings().get("project")
     if project is None:
         raise CommError("No default project configured.")
     run = run or slug or env.get_run(env=self._environ)
     return (project, "latest" if run is None else run)
Beispiel #12
0
 def read(self, size=-1):
     """Read up to ``size`` bytes from the wrapped file and report progress.

     The callback receives the number of bytes in this chunk and the running
     total of bytes read.

     Raises:
         CommError: If the file yields no data before ``self.len`` bytes
             have been read, i.e. it shrank during the upload.
     """
     chunk = self.file.read(size)
     chunk_size = len(chunk)
     self.bytes_read += chunk_size
     if chunk_size == 0 and self.bytes_read < self.len:
         # Files shrinking during uploads causes request timeouts. Maybe
         # we could avoid those by updating the self.len in real-time, but
         # files getting truncated while uploading seems like something
         # that shouldn't really be happening anyway.
         raise CommError('File {} size shrank from {} to {} while it was being uploaded.'.format(self.file.name, self.len, self.bytes_read))
     # Growing files are also likely to be bad, but our code didn't break
     # on those in the past so it's riskier to make that an error now.
     self.callback(chunk_size, self.bytes_read)
     return chunk