def test_connection_type_not_supported(self, mock_get_connection):
        """get_conn() must raise AirflowConfigException for a "NOT_SUPPORT" connection."""
        mock_get_connection.return_value = get_airflow_connection("NOT_SUPPORT")
        grpc_hook = GrpcHook("grpc_default")

        # Callable form of assertRaises; equivalent to the context-manager form.
        self.assertRaises(AirflowConfigException, grpc_hook.get_conn)
Example #2
0
    def test_connection_type_not_supported(self, mock_get_connection):
        """A "NOT_SUPPORT" connection is rejected when the channel is built.

        Constructing the hook is expected to succeed; the
        AirflowConfigException is expected only from get_conn().
        """
        unsupported_conn = get_airflow_connection("NOT_SUPPORT")
        mock_get_connection.return_value = unsupported_conn
        grpc_hook = GrpcHook("grpc_default")

        with self.assertRaises(AirflowConfigException):
            grpc_hook.get_conn()
    def test_connection_with_jwt(self,
                                 mock_secure_channel,
                                 mock_google_default_auth,
                                 mock_google_cred,
                                 mock_get_connection):
        """JWT_GOOGLE auth: the default credentials are wrapped and used for a secure channel.

        The mocked default auth returns a credential object. The test expects
        it to be passed to ``mock_google_cred``, and the result to be passed to
        ``mock_secure_channel`` along with ``None`` and the connection's
        ``host:port``.
        """
        conn = get_airflow_connection(
            auth_type="JWT_GOOGLE"
        )
        mock_get_connection.return_value = conn
        hook = GrpcHook("grpc_default")
        mocked_channel = self.channel_mock.return_value
        mock_secure_channel.return_value = mocked_channel
        mock_credential_object = "test_credential_object"
        mock_google_default_auth.return_value = (mock_credential_object, "")
        mock_google_cred.return_value = mock_credential_object

        channel = hook.get_conn()
        expected_url = "test:8080"

        mock_google_cred.assert_called_once_with(mock_credential_object)
        mock_secure_channel.assert_called_once_with(
            mock_credential_object,
            None,
            expected_url
        )
        # assertEquals is a deprecated alias that was removed in Python 3.12.
        self.assertEqual(channel, mocked_channel)
    def test_custom_connection_with_no_connection_func(self, mock_get_connection):
        """A CUSTOM connection without a custom_connection_func cannot build a channel."""
        mock_get_connection.return_value = get_airflow_connection("CUSTOM")
        grpc_hook = GrpcHook("grpc_default")

        self.assertRaises(AirflowConfigException, grpc_hook.get_conn)
Example #5
0
    def test_connection_with_google_oauth(self,
                                          mock_secure_channel,
                                          mock_google_default_auth,
                                          mock_google_auth_request,
                                          mock_get_connection):
        """OATH_GOOGLE auth: comma-separated scopes become a list for the default-auth call.

        The credential from the mocked default auth and the mocked auth request
        are expected to reach the secure-channel factory together with the
        connection's host:port.
        """
        fake_credential = "test_credential_object"
        fake_request = "request"
        mock_get_connection.return_value = get_airflow_connection(
            auth_type="OATH_GOOGLE",
            scopes="grpc,gcs",
        )
        grpc_hook = GrpcHook("grpc_default")
        expected_channel = self.channel_mock.return_value
        mock_secure_channel.return_value = expected_channel
        mock_google_default_auth.return_value = (fake_credential, "")
        mock_google_auth_request.return_value = fake_request

        result = grpc_hook.get_conn()

        mock_google_default_auth.assert_called_once_with(scopes=[u"grpc", u"gcs"])
        mock_secure_channel.assert_called_once_with(fake_credential, fake_request, "test:8080")
        self.assertEqual(result, expected_channel)
    def test_connection_with_tls(self,
                                 mock_secure_channel,
                                 mock_channel_credentials,
                                 mock_get_connection,
                                 mock_open):
        """TLS auth: the PEM file is read and its contents become the channel credentials.

        ``mock_open`` serves the file contents ``'credential'``. The test
        expects those contents to reach ``mock_channel_credentials``, and the
        resulting credential object to be passed to ``mock_secure_channel``
        with the connection's ``host:port``.
        """
        conn = get_airflow_connection(
            auth_type="TLS",
            credential_pem_file="pem"
        )
        mock_get_connection.return_value = conn
        mock_open.return_value = StringIO('credential')
        hook = GrpcHook("grpc_default")
        mocked_channel = self.channel_mock.return_value
        mock_secure_channel.return_value = mocked_channel
        mock_credential_object = "test_credential_object"
        mock_channel_credentials.return_value = mock_credential_object

        channel = hook.get_conn()
        expected_url = "test:8080"

        mock_open.assert_called_once_with("pem")
        mock_channel_credentials.assert_called_once_with('credential')
        mock_secure_channel.assert_called_once_with(
            expected_url,
            mock_credential_object
        )
        # assertEquals is a deprecated alias that was removed in Python 3.12.
        self.assertEqual(channel, mocked_channel)
