def test_sql_injection_via_groupby(self):
    """
    Ensure that invalid column names passed in ``groupby`` are caught.

    A SQL function call (``currentDatabase()``) is supplied as a groupby
    column; the resulting query entry must carry an ``error``.
    """
    self.login(username="******")
    payload = get_query_context("birth_names")
    first_query = payload["queries"][0]
    first_query["groupby"] = ["currentDatabase()"]

    result = ChartDataQueryContextSchema().load(payload).get_payload()

    assert result["queries"][0].get("error") is not None
def data(self) -> Response:
    """
    Run a client-built query context and return its payload as JSON.
    ---
    post:
      description: >-
        Takes a query context constructed in the client and returns payload
        data response for the given query.
      requestBody:
        description: >-
          A query context consists of a datasource from which to fetch data
          and one or many query objects.
        required: true
        content:
          application/json:
            schema:
              $ref: "#/components/schemas/ChartDataQueryContextSchema"
      responses:
        200:
          description: Query result
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/ChartDataResponseSchema"
        400:
          $ref: '#/components/responses/400'
        500:
          $ref: '#/components/responses/500'
    """
    if not request.is_json:
        return self.response_400(message="Request is not JSON")

    try:
        # This schema's ``load`` returns a (data, errors) pair rather than
        # raising on validation failure.
        query_context, load_errors = ChartDataQueryContextSchema().load(
            request.json
        )
        if load_errors:
            error_message = _("Request is incorrect: %(error)s", error=load_errors)
            return self.response_400(message=error_message)
    except KeyError:
        return self.response_400(message="Request is incorrect")

    # Permission failures are reported as 401.
    try:
        security_manager.assert_query_context_permission(query_context)
    except SupersetSecurityException:
        return self.response_401()

    result = query_context.get_payload()
    body = simplejson.dumps(
        {"result": result}, default=json_int_dttm_ser, ignore_nan=True
    )
    response = make_response(body, 200)
    response.headers["Content-Type"] = "application/json; charset=utf-8"
    return response
def test_query_response_type(self):
    """
    Ensure that the QUERY result type returns the generated SQL.
    """
    self.login(username="******")
    payload = get_query_context("birth_names")
    payload["result_type"] = ChartDataResultType.QUERY.value

    responses = ChartDataQueryContextSchema().load(payload).get_payload()
    assert len(responses) == 1

    first = responses["queries"][0]
    # The entry holds exactly two keys: the language and the SQL text.
    assert len(first) == 2
    assert first["language"] == "sql"
    assert "SELECT" in first["query"]
def test_csv_response_format(self):
    """
    Ensure that the CSV result format returns CSV text with a header row.
    """
    self.login(username="******")
    row_limit = 10
    payload = get_query_context("birth_names")
    payload["result_format"] = ChartDataResultFormat.CSV.value
    payload["queries"][0]["row_limit"] = row_limit

    responses = ChartDataQueryContextSchema().load(payload).get_payload()
    assert len(responses) == 1

    csv_text = responses["queries"][0]["data"]
    assert "name,sum__num\n" in csv_text
    # Presumably: header line + one line per row + the empty tail after the
    # trailing newline (10 + 2 == 12).
    assert len(csv_text.split("\n")) == row_limit + 2
def test_sql_injection_via_columns(self):
    """
    Ensure that invalid column names passed in ``columns`` are caught.

    A raw SQL fragment (``*, 'extra'``) is supplied as the only column; the
    resulting query entry must carry an ``error``.
    """
    self.login(username="******")
    table = self.get_table_by_name("birth_names")
    payload = get_query_context(table.name, table.id, table.type)

    query = payload["queries"][0]
    query.update(groupby=[], metrics=[], columns=["*, 'extra'"])

    query_payload = ChartDataQueryContextSchema().load(payload).get_payload()
    assert query_payload["queries"][0].get("error") is not None
def test_fetch_values_predicate_not_in_query(self):
    """
    Ensure that the fetch values predicate is not added to the generated query.
    """
    self.login(username="******")
    payload = get_query_context("birth_names")
    payload["result_type"] = ChartDataResultType.QUERY.value

    responses = ChartDataQueryContextSchema().load(payload).get_payload()
    assert len(responses) == 1

    entry = responses["queries"][0]
    assert len(entry) == 2
    assert entry["language"] == "sql"

    sql = entry["query"]
    # NOTE(review): "123 = 123" is assumed to be the fetch_values_predicate
    # configured on the birth_names fixture; confirm against the fixture.
    assert "123 = 123" not in sql
