Esempio n. 1
0
    def test_auth_custom_auth(self):
        """An allow-listed custom auth method is built from its stored params.

        The database's encrypted extra names the ``custom_auth`` method, and
        the allow-list is temporarily replaced so that only this method is
        permitted for ``trino``. The test then checks that the connection is
        switched to HTTPS and that the registered auth class is called exactly
        once with the configured keyword arguments.
        """
        custom_auth_cls = Mock()
        db = Mock()

        method_name = "custom_auth"
        method_kwargs = {"params1": "params1", "params2": "params2"}
        db.encrypted_extra = json.dumps(
            {"auth_method": method_name, "auth_params": method_kwargs}
        )

        allow_list = {"trino": {"custom_auth": custom_auth_cls}}
        # clear=True drops any pre-existing entries for the duration of the
        # block; patch.dict restores the original mapping on exit.
        with patch.dict(
            "superset.config.ALLOWED_EXTRA_AUTHENTICATIONS",
            allow_list,
            clear=True,
        ):
            engine_params: Dict[str, Any] = {}
            TrinoEngineSpec.update_encrypted_extra_params(db, engine_params)

            connect_args = engine_params.setdefault("connect_args", {})
            self.assertEqual(connect_args.get("http_scheme"), "https")

            custom_auth_cls.assert_called_once_with(**method_kwargs)
Esempio n. 2
0
    def test_adjust_database_uri_when_selected_schema_is_none(self):
        """With no selected schema, ``adjust_database_uri`` keeps the database.

        Checked both for a bare catalog (``hive``) and for a catalog that
        already carries a schema (``hive/default``).
        """
        uri = URL(drivername="trino", database="hive")

        # Catalog only: nothing to rewrite.
        TrinoEngineSpec.adjust_database_uri(uri, selected_schema=None)
        self.assertEqual(uri.database, "hive")

        # Catalog plus schema: the existing schema must be left in place.
        uri.database = "hive/default"
        TrinoEngineSpec.adjust_database_uri(uri, selected_schema=None)
        self.assertEqual(uri.database, "hive/default")
Esempio n. 3
0
    def test_convert_dttm(self):
        """``convert_dttm`` renders DATE and TIMESTAMP values as SQL literals."""
        moment = self.get_dttm()

        # (target column type, expected SQL expression), checked in order.
        cases = (
            ("DATE", "DATE '2019-01-02'"),
            ("TIMESTAMP", "TIMESTAMP '2019-01-02T03:04:05.678900'"),
        )
        for target_type, expected_sql in cases:
            self.assertEqual(
                TrinoEngineSpec.convert_dttm(target_type, moment),
                expected_sql,
            )
Esempio n. 4
0
    def test_convert_dttm(self):
        """``convert_dttm`` renders DATE and TIMESTAMP via ``from_iso8601_*``."""
        moment = self.get_dttm()

        # (target column type, expected SQL expression), checked in order.
        cases = (
            ("DATE", "from_iso8601_date('2019-01-02')"),
            ("TIMESTAMP", "from_iso8601_timestamp('2019-01-02T03:04:05.678900')"),
        )
        for target_type, expected_sql in cases:
            self.assertEqual(
                TrinoEngineSpec.convert_dttm(target_type, moment),
                expected_sql,
            )
Esempio n. 5
0
    def test_auth_basic(self, auth: Mock):
        """The ``basic`` auth method builds its auth object from stored params.

        ``auth`` is a mock supplied from outside this method (presumably by a
        patch decorator that is not visible here -- confirm against the file).
        The test checks it is called exactly once with the configured
        credentials and that the connection is switched to HTTPS.
        """
        credentials = {"username": "******", "password": "******"}

        db = Mock()
        db.encrypted_extra = json.dumps(
            {"auth_method": "basic", "auth_params": credentials}
        )

        engine_params: Dict[str, Any] = {}
        TrinoEngineSpec.update_encrypted_extra_params(db, engine_params)

        connect_args = engine_params.setdefault("connect_args", {})
        self.assertEqual(connect_args.get("http_scheme"), "https")
        auth.assert_called_once_with(**credentials)
Esempio n. 6
0
def test_cancel_query_failed(engine_mock: mock.Mock) -> None:
    """``cancel_query`` returns ``False`` when the cursor raises.

    :param engine_mock: mocked engine fixture; its context-managed return
        value is used as the cursor, as in the success-path test.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec
    from superset.models.sql_lab import Query

    query = Query()

    # A chained assignment of the form
    # ``cursor_mock = <mock>.side_effect = Exception()`` binds ``cursor_mock``
    # to the bare Exception instance, so the failure would come from an
    # accidental AttributeError rather than from a simulated database error.
    # Use a real cursor mock and make the statement call fail instead.
    # NOTE(review): assumes ``cancel_query`` issues its kill statement via
    # ``cursor.execute`` -- confirm against the engine spec.
    cursor_mock = engine_mock.return_value.__enter__.return_value
    cursor_mock.execute.side_effect = Exception()

    assert TrinoEngineSpec.cancel_query(cursor_mock, query, "123") is False
Esempio n. 7
0
def test_cancel_query_success(engine_mock: mock.Mock) -> None:
    """``cancel_query`` returns ``True`` when the cursor calls succeed.

    :param engine_mock: mocked engine fixture; its context-managed return
        value stands in for the database cursor.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec
    from superset.models.sql_lab import Query

    running_query = Query()
    cursor = engine_mock.return_value.__enter__.return_value

    cancelled = TrinoEngineSpec.cancel_query(cursor, running_query, "123")
    assert cancelled is True
