示例#1
0
    def _FormatResultAsJson(self, result, format_mode=None):
        if result is None:
            return dict(status="OK")

        if format_mode == JsonMode.PROTO3_JSON_MODE:
            return json.loads(
                json_format.MessageToJson(result.AsPrimitiveProto()))
        elif format_mode == JsonMode.GRR_ROOT_TYPES_STRIPPED_JSON_MODE:
            result_dict = {}
            for field, value in result.ListSetFields():
                if isinstance(
                        field,
                    (rdf_structs.ProtoDynamicEmbedded,
                     rdf_structs.ProtoEmbedded, rdf_structs.ProtoList)):
                    result_dict[field.name] = api_value_renderers.RenderValue(
                        value)
                else:
                    result_dict[field.name] = api_value_renderers.RenderValue(
                        value)["value"]
            return result_dict
        elif format_mode == JsonMode.GRR_TYPE_STRIPPED_JSON_MODE:
            rendered_data = api_value_renderers.RenderValue(result)
            return api_value_renderers.StripTypeInfo(rendered_data)
        elif format_mode == JsonMode.GRR_JSON_MODE:
            return api_value_renderers.RenderValue(result)
        else:
            raise ValueError("Invalid format_mode: %s", format_mode)
示例#2
0
    def HandleCheck(self, method_metadata, args=None, replace=None):
        """Does regression check for given method, args and a replace function.

        Args:
          method_metadata: Router method metadata. Its name selects the API
            method; its result_type decides how the response is parsed.
          args: Optional args value passed to the API method.
          replace: Callable applied to the URL, the request payload and the
            response content (e.g. to normalize dynamically generated values).
            Required.

        Returns:
          Dict with "url", "method" and "response" keys, plus
          "request_payload" and "type_stripped_response" when applicable.

        Raises:
          ValueError: if replace is not given, or if the class' api_version is
            neither 1 nor 2.
        """
        if not replace:
            raise ValueError("replace can't be None")

        api_version = self.__class__.api_version
        if api_version == 1:
            request, prepped_request = self._PrepareV1Request(
                method_metadata.name, args=args)
        elif api_version == 2:
            request, prepped_request = self._PrepareV2Request(
                method_metadata.name, args=args)
        else:
            # Report the value that was actually checked and format the
            # message eagerly (exceptions don't interpolate extra args).
            raise ValueError("api_version may be only 1 or 2, not %s" %
                             (api_version,))

        # Use the session as a context manager so its pooled connections are
        # released. send() without stream=True reads the body before
        # returning, so response.content stays available after the close.
        with requests.Session() as session:
            response = session.send(prepped_request)

        check_result = {
            "url": replace(prepped_request.path_url),
            "method": request.method
        }

        if request.data:
            request_payload = self._ParseJSON(replace(request.data))
            if request_payload:
                check_result["request_payload"] = request_payload

        if (method_metadata.result_type == api_call_router.
                RouterMethodMetadata.BINARY_STREAM_RESULT_TYPE):
            # Binary streams are recorded as text, not parsed as JSON.
            check_result["response"] = replace(
                utils.SmartUnicode(response.content))
        else:
            check_result["response"] = self._ParseJSON(
                replace(response.content))

        # Type stripping only makes sense for version 1 of the API.
        if api_version == 1:
            stripped_response = api_value_renderers.StripTypeInfo(
                check_result["response"])
            if stripped_response != check_result["response"]:
                check_result["type_stripped_response"] = stripped_response

        return check_result
示例#3
0
  def _PrepareV1Request(self, method, args=None):
    """Prepares API v1 request for a given method and args."""

    args_proto = None
    if args:
      args_proto = args.AsPrimitiveProto()
    request = self.connector.BuildRequest(method, args_proto)
    request.url = request.url.replace("/api/v2/", "/api/")
    if args and request.data:
      body_proto = args.__class__().AsPrimitiveProto()
      json_format.Parse(request.data, body_proto)
      body_args = args.__class__()
      body_args.ParseFromString(body_proto.SerializeToString())
      request.data = json.dumps(
          api_value_renderers.StripTypeInfo(
              api_value_renderers.RenderValue(body_args)),
          cls=http_api.JSONEncoderWithRDFPrimitivesSupport)

    prepped_request = request.prepare()

    return request, prepped_request