def test_query_response_type(self):
    """
    Ensure that the QUERY result type returns the generated SQL.
    """
    self.login(username="******")
    payload = get_query_context("birth_names")
    payload["result_type"] = ChartDataResultType.QUERY.value

    query_context = ChartDataQueryContextSchema().load(payload)
    responses = query_context.get_payload()
    self.assertEqual(len(responses), 1)

    only_query = responses["queries"][0]
    # Exactly two keys are expected: the language and the SQL text.
    self.assertEqual(len(only_query), 2)
    self.assertEqual(only_query["language"], "sql")
    self.assertIn("SELECT", only_query["query"])
def test_samples_response_type(self):
    """
    Ensure that the SAMPLES result type returns a list of sample rows that
    do not include the ``sum__num`` metric.
    """
    self.login(username="******")
    sample_size = 5
    payload = get_query_context("birth_names")
    payload["result_type"] = ChartDataResultType.SAMPLES.value
    payload["queries"][0]["row_limit"] = sample_size

    responses = ChartDataQueryContextSchema().load(payload).get_payload()
    assert len(responses) == 1

    rows = responses["queries"][0]["data"]
    assert isinstance(rows, list)
    assert len(rows) == sample_size
    assert "sum__num" not in rows[0]
def test_sql_injection_via_filters(self):
    """
    Ensure that invalid column names passed in ``filters`` are caught.

    A filter on column ``*`` with value ``;`` is supplied; the resulting
    query entry must carry an ``error``.
    """
    self.login(username="******")
    table = self.get_table_by_name("birth_names")
    payload = get_query_context(table.name, table.id, table.type)

    bad_filter = {"col": "*", "op": FilterOperator.EQUALS.value, "val": ";"}
    query = payload["queries"][0]
    query["groupby"] = ["name"]
    query["metrics"] = []
    query["filters"] = [bad_filter]

    query_payload = ChartDataQueryContextSchema().load(payload).get_payload()
    # NOTE(review): this indexes the payload as a list, whereas sibling tests
    # use query_payload["queries"][0]; confirm which shape get_payload()
    # returns for this code version.
    assert query_payload[0].get("error") is not None
def test_sql_injection_via_metrics(self):
    """
    Ensure that invalid column names referenced by metrics are caught.

    A SIMPLE adhoc metric aggregating an unknown column (``invalid_col``)
    is supplied; the resulting query entry must carry an ``error``.
    """
    self.login(username="******")
    payload = get_query_context("birth_names")
    payload["queries"][0]["groupby"] = ["name"]
    payload["queries"][0]["metrics"] = [
        {
            "expressionType": AdhocMetricExpressionType.SIMPLE.value,
            "column": {"column_name": "invalid_col"},
            "aggregate": "SUM",
            "label": "My Simple Label",
        }
    ]
    query_context = ChartDataQueryContextSchema().load(payload)
    query_payload = query_context.get_payload()
    assert query_payload["queries"][0].get("error") is not None
def test_csv_response_format(self):
    """
    Ensure that the CSV result format works and that the newest cache key
    recorded belongs to the queried datasource.
    """
    self.login(username="******")
    table_name = "birth_names"
    table = self.get_table_by_name(table_name)
    payload = get_query_context(table.name, table.id, table.type)
    payload["result_format"] = ChartDataResultFormat.CSV.value
    payload["queries"][0]["row_limit"] = 10
    query_context = ChartDataQueryContextSchema().load(payload)
    responses = query_context.get_payload()
    self.assertEqual(len(responses), 1)
    data = responses[0]["data"]
    self.assertIn("name,sum__num\n", data)
    self.assertEqual(len(data.split("\n")), 12)

    ck = db.session.query(CacheKey).order_by(CacheKey.id.desc()).first()
    # Fail with a clear message rather than an AttributeError on None.
    assert ck is not None
    # Build the expected uid from the fixture table's actual primary key
    # instead of hard-coding an id, which is not guaranteed to be 3.
    assert ck.datasource_uid == f"{table.id}__table"