Example #7
0
    def test_custom_connection_with_no_connection_func(self, mock_get_connection):
        """get_conn() raises AirflowConfigException for CUSTOM auth when no factory was given."""
        custom_conn = get_airflow_connection("CUSTOM")
        mock_get_connection.return_value = custom_conn
        grpc_hook = GrpcHook("grpc_default")

        with self.assertRaises(AirflowConfigException):
            grpc_hook.get_conn()
Example #8
0
    def test_connection_with_jwt(self,
                                 mock_secure_channel,
                                 mock_google_default_auth,
                                 mock_google_cred,
                                 mock_get_connection):
        """JWT_GOOGLE auth: the default credentials are wrapped and used for a secure channel."""
        fake_credential = "test_credential_object"
        mock_get_connection.return_value = get_airflow_connection(auth_type="JWT_GOOGLE")
        grpc_hook = GrpcHook("grpc_default")
        expected_channel = self.channel_mock.return_value
        mock_secure_channel.return_value = expected_channel
        mock_google_default_auth.return_value = (fake_credential, "")
        mock_google_cred.return_value = fake_credential

        result = grpc_hook.get_conn()

        mock_google_cred.assert_called_once_with(fake_credential)
        # No auth request object is expected for JWT, hence the None argument.
        mock_secure_channel.assert_called_once_with(fake_credential, None, "test:8080")
        self.assertEqual(result, expected_channel)
Example #9
0
    def test_connection_with_tls(self,
                                 mock_secure_channel,
                                 mock_channel_credentials,
                                 mock_get_connection,
                                 mock_open):
        """TLS auth: the PEM file is read and a secure channel is opened with its credentials."""
        pem_path = "pem"
        fake_credential = "test_credential_object"
        mock_get_connection.return_value = get_airflow_connection(
            auth_type="TLS",
            credential_pem_file=pem_path,
        )
        mock_open.return_value = StringIO('credential')
        grpc_hook = GrpcHook("grpc_default")
        expected_channel = self.channel_mock.return_value
        mock_secure_channel.return_value = expected_channel
        mock_channel_credentials.return_value = fake_credential

        result = grpc_hook.get_conn()

        mock_open.assert_called_once_with(pem_path)
        mock_channel_credentials.assert_called_once_with('credential')
        mock_secure_channel.assert_called_once_with("test:8080", fake_credential)
        self.assertEqual(result, expected_channel)
