Exemple #1
0
 def query(self):
     """
     Takes a query_obj constructed in the client and returns payload data response
     for the given query_obj.

     The ``query_context`` form field is required; a missing field is rejected
     as a client error rather than surfacing as an internal error.
     """
     # Index with [] instead of .get(): .get() yields None for a missing field and
     # json.loads(None) then raises TypeError (HTTP 500). Assuming ``request`` is
     # Flask's request, ``form`` is a Werkzeug MultiDict whose [] raises
     # BadRequestKeyError, which Flask answers with 400 Bad Request.
     raw_query_context = request.form["query_context"]
     query_context = QueryContext(**json.loads(raw_query_context))
     security_manager.assert_datasource_permission(query_context.datasource, g.user)
     payload_json = query_context.get_data()
     return data_payload_response(payload_json)
    def test_cache_key_changes_when_post_processing_is_updated(self):
        """
        The cache key must track post-processing: appending a ``None`` operation
        leaves it unchanged, while removing ``post_processing`` altogether
        changes it.
        """
        self.login(username="******")
        table = self.get_table_by_name("birth_names")
        payload = get_query_context(
            table.name, table.id, table.type, add_postprocessing_operations=True
        )

        def first_query_cache_key():
            # Rebuild the context from the (possibly mutated) payload on each call.
            context = QueryContext(**payload)
            return context.cache_key(context.queries[0])

        # Baseline key, computed with the post-processing operations present.
        baseline_key = first_query_cache_key()

        # A None entry in post_processing must not affect the key.
        payload["queries"][0]["post_processing"].append(None)
        self.assertEqual(baseline_key, first_query_cache_key())

        # Dropping post_processing entirely must yield a different key.
        payload["queries"][0].pop("post_processing")
        self.assertNotEqual(baseline_key, first_query_cache_key())
Exemple #3
0
 def query(self):
     """
     Takes a query_obj constructed in the client and returns payload data response
     for the given query_obj.

     The ``query_context`` form field is required; a missing field is rejected
     as a client error rather than surfacing as an internal error.
     """
     # Index with [] instead of .get(): .get() yields None for a missing field and
     # json.loads(None) then raises TypeError (HTTP 500). Assuming ``request`` is
     # Flask's request, ``form`` is a Werkzeug MultiDict whose [] raises
     # BadRequestKeyError, which Flask answers with 400 Bad Request.
     query_context = QueryContext(
         **json.loads(request.form["query_context"]))
     security_manager.assert_datasource_permission(query_context.datasource,
                                                   g.user)
     payload_json = query_context.get_data()
     return data_payload_response(payload_json)
Exemple #4
0
 def query(self):
     """
     Takes a query_obj constructed in the client and returns payload data response
     for the given query_obj.
     params: query_context: json_blob

     The ``query_context`` form field is required; a missing field is rejected
     as a client error rather than surfacing as an internal error.
     """
     # Index with [] instead of .get(): .get() yields None for a missing field and
     # json.loads(None) then raises TypeError (HTTP 500). Assuming ``request`` is
     # Flask's request, ``form`` is a Werkzeug MultiDict whose [] raises
     # BadRequestKeyError, which Flask answers with 400 Bad Request.
     query_context = QueryContext(**json.loads(request.form["query_context"]))
     security_manager.assert_query_context_permission(query_context)
     payload_json = query_context.get_payload()
     # NOTE(review): ignore_nan is a simplejson option; presumably ``json`` is
     # simplejson in this module — confirm against the file's imports.
     return json.dumps(
         payload_json, default=utils.json_int_dttm_ser, ignore_nan=True
     )
Exemple #5
0
    def query(self) -> FlaskResponse:
        """
        Build a QueryContext from the ``query_context`` form field, check access,
        and return the JSON-serialized ``queries`` section of its payload.

        raises SupersetSecurityException: If the user cannot access the resource
        """
        raw_context = json.loads(request.form["query_context"])
        context = QueryContext(**raw_context)
        # The access check runs before any payload is computed.
        context.raise_for_access()
        queries_payload = context.get_payload()["queries"]
        return json.dumps(
            queries_payload, default=utils.json_int_dttm_ser, ignore_nan=True
        )
 def test_csv_response_format(self):
     """
     Requesting the CSV result format with a row limit of 10 must return a
     single response whose data is CSV text starting with the expected header.
     """
     self.login(username="******")
     table = self.get_table_by_name("birth_names")
     payload = get_query_context(table.name, table.id, table.type)
     payload["result_format"] = ChartDataResultFormat.CSV.value
     first_query = payload["queries"][0]
     first_query["row_limit"] = 10
     responses = QueryContext(**payload).get_payload()
     self.assertEqual(len(responses), 1)
     csv_text = responses[0]["data"]
     self.assertIn("name,sum__num\n", csv_text)
     # Presumably header + 10 rows + the empty piece after a trailing newline.
     self.assertEqual(len(csv_text.split("\n")), 12)
