コード例 #1
0
 def test_sql_injection_via_groupby(self):
     """
     Verify that an invalid column expression in ``groupby`` is reported as
     a query error.
     """
     self.login(username="******")
     request_body = get_query_context("birth_names")
     # ``currentDatabase()`` is a function call rather than a column name.
     first_query = request_body["queries"][0]
     first_query["groupby"] = ["currentDatabase()"]
     result = ChartDataQueryContextSchema().load(request_body).get_payload()
     # The first query result must carry a non-null "error" entry.
     assert result["queries"][0].get("error") is not None
コード例 #2
0
 def data(self) -> Response:
     """
     Takes a query context constructed in the client and returns payload
     data response for the given query.
     ---
     post:
       description: >-
         Takes a query context constructed in the client and returns payload data
         response for the given query.
       requestBody:
         description: >-
           A query context consists of a datasource from which to fetch data
           and one or many query objects.
         required: true
         content:
           application/json:
             schema:
               $ref: "#/components/schemas/ChartDataQueryContextSchema"
       responses:
         200:
           description: Query result
           content:
             application/json:
               schema:
                 $ref: "#/components/schemas/ChartDataResponseSchema"
         400:
           $ref: '#/components/responses/400'
         500:
           $ref: '#/components/responses/500'
         """
     # Reject anything that is not a JSON request up front.
     if not request.is_json:
         return self.response_400(message="Request is not JSON")

     # Deserialize the body; load() is unpacked into the query context and
     # a collection of validation errors.
     try:
         query_context, errors = ChartDataQueryContextSchema().load(request.json)
         if errors:
             message = _("Request is incorrect: %(error)s", error=errors)
             return self.response_400(message=message)
     except KeyError:
         return self.response_400(message="Request is incorrect")

     # A security exception from the permission check maps to a 401.
     try:
         security_manager.assert_query_context_permission(query_context)
     except SupersetSecurityException:
         return self.response_401()

     # Serialize the payload under a top-level "result" key.
     body = simplejson.dumps(
         {"result": query_context.get_payload()},
         default=json_int_dttm_ser,
         ignore_nan=True,
     )
     response = make_response(body, 200)
     response.headers["Content-Type"] = "application/json; charset=utf-8"
     return response
コード例 #3
0
ファイル: query_context_tests.py プロジェクト: peifd/superset
 def test_query_response_type(self):
     """
     Verify that the ``QUERY`` result type yields a SQL query in the payload.
     """
     self.login(username="******")
     request_body = get_query_context("birth_names")
     request_body["result_type"] = ChartDataResultType.QUERY.value
     context = ChartDataQueryContextSchema().load(request_body)
     result = context.get_payload()
     self.assertEqual(len(result), 1)
     query_result = result["queries"][0]
     # Exactly two entries are expected; "language" and "query" are checked below.
     self.assertEqual(len(query_result), 2)
     self.assertEqual(query_result["language"], "sql")
     self.assertIn("SELECT", query_result["query"])
コード例 #4
0
ファイル: query_context_tests.py プロジェクト: peifd/superset
 def test_csv_response_format(self):
     """
     Verify that the CSV result format returns text starting from a
     ``name,sum__num`` header line.
     """
     self.login(username="******")
     request_body = get_query_context("birth_names")
     request_body["result_format"] = ChartDataResultFormat.CSV.value
     request_body["queries"][0]["row_limit"] = 10
     result = ChartDataQueryContextSchema().load(request_body).get_payload()
     self.assertEqual(len(result), 1)
     csv_text = result["queries"][0]["data"]
     self.assertIn("name,sum__num\n", csv_text)
     # 12 pieces after splitting on newlines: presumably the header, the 10
     # limited rows and a trailing empty string -- TODO confirm.
     csv_lines = csv_text.split("\n")
     self.assertEqual(len(csv_lines), 12)
コード例 #5
0
 def test_sql_injection_via_columns(self):
     """
     Verify that an invalid column expression in ``columns`` is reported as
     a query error.
     """
     self.login(username="******")
     birth_names = self.get_table_by_name("birth_names")
     request_body = get_query_context(
         birth_names.name, birth_names.id, birth_names.type
     )
     first_query = request_body["queries"][0]
     # Clear grouping and metrics so only the raw column list is exercised.
     first_query["groupby"] = []
     first_query["metrics"] = []
     first_query["columns"] = ["*, 'extra'"]
     context = ChartDataQueryContextSchema().load(request_body)
     result = context.get_payload()
     assert result["queries"][0].get("error") is not None
コード例 #6
0
 def test_fetch_values_predicate_not_in_query(self):
     """
     Verify that the generated SQL does not contain the fetch values predicate.
     """
     self.login(username="******")
     request_body = get_query_context("birth_names")
     request_body["result_type"] = ChartDataResultType.QUERY.value
     result = ChartDataQueryContextSchema().load(request_body).get_payload()
     assert len(result) == 1
     query_result = result["queries"][0]
     assert len(query_result) == 2
     assert query_result["language"] == "sql"
     # "123 = 123" is presumably the predicate configured on the test
     # datasource -- verify against the fixture.
     generated_sql = query_result["query"]
     assert "123 = 123" not in generated_sql
