Пример #1
0
    def test_query_context_limit_and_offset(self):
        """
        Check row_limit/row_offset handling: defaults when omitted, passthrough
        when valid, and a ValidationError naming both fields when negative.
        """
        self.login(username="******")
        payload = get_query_context("birth_names")
        first_query = payload["queries"][0]

        def load_first_query_object():
            # Build a fresh schema for every load, mirroring a new request.
            return ChartDataQueryContextSchema().load(payload).queries[0]

        # Omitted values fall back to the configured defaults.
        for field_name in ("row_limit", "row_offset"):
            first_query.pop(field_name, None)
        defaults_obj = load_first_query_object()
        self.assertEqual(defaults_obj.row_limit, app.config["ROW_LIMIT"])
        self.assertEqual(defaults_obj.row_offset, 0)

        # Explicit, valid values are kept as given.
        first_query.update(row_limit=100, row_offset=200)
        explicit_obj = load_first_query_object()
        self.assertEqual(explicit_obj.row_limit, 100)
        self.assertEqual(explicit_obj.row_offset, 200)

        # Negative values must be rejected during schema validation.
        first_query.update(row_limit=-1, row_offset=-1)
        with self.assertRaises(ValidationError) as raised:
            _ = ChartDataQueryContextSchema().load(payload)
        query_errors = raised.exception.messages["queries"][0]
        self.assertIn("row_limit", query_errors)
        self.assertIn("row_offset", query_errors)
Пример #2
0
    def test_query_context_limit_and_offset(self):
        """
        Check row_limit/row_offset handling when load() reports errors as a
        second return value: defaults, valid passthrough, and too-low values.
        """
        self.login(username="******")
        birth_names = self.get_table_by_name("birth_names")
        payload = get_query_context(
            birth_names.name, birth_names.id, birth_names.type
        )
        query_payload = payload["queries"][0]

        # Omitted values fall back to the configured defaults.
        query_payload.pop("row_limit", None)
        query_payload.pop("row_offset", None)
        loaded, errors = load_query_context(payload)
        self.assertEqual(errors, {})
        defaults = loaded.queries[0]
        self.assertEqual(defaults.row_limit, app.config["ROW_LIMIT"])
        self.assertEqual(defaults.row_offset, 0)

        # Valid explicit values are kept as-is.
        query_payload["row_limit"], query_payload["row_offset"] = 100, 200
        loaded, errors = ChartDataQueryContextSchema().load(payload)
        self.assertEqual(errors, {})
        explicit = loaded.queries[0]
        self.assertEqual(explicit.row_limit, 100)
        self.assertEqual(explicit.row_offset, 200)

        # Too-low values show up in the per-query error mapping; load() here
        # yields a (data, errors) pair rather than raising.
        query_payload["row_limit"], query_payload["row_offset"] = 0, -1
        _, errors = ChartDataQueryContextSchema().load(payload)
        first_query_errors = errors["queries"][0]
        self.assertIn("row_limit", first_query_errors)
        self.assertIn("row_offset", first_query_errors)
Пример #3
0
def get_sql_text(payload: Dict[str, Any]) -> str:
    """
    Return the SQL text generated for ``payload``.

    ``payload`` is mutated in place so that it requests a QUERY result type.
    The asserts check that the response holds exactly one SQL query entry.
    """
    payload["result_type"] = ChartDataResultType.QUERY.value
    responses = ChartDataQueryContextSchema().load(payload).get_payload()
    assert len(responses) == 1
    first_response = responses["queries"][0]
    # The entry is expected to carry only "language" and "query".
    assert len(first_response) == 2
    assert first_response["language"] == "sql"
    return first_response["query"]
