def query(self):
    """
    Takes a query_obj constructed in the client and returns payload data
    response for the given query_obj.

    :raises BadRequestKeyError: if the ``query_context`` form field is
        missing (Flask turns this into an HTTP 400).
    """
    # Indexed access: a missing form field now surfaces as a clear 400
    # instead of json.loads(None) raising an opaque TypeError (HTTP 500).
    query_context = QueryContext(**json.loads(request.form["query_context"]))
    security_manager.assert_datasource_permission(query_context.datasource, g.user)
    payload_json = query_context.get_data()
    return data_payload_response(payload_json)
def test_cache_key_changes_when_post_processing_is_updated(self):
    """Cache key ignores null post-processing entries but reflects removal."""
    self.login(username="******")
    table = self.get_table_by_name("birth_names")
    payload = get_query_context(
        table.name, table.id, table.type, add_postprocessing_operations=True
    )

    def first_query_cache_key(context_payload):
        # Build a QueryContext and return the cache key of its first query.
        context = QueryContext(**context_payload)
        return context.cache_key(context.queries[0])

    # Baseline cache key from a context that includes post-processing.
    original_key = first_query_cache_key(payload)

    # Appending a None post-processing operation must not change the key.
    payload["queries"][0]["post_processing"].append(None)
    self.assertEqual(original_key, first_query_cache_key(payload))

    # Dropping post-processing entirely must yield a different key.
    payload["queries"][0].pop("post_processing")
    self.assertNotEqual(original_key, first_query_cache_key(payload))
def query(self):
    """
    Takes a query_obj constructed in the client and returns payload data
    response for the given query_obj.
    """
    raw_context = request.form.get('query_context')
    query_context = QueryContext(**json.loads(raw_context))
    security_manager.assert_datasource_permission(
        query_context.datasource, g.user
    )
    payload_json = query_context.get_data()
    return data_payload_response(payload_json)
def query(self):
    """
    Takes a query_obj constructed in the client and returns payload data
    response for the given query_obj.

    params:
    query_context: json_blob
    """
    raw_form_value = request.form.get("query_context")
    query_context = QueryContext(**json.loads(raw_form_value))
    security_manager.assert_query_context_permission(query_context)
    # NaN values are serialized as null; datetimes as epoch milliseconds.
    return json.dumps(
        query_context.get_payload(),
        default=utils.json_int_dttm_ser,
        ignore_nan=True,
    )
def query(self) -> FlaskResponse:
    """
    Takes a query_obj constructed in the client and returns payload data
    response for the given query_obj.

    raises SupersetSecurityException: If the user cannot access the resource
    """
    query_context = QueryContext(**json.loads(request.form["query_context"]))
    query_context.raise_for_access()
    # Only the per-query results are returned, not the full payload wrapper.
    queries_payload = query_context.get_payload()["queries"]
    return json.dumps(
        queries_payload, default=utils.json_int_dttm_ser, ignore_nan=True
    )
def test_csv_response_format(self):
    """
    Ensure that CSV result format works
    """
    self.login(username="******")
    table = self.get_table_by_name("birth_names")
    payload = get_query_context(table.name, table.id, table.type)
    payload["result_format"] = ChartDataResultFormat.CSV.value
    payload["queries"][0]["row_limit"] = 10

    responses = QueryContext(**payload).get_payload()
    self.assertEqual(len(responses), 1)

    csv_data = responses[0]["data"]
    # Header line plus 10 data rows plus a trailing newline -> 12 segments.
    self.assertIn("name,sum__num\n", csv_data)
    self.assertEqual(len(csv_data.split("\n")), 12)
def create(
    self,
    *,
    datasource: DatasourceDict,
    queries: List[Dict[str, Any]],
    form_data: Optional[Dict[str, Any]] = None,
    result_type: Optional[ChartDataResultType] = None,
    result_format: Optional[ChartDataResultFormat] = None,
    force: bool = False,
    custom_cache_timeout: Optional[int] = None,
) -> QueryContext:
    """Build a QueryContext from raw request pieces.

    Raw query dicts become QueryObject instances, while the untouched
    inputs are retained in ``cache_values`` for cache-key computation.
    """
    datasource_model_instance = (
        self._convert_to_model(datasource) if datasource else None
    )
    result_type = result_type or ChartDataResultType.FULL
    result_format = result_format or ChartDataResultFormat.JSON
    query_objects = [
        self._query_object_factory.create(result_type, **raw_query)
        for raw_query in queries
    ]
    return QueryContext(
        datasource=datasource_model_instance,
        queries=query_objects,
        form_data=form_data,
        result_type=result_type,
        result_format=result_format,
        force=force,
        custom_cache_timeout=custom_cache_timeout,
        # Raw (pre-conversion) inputs, kept for cache-key derivation.
        cache_values={
            "datasource": datasource,
            "queries": queries,
            "result_type": result_type,
            "result_format": result_format,
        },
    )
