def test_get_partitions(self, mock_get_conn):
        """get_partitions returns partition values as a set of tuples and forwards the paging options."""
        # One page containing a single partition, handed back by the mocked paginator.
        pages = [{'Partitions': [{'Values': ['2015-01-01']}]}]
        paginator = mock.Mock()
        paginator.paginate.return_value = pages
        conn = mock.Mock()
        conn.get_paginator.return_value = paginator
        mock_get_conn.return_value = conn

        hook = AwsGlueCatalogHook(region_name="us-east-1")
        partitions = hook.get_partitions(
            'db',
            'tbl',
            expression='foo=bar',
            page_size=2,
            max_items=3,
        )

        self.assertEqual(partitions, {('2015-01-01',)})
        conn.get_paginator.assert_called_once_with('get_partitions')
        expected_pagination = {'PageSize': 2, 'MaxItems': 3}
        paginator.paginate.assert_called_once_with(
            DatabaseName='db',
            TableName='tbl',
            Expression='foo=bar',
            PaginationConfig=expected_pagination,
        )
# Example #2
0
    def test_check_for_partition(self, mock_get_partitions):
        """check_for_partition is truthy when get_partitions yields a partition, requesting one item only."""
        mock_get_partitions.return_value = {('2018-01-01', )}
        hook = AwsGlueCatalogHook(region_name="us-east-1")

        found = hook.check_for_partition('db', 'tbl', 'expr')

        self.assertTrue(found)
        mock_get_partitions.assert_called_once_with(
            'db', 'tbl', 'expr', max_items=1,
        )
    def get_hook(self) -> AwsGlueCatalogHook:
        """Return ``self.hook``, first building it from ``aws_conn_id``/``region_name`` if it is falsy."""
        if not self.hook:
            self.hook = AwsGlueCatalogHook(
                aws_conn_id=self.aws_conn_id,
                region_name=self.region_name,
            )
        return self.hook
    def get_hook(self):
        """
        Return the AwsGlueCatalogHook stored on ``self.hook``, creating and
        storing it first when the attribute does not exist yet.
        """
        if hasattr(self, 'hook'):
            return self.hook

        # Imported only when a hook actually has to be built.
        from airflow.providers.amazon.aws.hooks.glue_catalog import AwsGlueCatalogHook
        self.hook = AwsGlueCatalogHook(
            aws_conn_id=self.aws_conn_id,
            region_name=self.region_name,
        )
        return self.hook
    def test_get_partitions_empty(self, mock_get_conn):
        """get_partitions returns an empty set when the paginator yields no pages."""
        # The hook's paginator comes from get_conn().get_paginator(...), so the
        # canned pages must be set on the return_value chain. Attributes set
        # directly on mock_get_conn are not what those calls return.
        mock_paginator = mock_get_conn.return_value.get_paginator.return_value
        mock_paginator.paginate.return_value = []
        hook = AwsGlueCatalogHook(region_name="us-east-1")

        self.assertEqual(hook.get_partitions('db', 'tbl'), set())
 def test_region(self):
     """The hook exposes the region name it was constructed with."""
     expected_region = 'us-west-2'
     hook = AwsGlueCatalogHook(region_name=expected_region)
     self.assertEqual(hook.region_name, expected_region)
 def test_conn_id(self):
     """The hook exposes the AWS connection id it was constructed with."""
     expected_conn_id = 'my_aws_conn_id'
     hook = AwsGlueCatalogHook(
         aws_conn_id=expected_conn_id,
         region_name="us-east-1",
     )
     self.assertEqual(hook.aws_conn_id, expected_conn_id)
 def test_get_conn_returns_a_boto3_connection(self):
     """get_conn() returns something other than None."""
     conn = AwsGlueCatalogHook(region_name="us-east-1").get_conn()
     self.assertIsNotNone(conn)
 def setUp(self):
     """Create the Glue client and the hook under test, both for us-east-1."""
     region = 'us-east-1'
     self.client = boto3.client('glue', region_name=region)
     self.hook = AwsGlueCatalogHook(region_name=region)