Example #10
0
    def test_connection_with_google_oauth(self,
                                          mock_secure_channel,
                                          mock_google_default_auth,
                                          mock_google_auth_request,
                                          mock_get_connection):
        """OATH_GOOGLE auth: scopes are split and passed to the default-auth call.

        The connection's ``scopes="grpc,gcs"`` is expected to reach the mocked
        default auth as ``["grpc", "gcs"]``. The returned credential and the
        mocked auth request are expected to reach ``mock_secure_channel`` with
        the connection's ``host:port``.
        """
        conn = get_airflow_connection(
            auth_type="OATH_GOOGLE",
            scopes="grpc,gcs"
        )
        mock_get_connection.return_value = conn
        hook = GrpcHook("grpc_default")
        mocked_channel = self.channel_mock.return_value
        mock_secure_channel.return_value = mocked_channel
        mock_credential_object = "test_credential_object"
        mock_google_default_auth.return_value = (mock_credential_object, "")
        mock_google_auth_request.return_value = "request"

        channel = hook.get_conn()
        expected_url = "test:8080"

        mock_google_default_auth.assert_called_once_with(scopes=[u"grpc", u"gcs"])
        mock_secure_channel.assert_called_once_with(
            mock_credential_object,
            "request",
            expected_url
        )
        # assertEquals is a deprecated alias that was removed in Python 3.12.
        self.assertEqual(channel, mocked_channel)
Example #11
0
    def test_custom_connection(self, mock_get_connection):
        """CUSTOM auth: get_conn() returns the channel built by ``custom_connection_func``.

        NOTE(review): assumes ``self.custom_conn_func`` (defined outside this
        view) returns ``self.channel_mock.return_value``.
        """
        conn = get_airflow_connection("CUSTOM")
        mock_get_connection.return_value = conn
        mocked_channel = self.channel_mock.return_value
        hook = GrpcHook("grpc_default", custom_connection_func=self.custom_conn_func)

        channel = hook.get_conn()

        # assertEquals is a deprecated alias that was removed in Python 3.12.
        self.assertEqual(channel, mocked_channel)
Example #12
0
    def test_custom_connection(self, mock_get_connection):
        """CUSTOM auth delegates channel creation to the supplied ``custom_connection_func``.

        NOTE(review): presumes ``self.custom_conn_func`` returns
        ``self.channel_mock.return_value`` — it is defined outside this view.
        """
        mock_get_connection.return_value = get_airflow_connection("CUSTOM")
        expected_channel = self.channel_mock.return_value
        grpc_hook = GrpcHook(
            "grpc_default",
            custom_connection_func=self.custom_conn_func,
        )

        self.assertEqual(grpc_hook.get_conn(), expected_channel)
Example #13
0
    def test_stream_run(self, mock_get_conn, mock_get_connection):
        """hook.run() on ``stream_call`` yields the stub's streamed items.

        ``mock_get_conn`` (presumably a patch of ``GrpcHook.get_conn``) returns
        a Mock set up to work as a context manager, so no real channel is
        opened.
        """
        conn = get_airflow_connection()
        mock_get_connection.return_value = conn
        mocked_channel = mock.Mock()
        mocked_channel.__enter__ = mock.Mock(return_value=(mock.Mock(), None))
        mocked_channel.__exit__ = mock.Mock(return_value=None)
        hook = GrpcHook("grpc_default")
        mock_get_conn.return_value = mocked_channel

        response = hook.run(StubClass, "stream_call", data={'data': ['hello!', "hi"]})

        # assertEquals is a deprecated alias that was removed in Python 3.12.
        self.assertEqual(next(response), ["streaming", "call"])
Example #14
0
    def test_stream_run(self, mock_get_conn, mock_get_connection):
        """The first item yielded by hook.run() for ``stream_call`` is checked.

        The channel returned by ``mock_get_conn`` is a Mock with explicit
        ``__enter__``/``__exit__`` so the hook can use it in a ``with``
        statement.
        """
        conn = get_airflow_connection()
        mock_get_connection.return_value = conn
        mocked_channel = mock.Mock()
        mocked_channel.__enter__ = mock.Mock(return_value=(mock.Mock(), None))
        mocked_channel.__exit__ = mock.Mock(return_value=None)
        hook = GrpcHook("grpc_default")
        mock_get_conn.return_value = mocked_channel

        response = hook.run(StubClass, "stream_call", data={'data': ['hello!', "hi"]})

        # assertEquals is a deprecated alias that was removed in Python 3.12.
        self.assertEqual(next(response), ["streaming", "call"])
