예제 #1
0
파일: util.py 프로젝트: vannirobben/client
def no_retry_auth(e):
    """Decide whether a failed request should be retried.

    Returns True when the request should be retried and False when it
    should not. Authentication and permission failures are not retried;
    they are converted into a ``CommError`` with a user-facing message.

    Args:
        e: The exception that was raised, or an object exposing the real
            exception through an ``exception`` attribute.

    Returns:
        bool: True for non-HTTP errors, HTTP errors without a response,
        and HTTP errors whose status is not 400/401/403/404; False for
        status 400.

    Raises:
        CommError: For status 401 (invalid or missing API key) and for
            403/404 (permission denied).
    """
    if hasattr(e, "exception"):
        e = e.exception
    if not isinstance(e, requests.HTTPError):
        return True
    # An HTTPError may be constructed without a response attached; there is
    # no status code to inspect, so treat it like any other transient error
    # instead of crashing on ``None.status_code``.
    if e.response is None:
        return True
    status_code = e.response.status_code
    # Don't retry bad request errors; raise immediately
    if status_code == 400:
        return False
    # Retry all non-forbidden/unauthorized/not-found errors.
    if status_code not in (401, 403, 404):
        return True
    # Crash w/message on forbidden/unauthorized errors.
    if status_code == 401:
        raise CommError("Invalid or missing api_key.  Run wandb login")
    elif wandb.run:
        raise CommError("Permission denied to access {}".format(
            wandb.run.path))
    else:
        raise CommError(
            "Permission denied, ask the project owner to grant you access")
예제 #2
0
    def wrapper(*args, **kwargs):
        """Call ``func`` and normalize its failures into ``CommError``.

        * ``requests.HTTPError`` is always wrapped in ``CommError``.
        * ``RetryError``: the last underlying exception is re-raised as-is
          when ``env.is_debug()`` is true; otherwise a ``CommError`` is
          raised whose message is taken from the server response when one
          is available.
        * Any other exception: re-raised unchanged in debug mode; otherwise
          wrapped in ``CommError`` with a best-effort message.

        In every non-HTTPError case the original traceback is preserved.
        """
        message = "Whoa, you found a bug."
        try:
            return func(*args, **kwargs)
        except requests.HTTPError as err:
            raise CommError(err.response, err)
        except RetryError as err:
            last_exc = err.last_exception
            response = getattr(last_exc, "response", None)
            if response is not None:
                try:
                    message = response.json().get(
                        'errors', [{
                            'message': message
                        }])[0]['message']
                except (ValueError, AttributeError, IndexError, KeyError,
                        TypeError):
                    # The body is not JSON, or is not shaped like
                    # {"errors": [{"message": ...}]}. Fall back to the raw
                    # text rather than masking the real failure with a
                    # secondary exception raised from this handler.
                    message = response.text
            else:
                message = last_exc

            if env.is_debug():
                six.reraise(type(last_exc), last_exc, sys.exc_info()[2])
            else:
                six.reraise(CommError, CommError(message, last_exc),
                            sys.exc_info()[2])
        except Exception as err:
            # gql raises server errors with dict's as strings...
            if len(err.args) > 0:
                payload = err.args[0]
            else:
                payload = err
            payload_text = str(payload)
            message = str(err)
            if payload_text.startswith("{"):
                try:
                    message = ast.literal_eval(payload_text)["message"]
                except (ValueError, SyntaxError, KeyError, TypeError):
                    # Looked like a dict but is not a parseable literal or
                    # has no "message" key; keep the plain string form so
                    # the original error is still reported.
                    message = str(err)
            if env.is_debug():
                six.reraise(*sys.exc_info())
            else:
                six.reraise(CommError, CommError(message, err),
                            sys.exc_info()[2])
