コード例 #1
0
def test_failed_register(valid_simple_ml_data):
    """A NotFound from the POST surfaces as ModuleRegistrationFailedException.

    The exception text must name the predictor type and embed the underlying
    NotFound message.
    """
    missing_path = "/projects/uuid/not_found"
    fake_session = mock.Mock()
    fake_session.post_resource.side_effect = NotFound(missing_path)
    collection = PredictorCollection(uuid.uuid4(), fake_session)
    simple_predictor = SimpleMLPredictor.build(valid_simple_ml_data)

    with pytest.raises(ModuleRegistrationFailedException) as exc_info:
        collection.register(simple_predictor)

    expected_message = 'The "SimpleMLPredictor" failed to register. NotFound: /projects/uuid/not_found'
    assert expected_message in str(exc_info.value)
コード例 #2
0
    def checked_request(self, method: str, path: str, *args,
                        **kwargs) -> requests.Response:
        """Send a request relative to ``base_url`` and raise on non-2xx status.

        The access token is refreshed up front if it has expired. If the server
        answers 401 with a JSON object whose ``reason`` is ``"invalid-token"``,
        the token is refreshed again and the request is retried once.

        :param method: HTTP method, passed through to ``request``.
        :param path: path appended to ``self.base_url`` (leading '/' stripped).
        :return: the response when its status code is in the 2xx range.
        :raises BadRequest: on 400.
        :raises Unauthorized: on 401.
        :raises NotFound: on 404.
        :raises WorkflowConflictException: on 409.
        :raises WorkflowNotReadyException: on 425.
        :raises CitrineException: on any other non-2xx status.
        """
        if self._is_access_token_expired():
            self._refresh_access_token()
        uri = self.base_url + path.lstrip('/')
        response = super().request(method, uri, *args, **kwargs)

        if response.status_code == 401:
            # Only the JSON decode belongs inside the try: errors from the
            # token refresh or the retried request must propagate.
            try:
                body = response.json()
            except ValueError:
                # A 401 without a JSON body is handled as a plain 401 below.
                body = None
            # The body may also be valid JSON that is not an object (list,
            # string, null); calling .get on it would raise AttributeError.
            if isinstance(body, dict) and body.get("reason") == "invalid-token":
                self._refresh_access_token()
                response = super().request(method, uri, *args, **kwargs)

        # TODO: More substantial/careful error handling
        status = response.status_code
        if 200 <= status <= 299:
            self.logger.info('%s %s %s', status, method, path)
            return response

        stacktrace = self._extract_response_stacktrace(response)
        if stacktrace is not None:
            self.logger.error('Response arrived with stacktrace:')
            self.logger.error(stacktrace)

        # 409 and 425 are logged at debug rather than error (presumably because
        # they are expected workflow states — confirm); all others at error.
        if status in (409, 425):
            self.logger.debug('%s %s %s', status, method, path)
        else:
            self.logger.error('%s %s %s', status, method, path)

        if status == 400:
            raise BadRequest(path, response)
        if status == 401:
            raise Unauthorized(path, response)
        if status == 404:
            raise NotFound(path, response)
        if status == 409:
            raise WorkflowConflictException(response.text)
        if status == 425:
            msg = 'Cant execute at this time. Try again later. Error: {}'.format(
                response.text)
            raise WorkflowNotReadyException(msg)
        raise CitrineException(response.text)
コード例 #3
0
    def get_by_unique_name(self, unique_name: str) -> ResourceType:
        """Get a Dataset with the given unique name.

        :param unique_name: the dataset's unique name. It is percent-encoded
            before being placed in the query string.
        :return: the single matching resource, built with ``self.build``.
        :raises ValueError: if ``unique_name`` is None.
        :raises RuntimeError: if more than one resource matches.
        :raises NotFound: if no resource matches.
        """
        from urllib.parse import quote

        if unique_name is None:
            raise ValueError("You must supply a unique_name")
        # Encode the name so characters such as '&', '#', '+' or spaces cannot
        # split, truncate or alter the query string.
        path = self._get_path() + "?unique_name=" + quote(unique_name, safe='')
        data = self.session.get_resource(path)

        if len(data) == 1:
            return self.build(data[0])
        elif len(data) > 1:
            raise RuntimeError(
                "Received multiple results when requesting a unique dataset")
        else:
            raise NotFound(path)
コード例 #4
0
        def search(self,
                   search_params: Optional[dict] = None,
                   per_page: int = 100):
            """Filter ``self.projects`` by name and/or description criteria.

            Each entry of ``search_params`` is keyed by field (``"name"`` or
            ``"description"``) and maps to a dict with ``"search_method"``
            (``"EXACT"`` or ``"SUBSTRING"``) and ``"value"``. All present
            criteria must match. An unrecognized search method leaves that
            field unfiltered.

            :param search_params: filtering criteria; ``None`` means no filtering.
            :param per_page: accepted for signature compatibility; unused here.
            :return: the matching projects (``self.projects`` itself when no
                criteria apply).
            :raises NotFound: if ``self.search_implemented`` is falsy.
            """
            if not self.search_implemented:
                raise NotFound("search")

            # The parameter defaults to None; treat that as "no criteria".
            search_params = search_params or {}

            ans = self.projects
            for field in ("name", "description"):
                criterion = search_params.get(field)
                if not criterion:
                    continue
                method = criterion["search_method"]
                value = criterion["value"]
                if method == "EXACT":
                    ans = [x for x in ans if getattr(x, field) == value]
                elif method == "SUBSTRING":
                    # A missing (None) field cannot contain the value; skip it
                    # rather than raising TypeError from ``value in None``.
                    ans = [x for x in ans
                           if getattr(x, field) is not None
                           and value in getattr(x, field)]
            return ans
