def test_query_context_limit_and_offset(self):
    """Validate row_limit/row_offset defaults, explicit values, and bounds."""
    self.login(username="******")
    payload = get_query_context("birth_names")
    query = payload["queries"][0]

    # Defaults apply when limit/offset are absent from the request.
    query.pop("row_limit", None)
    query.pop("row_offset", None)
    query_object = ChartDataQueryContextSchema().load(payload).queries[0]
    self.assertEqual(query_object.row_limit, app.config["ROW_LIMIT"])
    self.assertEqual(query_object.row_offset, 0)

    # Explicit, valid values are passed through unchanged.
    query["row_limit"] = 100
    query["row_offset"] = 200
    query_object = ChartDataQueryContextSchema().load(payload).queries[0]
    self.assertEqual(query_object.row_limit, 100)
    self.assertEqual(query_object.row_offset, 200)

    # Negative values are rejected by schema validation.
    query["row_limit"] = -1
    query["row_offset"] = -1
    with self.assertRaises(ValidationError) as context:
        _ = ChartDataQueryContextSchema().load(payload)
    self.assertIn("row_limit", context.exception.messages["queries"][0])
    self.assertIn("row_offset", context.exception.messages["queries"][0])
def test_query_context_limit_and_offset(self):
    """Check row_limit/row_offset defaults, explicit values, and bounds."""
    self.login(username="******")
    table_name = "birth_names"
    table = self.get_table_by_name(table_name)
    payload = get_query_context(table.name, table.id, table.type)
    # Use defaults
    payload["queries"][0].pop("row_limit", None)
    payload["queries"][0].pop("row_offset", None)
    # marshmallow-2 style: load returns a (data, errors) tuple instead of raising
    query_context, errors = load_query_context(payload)
    self.assertEqual(errors, {})
    query_object = query_context.queries[0]
    self.assertEqual(query_object.row_limit, app.config["ROW_LIMIT"])
    self.assertEqual(query_object.row_offset, 0)
    # Valid limit and offset
    payload["queries"][0]["row_limit"] = 100
    payload["queries"][0]["row_offset"] = 200
    # NOTE(review): this section uses the schema directly while the one above
    # goes through load_query_context — presumably equivalent; confirm both
    # return (data, errors) tuples.
    query_context, errors = ChartDataQueryContextSchema().load(payload)
    self.assertEqual(errors, {})
    query_object = query_context.queries[0]
    self.assertEqual(query_object.row_limit, 100)
    self.assertEqual(query_object.row_offset, 200)
    # too low limit and offset
    payload["queries"][0]["row_limit"] = 0
    payload["queries"][0]["row_offset"] = -1
    query_context, errors = ChartDataQueryContextSchema().load(payload)
    self.assertIn("row_limit", errors["queries"][0])
    self.assertIn("row_offset", errors["queries"][0])
def get_sql_text(payload: Dict[str, Any]) -> str:
    """Run *payload* as a QUERY-type request and return the generated SQL."""
    payload["result_type"] = ChartDataResultType.QUERY.value
    responses = ChartDataQueryContextSchema().load(payload).get_payload()
    assert len(responses) == 1
    query_response = responses["queries"][0]
    # a QUERY-type response carries exactly language + query text
    assert len(query_response) == 2
    assert query_response["language"] == "sql"
    return query_response["query"]
def set_query_context(self, form_data: Dict[str, Any]) -> QueryContext:
    """Build and cache a QueryContext from raw chart form data.

    :param form_data: request payload in ChartDataQueryContextSchema shape
    :returns: the deserialized QueryContext (also stored on the instance)
    :raises ValidationError: if the payload is malformed or missing keys
    """
    self._form_data = form_data
    try:
        self._query_context = ChartDataQueryContextSchema().load(self._form_data)
    except KeyError as ex:
        # Normalize schema KeyErrors into the validation error callers expect.
        raise ValidationError("Request is incorrect") from ex
    # NOTE: the previous `except ValidationError as error: raise error` clause
    # was a no-op re-raise; letting the exception propagate is equivalent.
    return self._query_context