예제 #3
0
    def upload_urls(self, project, files, run=None, entity=None, description=None):
        """Request temporary, resumable upload URLs for a set of files.

        Args:
            project (str): Project name, sent as the query's ``name``.
            files (list or dict): Filenames to request URLs for. Iterating a
                dict yields its keys.
            run (str, optional): Run to upload to. Falls back to
                ``self.settings('run')`` when falsy.
            entity (str, optional): Entity that scopes the project. Falls
                back to ``self.settings('entity')`` when falsy.
            description (str, optional): Sent as the bucket ``desc``.

        Returns:
            tuple: ``(bucket_id, upload_headers, file_info)`` where
            ``file_info`` maps each returned file name to its node, e.g.
                {
                    'weights.h5': { "url": "https://weights.url" },
                }

        Raises:
            CommError: If the query returns no bucket for the run.
        """
        upload_query = gql('''
        query Model($name: String!, $files: [String]!, $entity: String!, $run: String!, $description: String) {
            model(name: $name, entityName: $entity) {
                bucket(name: $run, desc: $description) {
                    id
                    files(names: $files) {
                        uploadHeaders
                        edges {
                            node {
                                name
                                url(upload: true)
                                updatedAt
                            }
                        }
                    }
                }
            }
        }
        ''')
        run_id = run or self.settings('run')
        entity = entity or self.settings('entity')
        variables = {
            'name': project,
            'run': run_id,
            'entity': entity,
            'description': description,
            'files': list(files),
        }
        response = self.gql(upload_query, variable_values=variables)

        bucket = response['model']['bucket']
        if not bucket:
            raise CommError("Run does not exist {}/{}/{}.".format(entity, project, run_id))

        # Index the returned file nodes by name.
        file_info = {}
        for node in self._flatten_edges(bucket['files']):
            file_info[node['name']] = node
        return bucket['id'], bucket['files']['uploadHeaders'], file_info
예제 #4
0
 def parse_slug(self, slug, project=None, run=None):
     """Resolve a ``(project, run)`` pair from a slug and optional overrides.

     A slug containing "/" is split on "/" and its first two pieces are
     returned as project and run; the ``project`` and ``run`` arguments are
     ignored in that case. Otherwise the project comes from ``project`` or
     the "project" setting, and the run from ``run``, then ``slug``, then
     ``env.get_run``, defaulting to "latest".

     Raises:
         CommError: If no project is given and none is configured.
     """
     if slug and "/" in slug:
         pieces = slug.split("/")
         return (pieces[0], pieces[1])

     resolved_project = project or self.settings().get("project")
     if resolved_project is None:
         raise CommError("No default project configured.")

     resolved_run = run or slug or env.get_run(env=self._environ)
     return (resolved_project, "latest" if resolved_run is None else resolved_run)
예제 #5
0
 def read(self, size=-1):
     """Read from the wrapped file, track progress, and notify the callback.

     Args:
         size: Passed straight to ``self.file.read``.

     Returns:
         Whatever ``self.file.read(size)`` returned.

     Raises:
         CommError: If the file yields no more data before ``self.len``
             bytes have been read, i.e. it shrank during the upload.
     """
     chunk = self.file.read(size)
     chunk_size = len(chunk)
     self.bytes_read += chunk_size

     # A shrinking file makes the request time out, so fail loudly instead.
     # Updating self.len on the fly might avoid that, but a file being
     # truncated mid-upload should not be happening in the first place.
     ended_early = chunk_size == 0 and self.bytes_read < self.len
     if ended_early:
         raise CommError(
             "File {} size shrank from {} to {} while it was being uploaded."
             .format(self.file.name, self.len, self.bytes_read))

     # A growing file is probably bad too, but it never broke this code
     # before, so turning it into an error now would be the riskier choice.
     self.callback(chunk_size, self.bytes_read)
     return chunk
예제 #6
0
    def store_path(self, artifact, path, name=None, checksum=True, max_objects=None):
        """Build manifest entries for the S3 object(s) referenced by ``path``.

        Args:
            artifact: Not used by this method.
            path: URI that ``self._parse_uri`` splits into bucket and key.
            name: Optional entry name; when checksumming is off, the key is
                used if this is falsy.
            checksum: When False, return a single entry that uses ``path``
                as its digest without contacting S3 for the object.
            max_objects: Cap on the number of objects listed under a prefix;
                falls back to ``DEFAULT_MAX_OBJECTS`` when falsy.

        Returns:
            list: Manifest entries, one per object with a size above zero.

        Raises:
            CommError: If loading the object fails with a client error
                other than "404".
            ValueError: If the number of entries reaches ``max_objects``.
        """
        self.init_boto()
        bucket, key = self._parse_uri(path)
        max_objects = max_objects or DEFAULT_MAX_OBJECTS
        if not checksum:
            return [ArtifactManifestEntry(name or key, path, digest=path)]

        candidates = [self._s3.Object(bucket, key)]
        started_at = None
        multi = False
        try:
            candidates[0].load()
        except self._botocore.exceptions.ClientError as err:
            failure = err.response["Error"]
            if failure["Code"] != "404":
                raise CommError(
                    "Unable to connect to S3 (%s): %s"
                    % (failure["Code"], failure["Message"])
                )
            # No object at this exact key: treat the key as a prefix and
            # list the objects underneath it instead.
            multi = True
            started_at = time.time()
            termlog(
                'Generating checksum for up to %i objects with prefix "%s"... '
                % (max_objects, key),
                newline=False,
            )
            listing = self._s3.Bucket(bucket).objects.filter(Prefix=key)
            candidates = listing.limit(max_objects)

        entries = []
        for obj in candidates:
            # Objects whose size is not above zero get no entry.
            if self._size_from_obj(obj) > 0:
                entries.append(
                    self._entry_from_obj(obj, path, name, prefix=key, multi=multi)
                )

        if started_at is not None:
            termlog("Done. %.1fs" % (time.time() - started_at), prefix=False)
        if len(entries) >= max_objects:
            raise ValueError(
                "Exceeded %i objects tracked, pass max_objects to add_reference"
                % max_objects
            )
        return entries
