Example #1
0
    def test_query_context_null_post_processing_op(self):
        """
        Ensure that a ``post_processing`` list containing only ``None``
        deserializes to an empty list on the loaded query object.
        """
        self.login(username="******")
        birth_names = self.get_table_by_name("birth_names")
        payload = get_query_context(
            birth_names.name, birth_names.id, birth_names.type
        )

        first_query = payload["queries"][0]
        first_query["post_processing"] = [None]

        loaded = ChartDataQueryContextSchema().load(payload)
        self.assertEqual(loaded.queries[0].post_processing, [])
 def test_query_exec_not_allowed(self):
     """
     Chart data API: Test that posting a chart data query as the logged-in
     user is refused with HTTP 401.
     """
     self.login(username="******")
     tbl = self.get_table_by_name("birth_names")
     query_context_payload = get_query_context(tbl.name, tbl.id, tbl.type)
     response = self.post_assert_metric(
         CHART_DATA_URI, query_context_payload, "data"
     )
     self.assertEqual(response.status_code, 401)
Example #3
0
 def test_chart_data_with_invalid_datasource(self):
     """
     Chart data API: Test chart data query with an invalid datasource.

     The payload's ``datasource`` field is replaced with the malformed value
     ``"abc"``; the endpoint is expected to reject the request with HTTP 400.
     """
     self.login(username="******")
     table = self.get_table_by_name("birth_names")
     payload = get_query_context(table.name, table.id, table.type)
     # "abc" is not a usable datasource reference, so the request should be
     # rejected rather than executed.
     payload["datasource"] = "abc"
     rv = self.post_assert_metric(CHART_DATA_URI, payload, "data")
     self.assertEqual(rv.status_code, 400)
    def test_query_context_null_timegrain(self):
        """
        Ensure a query context whose ``time_grain_sqla`` extra is ``None``
        loads without validation errors.
        """
        self.login(username="******")
        table_name = "birth_names"
        table = self.get_table_by_name(table_name)
        payload = get_query_context(table.name, table.id, table.type)

        payload["queries"][0]["extras"]["time_grain_sqla"] = None
        # ``load`` returns the deserialized QueryContext, as the other tests
        # in this module rely on, not a ``(data, errors)`` tuple. Invalid input
        # is signalled by raising, so reaching the assertion means it loaded.
        query_context = ChartDataQueryContextSchema().load(payload)
        self.assertEqual(len(query_context.queries), 1)
Example #5
0
 def test_sql_injection_via_groupby(self):
     """
     Ensure that a groupby entry which is not a valid column name (here an
     SQL function call) results in an error being reported for the query.
     """
     self.login(username="******")
     payload = get_query_context("birth_names")
     first_query = payload["queries"][0]
     first_query["groupby"] = ["currentDatabase()"]

     context = ChartDataQueryContextSchema().load(payload)
     result = context.get_payload()
     error = result["queries"][0].get("error")
     assert error is not None
Example #6
0
 def test_chart_data_incorrect_result_type(self):
     """
     Chart data API: an unrecognised ``result_type`` is rejected with HTTP 400.
     """
     self.login(username="******")
     birth_names = self.get_table_by_name("birth_names")
     body = get_query_context(
         birth_names.name, birth_names.id, birth_names.type
     )
     body["result_type"] = "qwerty"

     response = self.post_assert_metric(CHART_DATA_URI, body, "data")
     self.assertEqual(response.status_code, 400)
Example #7
0
 def test_chart_data_csv_result_format(self):
     """
     Chart data API: requesting ``result_format="csv"`` succeeds (HTTP 200).
     """
     self.login(username="******")
     tbl = self.get_table_by_name("birth_names")
     body = get_query_context(tbl.name, tbl.id, tbl.type)
     body["result_format"] = "csv"

     response = self.post_assert_metric(CHART_DATA_URI, body, "data")
     self.assertEqual(response.status_code, 200)
 def test_chart_data_with_invalid_enum_value(self):
     """
     Chart data API: a ``time_range_endpoints`` entry that is not a valid
     enum value makes the request fail with HTTP 400.
     """
     self.login(username="******")
     tbl = self.get_table_by_name("birth_names")
     payload = get_query_context(tbl.name, tbl.id, tbl.type)

     extras = payload["queries"][0]["extras"]
     # "abc" is the invalid value under test.
     extras["time_range_endpoints"] = ["abc", "EXCLUSIVE"]

     # Posted directly through the test client rather than post_assert_metric.
     response = self.client.post(CHART_DATA_URI, json=payload)
     self.assertEqual(response.status_code, 400)