def test_sql_injection_via_groupby(self):
    """Invalid column names in groupby must surface as a query error."""
    self.login(username="******")
    payload = get_query_context("birth_names")
    payload["queries"][0]["groupby"] = ["currentDatabase()"]
    query_payload = ChartDataQueryContextSchema().load(payload).get_payload()
    assert query_payload["queries"][0].get("error") is not None
def test_cache_key_changes_when_post_processing_is_updated(self):
    """Cache keys must reflect post-processing changes, ignoring null entries."""
    self.login(username="******")
    table = self.get_table_by_name("birth_names")
    payload = get_query_context(
        table.name, table.id, table.type, add_postprocessing_operations=True
    )

    def cache_key_for(data):
        # recompute the cache key for the first query of a fresh context
        ctx = ChartDataQueryContextSchema().load(data)
        return ctx.query_cache_key(ctx.queries[0])

    baseline = cache_key_for(payload)

    # A None entry appended to post_processing must not perturb the key.
    payload["queries"][0]["post_processing"].append(None)
    self.assertEqual(baseline, cache_key_for(payload))

    # Removing post-processing entirely must change the key.
    payload["queries"][0].pop("post_processing")
    self.assertNotEqual(baseline, cache_key_for(payload))
def data(self) -> Response:
    """
    Takes a query context constructed in the client and returns payload
    data response for the given query.
    ---
    post:
      description: >-
        Takes a query context constructed in the client and returns payload data
        response for the given query.
      requestBody:
        description: >-
          A query context consists of a datasource from which to fetch data
          and one or many query objects.
        required: true
        content:
          application/json:
            schema:
              $ref: "#/components/schemas/ChartDataQueryContextSchema"
      responses:
        200:
          description: Query result
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/ChartDataResponseSchema"
        400:
          $ref: '#/components/responses/400'
        500:
          $ref: '#/components/responses/500'
    """
    # Reject non-JSON bodies up front.
    if not request.is_json:
        return self.response_400(message="Request is not JSON")
    try:
        # marshmallow-2 style: load returns (data, errors) rather than raising
        query_context, errors = ChartDataQueryContextSchema().load(request.json)
        if errors:
            return self.response_400(
                message=_("Request is incorrect: %(error)s", error=errors)
            )
    except KeyError:
        # Schema lookups on a structurally wrong payload raise KeyError.
        return self.response_400(message="Request is incorrect")
    try:
        # Verify the caller may access every datasource in the context.
        security_manager.assert_query_context_permission(query_context)
    except SupersetSecurityException:
        return self.response_401()
    payload_json = query_context.get_payload()
    # ignore_nan so NaN values serialize as null instead of breaking JSON
    response_data = simplejson.dumps(
        {"result": payload_json}, default=json_int_dttm_ser, ignore_nan=True
    )
    resp = make_response(response_data, 200)
    resp.headers["Content-Type"] = "application/json; charset=utf-8"
    return resp
def test_csv_response_format(self):
    """CSV result format should return comma-separated rows."""
    self.login(username="******")
    payload = get_query_context("birth_names")
    payload["result_format"] = ChartDataResultFormat.CSV.value
    payload["queries"][0]["row_limit"] = 10
    responses = ChartDataQueryContextSchema().load(payload).get_payload()
    self.assertEqual(len(responses), 1)
    data = responses["queries"][0]["data"]
    self.assertIn("name,sum__num\n", data)
    # header line + 10 data rows + trailing newline => 12 segments
    self.assertEqual(len(data.split("\n")), 12)
def test_fetch_values_predicate_not_in_query(self):
    """The fetch-values predicate must not leak into the generated SQL."""
    self.login(username="******")
    payload = get_query_context("birth_names")
    payload["result_type"] = ChartDataResultType.QUERY.value
    responses = ChartDataQueryContextSchema().load(payload).get_payload()
    assert len(responses) == 1
    query_response = responses["queries"][0]
    assert len(query_response) == 2
    assert query_response["language"] == "sql"
    # the fetch-values predicate ("123 = 123") belongs to filter-value
    # population only, never to chart queries
    assert "123 = 123" not in query_response["query"]