Пример #4
0
    def set_query_context(self, form_data: Dict[str, Any]) -> QueryContext:
        """
        Store ``form_data`` and deserialize it into a QueryContext.

        :param form_data: raw query-context payload; kept on ``self._form_data``
        :returns: the loaded QueryContext, also stored on ``self._query_context``
        :raises ValidationError: when schema loading fails, or when a
            ``KeyError`` is raised during loading
        """
        self._form_data = form_data
        try:
            self._query_context = ChartDataQueryContextSchema().load(self._form_data)
        except KeyError as ex:
            # Surface missing keys as the same exception type that schema
            # validation failures use, keeping the original as the cause.
            raise ValidationError("Request is incorrect") from ex
        # A ValidationError from load() propagates unchanged; no handler is
        # needed just to re-raise it.
        return self._query_context
Пример #5
0
 def test_sql_injection_via_groupby(self):
     """
     Ensure that a SQL expression passed as a groupby column name is
     reported as a query error.
     """
     self.login(username="******")
     payload = get_query_context("birth_names")
     # Not a real column: a function call smuggled in as a column name.
     payload["queries"][0]["groupby"] = ["currentDatabase()"]
     results = ChartDataQueryContextSchema().load(payload).get_payload()
     first_result = results["queries"][0]
     assert first_result.get("error") is not None
Пример #6
0
    def test_cache_key_changes_when_post_processing_is_updated(self):
        """
        The query cache key ignores a null post-processing entry but changes
        when post-processing is removed entirely.
        """
        self.login(username="******")
        table = self.get_table_by_name("birth_names")
        payload = get_query_context(
            table.name,
            table.id,
            table.type,
            add_postprocessing_operations=True,
        )

        def cache_key_of_first_query():
            # Re-load the (possibly mutated) payload and key its first query.
            context = ChartDataQueryContextSchema().load(payload)
            return context.query_cache_key(context.queries[0])

        # Baseline: payload that includes post-processing operations.
        baseline_key = cache_key_of_first_query()

        # Appending a None operation must leave the key unchanged.
        payload["queries"][0]["post_processing"].append(None)
        self.assertEqual(baseline_key, cache_key_of_first_query())

        # Dropping post-processing altogether must yield a different key.
        payload["queries"][0].pop("post_processing")
        self.assertNotEqual(baseline_key, cache_key_of_first_query())
Пример #7
0
 def data(self) -> Response:
     """
     Takes a query context constructed in the client and returns payload
     data response for the given query.
     ---
     post:
       description: >-
         Takes a query context constructed in the client and returns payload data
         response for the given query.
       requestBody:
         description: >-
           A query context consists of a datasource from which to fetch data
           and one or many query objects.
         required: true
         content:
           application/json:
             schema:
               $ref: "#/components/schemas/ChartDataQueryContextSchema"
       responses:
         200:
           description: Query result
           content:
             application/json:
               schema:
                 $ref: "#/components/schemas/ChartDataResponseSchema"
         400:
           $ref: '#/components/responses/400'
         500:
           $ref: '#/components/responses/500'
         """
     # The docstring above is an OpenAPI spec and is kept verbatim.
     if not request.is_json:
         return self.response_400(message="Request is not JSON")

     # load() here yields a (data, errors) pair; a non-empty errors mapping
     # or a KeyError raised while loading both become a 400 response.
     try:
         query_context, errors = ChartDataQueryContextSchema().load(request.json)
         if errors:
             error_message = _("Request is incorrect: %(error)s", error=errors)
             return self.response_400(message=error_message)
     except KeyError:
         return self.response_400(message="Request is incorrect")

     # A failed permission assertion is reported as 401.
     try:
         security_manager.assert_query_context_permission(query_context)
     except SupersetSecurityException:
         return self.response_401()

     result_json = simplejson.dumps(
         {"result": query_context.get_payload()},
         default=json_int_dttm_ser,
         ignore_nan=True,
     )
     response = make_response(result_json, 200)
     response.headers["Content-Type"] = "application/json; charset=utf-8"
     return response