Esempio n. 8
0
    def test_auth_custom_auth_denied(self):
        """A custom auth method missing from the allow-list is rejected.

        With an empty ``ALLOWED_EXTRA_AUTHENTICATIONS`` mapping,
        ``update_encrypted_extra_params`` must raise ``ValueError`` with a
        message naming the offending auth method.
        """
        # Imported locally so this test does not depend on the module's
        # import block already exposing ``patch``.
        from unittest.mock import patch

        database = Mock()
        auth_method = "my.module:TrinoAuthClass"
        auth_params = {"params1": "params1", "params2": "params2"}
        database.encrypted_extra = json.dumps(
            {"auth_method": auth_method, "auth_params": auth_params}
        )

        # Empty the allow-list only for the duration of the call. Rebinding
        # the config attribute directly would leak the change into every test
        # that runs afterwards; patch.dict restores the mapping on exit.
        with patch.dict(
            "superset.config.ALLOWED_EXTRA_AUTHENTICATIONS", {}, clear=True
        ):
            with pytest.raises(ValueError) as excinfo:
                TrinoEngineSpec.update_encrypted_extra_params(database, {})

        assert str(excinfo.value) == (
            f"For security reason, custom authentication '{auth_method}' "
            "must be listed in 'ALLOWED_EXTRA_AUTHENTICATIONS' config"
        )
Esempio n. 9
0
    def test_get_extra_params(self):
        """``get_extra_params`` without a server certificate.

        An empty ``extra`` yields the default ``engine_params`` skeleton, and
        a populated ``extra`` is returned unchanged.
        """
        db = Mock()

        # Case 1: empty extra -> skeleton with an empty connect_args dict.
        db.extra = json.dumps({})
        db.server_cert = None
        actual = TrinoEngineSpec.get_extra_params(db)
        self.assertEqual(actual, {"engine_params": {"connect_args": {}}})

        # Case 2: populated extra -> passed through as-is.
        populated = {
            "first": 1,
            "engine_params": {"second": "two", "connect_args": {"third": "three"}},
        }
        db.extra = json.dumps(populated)
        db.server_cert = None
        actual = TrinoEngineSpec.get_extra_params(db)
        self.assertEqual(actual, populated)
Esempio n. 10
0
    def test_auth_kerberos(self, auth: Mock):
        """The ``kerberos`` auth method builds its auth object from stored params.

        ``auth`` is a mock supplied from outside this method (presumably by a
        patch decorator that is not visible here -- confirm against the file).
        The test checks it is called exactly once with the configured options
        and that the connection is switched to HTTPS.
        """
        kerberos_options = {
            "service_name": "superset",
            "mutual_authentication": False,
            "delegate": True,
        }

        db = Mock()
        db.encrypted_extra = json.dumps(
            {"auth_method": "kerberos", "auth_params": kerberos_options}
        )

        engine_params: Dict[str, Any] = {}
        TrinoEngineSpec.update_encrypted_extra_params(db, engine_params)

        connect_args = engine_params.setdefault("connect_args", {})
        self.assertEqual(connect_args.get("http_scheme"), "https")
        auth.assert_called_once_with(**kerberos_options)
Esempio n. 11
0
    def test_get_extra_params_with_server_cert(self, create_ssl_cert_file_func: Mock):
        """``get_extra_params`` wires a server certificate into ``connect_args``.

        ``create_ssl_cert_file_func`` is a mock supplied from outside this
        method (presumably by a patch decorator that is not visible here --
        confirm against the file). Its return value is the path expected under
        ``verify``.
        """
        cert_path = "/path/to/tls.crt"
        create_ssl_cert_file_func.return_value = cert_path

        db = Mock()
        db.extra = json.dumps({})
        db.server_cert = "TEST_CERT"

        result = TrinoEngineSpec.get_extra_params(db)
        engine_params = result.get("engine_params", {})
        connect_args = engine_params.get("connect_args", {})

        self.assertEqual(connect_args.get("http_scheme"), "https")
        self.assertEqual(connect_args.get("verify"), "/path/to/tls.crt")
        create_ssl_cert_file_func.assert_called_once_with(db.server_cert)
Esempio n. 12
0
 def test_adjust_database_uri_when_database_contain_schema(self):
     """A selected schema replaces the schema already present in the database."""
     uri = URL(drivername="trino", database="hive/default")

     # "hive/default" + selected schema "foobar" -> "hive/foobar".
     TrinoEngineSpec.adjust_database_uri(uri, selected_schema="foobar")

     self.assertEqual(uri.database, "hive/foobar")