def test_transaction_sql_metrics_ok_call(self):
        """Check that a successful SQL read in a transaction records an 'ok' call.

        Runs only on the DirectRunner. Reads ``select * from Users`` through a
        transaction, asserts the rows equal ``self._data``, then verifies the
        'query-1' SQL read-call metric reports one call with status 'ok'.
        """
        if 'DirectRunner' not in self.runner_name:
            raise unittest.SkipTest('This test only runs with DirectRunner.')

        # Reset process-wide metrics so earlier tests don't affect the counts
        # verified below.
        MetricsEnvironment.process_wide_container().reset()

        with beam.Pipeline(argv=self.args) as p:
            transaction = (p
                           | create_transaction(self.project, self.instance,
                                                self.TEST_DATABASE))
            r = p | ReadFromSpanner(self.project,
                                    self.instance,
                                    self.TEST_DATABASE,
                                    sql="select * from Users",
                                    query_name='query-1',
                                    transaction=transaction)

            # The assertion must be attached before the `with` block exits:
            # the pipeline runs on exit, and transforms added afterwards are
            # never executed.
            assert_that(r, equal_to(self._data))

        self.verify_sql_read_call_metric(self.project, self.TEST_DATABASE,
                                         'query-1', 'ok', 1)
    def test_transaction_sql_metrics_error_call(self):
        """Check that a failing SQL read in a transaction records a '400' call.

        Runs only on the DirectRunner. Queries a table named ``NonExistent``,
        expects running the pipeline to raise, then verifies the 'query-2' SQL
        read-call metric reports one call with status '400'.
        """
        if 'DirectRunner' not in self.runner_name:
            raise unittest.SkipTest('This test only runs with DirectRunner.')

        # Reset process-wide metrics so earlier tests don't affect the counts
        # verified below.
        MetricsEnvironment.process_wide_container().reset()

        # Build the pipeline outside assertRaises so that construction-time
        # mistakes surface as genuine errors instead of satisfying the check.
        p = beam.Pipeline(argv=self.args)
        transaction = (p
                       | create_transaction(self.project, self.instance,
                                            self.TEST_DATABASE))
        _ = p | ReadFromSpanner(self.project,
                                self.instance,
                                self.TEST_DATABASE,
                                sql="select * from NonExistent",
                                query_name="query-2",
                                transaction=transaction)

        # Only executing the pipeline is expected to fail.
        with self.assertRaises(Exception):
            res = p.run()
            res.wait_until_finish()

        self.verify_sql_read_call_metric(self.project, self.TEST_DATABASE,
                                         'query-2', '400', 1)
    def test_transaction_table_metrics_error_call(self):
        """Check that a failing table read in a transaction records a '404' call.

        Runs only on the DirectRunner. Reads columns ``UserId`` and ``Key``
        from ``INVALID_TABLE``, expects running the pipeline to raise, then
        verifies the table read-call metric for 'INVALID_TABLE' reports one
        call with status '404'.
        """
        if 'DirectRunner' not in self.runner_name:
            raise unittest.SkipTest('This test only runs with DirectRunner.')

        # Reset process-wide metrics so earlier tests don't affect the counts
        # verified below.
        MetricsEnvironment.process_wide_container().reset()

        # Build the pipeline outside assertRaises so that construction-time
        # mistakes surface as genuine errors instead of satisfying the check.
        p = beam.Pipeline(argv=self.args)
        transaction = (p
                       | create_transaction(self.project, self.instance,
                                            self.TEST_DATABASE))
        _ = p | ReadFromSpanner(self.project,
                                self.instance,
                                self.TEST_DATABASE,
                                table="INVALID_TABLE",
                                columns=["UserId", "Key"],
                                transaction=transaction)

        # Only executing the pipeline is expected to fail.
        with self.assertRaises(Exception):
            res = p.run()
            res.wait_until_finish()

        self.verify_table_read_call_metric(self.project, self.TEST_DATABASE,
                                           'INVALID_TABLE', '404', 1)