Пример #8
0
 def test_csv_response_format(self):
     """
     Ensure that requesting the CSV result format returns CSV text.
     """
     self.login(username="******")
     payload = get_query_context("birth_names")
     payload["result_format"] = ChartDataResultFormat.CSV.value
     payload["queries"][0]["row_limit"] = 10
     responses = ChartDataQueryContextSchema().load(payload).get_payload()
     self.assertEqual(len(responses), 1)
     csv_text = responses["queries"][0]["data"]
     self.assertIn("name,sum__num\n", csv_text)
     # Presumably header + 10 rows + the empty piece after the final newline.
     csv_lines = csv_text.split("\n")
     self.assertEqual(len(csv_lines), 12)
 def test_fetch_values_predicate_not_in_query(self):
     """
     Ensure the generated SQL does not contain the "123 = 123" predicate.
     """
     self.login(username="******")
     payload = get_query_context("birth_names")
     payload["result_type"] = ChartDataResultType.QUERY.value
     responses = ChartDataQueryContextSchema().load(payload).get_payload()
     assert len(responses) == 1
     sql_entry = responses["queries"][0]
     assert len(sql_entry) == 2
     assert sql_entry["language"] == "sql"
     # NOTE(review): "123 = 123" is presumably the fixture's fetch-values
     # predicate — confirm against the dataset definition.
     generated_sql = sql_entry["query"]
     assert "123 = 123" not in generated_sql
Пример #10
0
 def test_query_response_type(self):
     """
     Ensure that the QUERY result type returns the generated SQL.
     """
     self.login(username="******")
     payload = get_query_context("birth_names")
     payload["result_type"] = ChartDataResultType.QUERY.value
     responses = ChartDataQueryContextSchema().load(payload).get_payload()
     self.assertEqual(len(responses), 1)
     sql_entry = responses["queries"][0]
     # Only "language" and "query" are expected in the entry.
     self.assertEqual(len(sql_entry), 2)
     self.assertEqual(sql_entry["language"], "sql")
     self.assertIn("SELECT", sql_entry["query"])
Пример #11
0
 def test_sql_injection_via_columns(self):
     """
     Ensure that an injected expression used as a column name is reported
     as a query error.
     """
     self.login(username="******")
     table = self.get_table_by_name("birth_names")
     payload = get_query_context(table.name, table.id, table.type)
     query = payload["queries"][0]
     query["groupby"] = []
     query["metrics"] = []
     query["columns"] = ["*, 'extra'"]
     results = ChartDataQueryContextSchema().load(payload).get_payload()
     assert results["queries"][0].get("error") is not None
Пример #12
0
 def test_query_response_type(self):
     """
     Ensure that the QUERY result type returns a SELECT statement.
     """
     self.login(username="******")
     payload = get_query_context("birth_names")
     payload["result_type"] = ChartDataResultType.QUERY.value
     responses = ChartDataQueryContextSchema().load(payload).get_payload()
     assert len(responses) == 1
     sql_entry = responses["queries"][0]
     assert len(sql_entry) == 2
     assert sql_entry["language"] == "sql"
     generated_sql = sql_entry["query"]
     assert "SELECT" in generated_sql
Пример #13
0
 def test_samples_response_type(self):
     """
     Ensure that the SAMPLES result type returns row_limit raw rows
     without the aggregated metric column.
     """
     self.login(username="******")
     payload = get_query_context("birth_names")
     payload["result_type"] = ChartDataResultType.SAMPLES.value
     payload["queries"][0]["row_limit"] = 5
     responses = ChartDataQueryContextSchema().load(payload).get_payload()
     self.assertEqual(len(responses), 1)
     sample_rows = responses["queries"][0]["data"]
     self.assertIsInstance(sample_rows, list)
     self.assertEqual(len(sample_rows), 5)
     first_row = sample_rows[0]
     self.assertNotIn("sum__num", first_row)
