Example #1
    def execute(self, context):
        self.log.info('Exporting data to Cloud Storage bucket %s', self.bucket)

        if self.overwrite_existing and self.namespace:
            gcs_hook = GoogleCloudStorageHook(self.cloud_storage_conn_id)
            objects = gcs_hook.list(self.bucket, prefix=self.namespace)
            for o in objects:
                gcs_hook.delete(self.bucket, o)

        ds_hook = DatastoreHook(self.datastore_conn_id, self.delegate_to)
        result = ds_hook.export_to_storage_bucket(
            bucket=self.bucket,
            namespace=self.namespace,
            entity_filter=self.entity_filter,
            labels=self.labels)
        operation_name = result['name']
        result = ds_hook.poll_operation_until_done(
            operation_name, self.polling_interval_in_seconds)

        state = result['metadata']['common']['state']
        if state != 'SUCCESSFUL':
            raise AirflowException(
                'Operation failed: result={}'.format(result))

        return result
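The execute() method in Example #1 comes from an Airflow operator that exports Cloud Datastore entities to a Cloud Storage bucket and polls the long-running operation until it finishes. As a rough idea of how such an operator is wired into a DAG, here is a minimal sketch; the import path, class name, and argument choices are assumptions based on the Airflow 1.10-era contrib layout rather than part of the snippet:

from datetime import datetime

from airflow import DAG
# Assumed contrib import path; adjust for your Airflow version.
from airflow.contrib.operators.datastore_export_operator import DatastoreExportOperator

with DAG(dag_id='datastore_export_example',
         start_date=datetime(2019, 1, 1),
         schedule_interval=None) as dag:

    export_entities = DatastoreExportOperator(
        task_id='export_datastore_to_gcs',
        bucket='my-backup-bucket',           # target GCS bucket
        namespace='datastore_backups',       # optional object prefix inside the bucket
        overwrite_existing=True,             # clear any objects under the prefix first
        polling_interval_in_seconds=10,      # how often execute() polls the operation
    )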
Example #2
    def execute(self, context):
        self.log.info('Importing data from Cloud Storage bucket %s', self.bucket)
        ds_hook = DatastoreHook(self.datastore_conn_id, self.delegate_to)
        result = ds_hook.import_from_storage_bucket(bucket=self.bucket,
                                                    file=self.file,
                                                    namespace=self.namespace,
                                                    entity_filter=self.entity_filter,
                                                    labels=self.labels)
        operation_name = result['name']
        result = ds_hook.poll_operation_until_done(operation_name,
                                                   self.polling_interval_in_seconds)

        state = result['metadata']['common']['state']
        if state != 'SUCCESSFUL':
            raise AirflowException('Operation failed: result={}'.format(result))

        if self.xcom_push:
            return result
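Example #2 is the import counterpart: it restores entities from a Cloud Storage export and only returns the operation result when xcom_push is enabled, so a downstream task can read it from XCom. A minimal sketch of that pattern follows; the operator import path, the metadata file name, and the downstream callable are illustrative assumptions, not part of the snippet:

from datetime import datetime

from airflow import DAG
from airflow.contrib.operators.datastore_import_operator import DatastoreImportOperator
from airflow.operators.python_operator import PythonOperator


def check_import_result(**context):
    # Pull the result dict that execute() returned via XCom (xcom_push=True).
    result = context['ti'].xcom_pull(task_ids='import_datastore_from_gcs')
    assert result['metadata']['common']['state'] == 'SUCCESSFUL'


with DAG(dag_id='datastore_import_example',
         start_date=datetime(2019, 1, 1),
         schedule_interval=None) as dag:

    import_entities = DatastoreImportOperator(
        task_id='import_datastore_from_gcs',
        bucket='my-backup-bucket',
        file='exports/2019-01-01/export.overall_export_metadata',  # placeholder path
        xcom_push=True,
    )

    check_result = PythonOperator(
        task_id='check_import_result',
        python_callable=check_import_result,
        provide_context=True,
    )

    import_entities >> check_result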
Example #3
    def execute(self, context):
        self.log.info('Exporting data to Cloud Storage bucket %s', self.bucket)

        if self.overwrite_existing and self.namespace:
            gcs_hook = GoogleCloudStorageHook(self.cloud_storage_conn_id)
            objects = gcs_hook.list(self.bucket, prefix=self.namespace)
            for o in objects:
                gcs_hook.delete(self.bucket, o)

        ds_hook = DatastoreHook(self.datastore_conn_id, self.delegate_to)
        result = ds_hook.export_to_storage_bucket(bucket=self.bucket,
                                                  namespace=self.namespace,
                                                  entity_filter=self.entity_filter,
                                                  labels=self.labels)
        operation_name = result['name']
        result = ds_hook.poll_operation_until_done(operation_name,
                                                   self.polling_interval_in_seconds)

        state = result['metadata']['common']['state']
        if state != 'SUCCESSFUL':
            raise AirflowException('Operation failed: result={}'.format(result))

        return result
Example #4
    def setUp(self):
        with patch(
                'airflow.contrib.hooks.gcp_api_base_hook.GoogleCloudBaseHook.__init__',
                new=mock_init):
            self.datastore_hook = DatastoreHook()
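Examples #4 through #7 patch GoogleCloudBaseHook.__init__ with a mock_init helper that is not shown in these snippets. It is typically just a no-op stand-in so that DatastoreHook() can be constructed without a configured Airflow connection; a minimal sketch of what such a helper might look like (the exact signature is an assumption):

def mock_init(self, gcp_conn_id, delegate_to=None):
    # No-op replacement for GoogleCloudBaseHook.__init__ so the hook can be
    # instantiated in tests without looking up a real Airflow connection.
    pass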
Example #5
class TestDatastoreHook(unittest.TestCase):
    def setUp(self):
        with patch(
                'airflow.contrib.hooks.gcp_api_base_hook.GoogleCloudBaseHook.__init__',
                new=mock_init):
            self.datastore_hook = DatastoreHook()

    @patch('airflow.contrib.hooks.datastore_hook.DatastoreHook._authorize')
    @patch('airflow.contrib.hooks.datastore_hook.build')
    def test_get_conn(self, mock_build, mock_authorize):
        conn = self.datastore_hook.get_conn()

        mock_build.assert_called_once_with('datastore',
                                           'v1',
                                           http=mock_authorize.return_value,
                                           cache_discovery=False)
        self.assertEqual(conn, mock_build.return_value)
        self.assertEqual(conn, self.datastore_hook.connection)

    @patch('airflow.contrib.hooks.datastore_hook.DatastoreHook.get_conn')
    def test_allocate_ids(self, mock_get_conn):
        self.datastore_hook.connection = mock_get_conn.return_value
        partial_keys = []

        keys = self.datastore_hook.allocate_ids(partial_keys)

        projects = self.datastore_hook.connection.projects
        projects.assert_called_once_with()
        allocate_ids = projects.return_value.allocateIds
        allocate_ids.assert_called_once_with(
            projectId=self.datastore_hook.project_id,
            body={'keys': partial_keys})
        execute = allocate_ids.return_value.execute
        execute.assert_called_once_with()
        self.assertEqual(keys, execute.return_value['keys'])

    @patch('airflow.contrib.hooks.datastore_hook.DatastoreHook.get_conn')
    def test_begin_transaction(self, mock_get_conn):
        self.datastore_hook.connection = mock_get_conn.return_value

        transaction = self.datastore_hook.begin_transaction()

        projects = self.datastore_hook.connection.projects
        projects.assert_called_once_with()
        begin_transaction = projects.return_value.beginTransaction
        begin_transaction.assert_called_once_with(
            projectId=self.datastore_hook.project_id, body={})
        execute = begin_transaction.return_value.execute
        execute.assert_called_once_with()
        self.assertEqual(transaction, execute.return_value['transaction'])

    @patch('airflow.contrib.hooks.datastore_hook.DatastoreHook.get_conn')
    def test_commit(self, mock_get_conn):
        self.datastore_hook.connection = mock_get_conn.return_value
        body = {'item': 'a'}

        resp = self.datastore_hook.commit(body)

        projects = self.datastore_hook.connection.projects
        projects.assert_called_once_with()
        commit = projects.return_value.commit
        commit.assert_called_once_with(
            projectId=self.datastore_hook.project_id, body=body)
        execute = commit.return_value.execute
        execute.assert_called_once_with()
        self.assertEqual(resp, execute.return_value)

    @patch('airflow.contrib.hooks.datastore_hook.DatastoreHook.get_conn')
    def test_lookup(self, mock_get_conn):
        self.datastore_hook.connection = mock_get_conn.return_value
        keys = []
        read_consistency = 'ENUM'
        transaction = 'transaction'

        resp = self.datastore_hook.lookup(keys, read_consistency, transaction)

        projects = self.datastore_hook.connection.projects
        projects.assert_called_once_with()
        lookup = projects.return_value.lookup
        lookup.assert_called_once_with(
            projectId=self.datastore_hook.project_id,
            body={
                'keys': keys,
                'readConsistency': read_consistency,
                'transaction': transaction
            })
        execute = lookup.return_value.execute
        execute.assert_called_once_with()
        self.assertEqual(resp, execute.return_value)

    @patch('airflow.contrib.hooks.datastore_hook.DatastoreHook.get_conn')
    def test_rollback(self, mock_get_conn):
        self.datastore_hook.connection = mock_get_conn.return_value
        transaction = 'transaction'

        self.datastore_hook.rollback(transaction)

        projects = self.datastore_hook.connection.projects
        projects.assert_called_once_with()
        rollback = projects.return_value.rollback
        rollback.assert_called_once_with(
            projectId=self.datastore_hook.project_id,
            body={'transaction': transaction})
        execute = rollback.return_value.execute
        execute.assert_called_once_with()

    @patch('airflow.contrib.hooks.datastore_hook.DatastoreHook.get_conn')
    def test_run_query(self, mock_get_conn):
        self.datastore_hook.connection = mock_get_conn.return_value
        body = {'item': 'a'}

        resp = self.datastore_hook.run_query(body)

        projects = self.datastore_hook.connection.projects
        projects.assert_called_once_with()
        run_query = projects.return_value.runQuery
        run_query.assert_called_once_with(
            projectId=self.datastore_hook.project_id, body=body)
        execute = run_query.return_value.execute
        execute.assert_called_once_with()
        self.assertEqual(resp, execute.return_value['batch'])

    @patch('airflow.contrib.hooks.datastore_hook.DatastoreHook.get_conn')
    def test_get_operation(self, mock_get_conn):
        self.datastore_hook.connection = mock_get_conn.return_value
        name = 'name'

        resp = self.datastore_hook.get_operation(name)

        projects = self.datastore_hook.connection.projects
        projects.assert_called_once_with()
        operations = projects.return_value.operations
        operations.assert_called_once_with()
        get = operations.return_value.get
        get.assert_called_once_with(name=name)
        execute = get.return_value.execute
        execute.assert_called_once_with()
        self.assertEqual(resp, execute.return_value)

    @patch('airflow.contrib.hooks.datastore_hook.DatastoreHook.get_conn')
    def test_delete_operation(self, mock_get_conn):
        self.datastore_hook.connection = mock_get_conn.return_value
        name = 'name'

        resp = self.datastore_hook.delete_operation(name)

        projects = self.datastore_hook.connection.projects
        projects.assert_called_once_with()
        operations = projects.return_value.operations
        operations.assert_called_once_with()
        delete = operations.return_value.delete
        delete.assert_called_once_with(name=name)
        execute = delete.return_value.execute
        execute.assert_called_once_with()
        self.assertEqual(resp, execute.return_value)

    @patch('airflow.contrib.hooks.datastore_hook.time.sleep')
    @patch('airflow.contrib.hooks.datastore_hook.DatastoreHook.get_operation',
           side_effect=[{
               'metadata': {
                   'common': {
                       'state': 'PROCESSING'
                   }
               }
           }, {
               'metadata': {
                   'common': {
                       'state': 'NOT PROCESSING'
                   }
               }
           }])
    def test_poll_operation_until_done(self, mock_get_operation,
                                       mock_time_sleep):
        name = 'name'
        polling_interval_in_seconds = 10

        result = self.datastore_hook.poll_operation_until_done(
            name, polling_interval_in_seconds)

        mock_get_operation.assert_has_calls([call(name), call(name)])
        mock_time_sleep.assert_called_once_with(polling_interval_in_seconds)
        self.assertEqual(result,
                         {'metadata': {
                             'common': {
                                 'state': 'NOT PROCESSING'
                             }
                         }})

    @patch('airflow.contrib.hooks.datastore_hook.DatastoreHook.get_conn')
    def test_export_to_storage_bucket(self, mock_get_conn):
        self.datastore_hook.admin_connection = mock_get_conn.return_value
        bucket = 'bucket'
        namespace = None
        entity_filter = {}
        labels = {}

        resp = self.datastore_hook.export_to_storage_bucket(
            bucket, namespace, entity_filter, labels)

        projects = self.datastore_hook.admin_connection.projects
        projects.assert_called_once_with()
        export = projects.return_value.export
        export.assert_called_once_with(
            projectId=self.datastore_hook.project_id,
            body={
                'outputUrlPrefix':
                'gs://' + '/'.join(filter(None, [bucket, namespace])),
                'entityFilter':
                entity_filter,
                'labels':
                labels,
            })
        execute = export.return_value.execute
        execute.assert_called_once_with()
        self.assertEqual(resp, execute.return_value)

    @patch('airflow.contrib.hooks.datastore_hook.DatastoreHook.get_conn')
    def test_import_from_storage_bucket(self, mock_get_conn):
        self.datastore_hook.admin_connection = mock_get_conn.return_value
        bucket = 'bucket'
        file = 'file'
        namespace = None
        entity_filter = {}
        labels = {}

        resp = self.datastore_hook.import_from_storage_bucket(
            bucket, file, namespace, entity_filter, labels)

        projects = self.datastore_hook.admin_connection.projects
        projects.assert_called_once_with()
        import_ = projects.return_value.import_
        import_.assert_called_once_with(
            projectId=self.datastore_hook.project_id,
            body={
                'inputUrl':
                'gs://' + '/'.join(filter(None, [bucket, namespace, file])),
                'entityFilter':
                entity_filter,
                'labels':
                labels,
            })
        execute = import_.return_value.execute
        execute.assert_called_once_with()
        self.assertEqual(resp, execute.return_value)
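Each test above verifies one link of a chained mock at a time: projects is asserted first, then the attribute on its return value, and so on down to execute. On the real client that get_conn() builds with googleapiclient, the same chain is a single fluent call; for instance, the assertions in test_get_operation correspond roughly to this sketch (illustrative only; hook stands for a configured DatastoreHook):

# Fluent call chain that test_get_operation decomposes into per-link assertions:
#   projects() -> operations() -> get(name=...) -> execute()
conn = hook.get_conn()
result = conn.projects().operations().get(name='name').execute()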
Example #6
    def setUp(self):
        with patch('airflow.contrib.hooks.gcp_api_base_hook.GoogleCloudBaseHook.__init__', new=mock_init):
            self.datastore_hook = DatastoreHook()
Example #7
class TestDatastoreHook(unittest.TestCase):

    def setUp(self):
        with patch('airflow.contrib.hooks.gcp_api_base_hook.GoogleCloudBaseHook.__init__', new=mock_init):
            self.datastore_hook = DatastoreHook()

    @patch('airflow.contrib.hooks.datastore_hook.DatastoreHook._authorize')
    @patch('airflow.contrib.hooks.datastore_hook.build')
    def test_get_conn(self, mock_build, mock_authorize):
        conn = self.datastore_hook.get_conn()

        mock_build.assert_called_once_with('datastore', 'v1', http=mock_authorize.return_value,
                                           cache_discovery=False)
        self.assertEqual(conn, mock_build.return_value)
        self.assertEqual(conn, self.datastore_hook.connection)

    @patch('airflow.contrib.hooks.datastore_hook.DatastoreHook.get_conn')
    def test_allocate_ids(self, mock_get_conn):
        self.datastore_hook.connection = mock_get_conn.return_value
        partial_keys = []

        keys = self.datastore_hook.allocate_ids(partial_keys)

        projects = self.datastore_hook.connection.projects
        projects.assert_called_once_with()
        allocate_ids = projects.return_value.allocateIds
        allocate_ids.assert_called_once_with(projectId=self.datastore_hook.project_id,
                                             body={'keys': partial_keys})
        execute = allocate_ids.return_value.execute
        execute.assert_called_once_with(num_retries=mock.ANY)
        self.assertEqual(keys, execute.return_value['keys'])

    @patch('airflow.contrib.hooks.datastore_hook.DatastoreHook.get_conn')
    def test_begin_transaction(self, mock_get_conn):
        self.datastore_hook.connection = mock_get_conn.return_value

        transaction = self.datastore_hook.begin_transaction()

        projects = self.datastore_hook.connection.projects
        projects.assert_called_once_with()
        begin_transaction = projects.return_value.beginTransaction
        begin_transaction.assert_called_once_with(projectId=self.datastore_hook.project_id, body={})
        execute = begin_transaction.return_value.execute
        execute.assert_called_once_with(num_retries=mock.ANY)
        self.assertEqual(transaction, execute.return_value['transaction'])

    @patch('airflow.contrib.hooks.datastore_hook.DatastoreHook.get_conn')
    def test_commit(self, mock_get_conn):
        self.datastore_hook.connection = mock_get_conn.return_value
        body = {'item': 'a'}

        resp = self.datastore_hook.commit(body)

        projects = self.datastore_hook.connection.projects
        projects.assert_called_once_with()
        commit = projects.return_value.commit
        commit.assert_called_once_with(projectId=self.datastore_hook.project_id, body=body)
        execute = commit.return_value.execute
        execute.assert_called_once_with(num_retries=mock.ANY)
        self.assertEqual(resp, execute.return_value)

    @patch('airflow.contrib.hooks.datastore_hook.DatastoreHook.get_conn')
    def test_lookup(self, mock_get_conn):
        self.datastore_hook.connection = mock_get_conn.return_value
        keys = []
        read_consistency = 'ENUM'
        transaction = 'transaction'

        resp = self.datastore_hook.lookup(keys, read_consistency, transaction)

        projects = self.datastore_hook.connection.projects
        projects.assert_called_once_with()
        lookup = projects.return_value.lookup
        lookup.assert_called_once_with(projectId=self.datastore_hook.project_id,
                                       body={
                                           'keys': keys,
                                           'readConsistency': read_consistency,
                                           'transaction': transaction
                                       })
        execute = lookup.return_value.execute
        execute.assert_called_once_with(num_retries=mock.ANY)
        self.assertEqual(resp, execute.return_value)

    @patch('airflow.contrib.hooks.datastore_hook.DatastoreHook.get_conn')
    def test_rollback(self, mock_get_conn):
        self.datastore_hook.connection = mock_get_conn.return_value
        transaction = 'transaction'

        self.datastore_hook.rollback(transaction)

        projects = self.datastore_hook.connection.projects
        projects.assert_called_once_with()
        rollback = projects.return_value.rollback
        rollback.assert_called_once_with(projectId=self.datastore_hook.project_id,
                                         body={'transaction': transaction})
        execute = rollback.return_value.execute
        execute.assert_called_once_with(num_retries=mock.ANY)

    @patch('airflow.contrib.hooks.datastore_hook.DatastoreHook.get_conn')
    def test_run_query(self, mock_get_conn):
        self.datastore_hook.connection = mock_get_conn.return_value
        body = {'item': 'a'}

        resp = self.datastore_hook.run_query(body)

        projects = self.datastore_hook.connection.projects
        projects.assert_called_once_with()
        run_query = projects.return_value.runQuery
        run_query.assert_called_once_with(projectId=self.datastore_hook.project_id, body=body)
        execute = run_query.return_value.execute
        execute.assert_called_once_with(num_retries=mock.ANY)
        self.assertEqual(resp, execute.return_value['batch'])

    @patch('airflow.contrib.hooks.datastore_hook.DatastoreHook.get_conn')
    def test_get_operation(self, mock_get_conn):
        self.datastore_hook.connection = mock_get_conn.return_value
        name = 'name'

        resp = self.datastore_hook.get_operation(name)

        projects = self.datastore_hook.connection.projects
        projects.assert_called_once_with()
        operations = projects.return_value.operations
        operations.assert_called_once_with()
        get = operations.return_value.get
        get.assert_called_once_with(name=name)
        execute = get.return_value.execute
        execute.assert_called_once_with(num_retries=mock.ANY)
        self.assertEqual(resp, execute.return_value)

    @patch('airflow.contrib.hooks.datastore_hook.DatastoreHook.get_conn')
    def test_delete_operation(self, mock_get_conn):
        self.datastore_hook.connection = mock_get_conn.return_value
        name = 'name'

        resp = self.datastore_hook.delete_operation(name)

        projects = self.datastore_hook.connection.projects
        projects.assert_called_once_with()
        operations = projects.return_value.operations
        operations.assert_called_once_with()
        delete = operations.return_value.delete
        delete.assert_called_once_with(name=name)
        execute = delete.return_value.execute
        execute.assert_called_once_with(num_retries=mock.ANY)
        self.assertEqual(resp, execute.return_value)

    @patch('airflow.contrib.hooks.datastore_hook.time.sleep')
    @patch('airflow.contrib.hooks.datastore_hook.DatastoreHook.get_operation',
           side_effect=[
               {'metadata': {'common': {'state': 'PROCESSING'}}},
               {'metadata': {'common': {'state': 'NOT PROCESSING'}}}
           ])
    def test_poll_operation_until_done(self, mock_get_operation, mock_time_sleep):
        name = 'name'
        polling_interval_in_seconds = 10

        result = self.datastore_hook.poll_operation_until_done(name, polling_interval_in_seconds)

        mock_get_operation.assert_has_calls([call(name), call(name)])
        mock_time_sleep.assert_called_once_with(polling_interval_in_seconds)
        self.assertEqual(result, {'metadata': {'common': {'state': 'NOT PROCESSING'}}})

    @patch('airflow.contrib.hooks.datastore_hook.DatastoreHook.get_conn')
    def test_export_to_storage_bucket(self, mock_get_conn):
        self.datastore_hook.admin_connection = mock_get_conn.return_value
        bucket = 'bucket'
        namespace = None
        entity_filter = {}
        labels = {}

        resp = self.datastore_hook.export_to_storage_bucket(bucket, namespace, entity_filter, labels)

        projects = self.datastore_hook.admin_connection.projects
        projects.assert_called_once_with()
        export = projects.return_value.export
        export.assert_called_once_with(projectId=self.datastore_hook.project_id,
                                       body={
                                           'outputUrlPrefix': 'gs://' + '/'.join(
                                               filter(None, [bucket, namespace])
                                           ),
                                           'entityFilter': entity_filter,
                                           'labels': labels,
                                       })
        execute = export.return_value.execute
        execute.assert_called_once_with(num_retries=mock.ANY)
        self.assertEqual(resp, execute.return_value)

    @patch('airflow.contrib.hooks.datastore_hook.DatastoreHook.get_conn')
    def test_import_from_storage_bucket(self, mock_get_conn):
        self.datastore_hook.admin_connection = mock_get_conn.return_value
        bucket = 'bucket'
        file = 'file'
        namespace = None
        entity_filter = {}
        labels = {}

        resp = self.datastore_hook.import_from_storage_bucket(bucket, file, namespace, entity_filter, labels)

        projects = self.datastore_hook.admin_connection.projects
        projects.assert_called_once_with()
        import_ = projects.return_value.import_
        import_.assert_called_once_with(projectId=self.datastore_hook.project_id,
                                        body={
                                            'inputUrl': 'gs://' + '/'.join(
                                                filter(None, [bucket, namespace, file])
                                            ),
                                            'entityFilter': entity_filter,
                                            'labels': labels,
                                        })
        execute = import_.return_value.execute
        execute.assert_called_once_with(num_retries=mock.ANY)
        self.assertEqual(resp, execute.return_value)
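Examples #5 and #7 are two revisions of the same test suite: the later one expects the hook to forward a num_retries argument when it calls execute(), and matches it loosely with mock.ANY. mock.ANY compares equal to any value, so the assertion only checks that the keyword was passed at all; a tiny self-contained illustration:

from unittest import mock

m = mock.Mock()
m.execute(num_retries=5)
# Passes regardless of which num_retries value was used.
m.execute.assert_called_once_with(num_retries=mock.ANY)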