def test_query_response_type(self):
    """QUERY result type should return the generated SQL, not data rows."""
    self.login(username="******")
    payload = get_query_context("birth_names")
    payload["result_type"] = ChartDataResultType.QUERY.value
    responses = ChartDataQueryContextSchema().load(payload).get_payload()
    self.assertEqual(len(responses), 1)
    query_response = responses["queries"][0]
    self.assertEqual(len(query_response), 2)
    self.assertEqual(query_response["language"], "sql")
    self.assertIn("SELECT", query_response["query"])
def test_sql_injection_via_columns(self):
    """Invalid column names in `columns` must surface as a query error."""
    self.login(username="******")
    table = self.get_table_by_name("birth_names")
    payload = get_query_context(table.name, table.id, table.type)
    query = payload["queries"][0]
    query["groupby"] = []
    query["metrics"] = []
    query["columns"] = ["*, 'extra'"]
    query_payload = ChartDataQueryContextSchema().load(payload).get_payload()
    assert query_payload["queries"][0].get("error") is not None
def test_query_response_type(self):
    """QUERY result type should yield SQL text rather than data rows."""
    self.login(username="******")
    payload = get_query_context("birth_names")
    payload["result_type"] = ChartDataResultType.QUERY.value
    responses = ChartDataQueryContextSchema().load(payload).get_payload()
    assert len(responses) == 1
    query_response = responses["queries"][0]
    assert len(query_response) == 2
    assert query_response["language"] == "sql"
    assert "SELECT" in query_response["query"]
def test_samples_response_type(self):
    """SAMPLES result type should return raw rows without aggregates."""
    self.login(username="******")
    payload = get_query_context("birth_names")
    payload["result_type"] = ChartDataResultType.SAMPLES.value
    payload["queries"][0]["row_limit"] = 5
    responses = ChartDataQueryContextSchema().load(payload).get_payload()
    self.assertEqual(len(responses), 1)
    data = responses["queries"][0]["data"]
    self.assertIsInstance(data, list)
    self.assertEqual(len(data), 5)
    # samples are raw rows, so the aggregate metric column must be absent
    self.assertNotIn("sum__num", data[0])
def test_query_cache_key_does_not_change_for_non_existent_or_null(self):
    """A missing granularity and an explicit null must hash identically."""
    self.login(username="******")
    payload = get_query_context("birth_names", add_postprocessing_operations=True)
    del payload["queries"][0]["granularity"]

    def compute_key(data) -> str:
        ctx: QueryContext = ChartDataQueryContextSchema().load(data)
        qobj: QueryObject = ctx.queries[0]
        return ctx.query_cache_key(qobj)

    # baseline: granularity absent entirely
    cache_key_original = compute_key(payload)

    # explicit null must produce the same cache key
    payload["queries"][0]["granularity"] = None
    assert compute_key(payload) == cache_key_original
def _create_query_context_from_form(form_data: Dict[str, Any]) -> QueryContext:
    """Deserialize chart form data into a QueryContext.

    :param form_data: request payload in ChartDataQueryContextSchema shape
    :returns: the deserialized QueryContext
    :raises ValidationError: if the payload is malformed or missing keys
    """
    try:
        return ChartDataQueryContextSchema().load(form_data)
    except KeyError as ex:
        # Surface missing-key failures as the validation error callers expect.
        raise ValidationError("Request is incorrect") from ex
    # NOTE: the previous `except ValidationError as error: raise error` clause
    # was a no-op re-raise and has been removed; propagation is identical.
def test_sql_injection_via_filters(self):
    """Invalid column names in filters must surface as a query error."""
    self.login(username="******")
    table = self.get_table_by_name("birth_names")
    payload = get_query_context(table.name, table.id, table.type)
    query = payload["queries"][0]
    query["groupby"] = ["name"]
    query["metrics"] = []
    query["filters"] = [
        {"col": "*", "op": FilterOperator.EQUALS.value, "val": ";"}
    ]
    query_payload = ChartDataQueryContextSchema().load(payload).get_payload()
    # NOTE(review): indexes the payload as a list here, while sibling tests use
    # query_payload["queries"][0] — confirm which shape get_payload returns in
    # this revision before changing.
    assert query_payload[0].get("error") is not None