示例#4
0
    def Check(self, method, url, payload=None, replace=None):
        """Records output of a given url accessed with a given method.

        Args:
          method: HTTP method. May be "GET", "POST", "DELETE" or "PATCH".
          url: String representing an url.
          payload: JSON-able payload that will be sent when "POST", "DELETE"
            or "PATCH" method is used.
          replace: Dictionary of key->value pairs, or a callable returning
            one. In the recorded JSON output every "key" will be replaced
            with its "value" counterpart. Keys may be plain strings or
            compiled regular expressions. This way we can properly handle
            dynamically generated values (like hunts IDs) in the regression
            data.

        Raises:
          ValueError: if unsupported method argument is passed. Currently only
            "GET", "POST", "DELETE" and "PATCH" are supported.
          RuntimeError: if request was handled by an unexpected API method
            (every test is annotated with an "api_method" attribute that
            points to the expected API method).
        """
        if self.use_api_v2:
            url = url.replace("/api/", "/api/v2/")

        parsed_url = urlparse.urlparse(url)
        request = utils.DataObject(method=method,
                                   scheme="http",
                                   path=parsed_url.path,
                                   environ={
                                       "SERVER_NAME": "foo.bar",
                                       "SERVER_PORT": 1234
                                   },
                                   user=self.token.username,
                                   body="")
        request.META = {"CONTENT_TYPE": "application/json"}

        if method == "GET":
            request.GET = dict(urlparse.parse_qsl(parsed_url.query))
        elif method in ["POST", "DELETE", "PATCH"]:
            # NOTE: this is a temporary trick. Payloads in regression tests
            # are using the API v1 (non-proto3) format. Here we're reparsing
            # them and serializing as proto3 JSON.
            # TODO(user): Make regression tests payload format-agnostic.
            # I.e. use protobuf and API client library to send requests.
            if self.use_api_v2 and payload:
                router_matcher = http_api.RouterMatcher()
                _, metadata, _ = router_matcher.MatchRouter(request)

                rdf_args = metadata.args_type()
                rdf_args.FromDict(payload)
                proto_args = metadata.args_type.protobuf()
                proto_args.ParseFromString(rdf_args.SerializeToString())

                request.body = json_format.MessageToJson(proto_args)
                payload = json.loads(request.body)
            else:
                request.body = json.dumps(payload or "")
        else:
            raise ValueError("Unsupported method: %s." % method)

        with self.NoAuthorizationChecks():
            http_response = http_api.RenderHttpResponse(request)

        api_method = http_response["X-API-Method"]
        if api_method != self.__class__.api_method:
            raise RuntimeError("Request was handled by an unexpected method. "
                               "Expected %s, got %s." %
                               (self.__class__.api_method, api_method))

        if hasattr(http_response, "streaming_content"):
            # We don't know the nature of response content, but we force it to be
            # unicode. It's a strategy that's good enough for testing purposes.
            content = utils.SmartUnicode("".join(
                http_response.streaming_content))
        else:
            content = http_response.content

        xssi_token = ")]}'\n"
        if content.startswith(xssi_token):
            content = content[len(xssi_token):]

        # replace the values of all tracebacks by <traceback content>
        regex = re.compile(r'"traceBack": "Traceback[^"\\]*(?:\\.[^"\\]*)*"',
                           re.DOTALL)
        content = regex.sub('"traceBack": "<traceback content>"', content)

        if replace:
            if hasattr(replace, "__call__"):
                replace = replace()

            # We reverse sort replacements by length to avoid cases when
            # replacements include each other and therefore order
            # of replacements affects the result. Compiled regexes have no
            # len(), so they are measured by their pattern text instead.
            def _ReplacementLength(key):
                return len(getattr(key, "pattern", key))

            for substr in sorted(replace, key=_ReplacementLength,
                                 reverse=True):
                repl = replace[substr]

                if hasattr(substr, "sub"):  # regex
                    content = substr.sub(repl, content)
                    url = substr.sub(repl, url)
                else:
                    content = content.replace(substr, repl)
                    url = url.replace(substr, repl)

        # We treat streaming content purely as strings and don't expect it to
        # contain JSON data.
        if hasattr(http_response, "streaming_content"):
            parsed_content = content
        else:
            parsed_content = json.loads(content)

        check_result = dict(api_method=api_method,
                            method=method,
                            url=url,
                            test_class=self.__class__.__name__,
                            response=parsed_content)

        if payload:
            check_result["request_payload"] = payload

        # Type stripping only makes sense for version 1 of the API.
        if not self.use_api_v2:
            stripped_content = api_value_renderers.StripTypeInfo(
                parsed_content)
            if parsed_content != stripped_content:
                check_result["type_stripped_response"] = stripped_content

        self.checks.append(check_result)