Пример #14
0
    def test_query_cache_key_does_not_change_for_non_existent_or_null(self):
        """
        A missing ``granularity`` and an explicit ``granularity: None`` must
        produce the same query cache key.
        """
        self.login(username="******")
        payload = get_query_context("birth_names", add_postprocessing_operations=True)
        first_query = payload["queries"][0]
        del first_query["granularity"]

        # Baseline key with the granularity key absent.
        context_without: QueryContext = ChartDataQueryContextSchema().load(payload)
        obj_without: QueryObject = context_without.queries[0]
        key_without = context_without.query_cache_key(obj_without)

        # Same payload, but with granularity present and set to null.
        first_query["granularity"] = None
        context_with_null = ChartDataQueryContextSchema().load(payload)
        obj_with_null = context_with_null.queries[0]

        assert context_with_null.query_cache_key(obj_with_null) == key_without
Пример #15
0
def _create_query_context_from_form(form_data: Dict[str, Any]) -> QueryContext:
    """
    Deserialize ``form_data`` into a QueryContext.

    :param form_data: raw query-context payload
    :returns: the loaded QueryContext
    :raises ValidationError: when schema loading fails, or when a ``KeyError``
        is raised during loading
    """
    try:
        return ChartDataQueryContextSchema().load(form_data)
    except KeyError as ex:
        # Surface missing keys as the same exception type that schema
        # validation failures use, keeping the original as the cause.
        raise ValidationError("Request is incorrect") from ex
    # A ValidationError from load() propagates unchanged; a handler that only
    # re-raises it is unnecessary.
 def test_sql_injection_via_filters(self):
     """
     Ensure that an invalid column name used in a filter is reported as a
     query error.
     """
     self.login(username="******")
     table = self.get_table_by_name("birth_names")
     payload = get_query_context(table.name, table.id, table.type)
     query = payload["queries"][0]
     query["groupby"] = ["name"]
     query["metrics"] = []
     injected_filter = {"col": "*", "op": FilterOperator.EQUALS.value, "val": ";"}
     query["filters"] = [injected_filter]
     results = ChartDataQueryContextSchema().load(payload).get_payload()
     # In this snippet get_payload() is indexed positionally.
     assert results[0].get("error") is not None
Пример #17
0
 def test_convert_deprecated_fields(self):
     """
     Ensure that legacy query fields are mapped onto their current names
     when the payload is loaded.
     """
     self.login(username="******")
     payload = get_query_context("birth_names")
     legacy_query = payload["queries"][0]
     columns = legacy_query["columns"]
     # Rewrite the query using only the deprecated field names.
     legacy_query["groupby"] = columns
     legacy_query["timeseries_limit"] = 99
     legacy_query["timeseries_limit_metric"] = "sum__num"
     del legacy_query["columns"]
     legacy_query["granularity_sqla"] = "timecol"
     legacy_query["having_filters"] = [{"col": "a", "op": "==", "val": "b"}]

     query_context = ChartDataQueryContextSchema().load(payload)
     self.assertEqual(len(query_context.queries), 1)
     converted = query_context.queries[0]
     self.assertEqual(converted.granularity, "timecol")
     self.assertEqual(converted.columns, columns)
     self.assertEqual(converted.series_limit, 99)
     self.assertEqual(converted.series_limit_metric, "sum__num")
     self.assertIn("having_druid", converted.extras)
Пример #18
0
    def test_query_context_null_post_processing_op(self):
        """A ``None`` post-processing entry is dropped when loading."""
        self.login(username="******")
        payload = get_query_context("birth_names")

        payload["queries"][0]["post_processing"] = [None]
        loaded_queries = ChartDataQueryContextSchema().load(payload).queries
        self.assertEqual(loaded_queries[0].post_processing, [])
