示例#1
0
def validate_all(call: APICall, endpoint: Endpoint):
    """
    Perform all required call/endpoint validation and attach the endpoint's
    request/response models to the call.

    Validation runs in order: authentication, role, impersonation (followed by
    a second role check when ``validate_impersonation`` returns a truthy value)
    and required fields. On success the endpoint's request/response schema
    validators are set on ``call`` / ``call.result``, and the request/response
    data-model classes are set when the endpoint defines them.

    :param call: the API call being processed; mutated in place.
    :param endpoint: the endpoint definition the call is routed to.
    :return: True when all validations pass.
    :raises errors.bad_request.ValidationError: when call parsing, jsonmodels
        validation or JSON-schema validation fails. The original exception is
        chained as ``__cause__`` so it stays visible in tracebacks.
    """
    try:
        validate_auth(endpoint, call)

        validate_role(endpoint, call)

        if validate_impersonation(endpoint, call):
            # if impersonating, validate role again
            validate_role(endpoint, call)

        # TODO: remove validate_required_fields once all endpoints have json schema
        validate_required_fields(endpoint, call)

        # set models. models will be validated automatically
        call.schema_validator = endpoint.request_schema_validator
        if endpoint.request_data_model:
            call.data_model_cls = endpoint.request_data_model

        call.result.schema_validator = endpoint.response_schema_validator
        if endpoint.response_data_model:
            call.result.data_model_cls = endpoint.response_data_model

        return True

    except CallParsingError as ex:
        raise errors.bad_request.ValidationError(str(ex)) from ex
    except jsonmodels.errors.ValidationError as ex:
        # join all exception args into a single lower-cased message
        raise errors.bad_request.ValidationError(
            " ".join(str(arg).lower() for arg in ex.args)
        ) from ex
    except fastjsonschema.exceptions.JsonSchemaException as ex:
        log.exception(f"{endpoint.name}: fastjsonschema exception")
        # guard against an exception raised without a message argument
        raise errors.bad_request.ValidationError(
            ex.args[0] if ex.args else str(ex)
        ) from ex
示例#2
0
def _call_or_empty_with_error(call, req, msg, code=500, subcode=0):
    call = call or APICall("",
                           remote_addr=req.remote_addr,
                           headers=dict(req.headers),
                           files=req.files)
    call.set_error_result(msg=msg, code=code, subcode=subcode)
    return call
示例#3
0
def create_api_call(req):
    """
    Build an APICall from an incoming request.

    The endpoint name/version are parsed from the request path. If the session
    auth cookie is present, it seeds an ``Authorization`` bearer header;
    explicit request headers always take precedence over it. The call data is
    then populated from the request.

    Any ``Exception`` raised while building the call is converted into an error
    result via ``_call_or_empty_with_error`` rather than propagated:

    * path parsing errors -> 400, with ``log_api`` disabled;
    * ``BadRequest`` -> 400;
    * ``BaseError`` -> the error's own code/subcode;
    * anything else -> 500, with the exception logged.

    :param req: the incoming request object.
    :return: the constructed APICall, possibly carrying an error result.
    """
    call = None
    try:
        # Parse the request path
        endpoint_version, endpoint_name = ServiceRepo.parse_endpoint_path(req.path)

        # Resolve authorization: if cookies contain an authorization token, use it
        # as a starting point. In any case, request headers always take precedence.
        auth_cookie = req.cookies.get(
            config.get("apiserver.auth.session_auth_cookie_name")
        )
        headers = (
            {"Authorization": f"{AuthType.bearer_token} {auth_cookie}"}
            if auth_cookie
            else {}
        )
        # add (possibly override with) the request headers
        headers.update(list(req.headers.items()))

        # Construct call instance
        call = APICall(
            endpoint_name=endpoint_name,
            remote_addr=req.remote_addr,
            endpoint_version=endpoint_version,
            headers=headers,
            files=req.files,
        )

        # Update call data from request
        with TimingContext("preprocess", "update_call_data"):
            update_call_data(call, req)

    except PathParsingError as ex:
        # guard like the generic handler below: an exception without args must
        # not raise IndexError from inside this handler
        msg = ex.args[0] if ex.args else type(ex).__name__
        call = _call_or_empty_with_error(call, req, msg, 400)
        # NOTE(review): presumably suppresses API-call logging for unparsable paths
        call.log_api = False
    except BadRequest as ex:
        call = _call_or_empty_with_error(call, req, ex.description, 400)
    except BaseError as ex:
        call = _call_or_empty_with_error(call, req, ex.msg, ex.code, ex.subcode)
    except Exception as ex:
        log.exception("Error creating call")
        call = _call_or_empty_with_error(
            call, req, ex.args[0] if ex.args else type(ex).__name__, 500
        )

    return call
示例#4
0
def add_batch(call: APICall, company_id, req_model):
    """
    Add a batch of events carried by the call.

    :param call: API call whose ``batched_data`` holds the events to add.
    :param company_id: company the events are added for.
    :param req_model: unused here.
    :raises errors.bad_request.BatchContainsNoItems: if the batch is missing
        or empty.

    On success, ``call.result.data`` holds the number of added events and the
    number of errors returned by ``event_bll.add_events``. The batch size is
    recorded in ``call.kpis["events"]``.
    """
    batch = call.batched_data
    if batch is None or not len(batch):
        raise errors.bad_request.BatchContainsNoItems()

    added, failures = event_bll.add_events(company_id, batch, call.worker)
    call.result.data = {"added": added, "errors": len(failures)}
    call.kpis["events"] = len(batch)
