示例#1
0
class CloudSpannerInstanceDatabaseDeleteOperator(BaseOperator):
    """
    Deletes a Cloud Spanner database.

    If the database does not exist, no action is taken and the operator
    succeeds.

    :param project_id: The ID of the project that owns the Cloud Spanner Database.
    :type project_id: str
    :param instance_id: Cloud Spanner instance ID.
    :type instance_id: str
    :param database_id: Cloud Spanner database ID.
    :type database_id: str
    :param gcp_conn_id: The connection ID used to connect to Google Cloud Platform.
    :type gcp_conn_id: str
    """
    # [START gcp_spanner_database_delete_template_fields]
    template_fields = ('project_id', 'instance_id', 'database_id',
                       'gcp_conn_id')
    # [END gcp_spanner_database_delete_template_fields]

    @apply_defaults
    def __init__(self,
                 project_id,
                 instance_id,
                 database_id,
                 gcp_conn_id='google_cloud_default',
                 *args,
                 **kwargs):
        self.instance_id = instance_id
        self.project_id = project_id
        self.database_id = database_id
        self.gcp_conn_id = gcp_conn_id
        # Fail fast when the operator is constructed with missing arguments.
        self._validate_inputs()
        super(CloudSpannerInstanceDatabaseDeleteOperator,
              self).__init__(*args, **kwargs)

    def _validate_inputs(self):
        """
        Check that all required parameters are non-empty.

        :raises AirflowException: if project_id, instance_id or database_id
            is empty or None.
        """
        if not self.project_id:
            raise AirflowException(
                "The required parameter 'project_id' is empty")
        if not self.instance_id:
            raise AirflowException(
                "The required parameter 'instance_id' is empty")
        if not self.database_id:
            raise AirflowException(
                "The required parameter 'database_id' is empty")

    def execute(self, context):
        """
        Delete the database if it exists.

        :return: True when the database was already missing, otherwise the
            value returned by ``CloudSpannerHook.delete_database``.
        """
        # Build the hook here rather than in __init__: gcp_conn_id is a
        # template field, and Airflow renders template fields after the
        # operator is constructed. A hook created in __init__ would keep the
        # unrendered connection id.
        hook = CloudSpannerHook(gcp_conn_id=self.gcp_conn_id)
        db = hook.get_database(project_id=self.project_id,
                               instance_id=self.instance_id,
                               database_id=self.database_id)
        if not db:
            self.log.info(
                "The Cloud Spanner database was missing: "
                "'%s' in project '%s' and instance '%s'. Assuming success.",
                self.database_id, self.project_id, self.instance_id)
            return True
        return hook.delete_database(project_id=self.project_id,
                                    instance_id=self.instance_id,
                                    database_id=self.database_id)
class CloudSpannerInstanceDatabaseDeleteOperator(BaseOperator):
    """
    Deletes a Cloud Spanner database.

    If the database does not exist, no action is taken and the operator
    succeeds.

    :param instance_id: Cloud Spanner instance ID.
    :type instance_id: str
    :param database_id: Cloud Spanner database ID.
    :type database_id: str
    :param project_id: Optional, the ID of the project that owns the Cloud Spanner
        Database.  If set to None or missing, the default project_id from the GCP connection is used.
    :type project_id: str
    :param gcp_conn_id: The connection ID used to connect to Google Cloud Platform.
    :type gcp_conn_id: str
    """
    # [START gcp_spanner_database_delete_template_fields]
    template_fields = ('project_id', 'instance_id', 'database_id',
                       'gcp_conn_id')
    # [END gcp_spanner_database_delete_template_fields]

    @apply_defaults
    def __init__(self,
                 instance_id,
                 database_id,
                 project_id=None,
                 gcp_conn_id='google_cloud_default',
                 *args, **kwargs):
        self.instance_id = instance_id
        self.project_id = project_id
        self.database_id = database_id
        self.gcp_conn_id = gcp_conn_id
        # Fail fast when the operator is constructed with missing arguments.
        self._validate_inputs()
        super(CloudSpannerInstanceDatabaseDeleteOperator, self).__init__(*args, **kwargs)

    def _validate_inputs(self):
        """
        Reject invalid arguments at construction time.

        project_id may be None (the connection's default project is then
        used) but not an explicit empty string. instance_id and database_id
        must be non-empty.

        :raises AirflowException: if a parameter fails validation.
        """
        if self.project_id == '':
            raise AirflowException("The required parameter 'project_id' is empty")
        if not self.instance_id:
            raise AirflowException("The required parameter 'instance_id' is empty"
                                   " or None")
        if not self.database_id:
            raise AirflowException("The required parameter 'database_id' is empty"
                                   " or None")

    def execute(self, context):
        """
        Delete the database if it exists.

        :return: True when the database was already missing, otherwise the
            value returned by ``CloudSpannerHook.delete_database``.
        """
        # Build the hook here rather than in __init__: gcp_conn_id is a
        # template field, and Airflow renders template fields after the
        # operator is constructed. A hook created in __init__ would keep the
        # unrendered connection id.
        hook = CloudSpannerHook(gcp_conn_id=self.gcp_conn_id)
        db = hook.get_database(project_id=self.project_id,
                               instance_id=self.instance_id,
                               database_id=self.database_id)
        if not db:
            self.log.info("The Cloud Spanner database was missing: "
                          "'%s' in project '%s' and instance '%s'. Assuming success.",
                          self.database_id, self.project_id, self.instance_id)
            return True
        return hook.delete_database(project_id=self.project_id,
                                    instance_id=self.instance_id,
                                    database_id=self.database_id)
示例#3
0
 def __init__(self,
              instance_id,
              project_id=None,
              gcp_conn_id='google_cloud_default',
              *args, **kwargs):
     """
     Record the operator arguments, validate them and create the Spanner hook.

     :param instance_id: Cloud Spanner instance ID.
     :param project_id: Optional project ID (presumably None means the
         connection's default project; confirm against the class docstring).
     :param gcp_conn_id: Connection ID handed to CloudSpannerHook.
     """
     self.gcp_conn_id = gcp_conn_id
     self.project_id = project_id
     self.instance_id = instance_id
     # _validate_inputs reads the attributes set above, so it must run after them.
     self._validate_inputs()
     spanner_hook = CloudSpannerHook(gcp_conn_id=gcp_conn_id)
     self._hook = spanner_hook
     super().__init__(*args, **kwargs)
示例#4
0
class CloudSpannerInstanceDeleteOperator(BaseOperator):
    """
    Deletes a Cloud Spanner instance. If an instance does not exist,
    no action is taken and the operator succeeds.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:CloudSpannerInstanceDeleteOperator`

    :param instance_id: The Cloud Spanner instance ID.
    :type instance_id: str
    :param project_id: Optional, the ID of the project that owns the Cloud Spanner
        instance.  If set to None or missing, the default project_id from the GCP connection is used.
    :type project_id: str
    :param gcp_conn_id: The connection ID used to connect to Google Cloud Platform.
    :type gcp_conn_id: str
    """
    # [START gcp_spanner_delete_template_fields]
    template_fields = ('project_id', 'instance_id', 'gcp_conn_id')
    # [END gcp_spanner_delete_template_fields]

    @apply_defaults
    def __init__(self,
                 instance_id,
                 project_id=None,
                 gcp_conn_id='google_cloud_default',
                 *args,
                 **kwargs):
        self.instance_id = instance_id
        self.project_id = project_id
        self.gcp_conn_id = gcp_conn_id
        # Fail fast when the operator is constructed with missing arguments.
        self._validate_inputs()
        super(CloudSpannerInstanceDeleteOperator,
              self).__init__(*args, **kwargs)

    def _validate_inputs(self):
        """
        Reject invalid arguments at construction time.

        project_id may be None (the connection's default project is then
        used) but not an explicit empty string. instance_id must be non-empty.

        :raises AirflowException: if a parameter fails validation.
        """
        if self.project_id == '':
            raise AirflowException(
                "The required parameter 'project_id' is empty")
        if not self.instance_id:
            raise AirflowException("The required parameter 'instance_id' "
                                   "is empty or None")

    def execute(self, context):
        """
        Delete the instance if it exists.

        :return: the value returned by ``CloudSpannerHook.delete_instance``,
            or True when the instance did not exist.
        """
        # Build the hook here rather than in __init__: gcp_conn_id is a
        # template field, and Airflow renders template fields after the
        # operator is constructed. A hook created in __init__ would keep the
        # unrendered connection id.
        hook = CloudSpannerHook(gcp_conn_id=self.gcp_conn_id)
        if hook.get_instance(project_id=self.project_id,
                             instance_id=self.instance_id):
            return hook.delete_instance(project_id=self.project_id,
                                        instance_id=self.instance_id)
        self.log.info(
            "Instance '%s' does not exist in project '%s'. "
            "Aborting delete.", self.instance_id, self.project_id)
        return True
class CloudSpannerInstanceDeleteOperator(BaseOperator):
    """
    Deletes a Cloud Spanner instance. If an instance does not exist,
    no action is taken and the operator succeeds.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:CloudSpannerInstanceDeleteOperator`

    :param instance_id: The Cloud Spanner instance ID.
    :type instance_id: str
    :param project_id: Optional, the ID of the project that owns the Cloud Spanner
        instance.  If set to None or missing, the default project_id from the GCP connection is used.
    :type project_id: str
    :param gcp_conn_id: The connection ID used to connect to Google Cloud Platform.
    :type gcp_conn_id: str
    """
    # [START gcp_spanner_delete_template_fields]
    template_fields = ('project_id', 'instance_id', 'gcp_conn_id')
    # [END gcp_spanner_delete_template_fields]

    @apply_defaults
    def __init__(self,
                 instance_id,
                 project_id=None,
                 gcp_conn_id='google_cloud_default',
                 *args, **kwargs):
        self.instance_id = instance_id
        self.project_id = project_id
        self.gcp_conn_id = gcp_conn_id
        # Fail fast when the operator is constructed with missing arguments.
        self._validate_inputs()
        super(CloudSpannerInstanceDeleteOperator, self).__init__(*args, **kwargs)

    def _validate_inputs(self):
        """
        Reject invalid arguments at construction time.

        project_id may be None (the connection's default project is then
        used) but not an explicit empty string. instance_id must be non-empty.

        :raises AirflowException: if a parameter fails validation.
        """
        if self.project_id == '':
            raise AirflowException("The required parameter 'project_id' is empty")
        if not self.instance_id:
            raise AirflowException("The required parameter 'instance_id' "
                                   "is empty or None")

    def execute(self, context):
        """
        Delete the instance if it exists.

        :return: the value returned by ``CloudSpannerHook.delete_instance``,
            or True when the instance did not exist.
        """
        # Build the hook here rather than in __init__: gcp_conn_id is a
        # template field, and Airflow renders template fields after the
        # operator is constructed. A hook created in __init__ would keep the
        # unrendered connection id.
        hook = CloudSpannerHook(gcp_conn_id=self.gcp_conn_id)
        if hook.get_instance(project_id=self.project_id, instance_id=self.instance_id):
            return hook.delete_instance(project_id=self.project_id,
                                        instance_id=self.instance_id)
        self.log.info("Instance '%s' does not exist in project '%s'. "
                      "Aborting delete.", self.instance_id, self.project_id)
        return True
