Example #1
0
    def test_connection_type_not_supported(self, mock_get_connection):
        """``get_conn`` must raise for a connection built from "NOT_SUPPORT".

        The mocked ``get_connection`` hands the hook a connection created by
        ``get_airflow_connection("NOT_SUPPORT")``; the test expects the hook
        to reject it with ``AirflowConfigException``.
        """
        unsupported_conn = get_airflow_connection("NOT_SUPPORT")
        mock_get_connection.return_value = unsupported_conn
        grpc_hook = GrpcHook("grpc_default")

        with self.assertRaises(AirflowConfigException):
            grpc_hook.get_conn()
Example #2
0
    def test_custom_connection_with_no_connection_func(self,
                                                       mock_get_connection):
        """A "CUSTOM" connection without a factory function must fail.

        The hook is built without ``custom_connection_func``, so the test
        expects ``get_conn`` to raise ``AirflowConfigException``.
        """
        custom_conn = get_airflow_connection("CUSTOM")
        mock_get_connection.return_value = custom_conn
        grpc_hook = GrpcHook("grpc_default")

        with self.assertRaises(AirflowConfigException):
            grpc_hook.get_conn()
Example #3
0
    def test_custom_connection(self, mock_get_connection):
        """A "CUSTOM" connection returns the channel from the supplied factory.

        ``self.custom_conn_func`` is passed as ``custom_connection_func``; the
        test expects ``get_conn`` to return ``self.channel_mock.return_value``
        (presumably what that factory yields -- confirm in the fixture).
        """
        custom_conn = get_airflow_connection("CUSTOM")
        mock_get_connection.return_value = custom_conn
        expected_channel = self.channel_mock.return_value
        grpc_hook = GrpcHook(
            "grpc_default",
            custom_connection_func=self.custom_conn_func,
        )

        self.assertEqual(grpc_hook.get_conn(), expected_channel)
Example #4
0
    def test_connection_with_port(self, mock_get_connection,
                                  mock_insecure_channel):
        """An insecure channel is opened on the connection's "host:port".

        The connection comes from ``get_airflow_connection_with_port()`` and
        the test expects the target string ``"test.com:1234"``.
        """
        ported_conn = get_airflow_connection_with_port()
        mock_get_connection.return_value = ported_conn
        grpc_hook = GrpcHook("grpc_default")
        fake_channel = self.channel_mock.return_value
        mock_insecure_channel.return_value = fake_channel

        result = grpc_hook.get_conn()

        mock_insecure_channel.assert_called_once_with("test.com:1234")
        self.assertEqual(result, fake_channel)
Example #5
0
    def test_no_auth_connection(self, mock_get_connection,
                                mock_insecure_channel):
        """A default connection produces an insecure channel to "test:8080"."""
        default_conn = get_airflow_connection()
        mock_get_connection.return_value = default_conn
        grpc_hook = GrpcHook("grpc_default")
        fake_channel = self.channel_mock.return_value
        mock_insecure_channel.return_value = fake_channel

        result = grpc_hook.get_conn()

        mock_insecure_channel.assert_called_once_with("test:8080")
        assert result == fake_channel
Example #6
0
    def test_connection_with_interceptors(self, mock_insecure_channel,
                                          mock_get_connection,
                                          mock_intercept_channel):
        """Configured interceptors are applied to the base channel.

        The hook is given ``interceptors=["test1"]``; the test expects the
        intercept helper to be called once with the base channel and
        ``"test1"``, and its result to be what ``get_conn`` returns.
        """
        default_conn = get_airflow_connection()
        mock_get_connection.return_value = default_conn
        fake_channel = self.channel_mock.return_value
        grpc_hook = GrpcHook("grpc_default", interceptors=["test1"])
        # Both the raw and the intercepted channel resolve to the same mock.
        mock_insecure_channel.return_value = fake_channel
        mock_intercept_channel.return_value = fake_channel

        result = grpc_hook.get_conn()

        self.assertEqual(result, fake_channel)
        mock_intercept_channel.assert_called_once_with(fake_channel, "test1")
Example #7
0
    def test_connection_with_jwt(self, mock_secure_channel,
                                 mock_google_default_auth, mock_google_cred,
                                 mock_get_connection):
        """A "JWT_GOOGLE" connection opens an authorized channel.

        Default credentials are wrapped by the mocked credential factory.
        The test expects the secure-channel helper to be called with those
        credentials, ``None`` as the request, and the "test:8080" target.
        """
        jwt_conn = get_airflow_connection(auth_type="JWT_GOOGLE")
        mock_get_connection.return_value = jwt_conn
        grpc_hook = GrpcHook("grpc_default")
        fake_channel = self.channel_mock.return_value
        mock_secure_channel.return_value = fake_channel
        fake_credentials = "test_credential_object"
        mock_google_default_auth.return_value = (fake_credentials, "")
        mock_google_cred.return_value = fake_credentials

        result = grpc_hook.get_conn()

        mock_google_cred.assert_called_once_with(fake_credentials)
        mock_secure_channel.assert_called_once_with(
            fake_credentials, None, "test:8080"
        )
        self.assertEqual(result, fake_channel)
Example #8
0
    def test_connection_with_tls(self, mock_secure_channel,
                                 mock_channel_credentials, mock_get_connection,
                                 mock_open):
        """A "TLS" connection reads its PEM file and opens a secure channel.

        ``open`` is mocked to yield the text "credential" for the "pem" path.
        The test expects that text to be turned into channel credentials,
        which are then passed with the "test:8080" target to the
        secure-channel helper.
        """
        tls_conn = get_airflow_connection(auth_type="TLS",
                                          credential_pem_file="pem")
        mock_get_connection.return_value = tls_conn
        mock_open.return_value = StringIO('credential')
        grpc_hook = GrpcHook("grpc_default")
        fake_channel = self.channel_mock.return_value
        mock_secure_channel.return_value = fake_channel
        fake_credentials = "test_credential_object"
        mock_channel_credentials.return_value = fake_credentials

        result = grpc_hook.get_conn()

        mock_open.assert_called_once_with("pem")
        mock_channel_credentials.assert_called_once_with('credential')
        mock_secure_channel.assert_called_once_with(
            "test:8080", fake_credentials
        )
        self.assertEqual(result, fake_channel)
Example #9
0
    def test_connection_with_google_oauth(self, mock_secure_channel,
                                          mock_google_default_auth,
                                          mock_google_auth_request,
                                          mock_get_connection):
        """An "OATH_GOOGLE" connection requests scoped default credentials.

        The comma-separated ``scopes="grpc,gcs"`` is expected to reach the
        default-auth call as ``["grpc", "gcs"]``. The resulting credentials,
        the mocked auth request and the "test:8080" target are expected to
        be passed to the secure-channel helper.
        """
        oauth_conn = get_airflow_connection(auth_type="OATH_GOOGLE",
                                            scopes="grpc,gcs")
        mock_get_connection.return_value = oauth_conn
        grpc_hook = GrpcHook("grpc_default")
        fake_channel = self.channel_mock.return_value
        mock_secure_channel.return_value = fake_channel
        fake_credentials = "test_credential_object"
        mock_google_default_auth.return_value = (fake_credentials, "")
        mock_google_auth_request.return_value = "request"

        result = grpc_hook.get_conn()

        mock_google_default_auth.assert_called_once_with(
            scopes=["grpc", "gcs"])
        mock_secure_channel.assert_called_once_with(
            fake_credentials, "request", "test:8080"
        )
        self.assertEqual(result, fake_channel)