Exemple #7
0
 def create(
     self,
     *,
     datasource: DatasourceDict,
     queries: List[Dict[str, Any]],
     form_data: Optional[Dict[str, Any]] = None,
     result_type: Optional[ChartDataResultType] = None,
     result_format: Optional[ChartDataResultFormat] = None,
     force: bool = False,
     custom_cache_timeout: Optional[int] = None,
 ) -> QueryContext:
     """
     Assemble a QueryContext from a datasource descriptor and raw query dicts.

     A falsy ``datasource`` yields no datasource model. ``result_type`` and
     ``result_format`` default to FULL and JSON. Each raw query dict is passed
     to the query-object factory. ``cache_values`` records the raw datasource
     and query dicts together with the resolved type and format.
     """
     model = self._convert_to_model(datasource) if datasource else None
     effective_type = result_type or ChartDataResultType.FULL
     effective_format = result_format or ChartDataResultFormat.JSON
     query_objects = [
         self._query_object_factory.create(effective_type, **raw_query)
         for raw_query in queries
     ]
     return QueryContext(
         datasource=model,
         queries=query_objects,
         form_data=form_data,
         result_type=effective_type,
         result_format=effective_format,
         force=force,
         custom_cache_timeout=custom_cache_timeout,
         cache_values={
             "datasource": datasource,
             "queries": queries,
             "result_type": effective_type,
             "result_format": effective_format,
         },
     )
 def test_query_response_type(self):
     """
     The QUERY result type must return one response with exactly two keys: a
     ``language`` of "sql" and a ``query`` containing a SELECT statement.
     """
     self.login(username="******")
     table = self.get_table_by_name("birth_names")
     payload = get_query_context(table.name, table.id, table.type)
     payload["result_type"] = ChartDataResultType.QUERY.value
     responses = QueryContext(**payload).get_payload()
     self.assertEqual(len(responses), 1)
     first = responses[0]
     self.assertEqual(len(first), 2)
     self.assertEqual(first["language"], "sql")
     self.assertIn("SELECT", first["query"])
 def test_samples_response_type(self):
     """
     The SAMPLES result type with a row limit of 5 must return one response
     whose data is a list of 5 rows, the first of which lacks ``sum__num``.
     """
     self.login(username="******")
     table = self.get_table_by_name("birth_names")
     payload = get_query_context(table.name, table.id, table.type)
     payload["result_type"] = ChartDataResultType.SAMPLES.value
     payload["queries"][0]["row_limit"] = 5
     responses = QueryContext(**payload).get_payload()
     self.assertEqual(len(responses), 1)
     rows = responses[0]["data"]
     self.assertIsInstance(rows, list)
     self.assertEqual(len(rows), 5)
     self.assertNotIn("sum__num", rows[0])