# Example #4
0
    def test_read_with_transaction(self, mock_batch_snapshot_class,
                                   mock_client_class):
        """Read through a shared transaction via every ReadFromSpanner form.

        The Spanner client class and BatchSnapshot class are replaced by the
        two mock arguments. They are presumably injected by ``mock.patch``
        decorators not visible here. The mocks are configured so that reads
        are expected to yield ``FAKE_ROWS``.

        The following reads are covered:

        * a SQL query;
        * a table read;
        * an indexed table read;
        * ``read_operations`` passed to the constructor;
        * read operations fed in from a PCollection.

        Finally, it checks that passing read operations both ways raises
        ValueError.
        """
        mock_client = mock.MagicMock()
        mock_instance = mock.MagicMock()
        mock_database = mock.MagicMock()
        mock_snapshot = mock.MagicMock()

        # Wire client -> instance -> database -> snapshot so every path that
        # obtains a snapshot receives the same mock.
        mock_client_class.return_value = mock_client
        mock_client.instance.return_value = mock_instance
        mock_instance.database.return_value = mock_database
        mock_database.batch_snapshot.return_value = mock_snapshot
        mock_batch_snapshot_class.return_value = mock_snapshot
        mock_batch_snapshot_class.from_dict.return_value = mock_snapshot
        mock_snapshot.to_dict.return_value = FAKE_TRANSACTION_INFO

        mock_session = mock.MagicMock()
        mock_transaction_ctx = mock.MagicMock()
        mock_transaction = mock.MagicMock()

        mock_snapshot._get_session.return_value = mock_session
        mock_session.transaction.return_value = mock_transaction
        # The transaction is used as a context manager; __enter__ yields the
        # object whose read methods return the fake rows.
        mock_transaction.__enter__.return_value = mock_transaction_ctx
        mock_transaction_ctx.execute_sql.return_value = FAKE_ROWS
        # NOTE(review): table and indexed reads presumably go through the
        # transaction's `read` method rather than `execute_sql`. Stub both so
        # every read yields FAKE_ROWS (verify against spannerio).
        mock_transaction_ctx.read.return_value = FAKE_ROWS

        ro = [ReadOperation.query("Select * from users")]
        p = TestPipeline()

        transaction = (p | create_transaction(
            project_id=TEST_PROJECT_ID,
            instance_id=TEST_INSTANCE_ID,
            database_id=_generate_database_name(),
            exact_staleness=datetime.timedelta(seconds=10)))

        read_query = (p | 'with query' >> ReadFromSpanner(
            project_id=TEST_PROJECT_ID,
            instance_id=TEST_INSTANCE_ID,
            database_id=_generate_database_name(),
            transaction=transaction,
            sql="Select * from users"))

        read_table = (p | 'with table' >> ReadFromSpanner(
            project_id=TEST_PROJECT_ID,
            instance_id=TEST_INSTANCE_ID,
            database_id=_generate_database_name(),
            transaction=transaction,
            table="users",
            columns=["Key", "Value"]))

        read_indexed_table = (p | 'with index' >> ReadFromSpanner(
            project_id=TEST_PROJECT_ID,
            instance_id=TEST_INSTANCE_ID,
            database_id=_generate_database_name(),
            transaction=transaction,
            table="users",
            index="Key",
            columns=["Key", "Value"]))

        read = (p | 'read all' >> ReadFromSpanner(TEST_PROJECT_ID,
                                                  TEST_INSTANCE_ID,
                                                  _generate_database_name(),
                                                  transaction=transaction,
                                                  read_operations=ro))

        read_pipeline = (p
                         | 'create read operations' >> beam.Create(ro)
                         |
                         'reads' >> ReadFromSpanner(TEST_PROJECT_ID,
                                                    TEST_INSTANCE_ID,
                                                    _generate_database_name(),
                                                    transaction=transaction))

        # Attach every assertion before running the pipeline. Transforms added
        # after p.run() are never executed, so late checks are silent no-ops.
        assert_that(read_query, equal_to(FAKE_ROWS), label='checkQuery')
        assert_that(read_table, equal_to(FAKE_ROWS), label='checkTable')
        assert_that(read_indexed_table,
                    equal_to(FAKE_ROWS),
                    label='checkTableIndex')
        assert_that(read, equal_to(FAKE_ROWS), label='checkReadAll')
        assert_that(read_pipeline,
                    equal_to(FAKE_ROWS),
                    label='checkReadPipeline')

        p.run()

        with self.assertRaises(ValueError):
            # Test the exception raised when user passes the read operations in the
            # constructor and also in the pipeline.
            _ = (p
                 | 'create read operations2' >> beam.Create(ro)
                 | 'reads with error' >> ReadFromSpanner(
                     TEST_PROJECT_ID,
                     TEST_INSTANCE_ID,
                     _generate_database_name(),
                     transaction=transaction,
                     read_operations=ro))
            p.run()