Ejemplo n.º 1
0
 def execute(self, context: 'Context') -> Optional[bool]:
     """
     Create the Cloud Spanner database unless it is already present.

     :param context: Airflow task context; not read by this method.
     :return: whatever ``SpannerHook.create_database`` returns when the database
         had to be created, ``True`` when it already existed.
     """
     spanner = SpannerHook(
         gcp_conn_id=self.gcp_conn_id,
         impersonation_chain=self.impersonation_chain,
     )
     existing = spanner.get_database(
         project_id=self.project_id,
         instance_id=self.instance_id,
         database_id=self.database_id,
     )

     # Guard clause: an existing database means there is nothing left to do.
     if existing:
         self.log.info(
             "The database '%s' in project '%s' and instance '%s'"
             " already exists. Nothing to do. Exiting.",
             self.database_id,
             self.project_id,
             self.instance_id,
         )
         return True

     self.log.info(
         "Creating Cloud Spanner database '%s' in project '%s' and instance '%s'",
         self.database_id,
         self.project_id,
         self.instance_id,
     )
     return spanner.create_database(
         project_id=self.project_id,
         instance_id=self.instance_id,
         database_id=self.database_id,
         ddl_statements=self.ddl_statements,
     )
Ejemplo n.º 2
0
 def setUp(self):
     """Create the hook under test while ``GoogleBaseHook.__init__`` is swapped for the stub initialiser."""
     base_init = 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__'
     patched_init = mock.patch(base_init, new=mock_base_gcp_hook_no_default_project_id)
     # The patch only needs to be active while the hook is being constructed.
     with patched_init:
         self.spanner_hook_no_default_project_id = SpannerHook(gcp_conn_id='test')
Ejemplo n.º 3
0
 def execute(self, context: 'Context') -> None:
     """
     Apply DDL statements to an existing Cloud Spanner database.

     :param context: Airflow task context, forwarded to ``SpannerDatabaseLink.persist``.
     :return: the value returned by ``SpannerHook.update_database``.
     :raises AirflowException: when the hook cannot find the database.
     """
     spanner = SpannerHook(
         gcp_conn_id=self.gcp_conn_id,
         impersonation_chain=self.impersonation_chain,
     )
     database = spanner.get_database(
         project_id=self.project_id,
         instance_id=self.instance_id,
         database_id=self.database_id,
     )
     if not database:
         raise AirflowException(
             f"The Cloud Spanner database '{self.database_id}' in project '{self.project_id}' "
             f"and instance '{self.instance_id}' is missing. "
             f"Create the database first before you can update it."
         )

     # Record the link before running the update; fall back to the hook's
     # project id when the operator was not given one.
     SpannerDatabaseLink.persist(
         context=context,
         task_instance=self,
         instance_id=self.instance_id,
         database_id=self.database_id,
         project_id=self.project_id or spanner.project_id,
     )
     return spanner.update_database(
         project_id=self.project_id,
         instance_id=self.instance_id,
         database_id=self.database_id,
         ddl_statements=self.ddl_statements,
         operation_id=self.operation_id,
     )
Ejemplo n.º 4
0
 def execute(self, context: 'Context'):
     """
     Run the configured DML query or queries against the Cloud Spanner database.

     :param context: Airflow task context, forwarded to ``SpannerDatabaseLink.persist``.
     """
     spanner = SpannerHook(
         gcp_conn_id=self.gcp_conn_id,
         impersonation_chain=self.impersonation_chain,
     )

     statements = self.query
     if isinstance(statements, str):
         # A single string may hold several ';'-separated statements.
         statements = [fragment.strip() for fragment in statements.split(';')]
         # The return value is not used, so any clean-up has to happen in place.
         self.sanitize_queries(statements)

     self.log.info(
         "Executing DML query(-ies) on projects/%s/instances/%s/databases/%s",
         self.project_id,
         self.instance_id,
         self.database_id,
     )
     self.log.info(statements)
     spanner.execute_dml(
         project_id=self.project_id,
         instance_id=self.instance_id,
         database_id=self.database_id,
         queries=statements,
     )
     # The link is persisted only after the DML call has returned.
     SpannerDatabaseLink.persist(
         context=context,
         task_instance=self,
         instance_id=self.instance_id,
         database_id=self.database_id,
         project_id=self.project_id or spanner.project_id,
     )