def test_query_response_type(self):
    """
    Ensure that the QUERY result type returns SQL that contains the expected
    ``IS NOT NULL`` and ``NOT (... IS NULL OR ... IN ('abc'))`` clauses.
    """
    self.login(username="******")
    payload = get_query_context("birth_names")
    payload["result_type"] = ChartDataResultType.QUERY.value

    responses = ChartDataQueryContextSchema().load(payload).get_payload()
    assert len(responses) == 1

    response = responses["queries"][0]
    assert len(response) == 2
    sql_text = response["query"]
    assert response["language"] == "sql"
    assert "SELECT" in sql_text

    # The character classes allow an identifier to be optionally wrapped in
    # backticks, double quotes or square brackets.
    not_null_pattern = r'[`"\[]?num[`"\]]? IS NOT NULL'
    excluded_name_pattern = (
        r"""NOT \([`"\[]?name[`"\]]? IS NULL[\s\n]* """
        r"""OR [`"\[]?name[`"\]]? IN \('abc'\)\)"""
    )
    assert re.search(not_null_pattern, sql_text)
    assert re.search(excluded_name_pattern, sql_text)
def test_time_offsets_in_query_object(self):
    """
    Ensure that ``time_offsets`` produce the expected extra columns and
    time-shifted queries.
    """
    self.login(username="******")
    payload = get_query_context("birth_names")
    query = payload["queries"][0]
    query["metrics"] = ["sum__num"]
    query["groupby"] = ["name"]
    query["is_timeseries"] = True
    query["timeseries_limit"] = 5
    query["time_offsets"] = ["1 year ago", "1 year later"]
    query["time_range"] = "1990 : 1991"

    responses = ChartDataQueryContextSchema().load(payload).get_payload()
    result = responses["queries"][0]

    expected_colnames = [
        "__timestamp",
        "name",
        "sum__num",
        "sum__num__1 year ago",
        "sum__num__1 year later",
    ]
    self.assertEqual(result["colnames"], expected_colnames)

    # The "query" field holds ';'-separated statements: three are expected,
    # i.e. the main query plus one per configured offset.
    sqls = [stmt for stmt in result["query"].split(";") if stmt.strip()]
    self.assertEqual(len(sqls), 3)

    ago_sql, later_sql = sqls[1], sqls[2]
    # NOTE(review): each offset query is also expected to contain the
    # original 1990-1991 range; presumably from a nested subquery — confirm.
    # 1 year ago
    assert re.search(r"1989-01-01.+1990-01-01", ago_sql, re.S)
    assert re.search(r"1990-01-01.+1991-01-01", ago_sql, re.S)
    # 1 year later
    assert re.search(r"1991-01-01.+1992-01-01", later_sql, re.S)
    assert re.search(r"1990-01-01.+1991-01-01", later_sql, re.S)
def data(self) -> Response:  # pylint: disable=too-many-return-statements
    """
    Run a client-built query context and return its payload as JSON or CSV.
    ---
    post:
      description: >-
        Takes a query context constructed in the client and returns payload
        data response for the given query.
      requestBody:
        description: >-
          A query context consists of a datasource from which to fetch data
          and one or many query objects.
        required: true
        content:
          application/json:
            schema:
              $ref: "#/components/schemas/ChartDataQueryContextSchema"
      responses:
        200:
          description: Query result
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/ChartDataResponseSchema"
        400:
          $ref: '#/components/responses/400'
        500:
          $ref: '#/components/responses/500'
    """
    if request.is_json:
        json_body = request.json
    else:
        form_data = request.form.get("form_data")
        if not form_data:
            return self.response_400(message="Request is not JSON")
        # CSV export submits regular form data rather than a JSON body.
        json_body = json.loads(form_data)

    try:
        query_context = ChartDataQueryContextSchema().load(json_body)
    except KeyError:
        return self.response_400(message="Request is incorrect")
    except ValidationError as error:
        return self.response_400(
            message=_("Request is incorrect: %(error)s", error=error.messages)
        )

    # Access failures are reported as 401.
    try:
        query_context.raise_for_access()
    except SupersetSecurityException:
        return self.response_401()

    payload = query_context.get_payload()
    # The first query that reports an error aborts the whole request.
    failed_query = next((query for query in payload if query.get("error")), None)
    if failed_query is not None:
        return self.response_400(message=f"Error: {failed_query['error']}")

    result_format = query_context.result_format
    if result_format == ChartDataResultFormat.CSV:
        # Only the first query's result is exported.
        return CsvResponse(
            payload[0]["data"],
            status=200,
            headers=generate_download_headers("csv"),
            mimetype="application/csv",
        )

    if result_format == ChartDataResultFormat.JSON:
        body = simplejson.dumps(
            {"result": payload}, default=json_int_dttm_ser, ignore_nan=True
        )
        response = make_response(body, 200)
        response.headers["Content-Type"] = "application/json; charset=utf-8"
        return response

    return self.response_400(message=f"Unsupported result_format: {result_format}")