Example 1
    def test_get_lb_policy(self):
        # test LB policies with no args
        self._assert_get_lb_policy('RoundRobinPolicy', {}, RoundRobinPolicy)
        self._assert_get_lb_policy('DCAwareRoundRobinPolicy', {},
                                   DCAwareRoundRobinPolicy)
        self._assert_get_lb_policy('TokenAwarePolicy', {},
                                   TokenAwarePolicy,
                                   expected_child_policy_type=RoundRobinPolicy)

        # test DCAwareRoundRobinPolicy with args
        self._assert_get_lb_policy('DCAwareRoundRobinPolicy', {
            'local_dc': 'foo',
            'used_hosts_per_remote_dc': '3'
        }, DCAwareRoundRobinPolicy)

        # test WhiteListRoundRobinPolicy with args
        fake_addr_info = [[
            'family', 'sockettype', 'proto', 'canonname',
            ('2606:2800:220:1:248:1893:25c8:1946', 80, 0, 0)
        ]]
        with patch('socket.getaddrinfo', return_value=fake_addr_info):
            self._assert_get_lb_policy('WhiteListRoundRobinPolicy',
                                       {'hosts': ['host1', 'host2']},
                                       WhiteListRoundRobinPolicy)

        # test TokenAwarePolicy with args
        with patch('socket.getaddrinfo', return_value=fake_addr_info):
            self._assert_get_lb_policy(
                'TokenAwarePolicy', {
                    'child_load_balancing_policy': 'WhiteListRoundRobinPolicy',
                    'child_load_balancing_policy_args': {
                        'hosts': ['host-1', 'host-2']
                    }
                },
                TokenAwarePolicy,
                expected_child_policy_type=WhiteListRoundRobinPolicy)

        # test that an invalid policy name defaults to RoundRobinPolicy
        self._assert_get_lb_policy('DoesNotExistPolicy', {}, RoundRobinPolicy)

        # test that an invalid child policy name defaults the child policy to RoundRobinPolicy
        self._assert_get_lb_policy('TokenAwarePolicy', {},
                                   TokenAwarePolicy,
                                   expected_child_policy_type=RoundRobinPolicy)
        self._assert_get_lb_policy(
            'TokenAwarePolicy',
            {'child_load_balancing_policy': 'DoesNotExistPolicy'},
            TokenAwarePolicy,
            expected_child_policy_type=RoundRobinPolicy)

        # test that WhiteListRoundRobinPolicy raises when no hosts are specified
        self._assert_get_lb_policy('WhiteListRoundRobinPolicy', {},
                                   WhiteListRoundRobinPolicy,
                                   should_throw=True)
        self._assert_get_lb_policy(
            'TokenAwarePolicy',
            {'child_load_balancing_policy': 'WhiteListRoundRobinPolicy'},
            TokenAwarePolicy,
            expected_child_policy_type=RoundRobinPolicy,
            should_throw=True)
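The WhiteListRoundRobinPolicy cases above work because socket.getaddrinfo is patched, so hostname resolution never touches the network. A minimal, self-contained sketch of that pattern (the address tuple below is illustrative, not Cassandra-specific):

import socket
from unittest.mock import patch

# fake getaddrinfo result; real entries are
# (family, type, proto, canonname, sockaddr) tuples
fake_addr_info = [('family', 'sockettype', 'proto', 'canonname',
                   ('192.0.2.1', 80))]

with patch('socket.getaddrinfo', return_value=fake_addr_info):
    # no DNS lookup happens; the canned result comes back instead
    assert socket.getaddrinfo('host1', 80) == fake_addr_info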
Example 2
    def test_hive_to_mysql(self):
        test_hive_results = 'test_hive_results'

        mock_hive_hook = MockHiveServer2Hook()
        mock_hive_hook.get_records = MagicMock(return_value=test_hive_results)

        mock_mysql_hook = MockMySqlHook()
        mock_mysql_hook.run = MagicMock()
        mock_mysql_hook.insert_rows = MagicMock()

        with patch('airflow.operators.hive_to_mysql.HiveServer2Hook',
                   return_value=mock_hive_hook):
            with patch('airflow.operators.hive_to_mysql.MySqlHook',
                       return_value=mock_mysql_hook):

                op = HiveToMySqlTransfer(
                    mysql_conn_id='airflow_db',
                    task_id='hive_to_mysql_check',
                    sql="""
                        SELECT name
                        FROM airflow.static_babynames
                        LIMIT 100
                        """,
                    mysql_table='test_static_babynames',
                    mysql_preoperator=[
                        'DROP TABLE IF EXISTS test_static_babynames;',
                        'CREATE TABLE test_static_babynames (name VARCHAR(500))',
                    ],
                    dag=self.dag)
                op.clear(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
                op.run(start_date=DEFAULT_DATE,
                       end_date=DEFAULT_DATE,
                       ignore_ti_state=True)

        raw_select_name_query = mock_hive_hook.get_records.call_args_list[0][
            0][0]
        actual_select_name_query = re.sub(r'\s{2,}', ' ',
                                          raw_select_name_query).strip()
        expected_select_name_query = 'SELECT name FROM airflow.static_babynames LIMIT 100'
        self.assertEqual(expected_select_name_query, actual_select_name_query)

        actual_hive_conf = mock_hive_hook.get_records.call_args_list[0][1][
            'hive_conf']
        expected_hive_conf = {
            'airflow.ctx.dag_owner': 'airflow',
            'airflow.ctx.dag_id': 'test_dag_id',
            'airflow.ctx.task_id': 'hive_to_mysql_check',
            'airflow.ctx.execution_date': '2015-01-01T00:00:00+00:00'
        }
        self.assertEqual(expected_hive_conf, actual_hive_conf)

        expected_mysql_preoperator = [
            'DROP TABLE IF EXISTS test_static_babynames;',
            'CREATE TABLE test_static_babynames (name VARCHAR(500))'
        ]
        mock_mysql_hook.run.assert_called_with(expected_mysql_preoperator)

        mock_mysql_hook.insert_rows.assert_called_with(
            table='test_static_babynames', rows=test_hive_results)
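The nested with patch(...) blocks above follow the standard mock rule: patch the name where it is looked up (the operator's module), not where the class is defined. A minimal runnable sketch of the same pattern, with __main__ and a hypothetical SourceHook standing in for the operator module and the Airflow hook:

from unittest.mock import MagicMock, patch

class SourceHook:
    def get_records(self, sql):
        raise RuntimeError('would query a real database')

def transfer():
    # SourceHook is resolved in this module's namespace at call time
    return SourceHook().get_records('SELECT 1')

mock_hook = MagicMock()
mock_hook.get_records.return_value = ['row1']

with patch('__main__.SourceHook', return_value=mock_hook):
    assert transfer() == ['row1']
mock_hook.get_records.assert_called_once_with('SELECT 1')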
Example 3
    def test_skip_serve_logs_on_worker_start(self):
        with patch(
                'airflow.cli.commands.worker_command.Process') as mock_popen:
            args = self.parser.parse_args(['worker', '-c', '-1', '-s'])

            with patch('celery.platforms.check_privileges') as mock_privil:
                mock_privil.return_value = 0
                worker_command.worker(args)
                mock_popen.assert_not_called()
Example 4
    def test_skip_serve_logs_on_worker_start(self):
        with patch('airflow.cli.commands.worker_command.subprocess.Popen') as mock_popen:
            mock_popen.return_value.communicate.return_value = (b'output', b'error')
            mock_popen.return_value.returncode = 0
            args = self.parser.parse_args(['worker', '-c', '-1', '-s'])

            with patch('celery.platforms.check_privileges') as mock_privil:
                mock_privil.return_value = 0
                worker_command.worker(args)
                mock_popen.assert_not_called()
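Both variants of this test hinge on assert_not_called(): patching the process launcher and then asserting it was never invoked proves the -s/--skip-serve-logs flag short-circuits it. A tiny sketch of that assertion style, with start_worker as a made-up stand-in:

from unittest.mock import MagicMock

def start_worker(skip_serve_logs, launch_log_server):
    if not skip_serve_logs:
        launch_log_server()

launcher = MagicMock()
start_worker(skip_serve_logs=True, launch_log_server=launcher)
launcher.assert_not_called()  # raises AssertionError if it ever ran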
Example 5
    def setUp(self):
        with patch(
                "airflow.contrib.hooks.gcp_api_base_hook.GoogleCloudBaseHook.__init__",
                new=mock_base_gcp_hook_default_project_id,
        ):
            self.gcp_text_to_speech_hook = GCPTextToSpeechHook(
                gcp_conn_id="test")
Example 6
    def setUp(self):
        with patch(
                "airflow.gcp.hooks.base.CloudBaseHook.__init__",
                new=mock_base_gcp_hook_default_project_id,
        ):
            self.gcp_text_to_speech_hook = CloudTextToSpeechHook(
                gcp_conn_id="test")
Example 7
def test_s3_server_ignore_subdomain_for_bucketnames():
    with patch("moto.s3.responses.S3_IGNORE_SUBDOMAIN_BUCKETNAME", True):
        test_client = authenticated_client()

        res = test_client.put("/mybucket", "http://foobaz.localhost:5000/")
        res.status_code.should.equal(200)
        res.data.should.contain(b"mybucket")
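patch is not limited to callables: here it temporarily rebinds a module-level boolean (S3_IGNORE_SUBDOMAIN_BUCKETNAME) and restores the original on exit. A minimal sketch run as a script, with __main__ standing in for the moto module:

from unittest.mock import patch

IGNORE_SUBDOMAIN = False

with patch('__main__.IGNORE_SUBDOMAIN', True):
    assert IGNORE_SUBDOMAIN is True   # rebound for the block's duration
assert IGNORE_SUBDOMAIN is False      # restored automatically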
Example 8
    def setUp(self):
        self._upload_dataframe()
        args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
        self.dag = DAG('test_dag_id', default_args=args)
        self.database = 'airflow'
        self.table = 'hive_server_hook'

        self.hql = """
        CREATE DATABASE IF NOT EXISTS {{ params.database }};
        USE {{ params.database }};
        DROP TABLE IF EXISTS {{ params.table }};
        CREATE TABLE IF NOT EXISTS {{ params.table }} (
            a int,
            b int)
        ROW FORMAT DELIMITED
        FIELDS TERMINATED BY ',';
        LOAD DATA LOCAL INPATH '{{ params.csv_path }}'
        OVERWRITE INTO TABLE {{ params.table }};
        """
        self.columns = ['{}.a'.format(self.table),
                        '{}.b'.format(self.table)]

        with patch('airflow.hooks.hive_hooks.HiveMetastoreHook.get_metastore_client') \
                as get_metastore_mock:
            get_metastore_mock.return_value = MagicMock()

            self.hook = HiveMetastoreHook()
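Note that the patch here is scoped to construction only: HiveMetastoreHook is instantiated inside the with block, so it captures the mocked metastore client, while everything after the block runs unpatched. A sketch of that idiom with a hypothetical Client/Hook pair:

from unittest.mock import MagicMock, patch

class Client:
    def connect(self):
        raise RuntimeError('would open a real connection')

class Hook:
    def __init__(self):
        self.client = Client().connect()

with patch.object(Client, 'connect', return_value=MagicMock()):
    hook = Hook()            # grabs the mock while the patch is live

assert isinstance(hook.client, MagicMock)  # kept after the patch ends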
Example 9
    def setUp(self):
        with patch(
                "airflow.gcp.hooks.base.GoogleCloudBaseHook.__init__",
                new=mock_base_gcp_hook_default_project_id,
        ):
            self.gcp_speech_to_text_hook = GCPSpeechToTextHook(
                gcp_conn_id="test")
Example 10
    def test_get_conn(self):
        with patch('airflow.hooks.hive_hooks.HiveMetastoreHook._find_valid_server') \
                as find_valid_server:
            find_valid_server.return_value = MagicMock(return_value={})
            metastore_hook = HiveMetastoreHook()

        self.assertIsInstance(metastore_hook.get_conn(), HMSClient)
Example 11
    def test_runs_for_hive_stats(self, mock_hive_metastore_hook):
        mock_mysql_hook = MockMySqlHook()
        mock_presto_hook = MockPrestoHook()
        with patch('airflow.operators.hive_stats_operator.PrestoHook',
                   return_value=mock_presto_hook):
            with patch('airflow.operators.hive_stats_operator.MySqlHook',
                       return_value=mock_mysql_hook):
                op = HiveStatsCollectionOperator(
                    task_id='hive_stats_check',
                    table="airflow.static_babynames_partitioned",
                    partition={'ds': DEFAULT_DATE_DS},
                    dag=self.dag)
                op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
                       ignore_ti_state=True)

        select_count_query = "SELECT COUNT(*) AS __count FROM airflow." \
            + "static_babynames_partitioned WHERE ds = '2015-01-01';"
        mock_presto_hook.get_first.assert_called_with(hql=select_count_query)

        expected_stats_select_query = "SELECT 1 FROM hive_stats WHERE table_name='airflow." \
            + "static_babynames_partitioned' AND " \
            + "partition_repr='{\"ds\": \"2015-01-01\"}' AND " \
            + "dttm='2015-01-01T00:00:00+00:00' " \
            + "LIMIT 1;"

        raw_stats_select_query = mock_mysql_hook.get_records.call_args_list[0][0][0]
        actual_stats_select_query = re.sub(r'\s{2,}', ' ', raw_stats_select_query).strip()

        self.assertEqual(expected_stats_select_query, actual_stats_select_query)

        insert_rows_val = [('2015-01-01', '2015-01-01T00:00:00+00:00',
                            'airflow.static_babynames_partitioned',
                            '{"ds": "2015-01-01"}', '', 'count', ['val_0', 'val_1'])]

        mock_mysql_hook.insert_rows.assert_called_with(table='hive_stats',
                                                       rows=insert_rows_val,
                                                       target_fields=[
                                                           'ds',
                                                           'dttm',
                                                           'table_name',
                                                           'partition_repr',
                                                           'col',
                                                           'metric',
                                                           'value',
                                                       ])
Example 12
    def test_get_lb_policy_with_args(self):
        # test DCAwareRoundRobinPolicy with args
        self._assert_get_lb_policy('DCAwareRoundRobinPolicy',
                                   {'local_dc': 'foo', 'used_hosts_per_remote_dc': '3'},
                                   DCAwareRoundRobinPolicy)

        # test WhiteListRoundRobinPolicy with args
        fake_addr_info = [['family', 'sockettype', 'proto', 'canonname', ('2606:2800:220:1:248:1893:25c8:1946', 80, 0, 0)]] # noqa
        with patch('socket.getaddrinfo', return_value=fake_addr_info):
            self._assert_get_lb_policy('WhiteListRoundRobinPolicy',
                                       {'hosts': ['host1', 'host2']},
                                       WhiteListRoundRobinPolicy)

        # test TokenAwarePolicy with args
        with patch('socket.getaddrinfo', return_value=fake_addr_info):
            self._assert_get_lb_policy('TokenAwarePolicy',
                                       {'child_load_balancing_policy': 'WhiteListRoundRobinPolicy',  # noqa
                                        'child_load_balancing_policy_args': {'hosts': ['host-1', 'host-2']}},  # noqa
                                       TokenAwarePolicy,
                                       expected_child_policy_type=WhiteListRoundRobinPolicy)  # noqa
Example 13
    def setUp(self):
        self.next_day = (DEFAULT_DATE +
                         datetime.timedelta(days=1)).isoformat()[:10]
        self.database = 'airflow'
        self.partition_by = 'ds'
        self.table = 'static_babynames_partitioned'
        with patch('airflow.hooks.hive_hooks.HiveMetastoreHook.get_metastore_client') \
                as get_metastore_mock:
            get_metastore_mock.return_value = MagicMock()

            self.hook = HiveMetastoreHook()
Example 14
def test_socktype_bad_python_version_regression():
    """ Some versions of python accidentally internally shadowed the SockType
    variable, so it was no longer the socket object but and int Enum representing
    the socket type e.g. AF_INET. Make sure we don't patch SockType in these cases
    https://bugs.python.org/issue20386
    """
    import socket
    someObject = object()
    with patch('socket.SocketType', someObject):
        HTTPretty.enable()
        expect(socket.SocketType).to.equal(someObject)
        HTTPretty.disable()
Example 15
    def test_hive2samba(self, mock_hive_server_hook, mock_temp_dir):
        mock_temp_dir.return_value = "tst"

        samba_hook = MockSambaHook(self.kwargs['samba_conn_id'])
        samba_hook.upload = MagicMock()

        with patch('airflow.operators.hive_to_samba_operator.SambaHook',
                   return_value=samba_hook):
            samba_hook.conn.upload = MagicMock()
            op = Hive2SambaOperator(
                task_id='hive2samba_check',
                samba_conn_id='tableau_samba',
                hql="SELECT * FROM airflow.static_babynames LIMIT 10000",
                destination_filepath='test_airflow.csv',
                dag=self.dag)
            op.run(start_date=DEFAULT_DATE,
                   end_date=DEFAULT_DATE,
                   ignore_ti_state=True)

        samba_hook.conn.upload.assert_called_with('/tmp/tmptst',
                                                  'test_airflow.csv')
Example 16
import unittest
import sqlalchemy
import airflow
from argparse import Namespace
from tests.compat import mock, patch

patch('airflow.utils.cli.action_logging', lambda x: x).start()
from airflow.bin import cli # noqa
mock_args = Namespace(queues=1, concurrency=1)


class TestWorkerPrecheck(unittest.TestCase):

    def setUp(self):
        airflow.configuration.load_test_config()

    @mock.patch('airflow.settings.validate_session')
    def test_error(self, mock_validate_session):
        """
        Test to verify the exit mechanism of airflow-worker cli
        by mocking validate_session method
        """
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

Example 17
import unittest
from argparse import Namespace

import sqlalchemy

import airflow
from airflow.bin import cli  # noqa
from tests.compat import mock, patch
from tests.test_utils.config import conf_vars

patch('airflow.utils.cli.action_logging', lambda x: x).start()
mock_args = Namespace(queues=1, concurrency=1)


class TestWorkerPrecheck(unittest.TestCase):
    @mock.patch('airflow.settings.validate_session')
    def test_error(self, mock_validate_session):
        """
        Test to verify the exit mechanism of airflow-worker cli
        by mocking validate_session method
        """
        mock_validate_session.return_value = False
        with self.assertRaises(SystemExit) as cm:
            # airflow.bin.cli.worker(mock_args)
            cli.worker(mock_args)
        self.assertEqual(cm.exception.code, 1)
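Both versions of this file use the module-level patch(...).start() idiom: started outside any test, the patch stays active for the whole module (in Example 16 it runs before airflow.bin.cli is imported, so action_logging is neutralized at import time) and is never explicitly stopped. A minimal sketch of start()/stop() outside a with block:

import os
from unittest.mock import patch

patcher = patch('os.getcwd', return_value='/patched')
patcher.start()                      # active until stop() or process exit
assert os.getcwd() == '/patched'

patcher.stop()                       # the test modules above skip this
assert os.getcwd() != '/patched'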
Example 18
    def setUp(self):
        with patch(
            "airflow.contrib.hooks.gcp_api_base_hook.GoogleCloudBaseHook.__init__",
            new=mock_base_gcp_hook_default_project_id,
        ):
            self.gcp_text_to_speech_hook = GCPTextToSpeechHook(gcp_conn_id="test")
Example 19
    def setUp(self):
        clear_db_runs()
        patcher = patch('airflow.jobs.base_job.sleep')
        self.addCleanup(patcher.stop)
        self.mock_base_job_sleep = patcher.start()
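Registering patcher.stop with addCleanup is the unittest-friendly way to undo a start()ed patch: the cleanup runs even if the test body or a later setUp step fails. A self-contained sketch of the pattern, patching time.sleep rather than Airflow's:

import time
import unittest
from unittest.mock import patch

class ExampleTest(unittest.TestCase):
    def setUp(self):
        patcher = patch('time.sleep')    # no real sleeping in tests
        self.addCleanup(patcher.stop)    # undone even if the test errors
        self.mock_sleep = patcher.start()

    def test_sleep_is_mocked(self):
        time.sleep(60)                   # returns immediately
        self.mock_sleep.assert_called_once_with(60)

if __name__ == '__main__':
    unittest.main()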
Example 20
def test_force_ignore_subdomain_for_bucketnames(monkeypatch):
    with patch("moto.s3.utils.S3_IGNORE_SUBDOMAIN_BUCKETNAME", True):
        expect(
            bucket_name_from_url(
                "https://subdomain.localhost:5000/abc/resource")).should.equal(
                    None)
Example 21
    def decorated(*args, **kwargs):
        ELASTIC_INSTANCES.clear()
        with patch('elasticsearch.Elasticsearch', _get_elasticmock):
            result = f(*args, **kwargs)
        return result
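This decorated wrapper shows patch driving a custom decorator: each call to the wrapped function runs with elasticsearch.Elasticsearch replaced for exactly the duration of the call. The same shape with a stdlib target, as a runnable sketch:

import functools
import time
from unittest.mock import patch

def with_frozen_time(f):
    @functools.wraps(f)
    def decorated(*args, **kwargs):
        # the patch is active only while the wrapped call runs
        with patch('time.time', return_value=0.0):
            return f(*args, **kwargs)
    return decorated

@with_frozen_time
def now():
    return time.time()

assert now() == 0.0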
Example 22
def test_socktype_good_python_version():
    import socket
    with patch('socket.SocketType', socket.socket):
        HTTPretty.enable()
        expect(socket.SocketType).to.equal(socket.socket)
        HTTPretty.disable()