Example #9
0
 def test_chart_data_simple(self):
     """
     Chart data API: a basic query against ``birth_names`` succeeds and the
     first result reports 45 rows.
     """
     self.login(username="******")
     tbl = self.get_table_by_name("birth_names")
     body = get_query_context(tbl.name, tbl.id, tbl.type)

     response = self.post_assert_metric(CHART_DATA_URI, body, "data")
     self.assertEqual(response.status_code, 200)

     decoded = json.loads(response.data.decode("utf-8"))
     first_result = decoded["result"][0]
     self.assertEqual(first_result["rowcount"], 45)
Example #10
0
 def test_chart_data(self):
     """
     Chart data API: a query posted to ``api/v1/chart/data`` succeeds and the
     first result reports 100 rows.
     """
     self.login(username="******")
     birth_names = self.get_table_by_name("birth_names")
     body = get_query_context(
         birth_names.name, birth_names.id, birth_names.type
     )
     endpoint = "api/v1/chart/data"

     response = self.post_assert_metric(endpoint, body, "data")
     self.assertEqual(response.status_code, 200)

     decoded = json.loads(response.data.decode("utf-8"))
     self.assertEqual(decoded["result"][0]["rowcount"], 100)
Example #11
0
 def test_chart_data_incorrect_request(self):
     """
     Chart data API: a query whose custom WHERE clause is not valid SQL is
     rejected with HTTP 400.
     """
     self.login(username="******")
     tbl = self.get_table_by_name("birth_names")
     body = get_query_context(tbl.name, tbl.id, tbl.type)

     query = body["queries"][0]
     query["filters"] = []
     # Deliberately malformed WHERE clause.
     query["extras"]["where"] = "(gender abc def)"

     response = self.post_assert_metric(CHART_DATA_URI, body, "data")
     self.assertEqual(response.status_code, 400)
Example #12
0
 def test_chart_data_default_row_limit(self):
     """
     Chart data API: Ensure row count doesn't exceed default limit.

     ``row_limit`` is removed from the query so the server-side default
     applies. NOTE(review): the expected count of 7 presumably comes from a
     ``ROW_LIMIT`` override applied to this test outside the visible lines
     (e.g. a patch decorator) -- confirm.
     """
     self.login(username="******")
     table = self.get_table_by_name("birth_names")
     request_payload = get_query_context(table.name, table.id, table.type)
     del request_payload["queries"][0]["row_limit"]
     rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
     # Check the status first, so a failed request is reported clearly instead
     # of failing later while reading fields from the response body.
     self.assertEqual(rv.status_code, 200)
     response_payload = json.loads(rv.data.decode("utf-8"))
     result = response_payload["result"][0]
     self.assertEqual(result["rowcount"], 7)
 def test_convert_deprecated_fields(self):
     """
     Ensure deprecated query fields are mapped onto their replacements:
     ``granularity_sqla`` becomes ``granularity`` and ``having_filters``
     ends up under ``extras["having_druid"]``.
     """
     self.login(username="******")
     payload = get_query_context("birth_names")

     query = payload["queries"][0]
     query["granularity_sqla"] = "timecol"
     query["having_filters"] = [{"col": "a", "op": "==", "val": "b"}]

     loaded = ChartDataQueryContextSchema().load(payload)
     self.assertEqual(len(loaded.queries), 1)

     first_object = loaded.queries[0]
     self.assertEqual(first_object.granularity, "timecol")
     self.assertIn("having_druid", first_object.extras)