def test_convert_deprecated_fields(self):
    """Deprecated request fields must be mapped onto their replacements."""
    self.login(username="******")
    payload = get_query_context("birth_names")
    query = payload["queries"][0]
    columns = query["columns"]
    query["groupby"] = columns
    query["timeseries_limit"] = 99
    query["timeseries_limit_metric"] = "sum__num"
    del query["columns"]
    query["granularity_sqla"] = "timecol"
    query["having_filters"] = [{"col": "a", "op": "==", "val": "b"}]
    query_context = ChartDataQueryContextSchema().load(payload)
    self.assertEqual(len(query_context.queries), 1)
    query_object = query_context.queries[0]
    # each legacy field lands on its modern attribute
    self.assertEqual(query_object.granularity, "timecol")
    self.assertEqual(query_object.columns, columns)
    self.assertEqual(query_object.series_limit, 99)
    self.assertEqual(query_object.series_limit_metric, "sum__num")
    self.assertIn("having_druid", query_object.extras)
def test_query_context_null_post_processing_op(self):
    """A null post-processing entry must be dropped during deserialization."""
    self.login(username="******")
    payload = get_query_context("birth_names")
    payload["queries"][0]["post_processing"] = [None]
    loaded = ChartDataQueryContextSchema().load(payload)
    self.assertEqual(loaded.queries[0].post_processing, [])
def test_schema_deserialization(self):
    """
    Ensure the deserialized QueryContext round-trips every payload field.
    """
    payload = get_query_context("birth_names", add_postprocessing_operations=True)
    query_context = ChartDataQueryContextSchema().load(payload)
    self.assertEqual(len(query_context.queries), len(payload["queries"]))
    for idx, query in enumerate(query_context.queries):
        expected = payload["queries"][idx]
        # basic properties
        self.assertEqual(query.extras, expected["extras"])
        self.assertEqual(query.filter, expected["filters"])
        self.assertEqual(query.columns, expected["columns"])
        # metrics are mutated during creation: adhoc metrics keep their full
        # definition, predefined ones collapse to their label
        for m_idx, metric in enumerate(query.metrics):
            raw_metric = expected["metrics"][m_idx]
            if "expressionType" not in raw_metric:
                raw_metric = raw_metric["label"]
            self.assertEqual(metric, raw_metric)
        self.assertEqual(query.orderby, expected["orderby"])
        self.assertEqual(query.time_range, expected["time_range"])
        # post-processing operations survive with operation and options intact
        for p_idx, post_proc in enumerate(query.post_processing):
            raw_post_proc = expected["post_processing"][p_idx]
            self.assertEqual(post_proc["operation"], raw_post_proc["operation"])
            self.assertEqual(post_proc["options"], raw_post_proc["options"])
def test_query_context_null_timegrain(self):
    """A null time_grain_sqla in extras must deserialize without error."""
    self.login(username="******")
    table = self.get_table_by_name("birth_names")
    payload = get_query_context(table.name, table.id, table.type)
    payload["queries"][0]["extras"]["time_grain_sqla"] = None
    # success criterion is simply that load() does not raise
    _ = ChartDataQueryContextSchema().load(payload)
def test_handle_metrics_field(self):
    """Both predefined and adhoc metrics must deserialize correctly."""
    self.login(username="******")
    adhoc_metric = {
        "expressionType": "SIMPLE",
        "column": {"column_name": "num_boys", "type": "BIGINT(20)"},
        "aggregate": "SUM",
        "label": "Boys",
        "optionName": "metric_11",
    }
    payload = get_query_context("birth_names")
    payload["queries"][0]["metrics"] = ["sum__num", {"label": "abc"}, adhoc_metric]
    query_object = ChartDataQueryContextSchema().load(payload).queries[0]
    # label-only metrics collapse to their label; adhoc metrics pass through
    self.assertEqual(query_object.metrics, ["sum__num", "abc", adhoc_metric])