Example #15
0
    def test_connection_with_port(self, mock_get_connection, mock_insecure_channel):
        """A connection with an explicit port is dialled at host:port on an insecure channel."""
        mock_get_connection.return_value = get_airflow_connection_with_port()
        grpc_hook = GrpcHook("grpc_default")
        expected_channel = self.channel_mock.return_value
        mock_insecure_channel.return_value = expected_channel

        result = grpc_hook.get_conn()

        mock_insecure_channel.assert_called_once_with("test.com:1234")
        self.assertEqual(result, expected_channel)
Example #16
0
    def test_connection_with_port(self, mock_get_connection, mock_insecure_channel):
        """The insecure channel is opened at ``host:port`` from a connection that sets a port.

        ``get_airflow_connection_with_port`` is expected to yield host
        ``test.com`` and port ``1234``, as implied by the asserted URL.
        """
        conn = get_airflow_connection_with_port()
        mock_get_connection.return_value = conn
        hook = GrpcHook("grpc_default")
        mocked_channel = self.channel_mock.return_value
        mock_insecure_channel.return_value = mocked_channel

        channel = hook.get_conn()
        expected_url = "test.com:1234"

        mock_insecure_channel.assert_called_once_with(expected_url)
        # assertEquals is a deprecated alias that was removed in Python 3.12.
        self.assertEqual(channel, mocked_channel)
Example #17
0
    def test_simple_run(self, mock_get_conn, mock_get_connection):
        """hook.run() on ``single_call`` yields the value the stub echoes back."""
        mock_get_connection.return_value = get_airflow_connection()
        fake_channel = mock.Mock()
        # Give the Mock explicit context-manager hooks so the hook can use it in ``with``.
        fake_channel.__enter__ = mock.Mock(return_value=(mock.Mock(), None))
        fake_channel.__exit__ = mock.Mock(return_value=None)
        grpc_hook = GrpcHook("grpc_default")
        mock_get_conn.return_value = fake_channel

        responses = grpc_hook.run(StubClass, "single_call", data={'data': 'hello'})
        first_response = next(responses)

        self.assertEqual(first_response, "hello")
Example #18
0
    def test_connection_with_interceptors(self, mock_insecure_channel,
                                          mock_get_connection,
                                          mock_intercept_channel):
        """The configured interceptor ("test1") is applied to the base insecure channel."""
        mock_get_connection.return_value = get_airflow_connection()
        expected_channel = self.channel_mock.return_value
        grpc_hook = GrpcHook("grpc_default", interceptors=["test1"])
        # Both the base channel and the intercepted channel resolve to the same mock.
        for patched in (mock_insecure_channel, mock_intercept_channel):
            patched.return_value = expected_channel

        result = grpc_hook.get_conn()

        self.assertEqual(result, expected_channel)
        mock_intercept_channel.assert_called_once_with(expected_channel, "test1")
Example #19
0
    def test_connection_with_interceptors(self,
                                          mock_insecure_channel,
                                          mock_get_connection,
                                          mock_intercept_channel):
        """get_conn() wraps the base channel with the configured interceptor.

        The test expects ``mock_intercept_channel`` to be called once with the
        base channel and the interceptor ``"test1"``, and its return value to
        be what get_conn() returns.
        """
        conn = get_airflow_connection()
        mock_get_connection.return_value = conn
        mocked_channel = self.channel_mock.return_value
        hook = GrpcHook("grpc_default", interceptors=["test1"])
        mock_insecure_channel.return_value = mocked_channel
        mock_intercept_channel.return_value = mocked_channel

        channel = hook.get_conn()

        # assertEquals is a deprecated alias that was removed in Python 3.12.
        self.assertEqual(channel, mocked_channel)
        mock_intercept_channel.assert_called_once_with(mocked_channel, "test1")
