def stats_timing(stats_key: str, stats_logger: BaseStatsLogger) -> Iterator[float]:
    """Yield a start timestamp and report the elapsed time when the caller is done.

    The duration is sent through ``stats_logger.timing`` from a ``finally``
    clause, so it is reported whether the timed code finishes normally or
    raises; any exception propagates unchanged.

    :param stats_key: The key the timing is reported under
    :param stats_logger: The logger whose ``timing`` method receives the value
    :returns: An iterator yielding the start timestamp exactly once
    """
    # NOTE(review): presumably wrapped by ``contextlib.contextmanager`` just
    # above this definition -- confirm, the decorator is not visible here.
    start_ts = now_as_float()
    try:
        yield start_ts
    finally:
        stats_logger.timing(stats_key, now_as_float() - start_ts)
class BaseSupersetModelRestApi(ModelRestApi):
    """
    Extends FAB's ModelRestApi to implement specific superset generic functionality
    """

    csrf_exempt = False
    # Maps API method names to the permission name used for each of them
    method_permission_name = {
        "bulk_delete": "delete",
        "data": "list",
        "data_from_cache": "list",
        "delete": "delete",
        "distinct": "list",
        "export": "mulexport",
        "import_": "add",
        "get": "show",
        "get_list": "list",
        "info": "list",
        "post": "add",
        "put": "edit",
        "refresh": "edit",
        "related": "list",
        "related_objects": "list",
        "schemas": "list",
        "select_star": "list",
        "table_metadata": "list",
        "test_connection": "post",
        "thumbnail": "list",
        "viz_types": "list",
    }

    order_rel_fields: Dict[str, Tuple[str, str]] = {}
    """
    Impose ordering on related fields query::

        order_rel_fields = {
            "<RELATED_FIELD>": ("<RELATED_FIELD_FIELD>", "<asc|desc>"),
             ...
        }
    """  # pylint: disable=pointless-string-statement
    related_field_filters: Dict[str, Union[RelatedFieldFilter, str]] = {}
    """
    Declare the filters for related fields::

        related_fields = {
            "<RELATED_FIELD>": <RelatedFieldFilter>)
        }
    """  # pylint: disable=pointless-string-statement
    filter_rel_fields: Dict[str, BaseFilter] = {}
    """
    Declare the related field base filter::

        filter_rel_fields_field = {
            "<RELATED_FIELD>": "<FILTER>")
        }
    """  # pylint: disable=pointless-string-statement
    allowed_rel_fields: Set[str] = set()
    """
    Declare a set of allowed related fields that the `related` endpoint supports
    """  # pylint: disable=pointless-string-statement

    text_field_rel_fields: Dict[str, str] = {}
    """
    Declare an alternative for the human readable representation of the Model object::

        text_field_rel_fields = {
            "<RELATED_FIELD>": "<RELATED_OBJECT_FIELD>"
        }
    """  # pylint: disable=pointless-string-statement

    # Columns the `distinct` endpoint accepts; any other column name gets a 404
    allowed_distinct_fields: Set[str] = set()

    openapi_spec_component_schemas: Tuple[Type[Schema], ...] = tuple()
    """
    Add extra schemas to the OpenAPI component schemas section
    """

    # pylint: disable=pointless-string-statement

    # Column lists; `_init_properties` falls back to the primary key column for
    # any of them that is None when no matching model schema is set
    add_columns: List[str]
    edit_columns: List[str]
    list_columns: List[str]
    show_columns: List[str]

    # Response descriptions keyed by HTTP status code; every entry shares the
    # same `error_payload_content` body definition
    responses = {
        "400": {
            "description": "Bad request",
            "content": error_payload_content
        },
        "401": {
            "description": "Unauthorized",
            "content": error_payload_content
        },
        "403": {
            "description": "Forbidden",
            "content": error_payload_content
        },
        "404": {
            "description": "Not found",
            "content": error_payload_content
        },
        "422": {
            "description": "Could not process entity",
            "content": error_payload_content,
        },
        "500": {
            "description": "Fatal error",
            "content": error_payload_content
        },
    }

    def __init__(self) -> None:
        """
        Attach a default stats logger and register the shared API spec schemas
        before handing over to the parent initialisation
        """
        # Default logger; `create_blueprint` swaps in the configured one
        self.stats_logger = BaseStatsLogger()

        # Make sure the related/distinct query parameter schema is declared
        if self.apispec_parameter_schemas is None:  # type: ignore
            self.apispec_parameter_schemas = {}
        self.apispec_parameter_schemas.update(
            {"get_related_schema": get_related_schema})

        # Append the response schemas used by the related/distinct endpoints
        declared_schemas = self.openapi_spec_component_schemas
        if declared_schemas is None:
            declared_schemas = ()
        self.openapi_spec_component_schemas = declared_schemas + (
            RelatedResponseSchema,
            DistincResponseSchema,
        )

        super().__init__()

    def add_apispec_components(self, api_spec: APISpec) -> None:
        """
        Register every schema declared on `openapi_spec_component_schemas`
        as an OpenAPI component, then let the parent add its own components

        :param api_spec: The APISpec object the components are registered on
        """
        for schema_cls in self.openapi_spec_component_schemas:
            component_name = schema_cls.__name__
            try:
                api_spec.components.schema(component_name, schema=schema_cls)
            except DuplicateComponentNameError:
                # A component with this name is already registered: keep it
                continue
        super().add_apispec_components(api_spec)

    def create_blueprint(self, appbuilder: AppBuilder, *args: Any,
                         **kwargs: Any) -> Blueprint:
        """
        Pick up the configured stats logger, then build the blueprint

        :param appbuilder: The AppBuilder this API is being registered with
        :param args: Extra positional arguments forwarded to the parent class
        :param kwargs: Extra keyword arguments forwarded to the parent class
        :returns: The blueprint created by the parent class
        """
        # Read the config through the `appbuilder` argument rather than
        # `self.appbuilder`: nothing in this class assigns that attribute
        # before this point, so relying on it depends on the caller.
        self.stats_logger = appbuilder.get_app.config["STATS_LOGGER"]
        return super().create_blueprint(appbuilder, *args, **kwargs)

    def _init_properties(self) -> None:
        """
        Default each column list that is None and has no matching model schema
        to the primary key column, then run the parent property initialisation
        """
        pk_name = self.datamodel.get_pk_name()
        for prefix in ("list", "show", "edit", "add"):
            columns_attr = f"{prefix}_columns"
            schema_attr = f"{prefix}_model_schema"
            columns_unset = getattr(self, columns_attr) is None
            if columns_unset and not getattr(self, schema_attr):
                setattr(self, columns_attr, [pk_name])
        super()._init_properties()

    def _get_related_filter(self, datamodel: SQLAInterface, column_name: str,
                            value: str) -> Filters:
        """
        Build the filters used to query a related field

        :param datamodel: The interface of the related model
        :param column_name: The related field name
        :param value: The text typed by the user, may be empty
        :returns: The filters to apply on the related model query
        """
        configured = self.related_field_filters.get(column_name)
        if isinstance(configured, str):
            # A plain string is shorthand for a "starts with" filter on it
            related_filter = RelatedFieldFilter(cast(str, configured),
                                                FilterStartsWith)
        else:
            related_filter = cast(RelatedFieldFilter, configured)

        if related_filter:
            search_columns = [related_filter.field_name]
        else:
            search_columns = None
        filters = datamodel.get_filters(search_columns)

        # Base filters declared for this related field always apply
        rel_base_filters = self.filter_rel_fields.get(column_name)
        if rel_base_filters:
            filters.add_filter_list(rel_base_filters)

        # The user supplied value only applies when a filter is configured
        if value and related_filter:
            filters.add_filter(related_filter.field_name,
                               related_filter.filter_class, value)
        return filters

    def _get_distinct_filter(self, column_name: str, value: str) -> Filters:
        """
        Build the filters used to query the distinct values of a column

        The class `base_filters` are always applied; a "starts with" filter on
        the column is added only when `value` is not empty.

        :param column_name: The column the distinct values are read from
        :param value: The text typed by the user, may be empty
        :returns: The filters to apply on the distinct values query
        """
        # The filter object was just constructed, so the previous no-op `cast`
        # and the truthiness checks on it could never take the other branch.
        filter_field = RelatedFieldFilter(column_name, FilterStartsWith)
        filters = self.datamodel.get_filters([filter_field.field_name])
        filters.add_filter_list(self.base_filters)
        if value:
            filters.add_filter(filter_field.field_name,
                               filter_field.filter_class, value)
        return filters

    def _get_text_for_model(self, model: Model, column_name: str) -> str:
        if column_name in self.text_field_rel_fields:
            model_column_name = self.text_field_rel_fields.get(column_name)
            if model_column_name:
                return getattr(model, model_column_name)
        return str(model)

    def _get_result_from_rows(self, datamodel: SQLAInterface,
                              rows: List[Model],
                              column_name: str) -> List[Dict[str, Any]]:
        return [{
            "value": datamodel.get_pk_value(row),
            "text": self._get_text_for_model(row, column_name),
        } for row in rows]

    def _add_extra_ids_to_result(
        self,
        datamodel: SQLAInterface,
        column_name: str,
        ids: List[int],
        result: List[Dict[str, Any]],
    ) -> None:
        if ids:
            # Filter out already present values on the result
            values = [row["value"] for row in result]
            ids = [id_ for id_ in ids if id_ not in values]
            pk_col = datamodel.get_pk()
            # Fetch requested values from ids
            extra_rows = db.session.query(datamodel.obj).filter(
                pk_col.in_(ids)).all()
            result += self._get_result_from_rows(datamodel, extra_rows,
                                                 column_name)

    def incr_stats(self, action: str, func_name: str) -> None:
        """
        Proxy function for statsd.incr to impose a key structure for REST API's

        :param action: String with an action name eg: error, success
        :param func_name: The function name
        """
        self.stats_logger.incr(
            f"{self.__class__.__name__}.{func_name}.{action}")

    def timing_stats(self, action: str, func_name: str, value: float) -> None:
        """
        Proxy function for statsd.incr to impose a key structure for REST API's

        :param action: String with an action name eg: error, success
        :param func_name: The function name
        :param value: A float with the time it took for the endpoint to execute
        """
        self.stats_logger.timing(
            f"{self.__class__.__name__}.{func_name}.{action}", value)

    def send_stats_metrics(self,
                           response: Response,
                           key: str,
                           time_delta: Optional[float] = None) -> None:
        """
        Helper function to handle sending statsd metrics

        :param response: flask response object, will evaluate if it was an error
        :param key: The function name
        :param time_delta: Optional time it took for the endpoint to execute
        """
        if 200 <= response.status_code < 400:
            self.incr_stats("success", key)
        else:
            self.incr_stats("error", key)
        if time_delta:
            self.timing_stats("time", key, time_delta)

    @event_logger.log_this_with_context(
        action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.info",
        object_ref=False,
        log_to_statsd=False,
    )
    def info_headless(self, **kwargs: Any) -> Response:
        """
        Run the builtin FAB _info endpoint and send statsd metrics for it
        """
        elapsed, resp = time_function(super().info_headless, **kwargs)
        # Metrics are keyed by the name of the public `info` endpoint
        endpoint_name = self.info.__name__
        self.send_stats_metrics(resp, endpoint_name, elapsed)
        return resp

    @event_logger.log_this_with_context(
        action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.get",
        object_ref=False,
        log_to_statsd=False,
    )
    def get_headless(self, pk: int, **kwargs: Any) -> Response:
        """
        Run the builtin FAB GET endpoint and send statsd metrics for it
        """
        elapsed, resp = time_function(super().get_headless, pk, **kwargs)
        # Metrics are keyed by the name of the public `get` endpoint
        endpoint_name = self.get.__name__
        self.send_stats_metrics(resp, endpoint_name, elapsed)
        return resp

    @event_logger.log_this_with_context(
        action=lambda self, *args, **kwargs:
        f"{self.__class__.__name__}.get_list",
        object_ref=False,
        log_to_statsd=False,
    )
    def get_list_headless(self, **kwargs: Any) -> Response:
        """
        Run the builtin FAB GET list endpoint and send statsd metrics for it
        """
        elapsed, resp = time_function(super().get_list_headless, **kwargs)
        # Metrics are keyed by the name of the public `get_list` endpoint
        endpoint_name = self.get_list.__name__
        self.send_stats_metrics(resp, endpoint_name, elapsed)
        return resp

    @event_logger.log_this_with_context(
        action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.post",
        object_ref=False,
        log_to_statsd=False,
    )
    def post_headless(self) -> Response:
        """
        Run the builtin FAB POST endpoint and send statsd metrics for it
        """
        elapsed, resp = time_function(super().post_headless)
        # Metrics are keyed by the name of the public `post` endpoint
        endpoint_name = self.post.__name__
        self.send_stats_metrics(resp, endpoint_name, elapsed)
        return resp

    @event_logger.log_this_with_context(
        action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.put",
        object_ref=False,
        log_to_statsd=False,
    )
    def put_headless(self, pk: int) -> Response:
        """
        Run the builtin FAB PUT endpoint and send statsd metrics for it
        """
        elapsed, resp = time_function(super().put_headless, pk)
        # Metrics are keyed by the name of the public `put` endpoint
        endpoint_name = self.put.__name__
        self.send_stats_metrics(resp, endpoint_name, elapsed)
        return resp

    @event_logger.log_this_with_context(
        action=lambda self, *args, **kwargs:
        f"{self.__class__.__name__}.delete",
        object_ref=False,
        log_to_statsd=False,
    )
    def delete_headless(self, pk: int) -> Response:
        """
        Run the builtin FAB DELETE endpoint and send statsd metrics for it
        """
        elapsed, resp = time_function(super().delete_headless, pk)
        # Metrics are keyed by the name of the public `delete` endpoint
        endpoint_name = self.delete.__name__
        self.send_stats_metrics(resp, endpoint_name, elapsed)
        return resp

    @expose("/related/<column_name>", methods=["GET"])
    @protect()
    @safe
    @statsd_metrics
    @rison(get_related_schema)
    def related(self, column_name: str, **kwargs: Any) -> FlaskResponse:
        """Get related fields data
        ---
        get:
          parameters:
          - in: path
            schema:
              type: string
            name: column_name
          - in: query
            name: q
            content:
              application/json:
                schema:
                  $ref: '#/components/schemas/get_related_schema'
          responses:
            200:
              description: Related column data
              content:
                application/json:
                  schema:
                    $ref: "#/components/schemas/RelatedResponseSchema"
            400:
              $ref: '#/components/responses/400'
            401:
              $ref: '#/components/responses/401'
            404:
              $ref: '#/components/responses/404'
            500:
              $ref: '#/components/responses/500'
        """
        # NOTE: the docstring above doubles as the OpenAPI YAML for this
        # endpoint, so explanatory notes live in comments instead.

        # Only explicitly allowed related fields are served
        if column_name not in self.allowed_rel_fields:
            self.incr_stats("error", self.related.__name__)
            return self.response_404()
        args = kwargs.get("rison", {})

        # handle pagination
        page, page_size = self._handle_page_args(args)
        try:
            datamodel = self.datamodel.get_related_interface(column_name)
        except KeyError:
            # The lookup of the related interface failed for this column
            return self.response_404()
        page, page_size = self._sanitize_page_args(page, page_size)
        # handle ordering, empty strings when none is declared for the field
        order_field = self.order_rel_fields.get(column_name)
        if order_field:
            order_column, order_direction = order_field
        else:
            order_column, order_direction = "", ""
        # handle filters
        filters = self._get_related_filter(datamodel, column_name,
                                           args.get("filter"))
        # Make the query, the count it returns is not used
        _, rows = datamodel.query(filters,
                                  order_column,
                                  order_direction,
                                  page=page,
                                  page_size=page_size)

        # produce response
        result = self._get_result_from_rows(datamodel, rows, column_name)

        # If ids are specified make sure we fetch and include them on the response
        ids = args.get("include_ids")
        self._add_extra_ids_to_result(datamodel, column_name, ids, result)

        # The reported count is the number of entries actually returned
        return self.response(200, count=len(result), result=result)

    @expose("/distinct/<column_name>", methods=["GET"])
    @protect()
    @safe
    @statsd_metrics
    @rison(get_related_schema)
    def distinct(self, column_name: str, **kwargs: Any) -> FlaskResponse:
        """Get distinct values from field data
        ---
        get:
          parameters:
          - in: path
            schema:
              type: string
            name: column_name
          - in: query
            name: q
            content:
              application/json:
                schema:
                  $ref: '#/components/schemas/get_related_schema'
          responses:
            200:
              description: Distinct field data
              content:
                application/json:
                  schema:
                    $ref: "#/components/schemas/DistincResponseSchema"
            400:
              $ref: '#/components/responses/400'
            401:
              $ref: '#/components/responses/401'
            404:
              $ref: '#/components/responses/404'
            500:
              $ref: '#/components/responses/500'
        """
        # NOTE: the docstring above doubles as the OpenAPI YAML for this
        # endpoint, so explanatory notes live in comments instead.

        # Only explicitly allowed columns are served
        if column_name not in self.allowed_distinct_fields:
            # Report the error under this endpoint's own name
            self.incr_stats("error", self.distinct.__name__)
            return self.response_404()
        args = kwargs.get("rison", {})
        # handle pagination
        page, page_size = self._sanitize_page_args(
            *self._handle_page_args(args))
        # Create generic base filters with added request filter
        filters = self._get_distinct_filter(column_name, args.get("filter"))
        # Count the distinct values first so an empty result skips the
        # second query entirely
        query_count = self.appbuilder.get_session.query(
            func.count(distinct(getattr(self.datamodel.obj, column_name))))
        count = self.datamodel.apply_filters(query_count, filters).scalar()
        if count == 0:
            return self.response(200, count=count, result=[])
        query = self.appbuilder.get_session.query(
            distinct(getattr(self.datamodel.obj, column_name)))
        # Apply generic base filters with added request filter
        query = self.datamodel.apply_filters(query, filters)
        # Apply sort
        query = self.datamodel.apply_order_by(query, column_name, "asc")
        # Apply pagination
        rows = self.datamodel.apply_pagination(query, page, page_size).all()
        # produce response, None values are left out of the entries
        result = [{
            "text": item[0],
            "value": item[0]
        } for item in rows if item[0] is not None]
        return self.response(200, count=count, result=result)