示例#5
0
def add(call: APICall, company_id, req_model):
    """
    Add a single event taken from the call data.

    The optional ``allow_locked`` key is removed from a copy of the call data
    and forwarded to ``event_bll.add_events`` as ``allow_locked_tasks``
    (default False). The remaining data is passed as the single event.

    :param call: API call whose ``data`` describes the event.
    :param company_id: company the event is added for.
    :param req_model: unused here.

    Sets ``call.result.data`` to the added count and the number of errors, and
    records one event in ``call.kpis["events"]``.
    """
    event = call.data.copy()
    allow_locked_tasks = event.pop("allow_locked", False)
    added, failures = event_bll.add_events(
        company_id,
        [event],
        call.worker,
        allow_locked_tasks=allow_locked_tasks,
    )
    call.result.data = {"added": added, "errors": len(failures)}
    call.kpis["events"] = 1
示例#6
0
def add(call: APICall, company_id, _):
    """
    Add a single event taken from the call data.

    The optional ``allow_locked`` key is removed from a copy of the call data
    and forwarded to ``event_bll.add_events`` as ``allow_locked_tasks``
    (default False). The remaining data is passed as the single event.

    :param call: API call whose ``data`` describes the event.
    :param company_id: company the event is added for.

    Sets ``call.result.data`` to the added count, the error count and the
    error info returned by ``event_bll.add_events``. Records one event in
    ``call.kpis["events"]``.
    """
    event = call.data.copy()
    allow_locked_tasks = event.pop("allow_locked", False)
    added, error_count, error_info = event_bll.add_events(
        company_id, [event], call.worker, allow_locked_tasks=allow_locked_tasks
    )
    call.result.data = {
        "added": added,
        "errors": error_count,
        "errors_info": error_info,
    }
    call.kpis["events"] = 1
示例#7
0
def update_for_task(call: APICall, company_id, _):
    """
    Set the output model of a task: override, update or create it.

    Exactly one of ``uri`` / ``override_model_id`` must be given in the call
    data:

    * ``override_model_id``: an existing model of the company is used; the
      response reports ``created=False``.
    * ``uri``: missing ``name`` / ``comment`` in ``call.data`` are filled in
      from the task. If the task already has an output model, that model is
      updated and returned with ``created=False``. Otherwise a new model is
      created from the task's execution details, with ``created=True``.

    In the override and create cases, ``TaskBLL.update_statistics`` is called
    with the model id as the task output model. The optional ``iteration`` is
    passed as ``last_iteration_max``.

    :param call: API call; ``call.data`` must contain ``task`` and may contain
        ``uri``, ``override_model_id``, ``iteration``, ``name``, ``comment``.
    :param company_id: company of the task and the model.
    :raises errors.bad_request.MissingRequiredFields: unless exactly one of
        ``uri`` / ``override_model_id`` is provided.
    :raises errors.bad_request.InvalidTaskId: if the task is not found.
    :raises errors.bad_request.InvalidTaskStatus: if the task is not in the
        created / in_progress state.
    :raises errors.bad_request.InvalidModelId: if the override model is not
        found.
    """
    task_id = call.data["task"]
    uri = call.data.get("uri")
    iteration = call.data.get("iteration")
    override_model_id = call.data.get("override_model_id")
    # both missing or both provided is an error
    if bool(uri) == bool(override_model_id):
        raise errors.bad_request.MissingRequiredFields(
            "exactly one field is required",
            fields=("uri", "override_model_id"))

    with translate_errors_context():

        query = dict(id=task_id, company=company_id)
        task = Task.get_for_writing(
            id=task_id,
            company=company_id,
            _only=["output", "execution", "name", "status", "project"],
        )
        if not task:
            raise errors.bad_request.InvalidTaskId(**query)

        allowed_states = [TaskStatus.created, TaskStatus.in_progress]
        if task.status not in allowed_states:
            raise errors.bad_request.InvalidTaskStatus(
                f"model can only be updated for tasks in the {allowed_states} states",
                **query,
            )

        if override_model_id:
            query = dict(company=company_id, id=override_model_id)
            model = Model.objects(**query).first()
            if not model:
                raise errors.bad_request.InvalidModelId(**query)
            # an existing model is attached, nothing is created
            created = False
        else:
            if "name" not in call.data:
                # use task name if name not provided
                call.data["name"] = task.name

            if "comment" not in call.data:
                call.data["comment"] = f"Created by task `{task.name}` ({task.id})"

            if task.output and task.output.model:
                # model exists, update
                res = _update_model(
                    call, company_id, model_id=task.output.model
                ).to_struct()
                res.update({"id": task.output.model, "created": False})
                call.result.data = res
                return

            # new model, create
            fields = parse_model_fields(call, create_fields)

            # create and save model
            model = Model(
                id=database.utils.id(),
                created=datetime.utcnow(),
                user=call.identity.user,
                company=company_id,
                project=task.project,
                framework=task.execution.framework,
                parent=task.execution.model,
                design=task.execution.model_desc,
                labels=task.execution.model_labels,
                # NOTE(review): the status check above only admits created /
                # in_progress tasks, so this is always False here
                ready=(task.status == TaskStatus.published),
                **fields,
            )
            model.save()
            _update_cached_tags(company_id, project=model.project, fields=fields)
            created = True

        TaskBLL.update_statistics(
            task_id=task_id,
            company_id=company_id,
            last_iteration_max=iteration,
            output__model=model.id,
        )

        call.result.data = {"id": model.id, "created": created}