示例#6
0
 def __init__(self,
              instance_id,
              database_id,
              project_id=None,
              gcp_conn_id='google_cloud_default',
              *args,
              **kwargs):
     """
     Record the database-delete arguments, validate them and create the hook.

     :param instance_id: Cloud Spanner instance ID.
     :param database_id: Cloud Spanner database ID.
     :param project_id: Optional project ID (presumably None means the
         connection's default project; confirm against the class docstring).
     :param gcp_conn_id: Connection ID handed to CloudSpannerHook.
     """
     self.gcp_conn_id = gcp_conn_id
     self.database_id = database_id
     self.project_id = project_id
     self.instance_id = instance_id
     # _validate_inputs reads the attributes set above, so it must run after them.
     self._validate_inputs()
     spanner_hook = CloudSpannerHook(gcp_conn_id=gcp_conn_id)
     self._hook = spanner_hook
     super(CloudSpannerInstanceDatabaseDeleteOperator, self).__init__(*args, **kwargs)
示例#7
0
class CloudSpannerInstanceDeleteOperator(BaseOperator):
    """
    Deletes a Cloud Spanner instance.
    If an instance does not exist, no action will be taken and the operator will succeed.

    :param project_id: The ID of the project which owns the instances, tables and data.
    :type project_id: str
    :param instance_id: Cloud Spanner instance ID.
    :type instance_id: str
    :param gcp_conn_id: The connection ID used to connect to Google Cloud Platform.
    :type gcp_conn_id: str
    """
    # [START gcp_spanner_delete_template_fields]
    template_fields = ('project_id', 'instance_id', 'gcp_conn_id')
    # [END gcp_spanner_delete_template_fields]

    @apply_defaults
    def __init__(self,
                 project_id,
                 instance_id,
                 gcp_conn_id='google_cloud_default',
                 *args,
                 **kwargs):
        self.instance_id = instance_id
        self.project_id = project_id
        self.gcp_conn_id = gcp_conn_id
        # Fail fast when the operator is constructed with missing arguments.
        self._validate_inputs()
        super(CloudSpannerInstanceDeleteOperator,
              self).__init__(*args, **kwargs)

    def _validate_inputs(self):
        """
        Check that all required parameters are non-empty.

        :raises AirflowException: if project_id or instance_id is empty or None.
        """
        if not self.project_id:
            raise AirflowException(
                "The required parameter 'project_id' is empty")
        if not self.instance_id:
            raise AirflowException(
                "The required parameter 'instance_id' is empty")

    def execute(self, context):
        """
        Delete the instance if it exists.

        :return: the value returned by ``CloudSpannerHook.delete_instance``,
            or True when the instance did not exist.
        """
        # Build the hook here rather than in __init__: gcp_conn_id is a
        # template field, and Airflow renders template fields after the
        # operator is constructed. A hook created in __init__ would keep the
        # unrendered connection id.
        hook = CloudSpannerHook(gcp_conn_id=self.gcp_conn_id)
        # Keyword arguments guard against positional-order mismatches with
        # the hook signature.
        if hook.get_instance(project_id=self.project_id,
                             instance_id=self.instance_id):
            return hook.delete_instance(project_id=self.project_id,
                                        instance_id=self.instance_id)
        self.log.info(
            "Instance '%s' does not exist in project '%s'. "
            "Aborting delete.", self.instance_id, self.project_id)
        return True
示例#8
0
 def __init__(self,
              instance_id,
              configuration_name,
              node_count,
              display_name,
              project_id=None,
              gcp_conn_id='google_cloud_default',
              *args, **kwargs):
     """
     Record the instance-deploy arguments, validate them and create the hook.

     :param instance_id: Cloud Spanner instance ID.
     :param configuration_name: Instance configuration name.
     :param node_count: Number of nodes for the instance.
     :param display_name: Human-readable instance name.
     :param project_id: Optional project ID (presumably None means the
         connection's default project; confirm against the class docstring).
     :param gcp_conn_id: Connection ID handed to CloudSpannerHook.
     """
     self.gcp_conn_id = gcp_conn_id
     self.display_name = display_name
     self.node_count = node_count
     self.configuration_name = configuration_name
     self.project_id = project_id
     self.instance_id = instance_id
     # _validate_inputs reads the attributes set above, so it must run after them.
     self._validate_inputs()
     spanner_hook = CloudSpannerHook(gcp_conn_id=gcp_conn_id)
     self._hook = spanner_hook
     super().__init__(*args, **kwargs)
示例#9
0
 def __init__(self,
              instance_id,
              database_id,
              ddl_statements,
              project_id=None,
              operation_id=None,
              gcp_conn_id='google_cloud_default',
              *args, **kwargs):
     """
     Record the database-update arguments, validate them and create the hook.

     :param instance_id: Cloud Spanner instance ID.
     :param database_id: Cloud Spanner database ID.
     :param ddl_statements: DDL statements to apply.
     :param project_id: Optional project ID (presumably None means the
         connection's default project; confirm against the class docstring).
     :param operation_id: Optional operation ID for the DDL update.
     :param gcp_conn_id: Connection ID handed to CloudSpannerHook.
     """
     self.gcp_conn_id = gcp_conn_id
     self.operation_id = operation_id
     self.ddl_statements = ddl_statements
     self.database_id = database_id
     self.project_id = project_id
     self.instance_id = instance_id
     # _validate_inputs reads the attributes set above, so it must run after them.
     self._validate_inputs()
     spanner_hook = CloudSpannerHook(gcp_conn_id=gcp_conn_id)
     self._hook = spanner_hook
     super().__init__(*args, **kwargs)
class CloudSpannerInstanceDeleteOperator(BaseOperator):
    """
    Deletes a Cloud Spanner instance. If an instance does not exist,
    no action is taken and the operator succeeds.

    :param project_id: The ID of the project that owns the Cloud Spanner instance.
    :type project_id: str
    :param instance_id: The Cloud Spanner instance ID.
    :type instance_id: str
    :param gcp_conn_id: The connection ID used to connect to Google Cloud Platform.
    :type gcp_conn_id: str
    """
    # [START gcp_spanner_delete_template_fields]
    template_fields = ('project_id', 'instance_id', 'gcp_conn_id')
    # [END gcp_spanner_delete_template_fields]

    @apply_defaults
    def __init__(self,
                 project_id,
                 instance_id,
                 gcp_conn_id='google_cloud_default',
                 *args, **kwargs):
        self.instance_id = instance_id
        self.project_id = project_id
        self.gcp_conn_id = gcp_conn_id
        # Fail fast when the operator is constructed with missing arguments.
        self._validate_inputs()
        super(CloudSpannerInstanceDeleteOperator, self).__init__(*args, **kwargs)

    def _validate_inputs(self):
        """
        Check that all required parameters are non-empty.

        :raises AirflowException: if project_id or instance_id is empty or None.
        """
        if not self.project_id:
            raise AirflowException("The required parameter 'project_id' is empty")
        if not self.instance_id:
            raise AirflowException("The required parameter 'instance_id' is empty")

    def execute(self, context):
        """
        Delete the instance if it exists.

        :return: the value returned by ``CloudSpannerHook.delete_instance``,
            or True when the instance did not exist.
        """
        # Build the hook here rather than in __init__: gcp_conn_id is a
        # template field, and Airflow renders template fields after the
        # operator is constructed. A hook created in __init__ would keep the
        # unrendered connection id.
        hook = CloudSpannerHook(gcp_conn_id=self.gcp_conn_id)
        # Keyword arguments guard against positional-order mismatches with
        # the hook signature.
        if hook.get_instance(project_id=self.project_id,
                             instance_id=self.instance_id):
            return hook.delete_instance(project_id=self.project_id,
                                        instance_id=self.instance_id)
        self.log.info("Instance '%s' does not exist in project '%s'. "
                      "Aborting delete.", self.instance_id, self.project_id)
        return True
 def __init__(self,
              instance_id,
              project_id=None,
              gcp_conn_id='google_cloud_default',
              *args, **kwargs):
     """
     Record the instance-delete arguments, validate them and create the hook.

     :param instance_id: Cloud Spanner instance ID.
     :param project_id: Optional project ID (presumably None means the
         connection's default project; confirm against the class docstring).
     :param gcp_conn_id: Connection ID handed to CloudSpannerHook.
     """
     self.gcp_conn_id = gcp_conn_id
     self.project_id = project_id
     self.instance_id = instance_id
     # _validate_inputs reads the attributes set above, so it must run after them.
     self._validate_inputs()
     spanner_hook = CloudSpannerHook(gcp_conn_id=gcp_conn_id)
     self._hook = spanner_hook
     super(CloudSpannerInstanceDeleteOperator, self).__init__(*args, **kwargs)
 def __init__(self,
              project_id,
              instance_id,
              database_id,
              query,
              gcp_conn_id='google_cloud_default',
              *args, **kwargs):
     """
     Record the query arguments, validate them and create the hook.

     :param project_id: ID of the project owning the instance.
     :param instance_id: Cloud Spanner instance ID.
     :param database_id: Cloud Spanner database ID.
     :param query: Query (or queries) to execute against the database.
     :param gcp_conn_id: Connection ID handed to CloudSpannerHook.
     """
     self.gcp_conn_id = gcp_conn_id
     self.query = query
     self.database_id = database_id
     self.project_id = project_id
     self.instance_id = instance_id
     # _validate_inputs reads the attributes set above, so it must run after them.
     self._validate_inputs()
     spanner_hook = CloudSpannerHook(gcp_conn_id=gcp_conn_id)
     self._hook = spanner_hook
     super(CloudSpannerInstanceDatabaseQueryOperator, self).__init__(*args, **kwargs)
 def __init__(self,
              instance_id,
              database_id,
              ddl_statements,
              project_id=None,
              gcp_conn_id='google_cloud_default',
              *args, **kwargs):
     """
     Record the database-deploy arguments, validate them and create the hook.

     :param instance_id: Cloud Spanner instance ID.
     :param database_id: Cloud Spanner database ID.
     :param ddl_statements: DDL statements used when creating the database.
     :param project_id: Optional project ID (presumably None means the
         connection's default project; confirm against the class docstring).
     :param gcp_conn_id: Connection ID handed to CloudSpannerHook.
     """
     self.gcp_conn_id = gcp_conn_id
     self.ddl_statements = ddl_statements
     self.database_id = database_id
     self.project_id = project_id
     self.instance_id = instance_id
     # _validate_inputs reads the attributes set above, so it must run after them.
     self._validate_inputs()
     spanner_hook = CloudSpannerHook(gcp_conn_id=gcp_conn_id)
     self._hook = spanner_hook
     super(CloudSpannerInstanceDatabaseDeployOperator, self).__init__(*args, **kwargs)
 def __init__(self,
              instance_id,
              configuration_name,
              node_count,
              display_name,
              project_id=None,
              gcp_conn_id='google_cloud_default',
              *args, **kwargs):
     """
     Record the instance-deploy arguments, validate them and create the hook.

     :param instance_id: Cloud Spanner instance ID.
     :param configuration_name: Instance configuration name.
     :param node_count: Number of nodes for the instance.
     :param display_name: Human-readable instance name.
     :param project_id: Optional project ID (presumably None means the
         connection's default project; confirm against the class docstring).
     :param gcp_conn_id: Connection ID handed to CloudSpannerHook.
     """
     self.gcp_conn_id = gcp_conn_id
     self.display_name = display_name
     self.node_count = node_count
     self.configuration_name = configuration_name
     self.project_id = project_id
     self.instance_id = instance_id
     # _validate_inputs reads the attributes set above, so it must run after them.
     self._validate_inputs()
     spanner_hook = CloudSpannerHook(gcp_conn_id=gcp_conn_id)
     self._hook = spanner_hook
     super(CloudSpannerInstanceDeployOperator, self).__init__(*args, **kwargs)
 def setUp(self):
     # Construct the hook while GoogleCloudBaseHook.__init__ is replaced by a
     # stub that (judging by its name) supplies no default project id.
     base_init = 'airflow.contrib.hooks.gcp_api_base_hook.GoogleCloudBaseHook.__init__'
     with mock.patch(base_init, new=mock_base_gcp_hook_no_default_project_id):
         self.spanner_hook_no_default_project_id = CloudSpannerHook(gcp_conn_id='test')