Ejemplo n.º 5
0
 def execute(self, context: 'Context') -> None:
     """
     Create the Cloud Spanner instance, or update it when it already exists.

     :param context: Airflow task context, forwarded to ``SpannerInstanceLink.persist``.
     """
     spanner = SpannerHook(
         gcp_conn_id=self.gcp_conn_id,
         impersonation_chain=self.impersonation_chain,
     )
     instance = spanner.get_instance(
         project_id=self.project_id,
         instance_id=self.instance_id,
     )

     # Choose the hook method first; both are called with the same arguments.
     if instance:
         self.log.info("Updating Cloud Spanner instance '%s'", self.instance_id)
         deploy = spanner.update_instance
     else:
         self.log.info("Creating Cloud Spanner instance '%s'", self.instance_id)
         deploy = spanner.create_instance

     deploy(
         project_id=self.project_id,
         instance_id=self.instance_id,
         configuration_name=self.configuration_name,
         node_count=self.node_count,
         display_name=self.display_name,
     )
     SpannerInstanceLink.persist(
         context=context,
         task_instance=self,
         instance_id=self.instance_id,
         project_id=self.project_id or spanner.project_id,
     )
Ejemplo n.º 6
0
 def execute(self, context):
     """
     Delete the Cloud Spanner instance when it exists.

     :param context: Airflow task context; not read by this method.
     :return: the value returned by ``SpannerHook.delete_instance``, or ``True``
         when the instance was not found.
     """
     spanner = SpannerHook(gcp_conn_id=self.gcp_conn_id)
     instance = spanner.get_instance(project_id=self.project_id,
                                     instance_id=self.instance_id)
     if not instance:
         # A missing instance is reported and treated as success.
         self.log.info("Instance '%s' does not exist in project '%s'. "
                       "Aborting delete.", self.instance_id, self.project_id)
         return True
     return spanner.delete_instance(project_id=self.project_id,
                                    instance_id=self.instance_id)
Ejemplo n.º 7
0
 def execute(self, context):
     """
     Create the Cloud Spanner instance, or update it when it already exists.

     :param context: Airflow task context; not read by this method.
     """
     spanner = SpannerHook(gcp_conn_id=self.gcp_conn_id)
     instance = spanner.get_instance(project_id=self.project_id,
                                     instance_id=self.instance_id)
     # Choose the hook method first; both are called with the same arguments.
     if instance:
         self.log.info("Updating Cloud Spanner instance '%s'", self.instance_id)
         deploy = spanner.update_instance
     else:
         self.log.info("Creating Cloud Spanner instance '%s'", self.instance_id)
         deploy = spanner.create_instance
     deploy(
         project_id=self.project_id,
         instance_id=self.instance_id,
         configuration_name=self.configuration_name,
         node_count=self.node_count,
         display_name=self.display_name,
     )
Ejemplo n.º 8
0
 def execute(self, context: 'Context') -> Optional[bool]:
     """
     Delete the Cloud Spanner instance when it exists.

     :param context: Airflow task context; not read by this method.
     :return: the value returned by ``SpannerHook.delete_instance``, or ``True``
         when the instance was not found.
     """
     spanner = SpannerHook(
         gcp_conn_id=self.gcp_conn_id,
         impersonation_chain=self.impersonation_chain,
     )
     instance = spanner.get_instance(
         project_id=self.project_id,
         instance_id=self.instance_id,
     )
     if not instance:
         # A missing instance is reported and treated as success.
         self.log.info(
             "Instance '%s' does not exist in project '%s'. Aborting delete.",
             self.instance_id,
             self.project_id,
         )
         return True
     return spanner.delete_instance(
         project_id=self.project_id,
         instance_id=self.instance_id,
     )