Example #14
0
 def test_chart_data_mixed_case_filter_op(self):
     """
     Chart data API: Ensure mixed case filter operator generates valid result.

     The first filter's operator is set to ``"In"`` (mixed case) and the row
     limit to 10; the request must succeed and return exactly 10 rows.
     """
     self.login(username="******")
     table = self.get_table_by_name("birth_names")
     request_payload = get_query_context(table.name, table.id, table.type)
     request_payload["queries"][0]["filters"][0]["op"] = "In"
     request_payload["queries"][0]["row_limit"] = 10
     rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
     # Check the status first, so a rejected operator is reported clearly
     # instead of failing later while reading fields from the response body.
     self.assertEqual(rv.status_code, 200)
     response_payload = json.loads(rv.data.decode("utf-8"))
     result = response_payload["result"][0]
     self.assertEqual(result["rowcount"], 10)
    def test_fetch_values_predicate(self):
        """
        Ensure the fetch values predicate text ("123 = 123") appears in the
        generated SQL only once ``apply_fetch_values_predicate`` is enabled.
        """
        self.login(username="******")

        payload = get_query_context("birth_names")
        predicate = "123 = 123"

        # Default payload: the predicate must not be added.
        assert predicate not in get_sql_text(payload)

        # Opting in must add the predicate to the SQL.
        payload["queries"][0]["apply_fetch_values_predicate"] = True
        assert predicate in get_sql_text(payload)
Example #16
0
 def test_csv_response_format(self):
     """
     Ensure the CSV result format returns CSV text with a
     ``name,sum__num`` header for a 10-row query.
     """
     self.login(username="******")
     payload = get_query_context("birth_names")
     payload["result_format"] = ChartDataResultFormat.CSV.value
     payload["queries"][0]["row_limit"] = 10

     responses = ChartDataQueryContextSchema().load(payload).get_payload()
     self.assertEqual(len(responses), 1)

     csv_text = responses["queries"][0]["data"]
     self.assertIn("name,sum__num\n", csv_text)
     # 12 pieces -- presumably header + 10 rows + the empty tail after the
     # final newline.
     lines = csv_text.split("\n")
     self.assertEqual(len(lines), 12)
Example #17
0
 def test_query_response_type(self):
     """
     Ensure the QUERY result type returns a single response holding exactly
     two entries: the SQL ``language`` and the generated ``query`` text.
     """
     self.login(username="******")
     payload = get_query_context("birth_names")
     payload["result_type"] = ChartDataResultType.QUERY.value

     responses = ChartDataQueryContextSchema().load(payload).get_payload()
     self.assertEqual(len(responses), 1)

     first_response = responses["queries"][0]
     self.assertEqual(len(first_response), 2)
     self.assertEqual(first_response["language"], "sql")
     self.assertIn("SELECT", first_response["query"])
Example #18
0
 def test_query_response_type(self):
     """
     Ensure the QUERY result type yields one response containing only the
     ``language`` ("sql") and a ``query`` with a SELECT statement.
     """
     self.login(username="******")
     payload = get_query_context("birth_names")
     payload["result_type"] = ChartDataResultType.QUERY.value

     responses = ChartDataQueryContextSchema().load(payload).get_payload()
     assert len(responses) == 1

     first_response = responses["queries"][0]
     assert len(first_response) == 2
     assert first_response["language"] == "sql"
     assert "SELECT" in first_response["query"]