class TestAwsGlueCatalogHook(unittest.TestCase):
    """Tests for AwsGlueCatalogHook; every test runs under the ``mock_glue`` decorator."""

    @mock_glue
    def setUp(self):
        """Create the Glue client and the hook under test, both for us-east-1."""
        self.client = boto3.client('glue', region_name='us-east-1')
        self.hook = AwsGlueCatalogHook(region_name="us-east-1")

    def _create_database_and_table(self):
        """Create the DB_NAME database and the TABLE_INPUT table through ``self.client``."""
        self.client.create_database(
            DatabaseInput={
                'Name': DB_NAME
            }
        )
        self.client.create_table(
            DatabaseName=DB_NAME,
            TableInput=TABLE_INPUT
        )

    @mock_glue
    def test_get_conn_returns_a_boto3_connection(self):
        """get_conn() returns something other than None."""
        hook = AwsGlueCatalogHook(region_name="us-east-1")
        self.assertIsNotNone(hook.get_conn())

    @mock_glue
    def test_conn_id(self):
        """The hook exposes the AWS connection id it was constructed with."""
        hook = AwsGlueCatalogHook(aws_conn_id='my_aws_conn_id', region_name="us-east-1")
        self.assertEqual(hook.aws_conn_id, 'my_aws_conn_id')

    @mock_glue
    def test_region(self):
        """The hook exposes the region name it was constructed with."""
        hook = AwsGlueCatalogHook(region_name="us-west-2")
        self.assertEqual(hook.region_name, 'us-west-2')

    @mock_glue
    @mock.patch.object(AwsGlueCatalogHook, 'get_conn')
    def test_get_partitions_empty(self, mock_get_conn):
        """get_partitions returns an empty set when the paginator yields no pages."""
        # The hook's paginator comes from get_conn().get_paginator(...), as
        # test_get_partitions verifies, so the canned pages must be set on the
        # return_value chain. Attributes set directly on mock_get_conn are not
        # what those calls return.
        mock_paginator = mock_get_conn.return_value.get_paginator.return_value
        mock_paginator.paginate.return_value = []
        hook = AwsGlueCatalogHook(region_name="us-east-1")

        self.assertEqual(hook.get_partitions('db', 'tbl'), set())

    @mock_glue
    @mock.patch.object(AwsGlueCatalogHook, 'get_conn')
    def test_get_partitions(self, mock_get_conn):
        """get_partitions returns partition values as a set of tuples and forwards the paging options."""
        response = [{
            'Partitions': [{
                'Values': ['2015-01-01']
            }]
        }]
        mock_paginator = mock.Mock()
        mock_paginator.paginate.return_value = response
        mock_conn = mock.Mock()
        mock_conn.get_paginator.return_value = mock_paginator
        mock_get_conn.return_value = mock_conn
        hook = AwsGlueCatalogHook(region_name="us-east-1")
        result = hook.get_partitions('db',
                                     'tbl',
                                     expression='foo=bar',
                                     page_size=2,
                                     max_items=3)

        self.assertEqual(result, {('2015-01-01',)})
        mock_conn.get_paginator.assert_called_once_with('get_partitions')
        mock_paginator.paginate.assert_called_once_with(DatabaseName='db',
                                                        TableName='tbl',
                                                        Expression='foo=bar',
                                                        PaginationConfig={
                                                            'PageSize': 2,
                                                            'MaxItems': 3})

    @mock_glue
    @mock.patch.object(AwsGlueCatalogHook, 'get_partitions')
    def test_check_for_partition(self, mock_get_partitions):
        """check_for_partition is truthy when get_partitions yields a partition, requesting one item only."""
        mock_get_partitions.return_value = {('2018-01-01',)}
        hook = AwsGlueCatalogHook(region_name="us-east-1")

        self.assertTrue(hook.check_for_partition('db', 'tbl', 'expr'))
        mock_get_partitions.assert_called_once_with('db', 'tbl', 'expr', max_items=1)

    @mock_glue
    @mock.patch.object(AwsGlueCatalogHook, 'get_partitions')
    def test_check_for_partition_false(self, mock_get_partitions):
        """check_for_partition is falsy when get_partitions yields an empty set."""
        mock_get_partitions.return_value = set()
        hook = AwsGlueCatalogHook(region_name="us-east-1")

        self.assertFalse(hook.check_for_partition('db', 'tbl', 'expr'))

    @mock_glue
    def test_get_table_exists(self):
        """get_table returns the created table's name and storage location."""
        self._create_database_and_table()

        result = self.hook.get_table(DB_NAME, TABLE_NAME)

        self.assertEqual(result['Name'], TABLE_INPUT['Name'])
        self.assertEqual(result['StorageDescriptor']['Location'],
                         TABLE_INPUT['StorageDescriptor']['Location'])

    @mock_glue
    def test_get_table_not_exists(self):
        """get_table raises when asked for a table that was never created."""
        self._create_database_and_table()

        # NOTE(review): Exception is very broad; narrow it once the concrete
        # error type raised by get_table is confirmed.
        with self.assertRaises(Exception):
            self.hook.get_table(DB_NAME, 'dummy_table')

    @mock_glue
    def test_get_table_location(self):
        """get_table_location returns the storage location of the created table."""
        self._create_database_and_table()

        result = self.hook.get_table_location(DB_NAME, TABLE_NAME)
        self.assertEqual(result, TABLE_INPUT['StorageDescriptor']['Location'])
    def test_check_for_partition_false(self, mock_get_partitions):
        """check_for_partition is falsy when get_partitions yields an empty set."""
        mock_get_partitions.return_value = set()
        hook = AwsGlueCatalogHook(region_name="us-east-1")

        found = hook.check_for_partition('db', 'tbl', 'expr')

        self.assertFalse(found)