Ejemplo n.º 9
0
 def execute(self, context):
     """
     Delete the Cloud Spanner database when it exists.

     :param context: Airflow task context; not read by this method.
     :return: the value returned by ``SpannerHook.delete_database``, or ``True``
         when the database was not found.
     """
     spanner = SpannerHook(gcp_conn_id=self.gcp_conn_id)
     found = spanner.get_database(
         project_id=self.project_id,
         instance_id=self.instance_id,
         database_id=self.database_id,
     )
     if found:
         return spanner.delete_database(
             project_id=self.project_id,
             instance_id=self.instance_id,
             database_id=self.database_id,
         )
     # Nothing to delete: log it and report success.
     self.log.info("The Cloud Spanner database was missing: "
                   "'%s' in project '%s' and instance '%s'. Assuming success.",
                   self.database_id, self.project_id, self.instance_id)
     return True
Ejemplo n.º 10
0
 def execute(self, context):
     """
     Run the configured DML query or queries against the Cloud Spanner database.

     :param context: Airflow task context; not read by this method.
     """
     spanner = SpannerHook(gcp_conn_id=self.gcp_conn_id)
     if isinstance(self.query, str):
         # A single string may hold several ';'-separated statements.
         statements = [fragment.strip() for fragment in self.query.split(';')]
         # The return value is not used, so any clean-up has to happen in place.
         self.sanitize_queries(statements)
     else:
         statements = self.query
     self.log.info(
         "Executing DML query(-ies) on "
         "projects/%s/instances/%s/databases/%s",
         self.project_id,
         self.instance_id,
         self.database_id,
     )
     self.log.info(statements)
     spanner.execute_dml(
         project_id=self.project_id,
         instance_id=self.instance_id,
         database_id=self.database_id,
         queries=statements,
     )
Ejemplo n.º 11
0
 def execute(self, context):
     """
     Apply DDL statements to an existing Cloud Spanner database.

     :param context: Airflow task context; not read by this method.
     :return: the value returned by ``SpannerHook.update_database``.
     :raises AirflowException: when the hook cannot find the database.
     """
     spanner = SpannerHook(gcp_conn_id=self.gcp_conn_id)
     database = spanner.get_database(
         project_id=self.project_id,
         instance_id=self.instance_id,
         database_id=self.database_id,
     )
     if not database:
         # Build the message up front so the raise statement stays short.
         problem = (
             "The Cloud Spanner database '{}' in project '{}' and "
             "instance '{}' is missing. Create the database first "
             "before you can update it."
         ).format(self.database_id, self.project_id, self.instance_id)
         raise AirflowException(problem)
     return spanner.update_database(
         project_id=self.project_id,
         instance_id=self.instance_id,
         database_id=self.database_id,
         ddl_statements=self.ddl_statements,
         operation_id=self.operation_id,
     )
Ejemplo n.º 12
0
 def execute(self, context):
     """
     Create the Cloud Spanner database unless it is already present.

     :param context: Airflow task context; not read by this method.
     :return: whatever ``SpannerHook.create_database`` returns when the database
         had to be created, ``True`` when it already existed.
     """
     spanner = SpannerHook(gcp_conn_id=self.gcp_conn_id)
     existing = spanner.get_database(
         project_id=self.project_id,
         instance_id=self.instance_id,
         database_id=self.database_id,
     )
     # Guard clause: an existing database means there is nothing left to do.
     if existing:
         self.log.info("The database '%s' in project '%s' and instance '%s'"
                       " already exists. Nothing to do. Exiting.",
                       self.database_id, self.project_id, self.instance_id)
         return True
     self.log.info("Creating Cloud Spanner database "
                   "'%s' in project '%s' and instance '%s'",
                   self.database_id, self.project_id, self.instance_id)
     return spanner.create_database(
         project_id=self.project_id,
         instance_id=self.instance_id,
         database_id=self.database_id,
         ddl_statements=self.ddl_statements,
     )