Пример #19
0
    def test_schema_deserialization(self):
        """
        Check that each deserialized query mirrors the corresponding query in
        the payload (extras, filters, columns, metrics, ordering, time range
        and post-processing).
        """
        payload = get_query_context("birth_names", add_postprocessing_operations=True)
        query_context = ChartDataQueryContextSchema().load(payload)
        self.assertEqual(len(query_context.queries), len(payload["queries"]))

        # The lengths were just asserted equal, so zip pairs every query.
        for query, payload_query in zip(query_context.queries, payload["queries"]):
            # basic properties are carried over verbatim
            self.assertEqual(query.extras, payload_query["extras"])
            self.assertEqual(query.filter, payload_query["filters"])
            self.assertEqual(query.columns, payload_query["columns"])

            # Metrics are transformed on load: a payload metric containing
            # "expressionType" is expected unchanged, otherwise its "label".
            for idx, metric in enumerate(query.metrics):
                raw_metric = payload_query["metrics"][idx]
                if "expressionType" in raw_metric:
                    expected_metric = raw_metric
                else:
                    expected_metric = raw_metric["label"]
                self.assertEqual(metric, expected_metric)

            self.assertEqual(query.orderby, payload_query["orderby"])
            self.assertEqual(query.time_range, payload_query["time_range"])

            # post-processing operations keep their name and options
            for idx, post_proc in enumerate(query.post_processing):
                expected_op = payload_query["post_processing"][idx]
                self.assertEqual(post_proc["operation"], expected_op["operation"])
                self.assertEqual(post_proc["options"], expected_op["options"])
Пример #20
0
 def test_query_context_null_timegrain(self):
     """
     Ensure that a null time grain in extras loads without raising.
     """
     self.login(username="******")
     table = self.get_table_by_name("birth_names")
     payload = get_query_context(table.name, table.id, table.type)
     extras = payload["queries"][0]["extras"]
     extras["time_grain_sqla"] = None
     # Loading must simply not raise; the result itself is not inspected.
     ChartDataQueryContextSchema().load(payload)
Пример #21
0
 def test_handle_metrics_field(self):
     """
     Ensure that named metrics, label-only metrics and adhoc metric
     definitions are all accepted in the metrics field.
     """
     self.login(username="******")
     adhoc_metric = dict(
         expressionType="SIMPLE",
         column={"column_name": "num_boys", "type": "BIGINT(20)"},
         aggregate="SUM",
         label="Boys",
         optionName="metric_11",
     )
     label_only_metric = {"label": "abc"}
     payload = get_query_context("birth_names")
     payload["queries"][0]["metrics"] = ["sum__num", label_only_metric, adhoc_metric]
     loaded_query = ChartDataQueryContextSchema().load(payload).queries[0]
     # A label-only dict collapses to its label; the full adhoc dict is kept.
     expected_metrics = ["sum__num", "abc", adhoc_metric]
     self.assertEqual(loaded_query.metrics, expected_metrics)
Пример #22
0
    def test_cache(self):
        """
        A query context cached through get_payload can be reloaded from the
        cache and yields an equivalent context, with ``force`` reset.
        """
        table = self.get_table_by_name("birth_names")
        payload = get_query_context(table.name, table.id)
        payload["force"] = True

        original_qc = ChartDataQueryContextSchema().load(payload)
        original_key = original_qc.query_cache_key(original_qc.queries[0])

        response = original_qc.get_payload(cache_query_context=True)
        cache_key = response["cache_key"]
        assert cache_key is not None

        cached = cache_manager.cache.get(cache_key)
        assert cached is not None

        # Rebuild a context from the cached form data and compare.
        rehydrated_qc = ChartDataQueryContextSchema().load(cached["data"])
        rehydrated_key = rehydrated_qc.query_cache_key(rehydrated_qc.queries[0])

        self.assertEqual(rehydrated_qc.datasource, original_qc.datasource)
        self.assertEqual(len(rehydrated_qc.queries), 1)
        self.assertEqual(original_key, rehydrated_key)
        self.assertEqual(rehydrated_qc.result_type, original_qc.result_type)
        self.assertEqual(rehydrated_qc.result_format, original_qc.result_format)
        self.assertFalse(rehydrated_qc.force)