Exemple #10
0
def _get_full(
    query_context: QueryContext,
    query_obj: QueryObject,
    force_cached: Optional[bool] = False,
) -> Dict[str, Any]:
    """
    Build the full response payload for ``query_obj``.

    The payload from ``query_context.get_df_payload`` is enriched, unless the
    query failed, with column names, index values, column types, serialized
    data and the result format. The raw ``df`` entry is then removed. Filter
    columns are split into ``applied_filters`` and ``rejected_filters``,
    extended with the time-filter status. For the RESULTS result type on a
    non-failed query, only ``data``, ``colnames`` and ``coltypes`` are returned.
    """
    datasource = _get_datasource(query_context, query_obj)
    result_type = query_obj.result_type or query_context.result_type
    payload = query_context.get_df_payload(query_obj, force_cached=force_cached)
    template_filters = payload.get("applied_template_filters", [])
    frame = payload["df"]
    succeeded = payload["status"] != QueryStatus.FAILED

    if succeeded:
        payload["colnames"] = list(frame.columns)
        payload["indexnames"] = list(frame.index)
        payload["coltypes"] = extract_dataframe_dtypes(frame, datasource)
        payload["data"] = query_context.get_data(frame)
        payload["result_format"] = query_context.result_format
    payload.pop("df")

    filter_columns = cast(List[str], [flt.get("col") for flt in query_obj.filter])
    datasource_columns = set(datasource.column_names)
    applied_time_columns, rejected_time_columns = get_time_filter_status(
        datasource, query_obj.applied_time_extras
    )

    def _is_applied(col: Any) -> Any:
        # Adhoc columns always count as applied; checking them first also keeps
        # (possibly unhashable) adhoc column values out of the set lookup.
        return (
            is_adhoc_column(col)
            or col in datasource_columns
            or col in template_filters
        )

    payload["applied_filters"] = [
        {"column": get_column_name(col)}
        for col in filter_columns
        if _is_applied(col)
    ] + applied_time_columns
    payload["rejected_filters"] = [
        {"reason": ExtraFiltersReasonType.COL_NOT_IN_DATASOURCE, "column": col}
        for col in filter_columns
        if not _is_applied(col)
    ] + rejected_time_columns

    if result_type == ChartDataResultType.RESULTS and succeeded:
        return {
            "data": payload.get("data"),
            "colnames": payload.get("colnames"),
            "coltypes": payload.get("coltypes"),
        }
    return payload
    def test_query_context_time_range_endpoints(self):
        """
        The first query object's extras must contain ``time_range_endpoints``
        equal to (INCLUSIVE, EXCLUSIVE).
        """
        query_context = QueryContext(**self._get_query_context_dict())
        query_object = query_context.queries[0]
        extras = query_object.to_dict()["extras"]
        # assertIn reports the container on failure, unlike assertTrue(x in y).
        self.assertIn("time_range_endpoints", extras)

        # assertEqual, not the assertEquals alias, which is deprecated and was
        # removed in Python 3.12.
        self.assertEqual(
            extras["time_range_endpoints"],
            (utils.TimeRangeEndpoint.INCLUSIVE, utils.TimeRangeEndpoint.EXCLUSIVE),
        )
    def test_cache_key_changes_when_datasource_is_updated(self):
        """
        After an edit-and-revert of the datasource, a QueryContext built from
        the same dict must produce a different cache key.
        """
        qc_dict = self._get_query_context_dict()

        def first_cache_key():
            context = QueryContext(**qc_dict)
            return context.cache_key(context.queries[0])

        key_before = first_cache_key()

        # Edit and immediately revert the description. The content ends up
        # unchanged, but the two commits refresh the changed_on property.
        datasource_ref = qc_dict["datasource"]
        datasource = ConnectorRegistry.get_datasource(
            datasource_type=datasource_ref["type"],
            datasource_id=datasource_ref["id"],
            session=db.session,
        )
        saved_description = datasource.description
        for description in ("temporary description", saved_description):
            datasource.description = description
            db.session.commit()

        # Same inputs, but the updated datasource should change the key.
        self.assertNotEqual(key_before, first_cache_key())
Exemple #13
0
    def test_cache_key_changes_when_datasource_is_updated(self):
        """
        After an edit-and-revert of the birth_names datasource, a QueryContext
        built from the same payload must produce a different cache key.
        """
        self.login(username="******")
        table = self.get_table_by_name("birth_names")
        payload = get_query_context(table.name, table.id, table.type)

        def first_cache_key():
            context = QueryContext(**payload)
            return context.cache_key(context.queries[0])

        key_before = first_cache_key()

        # Edit and immediately revert the description. The content ends up
        # unchanged, but the two commits refresh the changed_on property.
        datasource_ref = payload["datasource"]
        datasource = ConnectorRegistry.get_datasource(
            datasource_type=datasource_ref["type"],
            datasource_id=datasource_ref["id"],
            session=db.session,
        )
        saved_description = datasource.description
        for description in ("temporary description", saved_description):
            datasource.description = description
            db.session.commit()

        # Same inputs, but the updated datasource should change the key.
        self.assertNotEqual(key_before, first_cache_key())