示例#5
0
    def HandleRequest(self, request):
        """Handles given HTTP request.

        Matches the request to a router method and builds the handler through
        the router, which performs ACL checks. It then runs the handler and
        wraps the result with self._BuildResponse or
        self._BuildStreamingResponse.

        Args:
          request: Incoming HTTP request object. Its user, path and method
            attributes (and GET, when present) are read here.

        Returns:
          The response built by _BuildResponse/_BuildStreamingResponse.
          Failures are converted into responses with these status codes:
          - 403: invalid username or access denied.
          - 404: no router, or resource not found.
          - 405: method not allowed.
          - 501: not implemented.
          - 500: any other error.
        """
        impersonated_username = config_lib.CONFIG[
            "AdminUI.debug_impersonate_user"]
        if impersonated_username:
            logging.info("Overriding user as %s", impersonated_username)
            request.user = impersonated_username

        if not aff4_users.GRRUser.IsValidUsername(request.user):
            return self._BuildResponse(
                403, dict(message="Invalid username: %s" % request.user))

        strip_type_info = False
        if hasattr(request, "GET") and request.GET.get("strip_type_info", ""):
            strip_type_info = True

        try:
            router, method_metadata, route_args = self._router_matcher.MatchRouter(
                request)
            args = self._GetArgsFromRequest(request, method_metadata,
                                            route_args)
        except access_control.UnauthorizedAccess as e:
            logging.exception("Access denied to %s (%s): %s", request.path,
                              request.method, e)

            additional_headers = {
                "X-GRR-Unauthorized-Access-Reason": utils.SmartStr(e.message),
                "X-GRR-Unauthorized-Access-Subject": utils.SmartStr(e.subject)
            }
            return self._BuildResponse(
                403,
                dict(message="Access denied by ACL: %s" %
                     utils.SmartStr(e.message),
                     subject=utils.SmartStr(e.subject)),
                headers=additional_headers)

        except ApiCallRouterNotFoundError as e:
            return self._BuildResponse(404, dict(message=e.message))
        except werkzeug_exceptions.MethodNotAllowed as e:
            return self._BuildResponse(405, dict(message=e.message))
        except Error as e:
            logging.exception("Can't match URL to router/method: %s", e)

            return self._BuildResponse(
                500, dict(message=str(e), traceBack=traceback.format_exc()))

        # SetUID() is called here so that ACL checks done by the router do not
        # clash with datastore ACL checks.
        # TODO(user): increase token expiry time.
        token = self.BuildToken(request, 60).SetUID()

        handler = None
        try:
            # ACL checks are done here by the router. If this method succeeds (i.e.
            # does not raise), then handlers run without further ACL checks (they're
            # free to do some in their own implementations, though).
            handler = getattr(router, method_metadata.name)(args, token=token)

            if handler.args_type != method_metadata.args_type:
                raise RuntimeError(
                    "Handler args type doesn't match "
                    "method args type: %s vs %s" %
                    (handler.args_type, method_metadata.args_type))

            binary_result_type = (
                api_call_router.RouterMethodMetadata.BINARY_STREAM_RESULT_TYPE)

            # A handler without a declared result type is accepted for
            # binary-stream methods.
            if (handler.result_type != method_metadata.result_type and
                    not (handler.result_type is None and
                         method_metadata.result_type == binary_result_type)):
                raise RuntimeError(
                    "Handler result type doesn't match "
                    "method result type: %s vs %s" %
                    (handler.result_type, method_metadata.result_type))

            # HEAD method is only used for checking the ACLs for particular API
            # methods.
            if request.method == "HEAD":
                # If the request would return a stream, we add the Content-Length
                # header to the response.
                if (method_metadata.result_type ==
                        method_metadata.BINARY_STREAM_RESULT_TYPE):
                    binary_stream = handler.Handle(args, token=token)
                    headers = None
                    if binary_stream.content_length:
                        headers = {
                            "Content-Length": binary_stream.content_length
                        }
                    return self._BuildResponse(
                        200, {"status": "OK"},
                        method_name=method_metadata.name,
                        headers=headers,
                        no_audit_log=method_metadata.no_audit_log_required,
                        token=token)
                else:
                    return self._BuildResponse(
                        200, {"status": "OK"},
                        method_name=method_metadata.name,
                        no_audit_log=method_metadata.no_audit_log_required,
                        token=token)

            if (method_metadata.result_type ==
                    method_metadata.BINARY_STREAM_RESULT_TYPE):
                binary_stream = handler.Handle(args, token=token)
                return self._BuildStreamingResponse(
                    binary_stream, method_name=method_metadata.name)
            else:
                # Look up the per-HTTP-method rendering options. Initialize
                # the default before the loop so it is defined even when the
                # method declares no HTTP routes at all.
                strip_root_types = False
                for http_method, _, options in method_metadata.http_methods:
                    if http_method == request.method:
                        strip_root_types = options.get("strip_root_types",
                                                       False)
                        break

                rendered_data = self.CallApiHandler(
                    handler,
                    args,
                    token=token,
                    strip_root_types=strip_root_types)

                if strip_type_info:
                    rendered_data = api_value_renderers.StripTypeInfo(
                        rendered_data)

                return self._BuildResponse(
                    200,
                    rendered_data,
                    method_name=method_metadata.name,
                    no_audit_log=method_metadata.no_audit_log_required,
                    token=token)
        except access_control.UnauthorizedAccess as e:
            logging.exception("Access denied to %s (%s) with %s: %s",
                              request.path, request.method,
                              method_metadata.name, e)

            additional_headers = {
                "X-GRR-Unauthorized-Access-Reason": utils.SmartStr(e.message),
                "X-GRR-Unauthorized-Access-Subject": utils.SmartStr(e.subject)
            }
            # Same message format as the routing-stage ACL failure above.
            return self._BuildResponse(
                403,
                dict(message="Access denied by ACL: %s" %
                     utils.SmartStr(e.message),
                     subject=utils.SmartStr(e.subject)),
                headers=additional_headers,
                method_name=method_metadata.name,
                no_audit_log=method_metadata.no_audit_log_required,
                token=token)
        except api_call_handler_base.ResourceNotFoundError as e:
            return self._BuildResponse(
                404,
                dict(message=e.message),
                method_name=method_metadata.name,
                no_audit_log=method_metadata.no_audit_log_required,
                token=token)
        except NotImplementedError as e:
            return self._BuildResponse(
                501,
                dict(message=e.message),
                method_name=method_metadata.name,
                no_audit_log=method_metadata.no_audit_log_required,
                token=token)
        except Exception as e:  # pylint: disable=broad-except
            # Top-level boundary: log and convert any handler error into a 500.
            # handler may still be None here if building it failed.
            logging.exception("Error while processing %s (%s) with %s: %s",
                              request.path, request.method,
                              handler.__class__.__name__, e)
            return self._BuildResponse(
                500,
                dict(message=str(e), traceBack=traceback.format_exc()),
                method_name=method_metadata.name,
                no_audit_log=method_metadata.no_audit_log_required,
                token=token)