コード例 #7
0
 def test_query_response_type(self):
     """
     Verify that requesting the ``QUERY`` result type returns SQL text.
     """
     self.login(username="******")
     query_context_dict = get_query_context("birth_names")
     query_context_dict["result_type"] = ChartDataResultType.QUERY.value
     loaded_context = ChartDataQueryContextSchema().load(query_context_dict)
     payload_out = loaded_context.get_payload()
     assert len(payload_out) == 1
     first_result = payload_out["queries"][0]
     # Two entries expected in the result; language and query are checked next.
     assert len(first_result) == 2
     assert first_result["language"] == "sql"
     assert "SELECT" in first_result["query"]
コード例 #8
0
ファイル: query_context_tests.py プロジェクト: peifd/superset
 def test_samples_response_type(self):
     """
     Verify that the ``SAMPLES`` result type returns a list of rows capped
     by ``row_limit`` and without the ``sum__num`` metric.
     """
     self.login(username="******")
     request_body = get_query_context("birth_names")
     request_body["result_type"] = ChartDataResultType.SAMPLES.value
     request_body["queries"][0]["row_limit"] = 5
     context = ChartDataQueryContextSchema().load(request_body)
     result = context.get_payload()
     self.assertEqual(len(result), 1)
     rows = result["queries"][0]["data"]
     self.assertIsInstance(rows, list)
     self.assertEqual(len(rows), 5)
     # The metric column must be absent from the sampled rows.
     first_row = rows[0]
     self.assertNotIn("sum__num", first_row)
コード例 #9
0
 def test_sql_injection_via_filters(self):
     """
     Verify that a filter on an invalid column name is reported as a query
     error.
     """
     self.login(username="******")
     birth_names = self.get_table_by_name("birth_names")
     request_body = get_query_context(
         birth_names.name, birth_names.id, birth_names.type
     )
     first_query = request_body["queries"][0]
     first_query["groupby"] = ["name"]
     first_query["metrics"] = []
     # "*" is used as the filter column and ";" as its value.
     bad_filter = {"col": "*", "op": FilterOperator.EQUALS.value, "val": ";"}
     first_query["filters"] = [bad_filter]
     context = ChartDataQueryContextSchema().load(request_body)
     result = context.get_payload()
     # NOTE(review): the payload is indexed positionally here, i.e. it is
     # treated as a list of query results -- confirm this matches the
     # return shape of get_payload() in this code base.
     assert result[0].get("error") is not None
コード例 #10
0
 def test_sql_injection_via_metrics(self):
     """
     Verify that an adhoc metric referencing an invalid column name is
     reported as a query error.
     """
     self.login(username="******")
     request_body = get_query_context("birth_names")
     first_query = request_body["queries"][0]
     first_query["groupby"] = ["name"]
     # Simple adhoc metric that aggregates the column "invalid_col".
     bad_metric = {
         "expressionType": AdhocMetricExpressionType.SIMPLE.value,
         "column": {"column_name": "invalid_col"},
         "aggregate": "SUM",
         "label": "My Simple Label",
     }
     first_query["metrics"] = [bad_metric]
     result = ChartDataQueryContextSchema().load(request_body).get_payload()
     assert result["queries"][0].get("error") is not None
コード例 #11
0
    def test_csv_response_format(self):
        """
        Ensure that CSV result format works and that the most recent cache
        key row points at the queried datasource.
        """
        self.login(username="******")
        table_name = "birth_names"
        table = self.get_table_by_name(table_name)
        payload = get_query_context(table.name, table.id, table.type)
        payload["result_format"] = ChartDataResultFormat.CSV.value
        payload["queries"][0]["row_limit"] = 10
        query_context = ChartDataQueryContextSchema().load(payload)
        responses = query_context.get_payload()
        self.assertEqual(len(responses), 1)
        data = responses[0]["data"]
        self.assertIn("name,sum__num\n", data)
        # 12 pieces after splitting on newlines: presumably the header, the
        # 10 limited rows and a trailing empty string -- TODO confirm.
        self.assertEqual(len(data.split("\n")), 12)

        # Newest CacheKey row, ordered by descending id.
        ck = db.session.query(CacheKey).order_by(CacheKey.id.desc()).first()
        # Fail with a clear assertion instead of an AttributeError on None.
        assert ck is not None
        # Derive the expected uid from the table under test rather than
        # relying on the fixture receiving primary key 3. Assumes the uid
        # format is "<id>__<type>", as in the former literal "3__table".
        assert ck.datasource_uid == f"{table.id}__{table.type}"