Example #20
0
 def _get_grpc_hook(self):
     """Build a GrpcHook from this object's connection id, interceptors and custom connection factory."""
     conn_id = self.grpc_conn_id
     hook_options = {
         'interceptors': self.interceptors,
         'custom_connection_func': self.custom_connection_func,
     }
     return GrpcHook(conn_id, **hook_options)
 def get_hook(self):
     """Return a hook instance matching ``self.conn_type``.

     Only the module for the requested type is imported, and only when this
     method is called. The hook is constructed with ``self.conn_id`` passed
     under that hook's own connection-id keyword.

     :raises AirflowException: if ``self.conn_type`` is not a known type.
     """
     # conn_type -> (module path, hook class name, connection-id keyword argument)
     hook_table = {
         'mysql': ('airflow.hooks.mysql_hook', 'MySqlHook', 'mysql_conn_id'),
         'google_cloud_platform': ('airflow.gcp.hooks.bigquery', 'BigQueryHook', 'bigquery_conn_id'),
         'postgres': ('airflow.hooks.postgres_hook', 'PostgresHook', 'postgres_conn_id'),
         'pig_cli': ('airflow.hooks.pig_hook', 'PigCliHook', 'pig_cli_conn_id'),
         'hive_cli': ('airflow.hooks.hive_hooks', 'HiveCliHook', 'hive_cli_conn_id'),
         'presto': ('airflow.hooks.presto_hook', 'PrestoHook', 'presto_conn_id'),
         'hiveserver2': ('airflow.hooks.hive_hooks', 'HiveServer2Hook', 'hiveserver2_conn_id'),
         'sqlite': ('airflow.hooks.sqlite_hook', 'SqliteHook', 'sqlite_conn_id'),
         'jdbc': ('airflow.hooks.jdbc_hook', 'JdbcHook', 'jdbc_conn_id'),
         'mssql': ('airflow.hooks.mssql_hook', 'MsSqlHook', 'mssql_conn_id'),
         'oracle': ('airflow.hooks.oracle_hook', 'OracleHook', 'oracle_conn_id'),
         'vertica': ('airflow.contrib.hooks.vertica_hook', 'VerticaHook', 'vertica_conn_id'),
         'cloudant': ('airflow.contrib.hooks.cloudant_hook', 'CloudantHook', 'cloudant_conn_id'),
         'jira': ('airflow.contrib.hooks.jira_hook', 'JiraHook', 'jira_conn_id'),
         'redis': ('airflow.contrib.hooks.redis_hook', 'RedisHook', 'redis_conn_id'),
         'wasb': ('airflow.contrib.hooks.wasb_hook', 'WasbHook', 'wasb_conn_id'),
         'docker': ('airflow.hooks.docker_hook', 'DockerHook', 'docker_conn_id'),
         'azure_data_lake': ('airflow.contrib.hooks.azure_data_lake_hook', 'AzureDataLakeHook',
                             'azure_data_lake_conn_id'),
         'azure_cosmos': ('airflow.contrib.hooks.azure_cosmos_hook', 'AzureCosmosDBHook',
                          'azure_cosmos_conn_id'),
         'cassandra': ('airflow.contrib.hooks.cassandra_hook', 'CassandraHook', 'cassandra_conn_id'),
         # MongoHook takes the generic ``conn_id`` keyword rather than a prefixed one.
         'mongo': ('airflow.contrib.hooks.mongo_hook', 'MongoHook', 'conn_id'),
         'gcpcloudsql': ('airflow.gcp.hooks.cloud_sql', 'CloudSqlDatabaseHook', 'gcp_cloudsql_conn_id'),
         'grpc': ('airflow.contrib.hooks.grpc_hook', 'GrpcHook', 'grpc_conn_id'),
     }

     entry = hook_table.get(self.conn_type)
     if entry is None:
         raise AirflowException("Unknown hook type {}".format(self.conn_type))

     module_path, class_name, id_keyword = entry
     # Lazy equivalent of ``from <module_path> import <class_name>``.
     hook_module = __import__(module_path, fromlist=[class_name])
     hook_class = getattr(hook_module, class_name)
     return hook_class(**{id_keyword: self.conn_id})