def test_query_response_type(self):
    """
    Ensure that query result type works
    """
    self.login(username="******")
    table = self.get_table_by_name("birth_names")
    payload = get_query_context(table.name, table.id, table.type)
    payload["result_type"] = ChartDataResultType.QUERY.value

    responses = QueryContext(**payload).get_payload()
    self.assertEqual(len(responses), 1)

    response = responses[0]
    # A QUERY-type response carries exactly the language and the SQL text.
    self.assertEqual(len(response), 2)
    self.assertEqual(response["language"], "sql")
    self.assertIn("SELECT", response["query"])
def test_samples_response_type(self):
    """
    Ensure that samples result type works
    """
    self.login(username="******")
    table = self.get_table_by_name("birth_names")
    payload = get_query_context(table.name, table.id, table.type)
    payload["result_type"] = ChartDataResultType.SAMPLES.value
    payload["queries"][0]["row_limit"] = 5

    responses = QueryContext(**payload).get_payload()
    self.assertEqual(len(responses), 1)

    rows = responses[0]["data"]
    # Samples return raw rows (no aggregate columns), capped at row_limit.
    self.assertIsInstance(rows, list)
    self.assertEqual(len(rows), 5)
    self.assertNotIn("sum__num", rows[0])
def _get_full(
    query_context: QueryContext,
    query_obj: QueryObject,
    force_cached: Optional[bool] = False,
) -> Dict[str, Any]:
    """Execute a query object and annotate its payload with column metadata
    and applied/rejected filter bookkeeping.

    When the effective result type is RESULTS (and the query succeeded), a
    slimmed-down dict with only data/colnames/coltypes is returned instead
    of the full payload.
    """
    datasource = _get_datasource(query_context, query_obj)
    # The query object's own result_type takes precedence over the
    # context-wide default.
    result_type = query_obj.result_type or query_context.result_type
    payload = query_context.get_df_payload(query_obj, force_cached=force_cached)
    applied_template_filters = payload.get("applied_template_filters", [])
    df = payload["df"]
    status = payload["status"]
    if status != QueryStatus.FAILED:
        payload["colnames"] = list(df.columns)
        payload["indexnames"] = list(df.index)
        payload["coltypes"] = extract_dataframe_dtypes(df, datasource)
        payload["data"] = query_context.get_data(df)
        payload["result_format"] = query_context.result_format
    # The raw DataFrame is never serialized back to the caller.
    del payload["df"]

    filters = query_obj.filter
    filter_columns = cast(List[str], [flt.get("col") for flt in filters])
    columns = set(datasource.column_names)
    applied_time_columns, rejected_time_columns = get_time_filter_status(
        datasource, query_obj.applied_time_extras
    )
    # A filter counts as applied when its column is ad-hoc, exists on the
    # datasource, or was applied via a Jinja template filter.
    payload["applied_filters"] = [
        {"column": get_column_name(col)}
        for col in filter_columns
        if is_adhoc_column(col)
        or col in columns
        or col in applied_template_filters
    ] + applied_time_columns
    # Everything else is rejected with an explicit reason.
    payload["rejected_filters"] = [
        {"reason": ExtraFiltersReasonType.COL_NOT_IN_DATASOURCE, "column": col}
        for col in filter_columns
        if not is_adhoc_column(col)
        and col not in columns
        and col not in applied_template_filters
    ] + rejected_time_columns
    if result_type == ChartDataResultType.RESULTS and status != QueryStatus.FAILED:
        return {
            "data": payload.get("data"),
            "colnames": payload.get("colnames"),
            "coltypes": payload.get("coltypes"),
        }
    return payload
def test_query_context_time_range_endpoints(self):
    """Query objects expose the default (INCLUSIVE, EXCLUSIVE) endpoints."""
    query_context = QueryContext(**self._get_query_context_dict())
    query_object = query_context.queries[0]
    extras = query_object.to_dict()["extras"]
    # assertIn replaces assertTrue("x" in y) for a clearer failure message;
    # assertEquals is a deprecated alias of assertEqual.
    self.assertIn("time_range_endpoints", extras)
    self.assertEqual(
        extras["time_range_endpoints"],
        (utils.TimeRangeEndpoint.INCLUSIVE, utils.TimeRangeEndpoint.EXCLUSIVE),
    )
def test_cache_key_changes_when_datasource_is_updated(self):
    """Refreshing the datasource's changed_on must change the cache key."""
    qc_dict = self._get_query_context_dict()

    def current_cache_key():
        context = QueryContext(**qc_dict)
        return context.cache_key(context.queries[0])

    cache_key_original = current_cache_key()

    # Temporarily modify and revert the description so only the
    # changed_on timestamp is refreshed.
    datasource = ConnectorRegistry.get_datasource(
        datasource_type=qc_dict["datasource"]["type"],
        datasource_id=qc_dict["datasource"]["id"],
        session=db.session,
    )
    description_original = datasource.description
    datasource.description = "temporary description"
    db.session.commit()
    datasource.description = description_original
    db.session.commit()

    # Identical attributes but a newer changed_on -> different cache key.
    self.assertNotEqual(cache_key_original, current_cache_key())