Ejemplo n.º 13
0
class TestGcpSpannerHookNoDefaultProjectID(unittest.TestCase):
    """
    Unit tests for ``SpannerHook`` when the project id is passed explicitly.

    The hook is built with ``GoogleBaseHook.__init__`` replaced by
    ``mock_base_gcp_hook_no_default_project_id``, and every test supplies
    ``project_id`` itself. Apart from ``test_spanner_client_creation``, the
    tests patch ``SpannerHook._get_client`` and check the calls made on the
    mocked client/instance/database chain.
    """

    def setUp(self):
        # The patch only needs to be active while the hook is constructed.
        with mock.patch(
            'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__',
            new=mock_base_gcp_hook_no_default_project_id,
        ):
            self.spanner_hook_no_default_project_id = SpannerHook(gcp_conn_id='test')

    @mock.patch(
        "airflow.providers.google.cloud.hooks.spanner.SpannerHook.client_info", new_callable=mock.PropertyMock
    )
    @mock.patch(
        "airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_credentials",
        return_value="CREDENTIALS",
    )
    @mock.patch("airflow.providers.google.cloud.hooks.spanner.Client")
    def test_spanner_client_creation(self, mock_client, mock_get_creds, mock_client_info):
        # ``_get_client`` must build the client from project, credentials and
        # client info, return it, and keep it on ``_client``.
        result = self.spanner_hook_no_default_project_id._get_client(GCP_PROJECT_ID_HOOK_UNIT_TEST)
        mock_client.assert_called_once_with(
            project=GCP_PROJECT_ID_HOOK_UNIT_TEST,
            credentials=mock_get_creds.return_value,
            client_info=mock_client_info.return_value,
        )
        self.assertEqual(mock_client.return_value, result)
        self.assertEqual(self.spanner_hook_no_default_project_id._client, result)

    @mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client')
    def test_get_existing_instance_overridden_project_id(self, get_client):
        # Instance reports that it exists -> a non-None result is expected.
        instance_method = get_client.return_value.instance
        instance_exists_method = instance_method.return_value.exists
        instance_exists_method.return_value = True
        res = self.spanner_hook_no_default_project_id.get_instance(
            instance_id=SPANNER_INSTANCE, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST
        )
        get_client.assert_called_once_with(project_id='example-project')
        instance_method.assert_called_once_with(instance_id='instance')
        self.assertIsNotNone(res)

    @mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client')
    def test_get_non_existing_instance(self, get_client):
        # Instance reports that it does not exist -> ``None`` is expected.
        instance_method = get_client.return_value.instance
        instance_exists_method = instance_method.return_value.exists
        instance_exists_method.return_value = False
        res = self.spanner_hook_no_default_project_id.get_instance(
            instance_id=SPANNER_INSTANCE, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST
        )
        get_client.assert_called_once_with(project_id='example-project')
        instance_method.assert_called_once_with(instance_id='instance')
        self.assertIsNone(res)

    @mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client')
    def test_create_instance_overridden_project_id(self, get_client):
        # The instance must be built with the given configuration, display
        # name and node count; ``create_instance`` returns ``None``.
        instance_method = get_client.return_value.instance
        create_method = instance_method.return_value.create
        create_method.return_value = False
        res = self.spanner_hook_no_default_project_id.create_instance(
            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
            instance_id=SPANNER_INSTANCE,
            configuration_name=SPANNER_CONFIGURATION,
            node_count=1,
            display_name=SPANNER_DATABASE,
        )
        get_client.assert_called_once_with(project_id='example-project')
        instance_method.assert_called_once_with(
            instance_id='instance',
            configuration_name='configuration',
            display_name='database-name',
            node_count=1,
        )
        self.assertIsNone(res)

    @mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client')
    def test_update_instance_overridden_project_id(self, get_client):
        # ``update`` must be called once on the instance built from the given
        # arguments; ``update_instance`` returns ``None``.
        instance_method = get_client.return_value.instance
        instance_exists_method = instance_method.return_value.exists
        instance_exists_method.return_value = True
        update_method = instance_method.return_value.update
        update_method.return_value = False
        res = self.spanner_hook_no_default_project_id.update_instance(
            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
            instance_id=SPANNER_INSTANCE,
            configuration_name=SPANNER_CONFIGURATION,
            node_count=2,
            display_name=SPANNER_DATABASE,
        )
        get_client.assert_called_once_with(project_id='example-project')
        instance_method.assert_called_once_with(
            instance_id='instance',
            configuration_name='configuration',
            display_name='database-name',
            node_count=2,
        )
        update_method.assert_called_once_with()
        self.assertIsNone(res)

    @mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client')
    def test_delete_instance_overridden_project_id(self, get_client):
        # ``delete`` must be called once; note the instance id is expected
        # positionally here.
        instance_method = get_client.return_value.instance
        instance_exists_method = instance_method.return_value.exists
        instance_exists_method.return_value = True
        delete_method = instance_method.return_value.delete
        delete_method.return_value = False
        res = self.spanner_hook_no_default_project_id.delete_instance(
            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, instance_id=SPANNER_INSTANCE
        )
        get_client.assert_called_once_with(project_id='example-project')
        instance_method.assert_called_once_with('instance')
        delete_method.assert_called_once_with()
        self.assertIsNone(res)

    @mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client')
    def test_get_database_overridden_project_id(self, get_client):
        instance_method = get_client.return_value.instance
        instance_exists_method = instance_method.return_value.exists
        instance_exists_method.return_value = True
        database_method = instance_method.return_value.database
        # Configure and assert on the *database* mock's ``exists``. Binding
        # this name to the instance mock's ``exists`` would leave the database
        # existence check unconfigured and unverified.
        database_exists_method = database_method.return_value.exists
        database_exists_method.return_value = True
        res = self.spanner_hook_no_default_project_id.get_database(
            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
            instance_id=SPANNER_INSTANCE,
            database_id=SPANNER_DATABASE,
        )
        get_client.assert_called_once_with(project_id='example-project')
        instance_method.assert_called_once_with(instance_id='instance')
        database_method.assert_called_once_with(database_id='database-name')
        database_exists_method.assert_called_once_with()
        self.assertIsNotNone(res)

    @mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client')
    def test_create_database_overridden_project_id(self, get_client):
        # The database must be built with the DDL statements and ``create``
        # called once; ``create_database`` returns ``None``.
        instance_method = get_client.return_value.instance
        instance_exists_method = instance_method.return_value.exists
        instance_exists_method.return_value = True
        database_method = instance_method.return_value.database
        database_create_method = database_method.return_value.create
        res = self.spanner_hook_no_default_project_id.create_database(
            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
            instance_id=SPANNER_INSTANCE,
            database_id=SPANNER_DATABASE,
            ddl_statements=[],
        )
        get_client.assert_called_once_with(project_id='example-project')
        instance_method.assert_called_once_with(instance_id='instance')
        database_method.assert_called_once_with(database_id='database-name', ddl_statements=[])
        database_create_method.assert_called_once_with()
        self.assertIsNone(res)

    @mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client')
    def test_update_database_overridden_project_id(self, get_client):
        # Without an explicit operation id, ``update_ddl`` must receive
        # ``operation_id=None``.
        instance_method = get_client.return_value.instance
        instance_exists_method = instance_method.return_value.exists
        instance_exists_method.return_value = True
        database_method = instance_method.return_value.database
        database_update_ddl_method = database_method.return_value.update_ddl
        res = self.spanner_hook_no_default_project_id.update_database(
            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
            instance_id=SPANNER_INSTANCE,
            database_id=SPANNER_DATABASE,
            ddl_statements=[],
        )
        get_client.assert_called_once_with(project_id='example-project')
        instance_method.assert_called_once_with(instance_id='instance')
        database_method.assert_called_once_with(database_id='database-name')
        database_update_ddl_method.assert_called_once_with(ddl_statements=[], operation_id=None)
        self.assertIsNone(res)

    @mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client')
    def test_update_database_overridden_project_id_and_operation(self, get_client):
        # An explicit operation id must be forwarded to ``update_ddl``.
        instance_method = get_client.return_value.instance
        instance_exists_method = instance_method.return_value.exists
        instance_exists_method.return_value = True
        database_method = instance_method.return_value.database
        database_update_ddl_method = database_method.return_value.update_ddl
        res = self.spanner_hook_no_default_project_id.update_database(
            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
            instance_id=SPANNER_INSTANCE,
            database_id=SPANNER_DATABASE,
            operation_id="operation",
            ddl_statements=[],
        )
        get_client.assert_called_once_with(project_id='example-project')
        instance_method.assert_called_once_with(instance_id='instance')
        database_method.assert_called_once_with(database_id='database-name')
        database_update_ddl_method.assert_called_once_with(ddl_statements=[], operation_id="operation")
        self.assertIsNone(res)

    @mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client')
    def test_delete_database_overridden_project_id(self, get_client):
        # Database reports that it exists -> ``drop`` is called once and the
        # result is truthy.
        instance_method = get_client.return_value.instance
        instance_exists_method = instance_method.return_value.exists
        instance_exists_method.return_value = True
        database_method = instance_method.return_value.database
        database_drop_method = database_method.return_value.drop
        database_exists_method = database_method.return_value.exists
        database_exists_method.return_value = True
        res = self.spanner_hook_no_default_project_id.delete_database(
            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
            instance_id=SPANNER_INSTANCE,
            database_id=SPANNER_DATABASE,
        )
        get_client.assert_called_once_with(project_id='example-project')
        instance_method.assert_called_once_with(instance_id='instance')
        database_method.assert_called_once_with(database_id='database-name')
        database_exists_method.assert_called_once_with()
        database_drop_method.assert_called_once_with()
        self.assertTrue(res)

    @mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client')
    def test_delete_database_missing_database(self, get_client):
        # Database reports that it does not exist -> ``drop`` must not be called.
        instance_method = get_client.return_value.instance
        instance_exists_method = instance_method.return_value.exists
        instance_exists_method.return_value = True
        database_method = instance_method.return_value.database
        database_drop_method = database_method.return_value.drop
        database_exists_method = database_method.return_value.exists
        database_exists_method.return_value = False
        self.spanner_hook_no_default_project_id.delete_database(
            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
            instance_id=SPANNER_INSTANCE,
            database_id=SPANNER_DATABASE,
        )
        get_client.assert_called_once_with(project_id='example-project')
        instance_method.assert_called_once_with(instance_id='instance')
        database_method.assert_called_once_with(database_id='database-name')
        database_exists_method.assert_called_once_with()
        database_drop_method.assert_not_called()

    @mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client')
    def test_execute_dml_overridden_project_id(self, get_client):
        # The DML must go through ``run_in_transaction``; the callable passed
        # to it is not inspected (``mock.ANY``).
        instance_method = get_client.return_value.instance
        instance_exists_method = instance_method.return_value.exists
        instance_exists_method.return_value = True
        database_method = instance_method.return_value.database
        run_in_transaction_method = database_method.return_value.run_in_transaction
        res = self.spanner_hook_no_default_project_id.execute_dml(
            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
            instance_id=SPANNER_INSTANCE,
            database_id=SPANNER_DATABASE,
            queries='',
        )
        get_client.assert_called_once_with(project_id='example-project')
        instance_method.assert_called_once_with(instance_id='instance')
        database_method.assert_called_once_with(database_id='database-name')
        run_in_transaction_method.assert_called_once_with(mock.ANY)
        self.assertIsNone(res)