Example #19
0
 def test_sql_injection_via_columns(self):
     """
     Ensure that an arbitrary SQL fragment passed as a name in ``columns``
     is reported as an error for the query.
     """
     self.login(username="******")
     tbl = self.get_table_by_name("birth_names")
     payload = get_query_context(tbl.name, tbl.id, tbl.type)

     query = payload["queries"][0]
     query.update(groupby=[], metrics=[], columns=["*, 'extra'"])

     context = ChartDataQueryContextSchema().load(payload)
     result = context.get_payload()
     error = result["queries"][0].get("error")
     assert error is not None
 def test_query_response_type(self):
     """
     Ensure the SQL generated for the default ``birth_names`` query context
     is a SELECT containing a ``num IS NOT NULL`` predicate and the negated
     ``name`` IS NULL / IN ('abc') filter. The patterns accept optional
     backtick, double-quote or bracket identifier quoting.
     """
     self.login(username="******")
     sql_text = get_sql_text(get_query_context("birth_names"))

     assert "SELECT" in sql_text

     not_null_pattern = r'[`"\[]?num[`"\]]? IS NOT NULL'
     assert re.search(not_null_pattern, sql_text)

     negated_filter_pattern = (
         r"""NOT \([`"\[]?name[`"\]]? IS NULL[\s\n]* """
         r"""OR [`"\[]?name[`"\]]? IN \('abc'\)\)"""
     )
     assert re.search(negated_filter_pattern, sql_text)
 def test_fetch_values_predicate_not_in_query(self):
     """
     Ensure the SQL returned by the QUERY result type does not contain the
     fetch values predicate ("123 = 123") when the payload does not ask for it.
     """
     self.login(username="******")
     payload = get_query_context("birth_names")
     payload["result_type"] = ChartDataResultType.QUERY.value

     responses = ChartDataQueryContextSchema().load(payload).get_payload()
     assert len(responses) == 1

     first_response = responses["queries"][0]
     assert len(first_response) == 2
     assert first_response["language"] == "sql"
     assert "123 = 123" not in first_response["query"]
 def test_query_object_unknown_fields(self):
     """
     Ensure an unknown field on a query object is tolerated when loading and
     leaves the query's cache key unchanged.
     """
     self.login(username="******")
     payload = get_query_context("birth_names")

     def _first_cache_key():
         # Load the payload as it currently stands; return query 0's cache key.
         context = ChartDataQueryContextSchema().load(payload)
         return context.get_payload()["queries"][0]["cache_key"]

     orig_cache_key = _first_cache_key()
     payload["queries"][0]["foo"] = "bar"
     new_cache_key = _first_cache_key()

     self.assertEqual(orig_cache_key, new_cache_key)
    def test_query_cache_key_does_not_change_for_non_existent_or_null(self):
        """
        Ensure the query cache key is identical whether ``granularity`` is
        missing from the query payload or explicitly set to ``None``. The
        context is built with post-processing operations enabled.
        """
        self.login(username="******")
        payload = get_query_context("birth_names", add_postprocessing_operations=True)
        del payload["queries"][0]["granularity"]

        # Baseline: cache key with ``granularity`` absent altogether.
        context_missing: QueryContext = ChartDataQueryContextSchema().load(payload)
        object_missing: QueryObject = context_missing.queries[0]
        key_missing = context_missing.query_cache_key(object_missing)

        # Same payload, but with ``granularity`` explicitly null.
        payload["queries"][0]["granularity"] = None
        context_null = ChartDataQueryContextSchema().load(payload)
        object_null = context_null.queries[0]

        assert context_null.query_cache_key(object_null) == key_missing
Example #24
0
 def test_chart_data_query_missing_filter(self):
     """
     Chart data API: Ensure filter referencing missing column is ignored.

     The payload's only filter targets a column that does not exist. With the
     QUERY result type the response carries the generated SQL, which must
     succeed (HTTP 200) and must not mention that column.
     """
     self.login(username="******")
     table = self.get_table_by_name("birth_names")
     request_payload = get_query_context(table.name, table.id, table.type)
     request_payload["queries"][0]["filters"] = [
         {"col": "non_existent_filter", "op": "==", "val": "foo"},
     ]
     # Send the plain string value, as the other result-type tests do, rather
     # than the enum member, so the JSON body does not depend on how the enum
     # type happens to serialize.
     request_payload["result_type"] = utils.ChartDataResultType.QUERY.value
     rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
     self.assertEqual(rv.status_code, 200)
     response_payload = json.loads(rv.data.decode("utf-8"))
     assert "non_existent_filter" not in response_payload["result"][0]["query"]
Example #25
0
 def test_convert_deprecated_fields(self):
     """
     Ensure deprecated fields are converted when a ``QueryContext`` is built
     directly from the payload: ``granularity_sqla`` becomes ``granularity``
     and ``having_filters`` ends up under ``extras["having_druid"]``.
     """
     self.login(username="******")
     tbl = self.get_table_by_name("birth_names")
     payload = get_query_context(tbl.name, tbl.id, tbl.type)

     first = payload["queries"][0]
     first["granularity_sqla"] = "timecol"
     first["having_filters"] = {"col": "a", "op": "==", "val": "b"}

     context = QueryContext(**payload)
     self.assertEqual(len(context.queries), 1)

     first_object = context.queries[0]
     self.assertEqual(first_object.granularity, "timecol")
     self.assertIn("having_druid", first_object.extras)