コード例 #5
0
def test_archive_and_restore(valid_label_fractions_predictor_data):
    """Archive sets ``archived``, restore clears it, and lookup failures raise.

    Once ``get_resource`` raises NotFound, both archive and restore are
    expected to raise RuntimeError.
    """
    session = mock.Mock()
    collection = PredictorCollection(uuid.uuid4(), session)
    session.get_resource.return_value = valid_label_fractions_predictor_data

    # Echo the payload back, as if the update returned the serialized predictor.
    session.put_resource.side_effect = lambda url, data: data

    assert collection.archive(uuid.uuid4()).archived

    valid_label_fractions_predictor_data["archived"] = True
    session.get_resource.return_value = valid_label_fractions_predictor_data
    assert not collection.restore(uuid.uuid4()).archived

    # Once the predictor can no longer be fetched, both operations must fail.
    session.get_resource.side_effect = NotFound("")
    for operation in (collection.archive, collection.restore):
        with pytest.raises(RuntimeError):
            operation(uuid.uuid4())
コード例 #6
0
    def checked_request(self,
                        method: str,
                        path: str,
                        version: str = 'v1',
                        **kwargs) -> requests.Response:
        """Send a request to the versioned API and raise on non-2xx status.

        The access token is refreshed up front if it has expired. If the server
        answers 401 with a JSON object whose ``reason`` is ``"invalid-token"``,
        the token is refreshed again and the request is retried once.

        :param method: HTTP method, passed to ``_request_with_retry``.
        :param path: path appended to the base URL for ``version``
            (leading '/' stripped).
        :param version: API version used to build the base URL.
        :param kwargs: forwarded unchanged to ``_request_with_retry``.
        :return: the response when its status code is in the 2xx range.
        :raises BadRequest: on 400.
        :raises Unauthorized: on 401 or 403.
        :raises NotFound: on 404.
        :raises WorkflowConflictException: on 409.
        :raises WorkflowNotReadyException: on 425.
        :raises CitrineException: on any other non-2xx status.
        """
        # Lazy %-style arguments: nothing is formatted unless debug logging is
        # enabled, which matters when kwargs carry large request bodies.
        logger.debug('BEGIN request details:')
        logger.debug('\tmethod: %s', method)
        logger.debug('\tpath: %s', path)
        logger.debug('\tversion: %s', version)

        if self._is_access_token_expired():
            self._refresh_access_token()
        uri = self._versioned_base_url(version) + path.lstrip('/')

        logger.debug('\turi: %s', uri)
        for k, v in kwargs.items():
            logger.debug('\t%s: %s', k, v)
        logger.debug('END request details.')

        response = self._request_with_retry(method, uri, **kwargs)

        if response.status_code == 401:
            # Only the JSON decode belongs inside the try: errors from the
            # token refresh or the retried request must propagate, not be
            # swallowed or mislogged as a decode failure.
            reason = None
            try:
                body = response.json()
            except ValueError:
                # A 401 without a JSON body is handled as a plain 401 below.
                pass
            else:
                if isinstance(body, dict):
                    reason = body.get("reason")
                else:
                    # Valid JSON that is not an object has no "reason" field.
                    # The 401 status will be handled further down.
                    logger.error("Failed to decode json from response: %s",
                                 response.text)
            if reason == "invalid-token":
                self._refresh_access_token()
                response = self._request_with_retry(method, uri, **kwargs)

        # TODO: More substantial/careful error handling
        status = response.status_code
        if 200 <= status <= 299:
            logger.info('%s %s %s', status, method, path)
            return response

        stacktrace = self._extract_response_stacktrace(response)
        if stacktrace is not None:
            logger.error('Response arrived with stacktrace:')
            logger.error(stacktrace)

        # 409 and 425 are logged at debug rather than error (presumably because
        # they are expected workflow states — confirm); all others at error.
        if status in (409, 425):
            logger.debug('%s %s %s', status, method, path)
        else:
            logger.error('%s %s %s', status, method, path)

        if status == 400:
            logger.error(response.text)
            raise BadRequest(path, response)
        if status in (401, 403):
            raise Unauthorized(path, response)
        if status == 404:
            raise NotFound(path, response)
        if status == 409:
            raise WorkflowConflictException(response.text)
        if status == 425:
            msg = 'Cant execute at this time. Try again later. Error: {}'.format(
                response.text)
            raise WorkflowNotReadyException(msg)
        raise CitrineException(response.text)