class TestGcpSpannerHookNoDefaultProjectID(unittest.TestCase):

    def setUp(self):
        # Construct the hook while GoogleCloudBaseHook.__init__ is replaced by a
        # stub that (judging by its name and the tests below) supplies no
        # default project id, so hook methods need an explicit project_id.
        base_init = 'airflow.contrib.hooks.gcp_api_base_hook.GoogleCloudBaseHook.__init__'
        with mock.patch(base_init, new=mock_base_gcp_hook_no_default_project_id):
            self.spanner_hook_no_default_project_id = CloudSpannerHook(gcp_conn_id='test')

    @mock.patch('airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_get_existing_instance_missing_project_id(self, get_client):
        """get_instance without a project id fails before any client call."""
        mock_instance = get_client.return_value.instance
        mock_exists = mock_instance.return_value.exists
        mock_exists.return_value = True

        with self.assertRaises(AirflowException) as ctx:
            self.spanner_hook_no_default_project_id.get_instance(instance_id=SPANNER_INSTANCE)

        for untouched in (get_client, mock_instance, mock_exists):
            untouched.assert_not_called()
        self.assertIn("The project id must be passed", str(ctx.exception))

    @mock.patch('airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_get_existing_instance_overridden_project_id(self, get_client):
        """An explicit project_id reaches the client and the instance is returned."""
        mock_instance = get_client.return_value.instance
        mock_instance.return_value.exists.return_value = True

        result = self.spanner_hook_no_default_project_id.get_instance(
            instance_id=SPANNER_INSTANCE,
            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST)

        get_client.assert_called_once_with(project_id='example-project')
        mock_instance.assert_called_once_with(instance_id='instance')
        self.assertIsNotNone(result)

    @mock.patch('airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_get_non_existing_instance(self, get_client):
        """get_instance returns None when the instance does not exist."""
        mock_instance = get_client.return_value.instance
        mock_instance.return_value.exists.return_value = False

        result = self.spanner_hook_no_default_project_id.get_instance(
            instance_id=SPANNER_INSTANCE,
            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST)

        get_client.assert_called_once_with(project_id='example-project')
        mock_instance.assert_called_once_with(instance_id='instance')
        self.assertIsNone(result)

    @mock.patch('airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_create_instance_missing_project_id(self, get_client):
        """create_instance without a project id fails before touching the client."""
        instance_method = get_client.return_value.instance
        create_method = instance_method.return_value.create
        create_method.return_value = False
        with self.assertRaises(AirflowException) as cm:
            self.spanner_hook_no_default_project_id.create_instance(
                instance_id=SPANNER_INSTANCE,
                configuration_name=SPANNER_CONFIGURATION,
                node_count=1,
                display_name=SPANNER_DATABASE)
        # Assertions belong outside the assertRaises block: statements after
        # the raising call inside it never execute.
        get_client.assert_not_called()
        instance_method.assert_not_called()
        create_method.assert_not_called()
        err = cm.exception
        self.assertIn("The project id must be passed", str(err))

    @mock.patch('airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_create_instance_overridden_project_id(self, get_client):
        """create_instance forwards the explicit project_id and instance settings."""
        mock_instance = get_client.return_value.instance
        mock_instance.return_value.create.return_value = False

        result = self.spanner_hook_no_default_project_id.create_instance(
            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
            instance_id=SPANNER_INSTANCE,
            configuration_name=SPANNER_CONFIGURATION,
            node_count=1,
            display_name=SPANNER_DATABASE)

        get_client.assert_called_once_with(project_id='example-project')
        mock_instance.assert_called_once_with(
            instance_id='instance',
            configuration_name='configuration',
            display_name='database-name',
            node_count=1)
        self.assertIsNone(result)

    @mock.patch('airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_update_instance_missing_project_id(self, get_client):
        """update_instance without a project id fails before any client call."""
        mock_instance = get_client.return_value.instance
        mock_instance.return_value.exists.return_value = True
        mock_update = mock_instance.return_value.update
        mock_update.return_value = False

        with self.assertRaises(AirflowException) as ctx:
            self.spanner_hook_no_default_project_id.update_instance(
                instance_id=SPANNER_INSTANCE,
                configuration_name=SPANNER_CONFIGURATION,
                node_count=2,
                display_name=SPANNER_DATABASE)

        for untouched in (get_client, mock_instance, mock_update):
            untouched.assert_not_called()
        self.assertIn("The project id must be passed", str(ctx.exception))

    @mock.patch('airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_update_instance_overridden_project_id(self, get_client):
        """update_instance uses the explicit project_id and calls Instance.update."""
        mock_instance = get_client.return_value.instance
        mock_instance.return_value.exists.return_value = True
        mock_update = mock_instance.return_value.update
        mock_update.return_value = False

        result = self.spanner_hook_no_default_project_id.update_instance(
            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
            instance_id=SPANNER_INSTANCE,
            configuration_name=SPANNER_CONFIGURATION,
            node_count=2,
            display_name=SPANNER_DATABASE)

        get_client.assert_called_once_with(project_id='example-project')
        mock_instance.assert_called_once_with(
            instance_id='instance',
            configuration_name='configuration',
            display_name='database-name',
            node_count=2)
        mock_update.assert_called_with()
        self.assertIsNone(result)

    @mock.patch('airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_delete_instance_missing_project_id(self, get_client):
        """delete_instance without a project id fails before any client call."""
        mock_instance = get_client.return_value.instance
        mock_instance.return_value.exists.return_value = True
        mock_delete = mock_instance.return_value.delete
        mock_delete.return_value = False

        with self.assertRaises(AirflowException) as ctx:
            self.spanner_hook_no_default_project_id.delete_instance(
                instance_id=SPANNER_INSTANCE)

        for untouched in (get_client, mock_instance, mock_delete):
            untouched.assert_not_called()
        self.assertIn("The project id must be passed", str(ctx.exception))

    @mock.patch('airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_delete_instance_overridden_project_id(self, get_client):
        """delete_instance uses the explicit project_id and calls Instance.delete."""
        mock_instance = get_client.return_value.instance
        mock_instance.return_value.exists.return_value = True
        mock_delete = mock_instance.return_value.delete
        mock_delete.return_value = False

        result = self.spanner_hook_no_default_project_id.delete_instance(
            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
            instance_id=SPANNER_INSTANCE)

        get_client.assert_called_once_with(project_id='example-project')
        # The hook passes the instance id positionally here.
        mock_instance.assert_called_once_with('instance')
        mock_delete.assert_called_with()
        self.assertIsNone(result)

    @mock.patch('airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_get_database_missing_project_id(self, get_client):
        """get_database without a project id fails before any client call."""
        instance_method = get_client.return_value.instance
        instance_exists_method = instance_method.return_value.exists
        instance_exists_method.return_value = True
        database_method = instance_method.return_value.database
        # Database.exists lives on the object returned by instance.database(),
        # not on the instance itself.
        database_exists_method = database_method.return_value.exists
        database_exists_method.return_value = True
        with self.assertRaises(AirflowException) as cm:
            self.spanner_hook_no_default_project_id.get_database(
                instance_id=SPANNER_INSTANCE,
                database_id=SPANNER_DATABASE)
        get_client.assert_not_called()
        instance_method.assert_not_called()
        database_method.assert_not_called()
        database_exists_method.assert_not_called()
        err = cm.exception
        self.assertIn("The project id must be passed", str(err))

    @mock.patch('airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_get_database_overridden_project_id(self, get_client):
        """get_database uses the explicit project_id and checks Database.exists."""
        instance_method = get_client.return_value.instance
        instance_exists_method = instance_method.return_value.exists
        instance_exists_method.return_value = True
        database_method = instance_method.return_value.database
        # Database.exists lives on the object returned by instance.database(),
        # not on the instance itself. Binding the instance's mock here would
        # make the final exists assertion check the wrong object.
        database_exists_method = database_method.return_value.exists
        database_exists_method.return_value = True
        res = self.spanner_hook_no_default_project_id.get_database(
            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
            instance_id=SPANNER_INSTANCE,
            database_id=SPANNER_DATABASE)
        get_client.assert_called_once_with(project_id='example-project')
        instance_method.assert_called_once_with(instance_id='instance')
        database_method.assert_called_with(database_id='database-name')
        database_exists_method.assert_called_with()
        self.assertIsNotNone(res)

    @mock.patch('airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_create_database_missing_project_id(self, get_client):
        """create_database without a project id fails before any client call."""
        mock_instance = get_client.return_value.instance
        mock_instance.return_value.exists.return_value = True
        mock_database = mock_instance.return_value.database
        mock_create = mock_database.return_value.create

        with self.assertRaises(AirflowException) as ctx:
            self.spanner_hook_no_default_project_id.create_database(
                instance_id=SPANNER_INSTANCE,
                database_id=SPANNER_DATABASE,
                ddl_statements=[])

        for untouched in (get_client, mock_instance, mock_database, mock_create):
            untouched.assert_not_called()
        self.assertIn("The project id must be passed", str(ctx.exception))

    @mock.patch('airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_create_database_overridden_project_id(self, get_client):
        """create_database uses the explicit project_id and calls Database.create."""
        mock_instance = get_client.return_value.instance
        mock_instance.return_value.exists.return_value = True
        mock_database = mock_instance.return_value.database
        mock_create = mock_database.return_value.create

        result = self.spanner_hook_no_default_project_id.create_database(
            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
            instance_id=SPANNER_INSTANCE,
            database_id=SPANNER_DATABASE,
            ddl_statements=[])

        get_client.assert_called_once_with(project_id='example-project')
        mock_instance.assert_called_once_with(instance_id='instance')
        mock_database.assert_called_with(database_id='database-name', ddl_statements=[])
        mock_create.assert_called_with()
        self.assertIsNone(result)

    @mock.patch('airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_update_database_missing_project_id(self, get_client):
        """update_database without a project id fails before any client call."""
        instance_method = get_client.return_value.instance
        instance_exists_method = instance_method.return_value.exists
        instance_exists_method.return_value = True
        database_method = instance_method.return_value.database
        # The hook applies schema changes through Database.update_ddl (see the
        # overridden-project tests), so that is the call that must not happen.
        database_update_ddl_method = database_method.return_value.update_ddl
        with self.assertRaises(AirflowException) as cm:
            self.spanner_hook_no_default_project_id.update_database(
                instance_id=SPANNER_INSTANCE,
                database_id=SPANNER_DATABASE,
                ddl_statements=[])
        get_client.assert_not_called()
        instance_method.assert_not_called()
        database_method.assert_not_called()
        database_update_ddl_method.assert_not_called()
        err = cm.exception
        self.assertIn("The project id must be passed", str(err))

    @mock.patch('airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_update_database_overridden_project_id(self, get_client):
        """update_database forwards DDL via update_ddl with no operation id."""
        mock_instance = get_client.return_value.instance
        mock_instance.return_value.exists.return_value = True
        mock_database = mock_instance.return_value.database
        mock_update_ddl = mock_database.return_value.update_ddl

        result = self.spanner_hook_no_default_project_id.update_database(
            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
            instance_id=SPANNER_INSTANCE,
            database_id=SPANNER_DATABASE,
            ddl_statements=[])

        get_client.assert_called_once_with(project_id='example-project')
        mock_instance.assert_called_once_with(instance_id='instance')
        mock_database.assert_called_with(database_id='database-name')
        mock_update_ddl.assert_called_with(ddl_statements=[], operation_id=None)
        self.assertIsNone(result)

    @mock.patch('airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_update_database_overridden_project_id_and_operation(self, get_client):
        """update_database passes the explicit operation id through to update_ddl."""
        mock_instance = get_client.return_value.instance
        mock_instance.return_value.exists.return_value = True
        mock_database = mock_instance.return_value.database
        mock_update_ddl = mock_database.return_value.update_ddl

        result = self.spanner_hook_no_default_project_id.update_database(
            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
            instance_id=SPANNER_INSTANCE,
            database_id=SPANNER_DATABASE,
            operation_id="operation",
            ddl_statements=[])

        get_client.assert_called_once_with(project_id='example-project')
        mock_instance.assert_called_once_with(instance_id='instance')
        mock_database.assert_called_with(database_id='database-name')
        mock_update_ddl.assert_called_with(ddl_statements=[], operation_id="operation")
        self.assertIsNone(result)

    @mock.patch('airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_delete_database_missing_project_id(self, get_client):
        # With no project id the hook must fail fast: no client, instance or
        # database lookup, no existence check and no drop may happen.
        instance_mock = get_client.return_value.instance
        instance_mock.return_value.exists.return_value = True
        database_mock = instance_mock.return_value.database
        drop_mock = database_mock.return_value.drop
        db_exists_mock = database_mock.return_value.exists
        db_exists_mock.return_value = True

        hook = self.spanner_hook_no_default_project_id
        with self.assertRaises(AirflowException) as raised:
            hook.delete_database(instance_id=SPANNER_INSTANCE,
                                 database_id=SPANNER_DATABASE)

        for untouched in (get_client, instance_mock, database_mock,
                          db_exists_mock, drop_mock):
            untouched.assert_not_called()
        self.assertIn("The project id must be passed", str(raised.exception))

    @mock.patch('airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_delete_database_overridden_project_id(self, get_client):
        # With an explicit project id and an existing database, the hook calls
        # exists() and drop() on the database and returns None.
        instance_mock = get_client.return_value.instance
        instance_mock.return_value.exists.return_value = True
        database_mock = instance_mock.return_value.database
        drop_mock = database_mock.return_value.drop
        db_exists_mock = database_mock.return_value.exists
        db_exists_mock.return_value = True

        outcome = self.spanner_hook_no_default_project_id.delete_database(
            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
            instance_id=SPANNER_INSTANCE,
            database_id=SPANNER_DATABASE,
        )

        get_client.assert_called_once_with(project_id='example-project')
        instance_mock.assert_called_once_with(instance_id='instance')
        database_mock.assert_called_with(database_id='database-name')
        db_exists_mock.assert_called_with()
        drop_mock.assert_called_with()
        self.assertIsNone(outcome)

    @mock.patch('airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_delete_database_missing_database(self, get_client):
        # When the database reports that it does not exist, the hook still
        # checks for it but must not attempt a drop.
        instance_mock = get_client.return_value.instance
        instance_mock.return_value.exists.return_value = True
        database_mock = instance_mock.return_value.database
        drop_mock = database_mock.return_value.drop
        db_exists_mock = database_mock.return_value.exists
        db_exists_mock.return_value = False

        self.spanner_hook_no_default_project_id.delete_database(
            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
            instance_id=SPANNER_INSTANCE,
            database_id=SPANNER_DATABASE,
        )

        get_client.assert_called_once_with(project_id='example-project')
        instance_mock.assert_called_once_with(instance_id='instance')
        database_mock.assert_called_with(database_id='database-name')
        db_exists_mock.assert_called_with()
        drop_mock.assert_not_called()

    @mock.patch('airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_execute_dml_missing_project_id(self, get_client):
        # Without a project id no transaction may be started; the hook must
        # raise before reaching the client.
        instance_mock = get_client.return_value.instance
        instance_mock.return_value.exists.return_value = True
        database_mock = instance_mock.return_value.database
        run_in_txn = database_mock.return_value.run_in_transaction

        hook = self.spanner_hook_no_default_project_id
        with self.assertRaises(AirflowException) as raised:
            hook.execute_dml(instance_id=SPANNER_INSTANCE,
                             database_id=SPANNER_DATABASE,
                             queries='')

        for untouched in (get_client, instance_mock, database_mock, run_in_txn):
            untouched.assert_not_called()
        self.assertIn("The project id must be passed", str(raised.exception))

    @mock.patch('airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_execute_dml_overridden_project_id(self, get_client):
        # With an explicit project id the queries are executed through
        # run_in_transaction (called with a single callable argument).
        instance_mock = get_client.return_value.instance
        instance_mock.return_value.exists.return_value = True
        database_mock = instance_mock.return_value.database
        run_in_txn = database_mock.return_value.run_in_transaction

        outcome = self.spanner_hook_no_default_project_id.execute_dml(
            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
            instance_id=SPANNER_INSTANCE,
            database_id=SPANNER_DATABASE,
            queries='',
        )

        get_client.assert_called_once_with(project_id='example-project')
        instance_mock.assert_called_once_with(instance_id='instance')
        database_mock.assert_called_with(database_id='database-name')
        run_in_txn.assert_called_with(mock.ANY)
        self.assertIsNone(outcome)
示例#17
0
 def setUp(self):
     # Construct the hook while GoogleCloudBaseHook.__init__ is swapped for
     # mock_base_gcp_hook_no_default_project_id (presumably leaving the hook
     # without a default project id - confirm in that helper).
     init_patcher = mock.patch(
         'airflow.contrib.hooks.gcp_api_base_hook.GoogleCloudBaseHook.__init__',
         new=mock_base_gcp_hook_no_default_project_id)
     init_patcher.start()
     try:
         self.spanner_hook_no_default_project_id = CloudSpannerHook(
             gcp_conn_id='test')
     finally:
         # Restore the real __init__ even if hook construction fails.
         init_patcher.stop()
示例#18
0
class TestGcpSpannerHookNoDefaultProjectID(unittest.TestCase):
    """
    Tests for ``CloudSpannerHook`` constructed with
    ``mock_base_gcp_hook_no_default_project_id`` patched in as the base hook
    initializer (presumably a connection without a default project id).

    Each ``*_missing_project_id`` test checks that the hook raises
    ``AirflowException`` ("The project id must be passed ...") before touching
    any Spanner client mock. Each ``*_overridden_project_id`` test checks that
    an explicit ``project_id`` argument is what ``_get_client`` receives.
    """

    def setUp(self):
        with mock.patch(
                'airflow.contrib.hooks.gcp_api_base_hook.GoogleCloudBaseHook.__init__',
                new=mock_base_gcp_hook_no_default_project_id):
            self.spanner_hook_no_default_project_id = CloudSpannerHook(
                gcp_conn_id='test')

    @mock.patch(
        'airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_get_existing_instance_missing_project_id(self, get_client):
        instance_method = get_client.return_value.instance
        instance_exists_method = instance_method.return_value.exists
        instance_exists_method.return_value = True
        with self.assertRaises(AirflowException) as cm:
            self.spanner_hook_no_default_project_id.get_instance(
                instance_id=SPANNER_INSTANCE)
        get_client.assert_not_called()
        instance_method.assert_not_called()
        instance_exists_method.assert_not_called()
        err = cm.exception
        self.assertIn("The project id must be passed", str(err))

    @mock.patch(
        'airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_get_existing_instance_overridden_project_id(self, get_client):
        instance_method = get_client.return_value.instance
        instance_exists_method = instance_method.return_value.exists
        instance_exists_method.return_value = True
        res = self.spanner_hook_no_default_project_id.get_instance(
            instance_id=SPANNER_INSTANCE,
            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST)
        get_client.assert_called_once_with(project_id='example-project')
        instance_method.assert_called_once_with(instance_id='instance')
        self.assertIsNotNone(res)

    @mock.patch(
        'airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_get_non_existing_instance(self, get_client):
        instance_method = get_client.return_value.instance
        instance_exists_method = instance_method.return_value.exists
        instance_exists_method.return_value = False
        res = self.spanner_hook_no_default_project_id.get_instance(
            instance_id=SPANNER_INSTANCE,
            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST)
        get_client.assert_called_once_with(project_id='example-project')
        instance_method.assert_called_once_with(instance_id='instance')
        self.assertIsNone(res)

    @mock.patch(
        'airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_create_instance_missing_project_id(self, get_client):
        instance_method = get_client.return_value.instance
        create_method = instance_method.return_value.create
        create_method.return_value = False
        # Only the raising call belongs inside assertRaises: anything after it
        # in the block would never execute.
        with self.assertRaises(AirflowException) as cm:
            self.spanner_hook_no_default_project_id.create_instance(
                instance_id=SPANNER_INSTANCE,
                configuration_name=SPANNER_CONFIGURATION,
                node_count=1,
                display_name=SPANNER_DATABASE)
        get_client.assert_not_called()
        instance_method.assert_not_called()
        create_method.assert_not_called()
        err = cm.exception
        self.assertIn("The project id must be passed", str(err))

    @mock.patch(
        'airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_create_instance_overridden_project_id(self, get_client):
        instance_method = get_client.return_value.instance
        create_method = instance_method.return_value.create
        create_method.return_value = False
        res = self.spanner_hook_no_default_project_id.create_instance(
            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
            instance_id=SPANNER_INSTANCE,
            configuration_name=SPANNER_CONFIGURATION,
            node_count=1,
            display_name=SPANNER_DATABASE)
        get_client.assert_called_once_with(project_id='example-project')
        instance_method.assert_called_once_with(
            instance_id='instance',
            configuration_name='configuration',
            display_name='database-name',
            node_count=1)
        self.assertIsNone(res)

    @mock.patch(
        'airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_update_instance_missing_project_id(self, get_client):
        instance_method = get_client.return_value.instance
        instance_exists_method = instance_method.return_value.exists
        instance_exists_method.return_value = True
        update_method = instance_method.return_value.update
        update_method.return_value = False
        with self.assertRaises(AirflowException) as cm:
            self.spanner_hook_no_default_project_id.update_instance(
                instance_id=SPANNER_INSTANCE,
                configuration_name=SPANNER_CONFIGURATION,
                node_count=2,
                display_name=SPANNER_DATABASE)
        get_client.assert_not_called()
        instance_method.assert_not_called()
        update_method.assert_not_called()
        err = cm.exception
        self.assertIn("The project id must be passed", str(err))

    @mock.patch(
        'airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_update_instance_overridden_project_id(self, get_client):
        instance_method = get_client.return_value.instance
        instance_exists_method = instance_method.return_value.exists
        instance_exists_method.return_value = True
        update_method = instance_method.return_value.update
        update_method.return_value = False
        res = self.spanner_hook_no_default_project_id.update_instance(
            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
            instance_id=SPANNER_INSTANCE,
            configuration_name=SPANNER_CONFIGURATION,
            node_count=2,
            display_name=SPANNER_DATABASE)
        get_client.assert_called_once_with(project_id='example-project')
        instance_method.assert_called_once_with(
            instance_id='instance',
            configuration_name='configuration',
            display_name='database-name',
            node_count=2)
        update_method.assert_called_with()
        self.assertIsNone(res)

    @mock.patch(
        'airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_delete_instance_missing_project_id(self, get_client):
        instance_method = get_client.return_value.instance
        instance_exists_method = instance_method.return_value.exists
        instance_exists_method.return_value = True
        delete_method = instance_method.return_value.delete
        delete_method.return_value = False
        with self.assertRaises(AirflowException) as cm:
            self.spanner_hook_no_default_project_id.delete_instance(
                instance_id=SPANNER_INSTANCE)
        get_client.assert_not_called()
        instance_method.assert_not_called()
        delete_method.assert_not_called()
        err = cm.exception
        self.assertIn("The project id must be passed", str(err))

    @mock.patch(
        'airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_delete_instance_overridden_project_id(self, get_client):
        instance_method = get_client.return_value.instance
        instance_exists_method = instance_method.return_value.exists
        instance_exists_method.return_value = True
        delete_method = instance_method.return_value.delete
        delete_method.return_value = False
        res = self.spanner_hook_no_default_project_id.delete_instance(
            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
            instance_id=SPANNER_INSTANCE)
        get_client.assert_called_once_with(project_id='example-project')
        # delete_instance passes the instance id positionally.
        instance_method.assert_called_once_with('instance')
        delete_method.assert_called_with()
        self.assertIsNone(res)

    @mock.patch(
        'airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_get_database_missing_project_id(self, get_client):
        instance_method = get_client.return_value.instance
        instance_exists_method = instance_method.return_value.exists
        instance_exists_method.return_value = True
        database_method = instance_method.return_value.database
        # The database existence check is made on the database object, not on
        # the instance object.
        database_exists_method = database_method.return_value.exists
        database_exists_method.return_value = True
        with self.assertRaises(AirflowException) as cm:
            self.spanner_hook_no_default_project_id.get_database(
                instance_id=SPANNER_INSTANCE, database_id=SPANNER_DATABASE)
        get_client.assert_not_called()
        instance_method.assert_not_called()
        database_method.assert_not_called()
        database_exists_method.assert_not_called()
        err = cm.exception
        self.assertIn("The project id must be passed", str(err))

    @mock.patch(
        'airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_get_database_overridden_project_id(self, get_client):
        instance_method = get_client.return_value.instance
        instance_exists_method = instance_method.return_value.exists
        instance_exists_method.return_value = True
        database_method = instance_method.return_value.database
        # Check the database-level exists(), not the instance-level one.
        database_exists_method = database_method.return_value.exists
        database_exists_method.return_value = True
        res = self.spanner_hook_no_default_project_id.get_database(
            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
            instance_id=SPANNER_INSTANCE,
            database_id=SPANNER_DATABASE)
        get_client.assert_called_once_with(project_id='example-project')
        instance_method.assert_called_once_with(instance_id='instance')
        database_method.assert_called_with(database_id='database-name')
        database_exists_method.assert_called_with()
        self.assertIsNotNone(res)

    @mock.patch(
        'airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_create_database_missing_project_id(self, get_client):
        instance_method = get_client.return_value.instance
        instance_exists_method = instance_method.return_value.exists
        instance_exists_method.return_value = True
        database_method = instance_method.return_value.database
        database_create_method = database_method.return_value.create
        with self.assertRaises(AirflowException) as cm:
            self.spanner_hook_no_default_project_id.create_database(
                instance_id=SPANNER_INSTANCE,
                database_id=SPANNER_DATABASE,
                ddl_statements=[])
        get_client.assert_not_called()
        instance_method.assert_not_called()
        database_method.assert_not_called()
        database_create_method.assert_not_called()
        err = cm.exception
        self.assertIn("The project id must be passed", str(err))

    @mock.patch(
        'airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_create_database_overridden_project_id(self, get_client):
        instance_method = get_client.return_value.instance
        instance_exists_method = instance_method.return_value.exists
        instance_exists_method.return_value = True
        database_method = instance_method.return_value.database
        database_create_method = database_method.return_value.create
        res = self.spanner_hook_no_default_project_id.create_database(
            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
            instance_id=SPANNER_INSTANCE,
            database_id=SPANNER_DATABASE,
            ddl_statements=[])
        get_client.assert_called_once_with(project_id='example-project')
        instance_method.assert_called_once_with(instance_id='instance')
        database_method.assert_called_with(database_id='database-name',
                                           ddl_statements=[])
        database_create_method.assert_called_with()
        self.assertIsNone(res)

    @mock.patch(
        'airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_update_database_missing_project_id(self, get_client):
        instance_method = get_client.return_value.instance
        instance_exists_method = instance_method.return_value.exists
        instance_exists_method.return_value = True
        database_method = instance_method.return_value.database
        # DDL is applied via update_ddl (see the overridden-project tests), so
        # that is the call that must not happen; asserting on ``update`` would
        # pass vacuously.
        database_update_ddl_method = database_method.return_value.update_ddl
        with self.assertRaises(AirflowException) as cm:
            self.spanner_hook_no_default_project_id.update_database(
                instance_id=SPANNER_INSTANCE,
                database_id=SPANNER_DATABASE,
                ddl_statements=[])
        get_client.assert_not_called()
        instance_method.assert_not_called()
        database_method.assert_not_called()
        database_update_ddl_method.assert_not_called()
        err = cm.exception
        self.assertIn("The project id must be passed", str(err))

    @mock.patch(
        'airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_update_database_overridden_project_id(self, get_client):
        instance_method = get_client.return_value.instance
        instance_exists_method = instance_method.return_value.exists
        instance_exists_method.return_value = True
        database_method = instance_method.return_value.database
        database_update_ddl_method = database_method.return_value.update_ddl
        res = self.spanner_hook_no_default_project_id.update_database(
            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
            instance_id=SPANNER_INSTANCE,
            database_id=SPANNER_DATABASE,
            ddl_statements=[])
        get_client.assert_called_once_with(project_id='example-project')
        instance_method.assert_called_once_with(instance_id='instance')
        database_method.assert_called_with(database_id='database-name')
        database_update_ddl_method.assert_called_with(ddl_statements=[],
                                                      operation_id=None)
        self.assertIsNone(res)

    @mock.patch(
        'airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_update_database_overridden_project_id_and_operation(
            self, get_client):
        instance_method = get_client.return_value.instance
        instance_exists_method = instance_method.return_value.exists
        instance_exists_method.return_value = True
        database_method = instance_method.return_value.database
        database_update_ddl_method = database_method.return_value.update_ddl
        res = self.spanner_hook_no_default_project_id.update_database(
            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
            instance_id=SPANNER_INSTANCE,
            database_id=SPANNER_DATABASE,
            operation_id="operation",
            ddl_statements=[])
        get_client.assert_called_once_with(project_id='example-project')
        instance_method.assert_called_once_with(instance_id='instance')
        database_method.assert_called_with(database_id='database-name')
        database_update_ddl_method.assert_called_with(ddl_statements=[],
                                                      operation_id="operation")
        self.assertIsNone(res)

    @mock.patch(
        'airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_delete_database_missing_project_id(self, get_client):
        instance_method = get_client.return_value.instance
        instance_exists_method = instance_method.return_value.exists
        instance_exists_method.return_value = True
        database_method = instance_method.return_value.database
        database_drop_method = database_method.return_value.drop
        database_exists_method = database_method.return_value.exists
        database_exists_method.return_value = True
        with self.assertRaises(AirflowException) as cm:
            self.spanner_hook_no_default_project_id.delete_database(
                instance_id=SPANNER_INSTANCE, database_id=SPANNER_DATABASE)
        get_client.assert_not_called()
        instance_method.assert_not_called()
        database_method.assert_not_called()
        database_exists_method.assert_not_called()
        database_drop_method.assert_not_called()
        err = cm.exception
        self.assertIn("The project id must be passed", str(err))

    @mock.patch(
        'airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_delete_database_overridden_project_id(self, get_client):
        instance_method = get_client.return_value.instance
        instance_exists_method = instance_method.return_value.exists
        instance_exists_method.return_value = True
        database_method = instance_method.return_value.database
        database_drop_method = database_method.return_value.drop
        database_exists_method = database_method.return_value.exists
        database_exists_method.return_value = True
        res = self.spanner_hook_no_default_project_id.delete_database(
            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
            instance_id=SPANNER_INSTANCE,
            database_id=SPANNER_DATABASE)
        get_client.assert_called_once_with(project_id='example-project')
        instance_method.assert_called_once_with(instance_id='instance')
        database_method.assert_called_with(database_id='database-name')
        database_exists_method.assert_called_with()
        database_drop_method.assert_called_with()
        self.assertIsNone(res)

    @mock.patch(
        'airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_delete_database_missing_database(self, get_client):
        instance_method = get_client.return_value.instance
        instance_exists_method = instance_method.return_value.exists
        instance_exists_method.return_value = True
        database_method = instance_method.return_value.database
        database_drop_method = database_method.return_value.drop
        database_exists_method = database_method.return_value.exists
        database_exists_method.return_value = False
        self.spanner_hook_no_default_project_id.delete_database(
            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
            instance_id=SPANNER_INSTANCE,
            database_id=SPANNER_DATABASE)
        get_client.assert_called_once_with(project_id='example-project')
        instance_method.assert_called_once_with(instance_id='instance')
        database_method.assert_called_with(database_id='database-name')
        database_exists_method.assert_called_with()
        # A database that reports it does not exist must not be dropped.
        database_drop_method.assert_not_called()

    @mock.patch(
        'airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_execute_dml_missing_project_id(self, get_client):
        instance_method = get_client.return_value.instance
        instance_exists_method = instance_method.return_value.exists
        instance_exists_method.return_value = True
        database_method = instance_method.return_value.database
        run_in_transaction_method = database_method.return_value.run_in_transaction
        with self.assertRaises(AirflowException) as cm:
            self.spanner_hook_no_default_project_id.execute_dml(
                instance_id=SPANNER_INSTANCE,
                database_id=SPANNER_DATABASE,
                queries='')
        get_client.assert_not_called()
        instance_method.assert_not_called()
        database_method.assert_not_called()
        run_in_transaction_method.assert_not_called()
        err = cm.exception
        self.assertIn("The project id must be passed", str(err))

    @mock.patch(
        'airflow.contrib.hooks.gcp_spanner_hook.CloudSpannerHook._get_client')
    def test_execute_dml_overridden_project_id(self, get_client):
        instance_method = get_client.return_value.instance
        instance_exists_method = instance_method.return_value.exists
        instance_exists_method.return_value = True
        database_method = instance_method.return_value.database
        run_in_transaction_method = database_method.return_value.run_in_transaction
        res = self.spanner_hook_no_default_project_id.execute_dml(
            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
            instance_id=SPANNER_INSTANCE,
            database_id=SPANNER_DATABASE,
            queries='')
        get_client.assert_called_once_with(project_id='example-project')
        instance_method.assert_called_once_with(instance_id='instance')
        database_method.assert_called_with(database_id='database-name')
        run_in_transaction_method.assert_called_with(mock.ANY)
        self.assertIsNone(res)
示例#19
0
class CloudSpannerInstanceDatabaseUpdateOperator(BaseOperator):
    """
    Updates a Cloud Spanner database with the specified DDL statement.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:CloudSpannerInstanceDatabaseUpdateOperator`

    :param instance_id: The Cloud Spanner instance ID.
    :type instance_id: str
    :param database_id: The Cloud Spanner database ID.
    :type database_id: str
    :param ddl_statements: The string list containing DDL to apply to the database.
    :type ddl_statements: list[str]
    :param project_id: Optional, the ID of the project that owns the Cloud Spanner
        Database.  If set to None or missing, the default project_id from the GCP connection is used.
    :type project_id: str
    :param operation_id: (Optional) Unique per database operation id that can
           be specified to implement idempotency check.
    :type operation_id: str
    :param gcp_conn_id: The connection ID used to connect to Google Cloud Platform.
    :type gcp_conn_id: str
    :raises AirflowException: at construction time if project_id is the empty
        string or instance_id, database_id or ddl_statements is empty/None;
        at execution time if the target database does not exist.
    """
    # [START gcp_spanner_database_update_template_fields]
    template_fields = ('project_id', 'instance_id', 'database_id',
                       'ddl_statements', 'gcp_conn_id')
    template_ext = ('.sql', )
    # [END gcp_spanner_database_update_template_fields]

    @apply_defaults
    def __init__(self,
                 instance_id,
                 database_id,
                 ddl_statements,
                 project_id=None,
                 operation_id=None,
                 gcp_conn_id='google_cloud_default',
                 *args,
                 **kwargs):
        self.instance_id = instance_id
        self.project_id = project_id
        self.database_id = database_id
        self.ddl_statements = ddl_statements
        self.operation_id = operation_id
        self.gcp_conn_id = gcp_conn_id
        self._validate_inputs()
        # The hook is deliberately not built here: gcp_conn_id is a template
        # field, so it must be read only after rendering (in execute), and
        # constructing it here would also run at DAG-definition time.
        super().__init__(*args, **kwargs)

    def _validate_inputs(self):
        # project_id may be None (connection default is used), but an
        # explicit empty string is rejected.
        if self.project_id == '':
            raise AirflowException(
                "The required parameter 'project_id' is empty")
        if not self.instance_id:
            raise AirflowException(
                "The required parameter 'instance_id' is empty"
                " or None")
        if not self.database_id:
            raise AirflowException(
                "The required parameter 'database_id' is empty"
                " or None")
        if not self.ddl_statements:
            raise AirflowException(
                "The required parameter 'ddl_statements' is empty"
                " or None")

    def execute(self, context):
        """
        Apply ``ddl_statements`` to the existing database.

        :return: whatever ``CloudSpannerHook.update_database`` returns.
        :raises AirflowException: if the database does not exist.
        """
        # Built from the rendered gcp_conn_id, so templated connection ids work.
        hook = CloudSpannerHook(gcp_conn_id=self.gcp_conn_id)
        if not hook.get_database(project_id=self.project_id,
                                 instance_id=self.instance_id,
                                 database_id=self.database_id):
            raise AirflowException(
                "The Cloud Spanner database '{}' in project '{}' and "
                "instance '{}' is missing. Create the database first "
                "before you can update it.".format(self.database_id,
                                                   self.project_id,
                                                   self.instance_id))
        return hook.update_database(
            project_id=self.project_id,
            instance_id=self.instance_id,
            database_id=self.database_id,
            ddl_statements=self.ddl_statements,
            operation_id=self.operation_id)
示例#20
0
class CloudSpannerInstanceDeployOperator(BaseOperator):
    """
    Creates a new Cloud Spanner instance, or if an instance with the same instance_id
    exists in the specified project, updates the Cloud Spanner instance.

    :param instance_id: Cloud Spanner instance ID.
    :type instance_id: str
    :param configuration_name: The name of the Cloud Spanner instance configuration
      defining how the instance will be created. Required for
      instances that do not yet exist.
    :type configuration_name: str
    :param node_count: (Optional) The number of nodes allocated to the Cloud Spanner
      instance.
    :type node_count: int
    :param display_name: (Optional) The display name for the Cloud Spanner instance in
      the GCP Console. (Must be between 4 and 30 characters.) If this value is not set
      in the constructor, the name is the same as the instance ID.
    :type display_name: str
    :param project_id: Optional, the ID of the project which owns the Cloud Spanner
        instance. If set to None or missing, the default project_id from the GCP
        connection is used.
    :type project_id: str
    :param gcp_conn_id: The connection ID used to connect to Google Cloud Platform.
        The hook is created at execution time, so a templated value is honoured.
    :type gcp_conn_id: str
    """
    # [START gcp_spanner_deploy_template_fields]
    template_fields = ('project_id', 'instance_id', 'configuration_name',
                       'display_name', 'gcp_conn_id')
    # [END gcp_spanner_deploy_template_fields]

    @apply_defaults
    def __init__(self,
                 instance_id,
                 configuration_name,
                 node_count,
                 display_name,
                 project_id=None,
                 gcp_conn_id='google_cloud_default',
                 *args,
                 **kwargs):
        self.instance_id = instance_id
        self.project_id = project_id
        self.configuration_name = configuration_name
        self.node_count = node_count
        self.display_name = display_name
        self.gcp_conn_id = gcp_conn_id
        self._validate_inputs()
        # The CloudSpannerHook is deliberately not built here: gcp_conn_id is a
        # template field, and Airflow renders template fields only after the
        # operator is constructed. execute() builds the hook from the rendered value.
        super().__init__(*args, **kwargs)

    def _validate_inputs(self):
        """
        Reject obviously invalid constructor arguments early.

        :raises AirflowException: if ``project_id`` is an empty string or
            ``instance_id`` is empty/None. ``project_id=None`` is accepted.
        """
        if self.project_id == '':
            raise AirflowException(
                "The required parameter 'project_id' is empty")
        if not self.instance_id:
            raise AirflowException("The required parameter 'instance_id' "
                                   "is empty or None")

    def execute(self, context):
        """
        Create the instance if it does not exist yet, otherwise update it.

        :param context: Airflow task context (unused).
        """
        hook = CloudSpannerHook(gcp_conn_id=self.gcp_conn_id)
        if not hook.get_instance(project_id=self.project_id,
                                 instance_id=self.instance_id):
            self.log.info("Creating Cloud Spanner instance '%s'",
                          self.instance_id)
            func = hook.create_instance
        else:
            self.log.info("Updating Cloud Spanner instance '%s'",
                          self.instance_id)
            func = hook.update_instance
        # Both create_instance and update_instance are called with the same kwargs.
        func(project_id=self.project_id,
             instance_id=self.instance_id,
             configuration_name=self.configuration_name,
             node_count=self.node_count,
             display_name=self.display_name)
示例#21
0
class CloudSpannerInstanceDatabaseDeployOperator(BaseOperator):
    """
    Creates a new Cloud Spanner database, or if database exists,
    the operator does nothing.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:CloudSpannerInstanceDatabaseDeployOperator`

    :param instance_id: The Cloud Spanner instance ID.
    :type instance_id: str
    :param database_id: The Cloud Spanner database ID.
    :type database_id: str
    :param ddl_statements: The string list containing DDL for the new database.
    :type ddl_statements: list[str]
    :param project_id: Optional, the ID of the project that owns the Cloud Spanner
        Database.  If set to None or missing, the default project_id from the GCP connection is used.
    :type project_id: str
    :param gcp_conn_id: The connection ID used to connect to Google Cloud Platform.
        The hook is created at execution time, so a templated value is honoured.
    :type gcp_conn_id: str
    """
    # [START gcp_spanner_database_deploy_template_fields]
    template_fields = ('project_id', 'instance_id', 'database_id',
                       'ddl_statements', 'gcp_conn_id')
    template_ext = ('.sql', )
    # [END gcp_spanner_database_deploy_template_fields]

    @apply_defaults
    def __init__(self,
                 instance_id,
                 database_id,
                 ddl_statements,
                 project_id=None,
                 gcp_conn_id='google_cloud_default',
                 *args,
                 **kwargs):
        self.instance_id = instance_id
        self.project_id = project_id
        self.database_id = database_id
        self.ddl_statements = ddl_statements
        self.gcp_conn_id = gcp_conn_id
        self._validate_inputs()
        # The CloudSpannerHook is deliberately not built here: gcp_conn_id is a
        # template field, and Airflow renders template fields only after the
        # operator is constructed. execute() builds the hook from the rendered value.
        super().__init__(*args, **kwargs)

    def _validate_inputs(self):
        """
        Reject obviously invalid constructor arguments early.

        :raises AirflowException: if ``project_id`` is an empty string, or
            ``instance_id``/``database_id`` is empty/None.
            ``project_id=None`` is accepted.
        """
        if self.project_id == '':
            raise AirflowException(
                "The required parameter 'project_id' is empty")
        if not self.instance_id:
            raise AirflowException(
                "The required parameter 'instance_id' is empty "
                "or None")
        if not self.database_id:
            raise AirflowException(
                "The required parameter 'database_id' is empty"
                " or None")

    def execute(self, context):
        """
        Create the database unless it already exists.

        :param context: Airflow task context (unused).
        :return: the result of the hook's ``create_database`` when the database
            was created, ``True`` when it already existed.
        """
        hook = CloudSpannerHook(gcp_conn_id=self.gcp_conn_id)
        if hook.get_database(project_id=self.project_id,
                             instance_id=self.instance_id,
                             database_id=self.database_id):
            self.log.info(
                "The database '%s' in project '%s' and instance '%s'"
                " already exists. Nothing to do. Exiting.", self.database_id,
                self.project_id, self.instance_id)
            return True
        self.log.info(
            "Creating Cloud Spanner database "
            "'%s' in project '%s' and instance '%s'", self.database_id,
            self.project_id, self.instance_id)
        return hook.create_database(
            project_id=self.project_id,
            instance_id=self.instance_id,
            database_id=self.database_id,
            ddl_statements=self.ddl_statements)
示例#22
0
class CloudSpannerInstanceDatabaseQueryOperator(BaseOperator):
    """
    Executes an arbitrary DML query (INSERT, UPDATE, DELETE).

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:CloudSpannerInstanceDatabaseQueryOperator`

    :param instance_id: The Cloud Spanner instance ID.
    :type instance_id: str
    :param database_id: The Cloud Spanner database ID.
    :type database_id: str
    :param query: The query or list of queries to be executed. Can be a path to a SQL
       file. A single string is split on ';' into separate statements.
    :type query: str or list
    :param project_id: Optional, the ID of the project that owns the Cloud Spanner
        Database.  If set to None or missing, the default project_id from the GCP connection is used.
    :type project_id: str
    :param gcp_conn_id: The connection ID used to connect to Google Cloud Platform.
        The hook is created at execution time, so a templated value is honoured.
    :type gcp_conn_id: str
    """
    # [START gcp_spanner_query_template_fields]
    template_fields = ('project_id', 'instance_id', 'database_id', 'query',
                       'gcp_conn_id')
    template_ext = ('.sql', )
    # [END gcp_spanner_query_template_fields]

    @apply_defaults
    def __init__(self,
                 instance_id,
                 database_id,
                 query,
                 project_id=None,
                 gcp_conn_id='google_cloud_default',
                 *args,
                 **kwargs):
        self.instance_id = instance_id
        self.project_id = project_id
        self.database_id = database_id
        self.query = query
        self.gcp_conn_id = gcp_conn_id
        self._validate_inputs()
        # The CloudSpannerHook is deliberately not built here: gcp_conn_id is a
        # template field, and Airflow renders template fields only after the
        # operator is constructed. execute() builds the hook from the rendered value.
        super().__init__(*args, **kwargs)

    def _validate_inputs(self):
        """
        Reject obviously invalid constructor arguments early.

        :raises AirflowException: if ``project_id`` is an empty string, or
            ``instance_id``/``database_id``/``query`` is empty/None.
            ``project_id=None`` is accepted.
        """
        if self.project_id == '':
            raise AirflowException(
                "The required parameter 'project_id' is empty")
        if not self.instance_id:
            raise AirflowException("The required parameter 'instance_id' "
                                   "is empty or None")
        if not self.database_id:
            raise AirflowException("The required parameter 'database_id' "
                                   "is empty or None")
        if not self.query:
            raise AirflowException("The required parameter 'query' is empty")

    def execute(self, context):
        """
        Run the configured DML statement(s) against the target database.

        :param context: Airflow task context (unused).
        """
        queries = self.query
        if isinstance(self.query, six.string_types):
            # Split a multi-statement string into individual statements.
            queries = [x.strip() for x in self.query.split(';')]
            self.sanitize_queries(queries)
        self.log.info(
            "Executing DML query(-ies) on "
            "projects/%s/instances/%s/databases/%s", self.project_id,
            self.instance_id, self.database_id)
        self.log.info(queries)
        hook = CloudSpannerHook(gcp_conn_id=self.gcp_conn_id)
        hook.execute_dml(project_id=self.project_id,
                         instance_id=self.instance_id,
                         database_id=self.database_id,
                         queries=queries)

    @staticmethod
    def sanitize_queries(queries):
        """
        Drop empty statements from ``queries`` in place.

        Splitting on ';' yields '' both for a trailing separator and for
        consecutive separators (e.g. ``"a;;b"``). Every such empty entry is
        removed, not only the last one. Otherwise it would be forwarded to
        ``execute_dml`` as a statement of its own.

        :param queries: list of already-stripped statements; mutated in place.
        :type queries: list[str]
        """
        queries[:] = [query for query in queries if query]
class CloudSpannerInstanceDatabaseQueryOperator(BaseOperator):
    """
    Executes an arbitrary DML query (INSERT, UPDATE, DELETE).

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:CloudSpannerInstanceDatabaseQueryOperator`

    :param instance_id: The Cloud Spanner instance ID.
    :type instance_id: str
    :param database_id: The Cloud Spanner database ID.
    :type database_id: str
    :param query: The query or list of queries to be executed. Can be a path to a SQL
       file. A single string is split on ';' into separate statements.
    :type query: str or list
    :param project_id: Optional, the ID of the project that owns the Cloud Spanner
        Database.  If set to None or missing, the default project_id from the GCP connection is used.
    :type project_id: str
    :param gcp_conn_id: The connection ID used to connect to Google Cloud Platform.
        The hook is created at execution time, so a templated value is honoured.
    :type gcp_conn_id: str
    """
    # [START gcp_spanner_query_template_fields]
    template_fields = ('project_id', 'instance_id', 'database_id', 'query', 'gcp_conn_id')
    template_ext = ('.sql',)
    # [END gcp_spanner_query_template_fields]

    @apply_defaults
    def __init__(self,
                 instance_id,
                 database_id,
                 query,
                 project_id=None,
                 gcp_conn_id='google_cloud_default',
                 *args, **kwargs):
        self.instance_id = instance_id
        self.project_id = project_id
        self.database_id = database_id
        self.query = query
        self.gcp_conn_id = gcp_conn_id
        self._validate_inputs()
        # The CloudSpannerHook is deliberately not built here: gcp_conn_id is a
        # template field, and Airflow renders template fields only after the
        # operator is constructed. execute() builds the hook from the rendered value.
        super(CloudSpannerInstanceDatabaseQueryOperator, self).__init__(*args, **kwargs)

    def _validate_inputs(self):
        """
        Reject obviously invalid constructor arguments early.

        :raises AirflowException: if ``project_id`` is an empty string, or
            ``instance_id``/``database_id``/``query`` is empty/None.
            ``project_id=None`` is accepted.
        """
        if self.project_id == '':
            raise AirflowException("The required parameter 'project_id' is empty")
        if not self.instance_id:
            raise AirflowException("The required parameter 'instance_id' "
                                   "is empty or None")
        if not self.database_id:
            raise AirflowException("The required parameter 'database_id' "
                                   "is empty or None")
        if not self.query:
            raise AirflowException("The required parameter 'query' is empty")

    def execute(self, context):
        """
        Run the configured DML statement(s) against the target database.

        :param context: Airflow task context (unused).
        """
        queries = self.query
        if isinstance(self.query, six.string_types):
            # Split a multi-statement string into individual statements.
            queries = [x.strip() for x in self.query.split(';')]
            self.sanitize_queries(queries)
        self.log.info("Executing DML query(-ies) on "
                      "projects/%s/instances/%s/databases/%s",
                      self.project_id, self.instance_id, self.database_id)
        self.log.info(queries)
        hook = CloudSpannerHook(gcp_conn_id=self.gcp_conn_id)
        hook.execute_dml(project_id=self.project_id,
                         instance_id=self.instance_id,
                         database_id=self.database_id,
                         queries=queries)

    @staticmethod
    def sanitize_queries(queries):
        """
        Drop empty statements from ``queries`` in place.

        Splitting on ';' yields '' both for a trailing separator and for
        consecutive separators (e.g. ``"a;;b"``). Every such empty entry is
        removed, not only the last one. Otherwise it would be forwarded to
        ``execute_dml`` as a statement of its own.

        :param queries: list of already-stripped statements; mutated in place.
        :type queries: list[str]
        """
        queries[:] = [query for query in queries if query]
class CloudSpannerInstanceDatabaseUpdateOperator(BaseOperator):
    """
    Updates a Cloud Spanner database with the specified DDL statement.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:CloudSpannerInstanceDatabaseUpdateOperator`

    :param instance_id: The Cloud Spanner instance ID.
    :type instance_id: str
    :param database_id: The Cloud Spanner database ID.
    :type database_id: str
    :param ddl_statements: The string list containing DDL to apply to the database.
    :type ddl_statements: list[str]
    :param project_id: Optional, the ID of the project that owns the Cloud Spanner
        Database.  If set to None or missing, the default project_id from the GCP connection is used.
    :type project_id: str
    :param operation_id: (Optional) Unique per database operation id that can
           be specified to implement idempotency check.
    :type operation_id: str
    :param gcp_conn_id: The connection ID used to connect to Google Cloud Platform.
        The hook is created at execution time, so a templated value is honoured.
    :type gcp_conn_id: str
    """
    # [START gcp_spanner_database_update_template_fields]
    template_fields = ('project_id', 'instance_id', 'database_id', 'ddl_statements',
                       'gcp_conn_id')
    template_ext = ('.sql', )
    # [END gcp_spanner_database_update_template_fields]

    @apply_defaults
    def __init__(self,
                 instance_id,
                 database_id,
                 ddl_statements,
                 project_id=None,
                 operation_id=None,
                 gcp_conn_id='google_cloud_default',
                 *args, **kwargs):
        self.instance_id = instance_id
        self.project_id = project_id
        self.database_id = database_id
        self.ddl_statements = ddl_statements
        self.operation_id = operation_id
        self.gcp_conn_id = gcp_conn_id
        self._validate_inputs()
        # The CloudSpannerHook is deliberately not built here: gcp_conn_id is a
        # template field, and Airflow renders template fields only after the
        # operator is constructed. execute() builds the hook from the rendered value.
        super(CloudSpannerInstanceDatabaseUpdateOperator, self).__init__(*args, **kwargs)

    def _validate_inputs(self):
        """
        Reject obviously invalid constructor arguments early.

        :raises AirflowException: if ``project_id`` is an empty string, or
            ``instance_id``/``database_id``/``ddl_statements`` is empty/None.
            ``project_id=None`` is accepted.
        """
        if self.project_id == '':
            raise AirflowException("The required parameter 'project_id' is empty")
        if not self.instance_id:
            raise AirflowException("The required parameter 'instance_id' is empty"
                                   " or None")
        if not self.database_id:
            raise AirflowException("The required parameter 'database_id' is empty"
                                   " or None")
        if not self.ddl_statements:
            raise AirflowException("The required parameter 'ddl_statements' is empty"
                                   " or None")

    def execute(self, context):
        """
        Apply ``ddl_statements`` to an already existing database.

        :param context: Airflow task context (unused).
        :return: whatever the hook's ``update_database`` returns.
        :raises AirflowException: if the target database cannot be found.
        """
        hook = CloudSpannerHook(gcp_conn_id=self.gcp_conn_id)
        if not hook.get_database(project_id=self.project_id,
                                 instance_id=self.instance_id,
                                 database_id=self.database_id):
            raise AirflowException("The Cloud Spanner database '{}' in project '{}' and "
                                   "instance '{}' is missing. Create the database first "
                                   "before you can update it.".format(self.database_id,
                                                                      self.project_id,
                                                                      self.instance_id))
        return hook.update_database(project_id=self.project_id,
                                    instance_id=self.instance_id,
                                    database_id=self.database_id,
                                    ddl_statements=self.ddl_statements,
                                    operation_id=self.operation_id)
class CloudSpannerInstanceDeployOperator(BaseOperator):
    """
    Creates a new Cloud Spanner instance, or if an instance with the same instance_id
    exists in the specified project, updates the Cloud Spanner instance.

    :param instance_id: Cloud Spanner instance ID.
    :type instance_id: str
    :param configuration_name: The name of the Cloud Spanner instance configuration
      defining how the instance will be created. Required for
      instances that do not yet exist.
    :type configuration_name: str
    :param node_count: (Optional) The number of nodes allocated to the Cloud Spanner
      instance.
    :type node_count: int
    :param display_name: (Optional) The display name for the Cloud Spanner instance in
      the GCP Console. (Must be between 4 and 30 characters.) If this value is not set
      in the constructor, the name is the same as the instance ID.
    :type display_name: str
    :param project_id: Optional, the ID of the project which owns the Cloud Spanner
        instance. If set to None or missing, the default project_id from the GCP
        connection is used.
    :type project_id: str
    :param gcp_conn_id: The connection ID used to connect to Google Cloud Platform.
        The hook is created at execution time, so a templated value is honoured.
    :type gcp_conn_id: str
    """
    # [START gcp_spanner_deploy_template_fields]
    template_fields = ('project_id', 'instance_id', 'configuration_name', 'display_name',
                       'gcp_conn_id')
    # [END gcp_spanner_deploy_template_fields]

    @apply_defaults
    def __init__(self,
                 instance_id,
                 configuration_name,
                 node_count,
                 display_name,
                 project_id=None,
                 gcp_conn_id='google_cloud_default',
                 *args, **kwargs):
        self.instance_id = instance_id
        self.project_id = project_id
        self.configuration_name = configuration_name
        self.node_count = node_count
        self.display_name = display_name
        self.gcp_conn_id = gcp_conn_id
        self._validate_inputs()
        # The CloudSpannerHook is deliberately not built here: gcp_conn_id is a
        # template field, and Airflow renders template fields only after the
        # operator is constructed. execute() builds the hook from the rendered value.
        super(CloudSpannerInstanceDeployOperator, self).__init__(*args, **kwargs)

    def _validate_inputs(self):
        """
        Reject obviously invalid constructor arguments early.

        :raises AirflowException: if ``project_id`` is an empty string or
            ``instance_id`` is empty/None. ``project_id=None`` is accepted.
        """
        if self.project_id == '':
            raise AirflowException("The required parameter 'project_id' is empty")
        if not self.instance_id:
            raise AirflowException("The required parameter 'instance_id' "
                                   "is empty or None")

    def execute(self, context):
        """
        Create the instance if it does not exist yet, otherwise update it.

        :param context: Airflow task context (unused).
        """
        hook = CloudSpannerHook(gcp_conn_id=self.gcp_conn_id)
        if not hook.get_instance(project_id=self.project_id, instance_id=self.instance_id):
            self.log.info("Creating Cloud Spanner instance '%s'", self.instance_id)
            func = hook.create_instance
        else:
            self.log.info("Updating Cloud Spanner instance '%s'", self.instance_id)
            func = hook.update_instance
        # Both create_instance and update_instance are called with the same kwargs.
        func(project_id=self.project_id,
             instance_id=self.instance_id,
             configuration_name=self.configuration_name,
             node_count=self.node_count,
             display_name=self.display_name)
class CloudSpannerInstanceDatabaseDeployOperator(BaseOperator):
    """
    Creates a new Cloud Spanner database, or if database exists,
    the operator does nothing.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:CloudSpannerInstanceDatabaseDeployOperator`

    :param instance_id: The Cloud Spanner instance ID.
    :type instance_id: str
    :param database_id: The Cloud Spanner database ID.
    :type database_id: str
    :param ddl_statements: The string list containing DDL for the new database.
    :type ddl_statements: list[str]
    :param project_id: Optional, the ID of the project that owns the Cloud Spanner
        Database.  If set to None or missing, the default project_id from the GCP connection is used.
    :type project_id: str
    :param gcp_conn_id: The connection ID used to connect to Google Cloud Platform.
        The hook is created at execution time, so a templated value is honoured.
    :type gcp_conn_id: str
    """
    # [START gcp_spanner_database_deploy_template_fields]
    template_fields = ('project_id', 'instance_id', 'database_id', 'ddl_statements',
                       'gcp_conn_id')
    template_ext = ('.sql', )
    # [END gcp_spanner_database_deploy_template_fields]

    @apply_defaults
    def __init__(self,
                 instance_id,
                 database_id,
                 ddl_statements,
                 project_id=None,
                 gcp_conn_id='google_cloud_default',
                 *args, **kwargs):
        self.instance_id = instance_id
        self.project_id = project_id
        self.database_id = database_id
        self.ddl_statements = ddl_statements
        self.gcp_conn_id = gcp_conn_id
        self._validate_inputs()
        # The CloudSpannerHook is deliberately not built here: gcp_conn_id is a
        # template field, and Airflow renders template fields only after the
        # operator is constructed. execute() builds the hook from the rendered value.
        super(CloudSpannerInstanceDatabaseDeployOperator, self).__init__(*args, **kwargs)

    def _validate_inputs(self):
        """
        Reject obviously invalid constructor arguments early.

        :raises AirflowException: if ``project_id`` is an empty string, or
            ``instance_id``/``database_id`` is empty/None.
            ``project_id=None`` is accepted.
        """
        if self.project_id == '':
            raise AirflowException("The required parameter 'project_id' is empty")
        if not self.instance_id:
            raise AirflowException("The required parameter 'instance_id' is empty "
                                   "or None")
        if not self.database_id:
            raise AirflowException("The required parameter 'database_id' is empty"
                                   " or None")

    def execute(self, context):
        """
        Create the database unless it already exists.

        :param context: Airflow task context (unused).
        :return: the result of the hook's ``create_database`` when the database
            was created, ``True`` when it already existed.
        """
        hook = CloudSpannerHook(gcp_conn_id=self.gcp_conn_id)
        if hook.get_database(project_id=self.project_id,
                             instance_id=self.instance_id,
                             database_id=self.database_id):
            self.log.info("The database '%s' in project '%s' and instance '%s'"
                          " already exists. Nothing to do. Exiting.",
                          self.database_id, self.project_id, self.instance_id)
            return True
        self.log.info("Creating Cloud Spanner database "
                      "'%s' in project '%s' and instance '%s'",
                      self.database_id, self.project_id, self.instance_id)
        return hook.create_database(project_id=self.project_id,
                                    instance_id=self.instance_id,
                                    database_id=self.database_id,
                                    ddl_statements=self.ddl_statements)