Пример #23
0
    def test_cache(self):
        """
        A successfully executed query context cached through get_payload can
        be reloaded from the cache and yields an equivalent context, with
        ``force`` reset.
        """
        payload = get_query_context(
            query_name="birth_names",
            add_postprocessing_operations=True,
        )
        payload["force"] = True

        original_qc = ChartDataQueryContextSchema().load(payload)
        original_key = original_qc.query_cache_key(original_qc.queries[0])

        response = original_qc.get_payload(cache_query_context=True)
        # The query itself must have succeeded for the cache check to be meaningful.
        first_result = response["queries"][0]
        assert first_result["status"] == QueryStatus.SUCCESS

        cache_key = response["cache_key"]
        assert cache_key is not None

        cached = cache_manager.cache.get(cache_key)
        assert cached is not None

        # Rebuild a context from the cached form data and compare.
        rehydrated_qc = ChartDataQueryContextSchema().load(cached["data"])
        rehydrated_key = rehydrated_qc.query_cache_key(rehydrated_qc.queries[0])

        self.assertEqual(rehydrated_qc.datasource, original_qc.datasource)
        self.assertEqual(len(rehydrated_qc.queries), 1)
        self.assertEqual(original_key, rehydrated_key)
        self.assertEqual(rehydrated_qc.result_type, original_qc.result_type)
        self.assertEqual(rehydrated_qc.result_format, original_qc.result_format)
        self.assertFalse(rehydrated_qc.force)
Пример #24
0
    def test_query_context_null_post_processing_op(self):
        """A ``None`` post-processing entry is dropped when loading."""
        self.login(username="******")
        table = self.get_table_by_name("birth_names")
        payload = get_query_context(table.name, table.id, table.type)

        payload["queries"][0]["post_processing"] = [None]
        loaded_queries = ChartDataQueryContextSchema().load(payload).queries
        self.assertEqual(loaded_queries[0].post_processing, [])
Пример #25
0
 def test_sql_injection_via_metrics(self):
     """
     Ensure that an adhoc metric referencing an invalid column name is
     reported as a query error.
     """
     # The previous docstring said "filters", but this test exercises metrics.
     self.login(username="******")
     payload = get_query_context("birth_names")
     query = payload["queries"][0]
     query["groupby"] = ["name"]
     query["metrics"] = [
         {
             "expressionType": AdhocMetricExpressionType.SIMPLE.value,
             "column": {"column_name": "invalid_col"},
             "aggregate": "SUM",
             "label": "My Simple Label",
         }
     ]
     results = ChartDataQueryContextSchema().load(payload).get_payload()
     assert results["queries"][0].get("error") is not None
    def test_csv_response_format(self):
        """
        Ensure that the CSV result format returns CSV text and that a cache
        key row is recorded for the datasource.
        """
        self.login(username="******")
        table = self.get_table_by_name("birth_names")
        payload = get_query_context(table.name, table.id, table.type)
        payload["result_format"] = ChartDataResultFormat.CSV.value
        payload["queries"][0]["row_limit"] = 10
        responses = ChartDataQueryContextSchema().load(payload).get_payload()
        self.assertEqual(len(responses), 1)
        # In this snippet get_payload() is indexed positionally.
        csv_text = responses[0]["data"]
        self.assertIn("name,sum__num\n", csv_text)
        self.assertEqual(len(csv_text.split("\n")), 12)

        # The most recently stored CacheKey row should belong to this table.
        latest_key = (
            db.session.query(CacheKey).order_by(CacheKey.id.desc()).first()
        )
        assert latest_key.datasource_uid == "3__table"