# Example #3
# 0
class BaseSupersetModelRestApi(ModelRestApi):
    """
    Extends FAB's ModelRestApi to implement specific superset generic functionality
    """

    csrf_exempt = False
    # Maps API method names to the permission name used for each of them
    method_permission_name = {
        "get_list": "list",
        "get": "show",
        "export": "mulexport",
        "post": "add",
        "put": "edit",
        "delete": "delete",
        "bulk_delete": "delete",
        "info": "list",
        "related": "list",
        "thumbnail": "list",
        "refresh": "edit",
        "data": "list",
        "viz_types": "list",
        "datasources": "list",
    }

    order_rel_fields: Dict[str, Tuple[str, str]] = {}
    """
    Impose ordering on related fields query::

        order_rel_fields = {
            "<RELATED_FIELD>": ("<RELATED_FIELD_FIELD>", "<asc|desc>"),
             ...
        }
    """  # pylint: disable=pointless-string-statement
    related_field_filters: Dict[str, Union[RelatedFieldFilter, str]] = {}
    """
    Declare the filters for related fields::

        related_fields = {
            "<RELATED_FIELD>": <RelatedFieldFilter>)
        }
    """  # pylint: disable=pointless-string-statement
    filter_rel_fields: Dict[str, BaseFilter] = {}
    """
    Declare the related field base filter::

        filter_rel_fields_field = {
            "<RELATED_FIELD>": "<FILTER>")
        }
    """  # pylint: disable=pointless-string-statement
    # Related fields the `related` endpoint accepts; any other name gets a 404
    allowed_rel_fields: Set[str] = set()

    openapi_spec_component_schemas: Tuple[Schema, ...] = tuple()
    """
    Add extra schemas to the OpenAPI component schemas section
    """

    # pylint: disable=pointless-string-statement

    # Column lists; `_init_properties` falls back to the primary key column for
    # any of them that is None when no matching model schema is set
    add_columns: List[str]
    edit_columns: List[str]
    list_columns: List[str]
    show_columns: List[str]

    def __init__(self) -> None:
        """
        Run the parent initialisation, then attach a default stats logger
        """
        super().__init__()
        # Default logger; `create_blueprint` replaces it with the configured one
        self.stats_logger = BaseStatsLogger()

    def add_apispec_components(self, api_spec: APISpec) -> None:
        """
        Adds extra OpenApi schema spec components, these are declared
        on the `openapi_spec_component_schemas` class property

        :param api_spec: The APISpec object the components are registered on
        """
        for schema in self.openapi_spec_component_schemas:
            try:
                api_spec.components.schema(
                    schema.__name__,
                    schema=schema,
                )
            except DuplicateComponentNameError:
                # The same schema may be declared by more than one API class;
                # a second registration under the same name is not an error.
                pass
        super().add_apispec_components(api_spec)

    def create_blueprint(self, appbuilder: AppBuilder, *args: Any,
                         **kwargs: Any) -> Blueprint:
        """
        Pick up the configured stats logger, then build the blueprint

        :param appbuilder: The AppBuilder this API is being registered with
        :param args: Extra positional arguments forwarded to the parent class
        :param kwargs: Extra keyword arguments forwarded to the parent class
        :returns: The blueprint created by the parent class
        """
        # Read the config through the `appbuilder` argument rather than
        # `self.appbuilder`: nothing in this class assigns that attribute
        # before this point, so relying on it depends on the caller.
        self.stats_logger = appbuilder.get_app.config["STATS_LOGGER"]
        return super().create_blueprint(appbuilder, *args, **kwargs)

    def _init_properties(self) -> None:
        """
        Default each column list that is None and has no matching model schema
        to the primary key column, then run the parent property initialisation
        """
        pk_name = self.datamodel.get_pk_name()
        for prefix in ("list", "show", "edit", "add"):
            columns_attr = f"{prefix}_columns"
            schema_attr = f"{prefix}_model_schema"
            columns_unset = getattr(self, columns_attr) is None
            if columns_unset and not getattr(self, schema_attr):
                setattr(self, columns_attr, [pk_name])
        super()._init_properties()

    def _get_related_filter(self, datamodel: Model, column_name: str,
                            value: str) -> Filters:
        """
        Build the filters used to query a related field

        :param datamodel: The object `get_filters` is called on for the
            related model
        :param column_name: The related field name
        :param value: The text typed by the user, may be empty
        :returns: The filters to apply on the related model query
        """
        configured = self.related_field_filters.get(column_name)
        if isinstance(configured, str):
            # A plain string is shorthand for a "starts with" filter on it
            related_filter = RelatedFieldFilter(cast(str, configured),
                                                FilterStartsWith)
        else:
            related_filter = cast(RelatedFieldFilter, configured)

        if related_filter:
            search_columns = [related_filter.field_name]
        else:
            search_columns = None
        filters = datamodel.get_filters(search_columns)

        # Base filters declared for this related field always apply
        rel_base_filters = self.filter_rel_fields.get(column_name)
        if rel_base_filters:
            filters.add_filter_list(rel_base_filters)

        # The user supplied value only applies when a filter is configured
        if value and related_filter:
            filters.add_filter(related_filter.field_name,
                               related_filter.filter_class, value)
        return filters

    def incr_stats(self, action: str, func_name: str) -> None:
        """
        Proxy function for statsd.incr to impose a key structure for REST API's

        :param action: String with an action name eg: error, success
        :param func_name: The function name
        """
        self.stats_logger.incr(
            f"{self.__class__.__name__}.{func_name}.{action}")

    def timing_stats(self, action: str, func_name: str, value: float) -> None:
        """
        Proxy function for statsd.incr to impose a key structure for REST API's

        :param action: String with an action name eg: error, success
        :param func_name: The function name
        :param value: A float with the time it took for the endpoint to execute
        """
        self.stats_logger.timing(
            f"{self.__class__.__name__}.{func_name}.{action}", value)

    def send_stats_metrics(self,
                           response: Response,
                           key: str,
                           time_delta: Optional[float] = None) -> None:
        """
        Helper function to handle sending statsd metrics

        :param response: flask response object, will evaluate if it was an error
        :param key: The function name
        :param time_delta: Optional time it took for the endpoint to execute
        """
        if 200 <= response.status_code < 400:
            self.incr_stats("success", key)
        else:
            self.incr_stats("error", key)
        if time_delta:
            self.timing_stats("time", key, time_delta)

    def info_headless(self, **kwargs: Any) -> Response:
        """
        Run the builtin FAB _info endpoint and send statsd metrics for it
        """
        elapsed, resp = time_function(super().info_headless, **kwargs)
        # Metrics are keyed by the name of the public `info` endpoint
        endpoint_name = self.info.__name__
        self.send_stats_metrics(resp, endpoint_name, elapsed)
        return resp

    def get_headless(self, pk: int, **kwargs: Any) -> Response:
        """
        Run the builtin FAB GET endpoint and send statsd metrics for it
        """
        elapsed, resp = time_function(super().get_headless, pk, **kwargs)
        # Metrics are keyed by the name of the public `get` endpoint
        endpoint_name = self.get.__name__
        self.send_stats_metrics(resp, endpoint_name, elapsed)
        return resp

    def get_list_headless(self, **kwargs: Any) -> Response:
        """
        Run the builtin FAB GET list endpoint and send statsd metrics for it
        """
        elapsed, resp = time_function(super().get_list_headless, **kwargs)
        # Metrics are keyed by the name of the public `get_list` endpoint
        endpoint_name = self.get_list.__name__
        self.send_stats_metrics(resp, endpoint_name, elapsed)
        return resp

    @expose("/related/<column_name>", methods=["GET"])
    @protect()
    @safe
    @statsd_metrics
    @rison(get_related_schema)
    def related(self, column_name: str, **kwargs: Any) -> FlaskResponse:
        """Get related fields data
        ---
        get:
          parameters:
          - in: path
            schema:
              type: string
            name: column_name
          - in: query
            name: q
            content:
              application/json:
                schema:
                  type: object
                  properties:
                    page_size:
                      type: integer
                    page:
                      type: integer
                    filter:
                      type: string
          responses:
            200:
              description: Related column data
              content:
                application/json:
                  schema:
                    type: object
                    properties:
                      count:
                        type: integer
                      result:
                        type: array
                        items:
                          type: object
                          properties:
                            value:
                              type: integer
                            text:
                              type: string
            400:
              $ref: '#/components/responses/400'
            401:
              $ref: '#/components/responses/401'
            404:
              $ref: '#/components/responses/404'
            422:
              $ref: '#/components/responses/422'
            500:
              $ref: '#/components/responses/500'
        """
        # NOTE: the docstring above doubles as the OpenAPI YAML for this
        # endpoint, so explanatory notes live in comments instead. `result`
        # is declared as an array because the code below returns a list.

        # Only explicitly allowed related fields are served
        if column_name not in self.allowed_rel_fields:
            self.incr_stats("error", self.related.__name__)
            return self.response_404()
        args = kwargs.get("rison", {})
        # handle pagination
        page, page_size = self._handle_page_args(args)
        try:
            datamodel = self.datamodel.get_related_interface(column_name)
        except KeyError:
            # The lookup of the related interface failed for this column
            return self.response_404()
        page, page_size = self._sanitize_page_args(page, page_size)
        # handle ordering, empty strings when none is declared for the field
        order_field = self.order_rel_fields.get(column_name)
        if order_field:
            order_column, order_direction = order_field
        else:
            order_column, order_direction = "", ""
        # handle filters
        filters = self._get_related_filter(datamodel, column_name,
                                           args.get("filter"))
        # Make the query
        count, values = datamodel.query(filters,
                                        order_column,
                                        order_direction,
                                        page=page,
                                        page_size=page_size)
        # produce response: one value/text entry per returned object
        result = [{
            "value": datamodel.get_pk_value(value),
            "text": str(value)
        } for value in values]
        return self.response(200, count=count, result=result)