Example #26
0
 def test_samples_response_type(self):
     """
     Ensure the SAMPLES result type returns a list of rows that honours the
     row limit (5) and whose first row has no ``sum__num`` entry.
     """
     self.login(username="******")
     payload = get_query_context("birth_names")
     payload["result_type"] = ChartDataResultType.SAMPLES.value
     payload["queries"][0]["row_limit"] = 5

     responses = ChartDataQueryContextSchema().load(payload).get_payload()
     self.assertEqual(len(responses), 1)

     rows = responses["queries"][0]["data"]
     self.assertIsInstance(rows, list)
     self.assertEqual(len(rows), 5)
     self.assertNotIn("sum__num", rows[0])
Example #27
0
    def test_load_chart_data_into_cache(self, mock_update_job):
        """
        Ensure ``load_chart_data_into_cache`` ends by calling the (mocked)
        job-update hook with the job metadata, status ``"done"`` and a
        ``result_url`` keyword.

        NOTE(review): ``mock_update_job`` is presumably injected by a
        ``mock.patch`` decorator outside the visible lines -- confirm.
        """
        async_query_manager.init_app(app)
        query_context = get_query_context("birth_names")
        job_metadata = dict(
            channel_id=str(uuid4()),
            job_id=str(uuid4()),
            user_id=1,
            status="pending",
            errors=[],
        )

        load_chart_data_into_cache(job_metadata, query_context)

        mock_update_job.assert_called_with(
            job_metadata, "done", result_url=mock.ANY
        )
Example #28
0
 def test_query_context_time_range_endpoints(self):
     """
     Ensure ``time_range_endpoints`` is filled in with
     (INCLUSIVE, EXCLUSIVE) when the payload's extras omit it.
     """
     self.login(username="******")
     payload = get_query_context("birth_names")
     del payload["queries"][0]["extras"]["time_range_endpoints"]

     first_object = ChartDataQueryContextSchema().load(payload).queries[0]
     extras = first_object.to_dict()["extras"]

     assert "time_range_endpoints" in extras
     expected = (TimeRangeEndpoint.INCLUSIVE, TimeRangeEndpoint.EXCLUSIVE)
     self.assertEqual(extras["time_range_endpoints"], expected)
Example #29
0
 def test_chart_data_no_data(self):
     """
     Chart data API: a filter that matches no rows still succeeds (HTTP 200)
     and returns a zero row count with an empty data list.
     """
     self.login(username="******")
     tbl = self.get_table_by_name("birth_names")
     body = get_query_context(tbl.name, tbl.id, tbl.type)

     no_match_filter = {"col": "gender", "op": "==", "val": "foo"}
     body["queries"][0]["filters"] = [no_match_filter]

     response = self.post_assert_metric(CHART_DATA_URI, body, "data")
     self.assertEqual(response.status_code, 200)

     decoded = json.loads(response.data.decode("utf-8"))
     first_result = decoded["result"][0]
     self.assertEqual(first_result["rowcount"], 0)
     self.assertEqual(first_result["data"], [])
Example #30
0
 def test_query_response_type(self):
     """
     Ensure a ``QueryContext`` built directly with the QUERY result type
     yields one response holding exactly the ``language`` ("sql") and a
     ``query`` containing SELECT.
     """
     self.login(username="******")
     tbl = self.get_table_by_name("birth_names")
     payload = get_query_context(tbl.name, tbl.id, tbl.type)
     payload["result_type"] = ChartDataResultType.QUERY.value

     responses = QueryContext(**payload).get_payload()
     self.assertEqual(len(responses), 1)

     first_response = responses[0]
     self.assertEqual(len(first_response), 2)
     self.assertEqual(first_response["language"], "sql")
     self.assertIn("SELECT", first_response["query"])