Пример #27
0
    def test_query_context_limit_and_offset(self):
        """Negative row_limit and row_offset are both rejected on load."""
        self.login(username="******")
        payload = get_query_context("birth_names")

        # too low limit and offset
        payload["queries"][0].update(row_limit=-1, row_offset=-1)
        with self.assertRaises(ValidationError) as raised:
            _ = ChartDataQueryContextSchema().load(payload)
        query_errors = raised.exception.messages["queries"][0]
        for field_name in ("row_limit", "row_offset"):
            self.assertIn(field_name, query_errors)
 def test_query_response_type(self):
     """
     Ensure that the QUERY result type returns SQL containing the expected
     SELECT and null-handling predicates.
     """
     self.login(username="******")
     payload = get_query_context("birth_names")
     payload["result_type"] = ChartDataResultType.QUERY.value
     responses = ChartDataQueryContextSchema().load(payload).get_payload()
     assert len(responses) == 1
     sql_entry = responses["queries"][0]
     assert len(sql_entry) == 2
     generated_sql = sql_entry["query"]
     assert sql_entry["language"] == "sql"
     assert "SELECT" in generated_sql
     # Identifier quoting varies by dialect, hence the optional quote classes.
     not_null_pattern = r'[`"\[]?num[`"\]]? IS NOT NULL'
     assert re.search(not_null_pattern, generated_sql)
     name_filter_pattern = (
         r"""NOT \([`"\[]?name[`"\]]? IS NULL[\s\n]* """
         r"""OR [`"\[]?name[`"\]]? IN \('abc'\)\)"""
     )
     assert re.search(name_filter_pattern, generated_sql)
Пример #29
0
    def test_time_offsets_sql(self):
        """
        Time offsets shift the generated SQL date range by the requested
        amount in each direction.
        """
        payload = get_query_context("birth_names")
        query = payload["queries"][0]
        query.update(
            {
                "metrics": ["sum__num"],
                "groupby": ["state"],
                "is_timeseries": True,
                "timeseries_limit": 5,
                "time_offsets": [],
                "time_range": "1980 : 1991",
                "granularity": "ds",
            }
        )
        query["extras"]["time_grain_sqla"] = "P1Y"

        # Run the main query (no offsets) to obtain its dataframe.
        base_context = ChartDataQueryContextSchema().load(payload)
        base_result = base_context.get_query_result(base_context.queries[0])
        main_df = base_result.df

        # Re-load with offsets and generate the shifted queries.
        query["time_offsets"] = ["3 years ago", "3 years later"]
        offset_context = ChartDataQueryContextSchema().load(payload)
        offsets = offset_context.processing_time_offsets(
            main_df, offset_context.queries[0]
        )
        earlier_sql, later_sql = offsets["queries"][0], offsets["queries"][1]

        # Each shifted query should span the expected date range.
        for expected_date in ("1977-01-01", "1988-01-01"):
            assert expected_date in earlier_sql
        for expected_date in ("1983-01-01", "1994-01-01"):
            assert expected_date in later_sql
Пример #30
0
    def test_query_cache_key_changes_when_datasource_is_updated(self):
        """
        Touching the datasource (a change followed by a revert) must change
        the query cache key even though the payload is identical.
        """
        self.login(username="******")
        payload = get_query_context("birth_names")

        def cache_key_of_first_query():
            # Load the payload afresh and key its first query.
            context = ChartDataQueryContextSchema().load(payload)
            return context.query_cache_key(context.queries[0])

        key_before = cache_key_of_first_query()

        # Modify and restore the description so that changed_on is refreshed.
        datasource_ref = payload["datasource"]
        datasource = ConnectorRegistry.get_datasource(
            datasource_type=datasource_ref["type"],
            datasource_id=datasource_ref["id"],
            session=db.session,
        )
        saved_description = datasource.description
        for description in ("temporary description", saved_description):
            datasource.description = description
            db.session.commit()

        key_after = cache_key_of_first_query()

        # The refreshed datasource should yield a different cache key.
        self.assertNotEqual(key_before, key_after)