def test_cache_key_changes_when_datasource_is_updated(self):
    """Refreshing the datasource's changed_on must change the cache key."""
    self.login(username="******")
    table = self.get_table_by_name("birth_names")
    payload = get_query_context(table.name, table.id, table.type)

    def compute_cache_key():
        context = QueryContext(**payload)
        return context.cache_key(context.queries[0])

    cache_key_original = compute_cache_key()

    # Temporarily modify and revert the description so only the
    # changed_on timestamp is refreshed.
    datasource = ConnectorRegistry.get_datasource(
        datasource_type=payload["datasource"]["type"],
        datasource_id=payload["datasource"]["id"],
        session=db.session,
    )
    description_original = datasource.description
    datasource.description = "temporary description"
    db.session.commit()
    datasource.description = description_original
    db.session.commit()

    # Same payload, refreshed datasource -> different cache key.
    self.assertNotEqual(cache_key_original, compute_cache_key())
def test_convert_deprecated_fields(self):
    """
    Ensure that deprecated fields are converted correctly
    """
    self.login(username="******")
    table = self.get_table_by_name("birth_names")
    payload = get_query_context(table.name, table.id, table.type)
    first_query = payload["queries"][0]
    first_query["granularity_sqla"] = "timecol"
    first_query["having_filters"] = {"col": "a", "op": "==", "val": "b"}

    query_context = QueryContext(**payload)
    self.assertEqual(len(query_context.queries), 1)

    query_object = query_context.queries[0]
    # Deprecated keys are migrated onto their modern counterparts.
    self.assertEqual(query_object.granularity, "timecol")
    self.assertIn("having_druid", query_object.extras)
def test_query_context_time_range_endpoints(self):
    """
    Ensure that time_range_endpoints are populated automatically when
    missing from the payload
    """
    self.login(username="******")
    table_name = "birth_names"
    table = self.get_table_by_name(table_name)
    payload = get_query_context(table.name, table.id, table.type)
    del payload["queries"][0]["extras"]["time_range_endpoints"]
    query_context = QueryContext(**payload)
    query_object = query_context.queries[0]
    extras = query_object.to_dict()["extras"]
    # assertIn/assertEqual replace assertTrue(... in ...) and the
    # deprecated assertEquals alias.
    self.assertIn("time_range_endpoints", extras)
    self.assertEqual(
        extras["time_range_endpoints"],
        (TimeRangeEndpoint.INCLUSIVE, TimeRangeEndpoint.EXCLUSIVE),
    )
def make_query_context(self, data: Dict[str, Any]) -> QueryContext:
    """Construct a QueryContext directly from a raw keyword mapping."""
    return QueryContext(**data)
def data(self) -> Response:
    """
    Takes a query context constructed in the client and returns payload data
    response for the given query.
    ---
    post:
      description: >-
        Takes a query context constructed in the client and returns payload data
        response for the given query.
      requestBody:
        description: Query context schema
        required: true
        content:
          application/json:
            schema:
              type: object
              properties:
                datasource:
                  type: object
                  description: The datasource where the query will run
                  properties:
                    id:
                      type: integer
                    type:
                      type: string
                queries:
                  type: array
                  items:
                    type: object
                    properties:
                      granularity:
                        type: string
                      groupby:
                        type: array
                        items:
                          type: string
                      metrics:
                        type: array
                        items:
                          type: object
                      filters:
                        type: array
                        items:
                          type: string
                      row_limit:
                        type: integer
      responses:
        200:
          description: Query result
          content:
            application/json:
              schema:
                type: array
                items:
                  type: object
                  properties:
                    cache_key:
                      type: string
                    cached_dttm:
                      type: string
                    cache_timeout:
                      type: integer
                    error:
                      type: string
                    is_cached:
                      type: boolean
                    query:
                      type: string
                    status:
                      type: string
                    stacktrace:
                      type: string
                    rowcount:
                      type: integer
                    data:
                      type: array
                      items:
                        type: object
        400:
          $ref: '#/components/responses/400'
        401:
          $ref: '#/components/responses/401'
        404:
          $ref: '#/components/responses/404'
        500:
          $ref: '#/components/responses/500'
    """
    # The payload must arrive as JSON; anything else is a client error.
    if not request.is_json:
        return self.response_400(message="Request is not JSON")
    try:
        query_context = QueryContext(**request.json)
    except KeyError:
        # Required top-level keys missing from the query context body.
        return self.response_400(message="Request is incorrect")
    try:
        security_manager.assert_query_context_permission(query_context)
    except SupersetSecurityException:
        return self.response_401()
    payload_json = query_context.get_payload()
    # simplejson: NaN values are serialized as null (ignore_nan) and
    # datetimes as epoch milliseconds (json_int_dttm_ser).
    response_data = simplejson.dumps(
        payload_json, default=json_int_dttm_ser, ignore_nan=True
    )
    resp = make_response(response_data, 200)
    resp.headers["Content-Type"] = "application/json; charset=utf-8"
    return resp