예제 #7
0
 def query_with_timeout(self, timeout=None):
     """Fetch viewer and server info without blocking past a timeout.

     Does nothing when ``self._settings._disable_viewer`` is set. On
     success stores ``self._viewer``, ``self._serverinfo`` and the decoded
     ``self._flags`` and clears ``self._error_network``. When the call
     raises, or its thread is still alive afterwards,
     ``self._error_network`` is set to True instead.

     Args:
         timeout: Passed to ``util.async_call``; a falsy value becomes 5.

     Raises:
         CommError: If the call is still running and ``util._is_kaggle()``
             is true.
     """
     settings = self._settings
     if settings and settings._disable_viewer:
         return

     fetch_viewer = util.async_call(
         self._api.viewer_server_info, timeout=timeout or 5
     )
     try:
         result, worker = fetch_viewer()
     except Exception:
         # TODO: currently a bare exception as lots can happen, we should classify
         self._error_network = True
         return

     if not worker.is_alive():
         self._error_network = False
         # TODO(jhr): should we kill the thread?
         self._viewer, self._serverinfo = result
         self._flags = json.loads(self._viewer.get("flags", "{}"))
         return

     if util._is_kaggle():
         raise CommError(
             "To use W&B in kaggle you must enable internet in the settings panel on the right."  # noqa: E501
         )
     # this is likely a DNS hang
     self._error_network = True
예제 #8
0
    def upload_urls(self,
                    project,
                    files,
                    run=None,
                    entity=None,
                    description=None):
        """Generate temporary resumable upload urls.

        Arguments:
            project (str): Project name, sent as the query's ``name``.
            files (list or dict): The filenames to upload. Iterating a dict
                yields its keys.
            run (str, optional): The run to upload to. Falls back to
                ``self.current_run_id`` when falsy.
            entity (str, optional): The entity to scope this project to.
                Falls back to ``self.settings("entity")`` when falsy.
            description (str, optional): Sent as the bucket ``desc``.

        Returns:
            (bucket_id, upload_headers, file_info)
            bucket_id: id of the bucket returned by the query
            upload_headers: the ``uploadHeaders`` value returned for the files
            file_info: A dict mapping each returned file name to its node.
                {
                    'weights.h5': { "url": "https://weights.url" },
                    'model.json': { "url": "https://model.json", "updatedAt": '2013-04-26T22:22:23.832Z' },
                }

        Raises:
            AssertionError: If neither ``run`` nor ``self.current_run_id``
                provides a run id.
            CommError: If the query returns no bucket for the run.
        """
        query = gql("""
        query Model($name: String!, $files: [String]!, $entity: String!, $run: String!, $description: String) {
            model(name: $name, entityName: $entity) {
                bucket(name: $run, desc: $description) {
                    id
                    files(names: $files) {
                        uploadHeaders
                        edges {
                            node {
                                name
                                url(upload: true)
                                updatedAt
                            }
                        }
                    }
                }
            }
        }
        """)
        run_id = run or self.current_run_id
        # Validate the resolved id, not the raw argument: checking ``run``
        # here would reject every call that relies on the current_run_id
        # fallback computed on the previous line.
        assert run_id, "run must be specified"
        entity = entity or self.settings("entity")
        query_result = self.gql(
            query,
            variable_values={
                "name": project,
                "run": run_id,
                "entity": entity,
                "description": description,
                "files": list(files),
            },
        )

        bucket = query_result["model"]["bucket"]
        if not bucket:
            raise CommError("Run does not exist {}/{}/{}.".format(
                entity, project, run_id))

        result = {
            node["name"]: node
            for node in self._flatten_edges(bucket["files"])
        }
        return bucket["id"], bucket["files"]["uploadHeaders"], result