示例#1
0
    def GetModelStatus(self, request, context):
        """Build a GetModelStatusResponse for the model named in *request*.

        The model name and version are read from ``request.model_spec``.
        When ``check_availability_of_requested_status`` rejects that pair,
        *context* is marked ``NOT_FOUND`` with a ``WRONG_MODEL_SPEC`` detail
        message and an empty response is returned.

        A truthy version reports only that version's status; a falsy one
        (e.g. 0) reports every entry of the model's ``versions_statuses``.
        """
        logger.debug("MODEL_STATUS, get request: {}".format(request))
        spec = request.model_spec
        model_name = spec.name
        requested_version = spec.version.value

        is_known = check_availability_of_requested_status(
            models=self.models,
            requested_version=requested_version,
            model_name=model_name)
        if not is_known:
            context.set_code(StatusCode.NOT_FOUND)
            context.set_details(
                WRONG_MODEL_SPEC.format(model_name, requested_version))
            logger.debug("MODEL_STATUS, invalid model spec from request")
            return get_model_status_pb2.GetModelStatusResponse()

        response = get_model_status_pb2.GetModelStatusResponse()
        all_statuses = self.models[model_name].versions_statuses
        # One specific version, or every known version when none was given.
        if requested_version:
            selected = [all_statuses[requested_version]]
        else:
            selected = all_statuses.values()
        for version_status in selected:
            add_status_to_response(version_status, response)

        logger.debug("MODEL_STATUS created a response for {} - {}".format(
            model_name, requested_version))
        return response
示例#2
0
    def on_get(self, req, resp, model_name, requested_version=0):
        """Write the model status for *model_name* into *resp* as JSON.

        When ``check_availability_of_requested_status`` rejects the
        name/version pair, *resp* gets ``falcon.HTTP_NOT_FOUND`` and a JSON
        body of the form ``{"error": ...}``.

        Otherwise *requested_version* is converted with ``int``. A non-zero
        value reports only that version; zero reports every entry of the
        model's ``versions_statuses``. The response message is serialized
        with ``MessageToJson`` and the status is set to ``falcon.HTTP_200``.
        """
        logger.debug("MODEL_STATUS, get request: {}, {}".format(
            model_name, requested_version))
        is_known = check_availability_of_requested_status(
            models=self.models,
            requested_version=requested_version,
            model_name=model_name)

        if not is_known:
            resp.status = falcon.HTTP_NOT_FOUND
            logger.debug("MODEL_STATUS, invalid model spec from request")
            resp.body = json.dumps({
                'error': WRONG_MODEL_SPEC.format(model_name, requested_version)
            })
            return

        # NOTE(review): the conversion happens only after the availability
        # check; presumably that helper copes with the unconverted value -
        # confirm against its implementation.
        requested_version = int(requested_version)

        response = get_model_status_pb2.GetModelStatusResponse()
        all_statuses = self.models[model_name].versions_statuses
        if requested_version:
            selected = [all_statuses[requested_version]]
        else:
            selected = all_statuses.values()
        for version_status in selected:
            add_status_to_response(version_status, response)

        logger.debug("MODEL_STATUS created a response for {} - {}".format(
            model_name, requested_version))
        resp.status = falcon.HTTP_200
        resp.body = MessageToJson(response,
                                  including_default_value_fields=True)
    def test_get_model_status_rest(self, model_version_policy_models,
                                   start_server_model_ver_policy, model_name,
                                   throw_error):
        """Check the REST model-status endpoint for versions 1, 2 and 3.

        For the version at index ``i``, a truthy ``throw_error[i]`` means
        the request must answer 404; otherwise the first returned status
        must describe that version as AVAILABLE with error code OK.

        When *model_name* is ``'all'``, the version-less URL is queried too
        and must list exactly three AVAILABLE statuses.
        """
        _, ports = start_server_model_ver_policy
        print("Downloaded model files:", model_version_policy_models)

        for x, version in enumerate([1, 2, 3]):
            rest_url = 'http://localhost:{}/v1/models/{}/' \
                       'versions/{}'.format(ports["rest_port"], model_name,
                                            version)
            result = requests.get(rest_url)
            if throw_error[x]:
                assert 404 == result.status_code
                continue

            response = Parse(result.text,
                             get_model_status_pb2.GetModelStatusResponse(),
                             ignore_unknown_fields=False)
            version_status = response.model_version_status[0]
            assert version_status.version == version
            assert version_status.state == ModelVersionState.AVAILABLE
            assert version_status.status.error_code == ErrorCode.OK
            assert version_status.status.error_message == _ERROR_MESSAGE[
                ModelVersionState.AVAILABLE][ErrorCode.OK]

        # Aggregated results check: only applies to the 'all' model.
        if model_name != 'all':
            return

        rest_url = 'http://localhost:{}/v1/models/all'.format(
            ports["rest_port"])
        response = get_model_status_response_rest(rest_url)
        versions_statuses = response.model_version_status
        assert len(versions_statuses) == 3
        for version_status in versions_statuses:
            assert version_status.state == ModelVersionState.AVAILABLE
            assert version_status.status.error_code == ErrorCode.OK
            assert version_status.status.error_message == _ERROR_MESSAGE[
                ModelVersionState.AVAILABLE][ErrorCode.OK]
def get_model_status_response_rest(rest_url):
    """GET *rest_url* and parse the body into a GetModelStatusResponse.

    The response text is parsed with ``Parse`` using
    ``ignore_unknown_fields=False``; the HTTP status code is not inspected.
    """
    body_text = requests.get(rest_url).text
    return Parse(body_text,
                 get_model_status_pb2.GetModelStatusResponse(),
                 ignore_unknown_fields=False)
示例#5
0
def _make_response(
    payload: Dict[Text, Any]) -> get_model_status_pb2.GetModelStatusResponse:
  """Return a GetModelStatusResponse populated from the dict *payload*."""
  # ParseDict fills the message it is given and returns that same object.
  return json_format.ParseDict(
      payload, get_model_status_pb2.GetModelStatusResponse())
示例#6
0
 def __init__(self, model_version_status=None):
     """Initialize the parent with a fresh GetModelStatusResponse.

     *model_version_status* is forwarded unchanged to the parent
     initializer as a keyword argument.
     """
     empty_response = get_model_status_pb2.GetModelStatusResponse()
     super().__init__(empty_response,
                      model_version_status=model_version_status)