示例#6
0
    def Check(self, method, url, payload=None, replace=None):
        """Records output of a given url accessed with a given method.

        Args:
          method: HTTP method. May be "GET", "POST", "DELETE" or "PATCH".
          url: String representing an url.
          payload: JSON-able payload that will be sent when "POST", "DELETE"
            or "PATCH" method is used.
          replace: Dictionary of key->value pairs, or a callable returning
            one. In the recorded JSON output every "key" will be replaced
            with its "value" counterpart. Keys may be plain strings or
            compiled regular expressions. This way we can properly handle
            dynamically generated values (like hunts IDs) in the regression
            data.

        Raises:
          ValueError: if unsupported method argument is passed. Currently only
            "GET", "POST", "DELETE" and "PATCH" are supported.
          RuntimeError: if request was handled by an unexpected API method
            (every test is annotated with an "api_method" attribute that
            points to the expected API method).
        """
        parsed_url = urlparse.urlparse(url)
        request = utils.DataObject(method=method,
                                   scheme="http",
                                   path=parsed_url.path,
                                   environ={
                                       "SERVER_NAME": "foo.bar",
                                       "SERVER_PORT": 1234
                                   },
                                   user="******")
        request.META = {"CONTENT_TYPE": "application/json"}

        if method == "GET":
            request.GET = dict(urlparse.parse_qsl(parsed_url.query))
        elif method in ["POST", "DELETE", "PATCH"]:
            request.body = json.dumps(payload or "")
        else:
            raise ValueError("Unsupported method: %s." % method)

        with self.NoAuthorizationChecks():
            http_response = http_api.RenderHttpResponse(request)

        api_method = http_response["X-API-Method"]
        if api_method != self.__class__.api_method:
            raise RuntimeError("Request was handled by an unexpected method. "
                               "Expected %s, got %s." %
                               (self.__class__.api_method, api_method))

        content = http_response.content

        # Strip the anti-XSSI prefix before parsing the JSON body.
        xssi_token = ")]}'\n"
        if content.startswith(xssi_token):
            content = content[len(xssi_token):]

        # replace the values of all tracebacks by <traceback content>
        regex = re.compile(r'"traceBack": "Traceback[^"\\]*(?:\\.[^"\\]*)*"',
                           re.DOTALL)
        content = regex.sub('"traceBack": "<traceback content>"', content)

        if replace:
            if hasattr(replace, "__call__"):
                replace = replace()

            # Apply longer keys first so that replacements which contain each
            # other give a deterministic result (plain dict iteration order
            # is arbitrary). Compiled regexes have no len(), so they are
            # measured by their pattern text instead.
            def _ReplacementLength(key):
                return len(getattr(key, "pattern", key))

            for substr in sorted(replace, key=_ReplacementLength,
                                 reverse=True):
                repl = replace[substr]
                if hasattr(substr, "sub"):  # regex
                    content = substr.sub(repl, content)
                    url = substr.sub(repl, url)
                else:
                    content = content.replace(substr, repl)
                    url = url.replace(substr, repl)

        parsed_content = json.loads(content)
        check_result = dict(method=method,
                            url=url,
                            test_class=self.__class__.__name__,
                            response=parsed_content)

        if payload:
            check_result["request_payload"] = payload

        stripped_content = api_value_renderers.StripTypeInfo(parsed_content)
        if parsed_content != stripped_content:
            check_result["type_stripped_response"] = stripped_content

        self.checks.append(check_result)