def test_cache(self):
    """A forced query must be cached and rehydrate to an equivalent context."""
    table = self.get_table_by_name("birth_names")
    payload = get_query_context(table.name, table.id)
    payload["force"] = True
    query_context = ChartDataQueryContextSchema().load(payload)
    query_cache_key = query_context.query_cache_key(query_context.queries[0])
    response = query_context.get_payload(cache_query_context=True)
    cache_key = response["cache_key"]
    assert cache_key is not None
    cached = cache_manager.cache.get(cache_key)
    assert cached is not None
    # rebuild the context from the cached payload and compare
    rehydrated_qc = ChartDataQueryContextSchema().load(cached["data"])
    rehydrated_key = rehydrated_qc.query_cache_key(rehydrated_qc.queries[0])
    self.assertEqual(rehydrated_qc.datasource, query_context.datasource)
    self.assertEqual(len(rehydrated_qc.queries), 1)
    self.assertEqual(query_cache_key, rehydrated_key)
    self.assertEqual(rehydrated_qc.result_type, query_context.result_type)
    self.assertEqual(rehydrated_qc.result_format, query_context.result_format)
    # `force` must not survive rehydration
    self.assertFalse(rehydrated_qc.force)
def test_cache(self):
    """A forced, successful query must cache and rehydrate equivalently."""
    payload = get_query_context(
        query_name="birth_names",
        add_postprocessing_operations=True,
    )
    payload["force"] = True
    query_context = ChartDataQueryContextSchema().load(payload)
    query_cache_key = query_context.query_cache_key(query_context.queries[0])
    response = query_context.get_payload(cache_query_context=True)
    # the underlying query MUST have succeeded for the rest to be meaningful
    assert response["queries"][0]["status"] == QueryStatus.SUCCESS
    cache_key = response["cache_key"]
    assert cache_key is not None
    cached = cache_manager.cache.get(cache_key)
    assert cached is not None
    # rebuild the context from the cached payload and compare
    rehydrated_qc = ChartDataQueryContextSchema().load(cached["data"])
    rehydrated_key = rehydrated_qc.query_cache_key(rehydrated_qc.queries[0])
    self.assertEqual(rehydrated_qc.datasource, query_context.datasource)
    self.assertEqual(len(rehydrated_qc.queries), 1)
    self.assertEqual(query_cache_key, rehydrated_key)
    self.assertEqual(rehydrated_qc.result_type, query_context.result_type)
    self.assertEqual(rehydrated_qc.result_format, query_context.result_format)
    # `force` must not survive rehydration
    self.assertFalse(rehydrated_qc.force)
def test_query_context_null_post_processing_op(self):
    """A null post-processing entry must be dropped during deserialization."""
    self.login(username="******")
    table = self.get_table_by_name("birth_names")
    payload = get_query_context(table.name, table.id, table.type)
    payload["queries"][0]["post_processing"] = [None]
    loaded = ChartDataQueryContextSchema().load(payload)
    self.assertEqual(loaded.queries[0].post_processing, [])
def test_sql_injection_via_metrics(self):
    """Adhoc metrics over invalid columns must surface as a query error."""
    self.login(username="******")
    payload = get_query_context("birth_names")
    query = payload["queries"][0]
    query["groupby"] = ["name"]
    query["metrics"] = [
        {
            "expressionType": AdhocMetricExpressionType.SIMPLE.value,
            "column": {"column_name": "invalid_col"},
            "aggregate": "SUM",
            "label": "My Simple Label",
        }
    ]
    query_payload = ChartDataQueryContextSchema().load(payload).get_payload()
    assert query_payload["queries"][0].get("error") is not None
def test_csv_response_format(self):
    """CSV result format should return rows and record a cache key."""
    self.login(username="******")
    table = self.get_table_by_name("birth_names")
    payload = get_query_context(table.name, table.id, table.type)
    payload["result_format"] = ChartDataResultFormat.CSV.value
    payload["queries"][0]["row_limit"] = 10
    responses = ChartDataQueryContextSchema().load(payload).get_payload()
    self.assertEqual(len(responses), 1)
    # NOTE(review): indexes the payload as a list here, unlike the revision
    # that reads responses["queries"][0] — confirm the get_payload shape for
    # this revision before changing.
    data = responses[0]["data"]
    self.assertIn("name,sum__num\n", data)
    # header line + 10 data rows + trailing newline => 12 segments
    self.assertEqual(len(data.split("\n")), 12)
    # the most recent cache entry must reference the queried datasource
    ck = db.session.query(CacheKey).order_by(CacheKey.id.desc()).first()
    assert ck.datasource_uid == "3__table"