Exemple #14
0
 def test_convert_deprecated_fields(self):
     """
     Deprecated query fields must be mapped onto their replacements:
     ``granularity_sqla`` onto ``granularity``, and ``having_filters`` into the
     ``having_druid`` extra.
     """
     self.login(username="******")
     table = self.get_table_by_name("birth_names")
     payload = get_query_context(table.name, table.id, table.type)
     first_query = payload["queries"][0]
     first_query["granularity_sqla"] = "timecol"
     first_query["having_filters"] = {"col": "a", "op": "==", "val": "b"}
     query_context = QueryContext(**payload)
     self.assertEqual(len(query_context.queries), 1)
     converted = query_context.queries[0]
     self.assertEqual(converted.granularity, "timecol")
     self.assertIn("having_druid", converted.extras)
Exemple #15
0
    def test_query_context_time_range_endpoints(self):
        """
        Ensure that time_range_endpoints are populated automatically when missing
        from the payload
        """
        self.login(username="******")
        table_name = "birth_names"
        table = self.get_table_by_name(table_name)
        payload = get_query_context(table.name, table.id, table.type)
        # Remove the field so the test exercises the automatic fallback.
        del payload["queries"][0]["extras"]["time_range_endpoints"]
        query_context = QueryContext(**payload)
        query_object = query_context.queries[0]
        extras = query_object.to_dict()["extras"]
        # assertIn reports the container on failure, unlike assertTrue(x in y).
        self.assertIn("time_range_endpoints", extras)

        # assertEqual, not the assertEquals alias, which is deprecated and was
        # removed in Python 3.12.
        self.assertEqual(
            extras["time_range_endpoints"],
            (TimeRangeEndpoint.INCLUSIVE, TimeRangeEndpoint.EXCLUSIVE),
        )
Exemple #16
0
 def make_query_context(self, data: Dict[str, Any]) -> QueryContext:
     """Build a QueryContext by unpacking ``data`` as keyword arguments."""
     return QueryContext(**data)
Exemple #17
0
 def data(self) -> Response:
     """
     Takes a query context constructed in the client and returns payload
     data response for the given query.
     ---
     post:
       description: >-
         Takes a query context constructed in the client and returns payload data
         response for the given query.
       requestBody:
         description: Query context schema
         required: true
         content:
           application/json:
             schema:
               type: object
               properties:
                 datasource:
                   type: object
                   description: The datasource where the query will run
                   properties:
                     id:
                       type: integer
                     type:
                       type: string
                 queries:
                   type: array
                   items:
                     type: object
                     properties:
                       granularity:
                         type: string
                       groupby:
                         type: array
                         items:
                           type: string
                       metrics:
                         type: array
                         items:
                           type: object
                       filters:
                         type: array
                         items:
                           type: string
                       row_limit:
                         type: integer
       responses:
         200:
           description: Query result
           content:
             application/json:
               schema:
                 type: array
                 items:
                   type: object
                   properties:
                     cache_key:
                       type: string
                     cached_dttm:
                       type: string
                     cache_timeout:
                       type: integer
                     error:
                       type: string
                     is_cached:
                       type: boolean
                     query:
                       type: string
                     status:
                       type: string
                     stacktrace:
                       type: string
                     rowcount:
                       type: integer
                     data:
                       type: array
                       items:
                         type: object
         400:
           $ref: '#/components/responses/400'
         401:
           $ref: '#/components/responses/401'
         404:
           $ref: '#/components/responses/404'
         500:
           $ref: '#/components/responses/500'
     """
     # NOTE: the docstring above is an OpenAPI spec and is kept verbatim.
     if not request.is_json:
         return self.response_400(message="Request is not JSON")
     body = request.json
     # The spec requires a JSON object. Any other JSON value (array, scalar, null)
     # cannot be unpacked with ** and would raise TypeError (HTTP 500), so it is
     # rejected here as a client error.
     if not isinstance(body, dict):
         return self.response_400(message="Request is incorrect")
     try:
         query_context = QueryContext(**body)
     except KeyError:
         return self.response_400(message="Request is incorrect")
     try:
         security_manager.assert_query_context_permission(query_context)
     except SupersetSecurityException:
         return self.response_401()
     payload_json = query_context.get_payload()
     response_data = simplejson.dumps(
         payload_json, default=json_int_dttm_ser, ignore_nan=True
     )
     resp = make_response(response_data, 200)
     resp.headers["Content-Type"] = "application/json; charset=utf-8"
     return resp