コード例 #12
0
 def test_query_response_type(self):
     """
     Verify that the ``QUERY`` result type returns SQL text containing the
     expected ``num`` and ``name`` filter clauses.
     """
     self.login(username="******")
     request_body = get_query_context("birth_names")
     request_body["result_type"] = ChartDataResultType.QUERY.value
     result = ChartDataQueryContextSchema().load(request_body).get_payload()
     assert len(result) == 1
     query_result = result["queries"][0]
     assert len(query_result) == 2
     generated_sql = query_result["query"]
     assert query_result["language"] == "sql"
     assert "SELECT" in generated_sql
     # Identifiers may be wrapped in backticks, double quotes or brackets.
     not_null_clause = re.search(
         r'[`"\[]?num[`"\]]? IS NOT NULL', generated_sql
     )
     assert not_null_clause
     not_in_clause = re.search(
         r"""NOT \([`"\[]?name[`"\]]? IS NULL[\s\n]* """
         r"""OR [`"\[]?name[`"\]]? IN \('abc'\)\)""",
         generated_sql,
     )
     assert not_in_clause
コード例 #13
0
    def test_time_offsets_in_query_object(self):
        """
        Verify that ``time_offsets`` adds one result column and one extra
        SQL statement per offset, each covering the shifted time range.
        """
        self.login(username="******")
        request_body = get_query_context("birth_names")
        query_object = request_body["queries"][0]
        query_object["metrics"] = ["sum__num"]
        query_object["groupby"] = ["name"]
        query_object["is_timeseries"] = True
        query_object["timeseries_limit"] = 5
        query_object["time_offsets"] = ["1 year ago", "1 year later"]
        query_object["time_range"] = "1990 : 1991"
        context = ChartDataQueryContextSchema().load(request_body)
        result = context.get_payload()
        query_result = result["queries"][0]

        # The base metric is followed by one suffixed column per offset.
        expected_columns = [
            "__timestamp",
            "name",
            "sum__num",
            "sum__num__1 year ago",
            "sum__num__1 year later",
        ]
        self.assertEqual(query_result["colnames"], expected_columns)

        # The query text holds ";"-separated statements; skip blank pieces.
        statements = []
        for fragment in query_result["query"].split(";"):
            if fragment.strip():
                statements.append(fragment)
        self.assertEqual(len(statements), 3)

        # 1 year ago: both the shifted and the original range must appear.
        year_ago_sql = statements[1]
        assert re.search(r"1989-01-01.+1990-01-01", year_ago_sql, re.S)
        assert re.search(r"1990-01-01.+1991-01-01", year_ago_sql, re.S)

        # 1 year later: both the shifted and the original range must appear.
        year_later_sql = statements[2]
        assert re.search(r"1991-01-01.+1992-01-01", year_later_sql, re.S)
        assert re.search(r"1990-01-01.+1991-01-01", year_later_sql, re.S)
コード例 #14
0
ファイル: api.py プロジェクト: milindgv94/incubator-superset
    def data(self) -> Response:  # pylint: disable=too-many-return-statements
        """
        Takes a query context constructed in the client and returns payload
        data response for the given query.
        ---
        post:
          description: >-
            Takes a query context constructed in the client and returns payload data
            response for the given query.
          requestBody:
            description: >-
              A query context consists of a datasource from which to fetch data
              and one or many query objects.
            required: true
            content:
              application/json:
                schema:
                  $ref: "#/components/schemas/ChartDataQueryContextSchema"
          responses:
            200:
              description: Query result
              content:
                application/json:
                  schema:
                    $ref: "#/components/schemas/ChartDataResponseSchema"
            400:
              $ref: '#/components/responses/400'
            500:
              $ref: '#/components/responses/500'
        """
        # Accept either a JSON body or a "form_data" form field holding JSON.
        if request.is_json:
            json_body = request.json
        elif request.form.get("form_data"):
            # CSV export submits regular form data
            try:
                json_body = json.loads(request.form["form_data"])
            except ValueError:
                # Malformed JSON in the form field is a client error; answer
                # with 400 instead of letting the decode error propagate.
                # JSON decode errors are ValueError subclasses.
                return self.response_400(message="Request is not JSON")
        else:
            return self.response_400(message="Request is not JSON")

        # Deserialize and validate the query context.
        try:
            query_context = ChartDataQueryContextSchema().load(json_body)
        except KeyError:
            return self.response_400(message="Request is incorrect")
        except ValidationError as error:
            return self.response_400(message=_(
                "Request is incorrect: %(error)s", error=error.messages))

        # A security exception from the access check maps to a 401.
        try:
            query_context.raise_for_access()
        except SupersetSecurityException:
            return self.response_401()

        payload = query_context.get_payload()
        # Report the first query that carries an error as a 400.
        for query in payload:
            if query.get("error"):
                return self.response_400(message=f"Error: {query['error']}")

        result_format = query_context.result_format
        if result_format == ChartDataResultFormat.CSV:
            # return the first result
            result = payload[0]["data"]
            return CsvResponse(
                result,
                status=200,
                headers=generate_download_headers("csv"),
                mimetype="application/csv",
            )

        if result_format == ChartDataResultFormat.JSON:
            # Wrap all query results under a top-level "result" key.
            response_data = simplejson.dumps({"result": payload},
                                             default=json_int_dttm_ser,
                                             ignore_nan=True)
            resp = make_response(response_data, 200)
            resp.headers["Content-Type"] = "application/json; charset=utf-8"
            return resp

        # Any other result format is rejected explicitly.
        return self.response_400(
            message=f"Unsupported result_format: {result_format}")