def test_query_context_limit_and_offset(self):
    """Negative row_limit/row_offset must fail schema validation."""
    self.login(username="******")
    payload = get_query_context("birth_names")
    payload["queries"][0]["row_limit"] = -1
    payload["queries"][0]["row_offset"] = -1
    with self.assertRaises(ValidationError) as context:
        _ = ChartDataQueryContextSchema().load(payload)
    messages = context.exception.messages["queries"][0]
    self.assertIn("row_limit", messages)
    self.assertIn("row_offset", messages)
def test_query_response_type(self):
    """QUERY result type should return SQL with the expected filters."""
    self.login(username="******")
    payload = get_query_context("birth_names")
    payload["result_type"] = ChartDataResultType.QUERY.value
    responses = ChartDataQueryContextSchema().load(payload).get_payload()
    assert len(responses) == 1
    query_response = responses["queries"][0]
    assert len(query_response) == 2
    assert query_response["language"] == "sql"
    sql_text = query_response["query"]
    assert "SELECT" in sql_text
    # IS NOT NULL filter, allowing for dialect-specific identifier quoting
    assert re.search(r'[`"\[]?num[`"\]]? IS NOT NULL', sql_text)
    # negated "IS NULL OR IN (...)" filter, quoting-agnostic
    assert re.search(
        r"""NOT \([`"\[]?name[`"\]]? IS NULL[\s\n]* """
        r"""OR [`"\[]?name[`"\]]? IN \('abc'\)\)""",
        sql_text,
    )
def test_time_offsets_sql(self):
    """Time offsets must shift the generated SQL date range accordingly."""
    payload = get_query_context("birth_names")
    query = payload["queries"][0]
    query["metrics"] = ["sum__num"]
    query["groupby"] = ["state"]
    query["is_timeseries"] = True
    query["timeseries_limit"] = 5
    query["time_offsets"] = []
    query["time_range"] = "1980 : 1991"
    query["granularity"] = "ds"
    query["extras"]["time_grain_sqla"] = "P1Y"

    # run the main (un-offset) query to obtain its dataframe
    query_context = ChartDataQueryContextSchema().load(payload)
    query_object = query_context.queries[0]
    df = query_context.get_query_result(query_object).df

    # reload with offsets and generate the offset queries
    query["time_offsets"] = ["3 years ago", "3 years later"]
    query_context = ChartDataQueryContextSchema().load(payload)
    query_object = query_context.queries[0]
    time_offsets_obj = query_context.processing_time_offsets(df, query_object)
    query_from_1977_to_1988 = time_offsets_obj["queries"][0]
    query_from_1983_to_1994 = time_offsets_obj["queries"][1]

    # the 1980–1991 window shifted by ±3 years shows up in the SQL
    assert "1977-01-01" in query_from_1977_to_1988
    assert "1988-01-01" in query_from_1977_to_1988
    assert "1983-01-01" in query_from_1983_to_1994
    assert "1994-01-01" in query_from_1983_to_1994
def test_query_cache_key_changes_when_datasource_is_updated(self):
    """Touching the datasource must invalidate the query cache key."""
    self.login(username="******")
    payload = get_query_context("birth_names")

    def current_cache_key() -> str:
        # build a fresh context from the unchanged payload each time
        ctx = ChartDataQueryContextSchema().load(payload)
        return ctx.query_cache_key(ctx.queries[0])

    cache_key_original = current_cache_key()

    # make a temporary change and revert it so only changed_on moves
    datasource = ConnectorRegistry.get_datasource(
        datasource_type=payload["datasource"]["type"],
        datasource_id=payload["datasource"]["id"],
        session=db.session,
    )
    description_original = datasource.description
    datasource.description = "temporary description"
    db.session.commit()
    datasource.description = description_original
    db.session.commit()

    # identical payload, updated datasource => different cache key
    self.assertNotEqual(cache_key_original, current_cache_key())