def test_display_data(self, *args):
    """display_data() should expose keys matching how the source was built."""
    # Source configured with a raw SQL statement.
    sql_display = ReadFromSpanner(
        project_id=TEST_PROJECT_ID,
        instance_id=TEST_INSTANCE_ID,
        database_id=_generate_database_name(),
        sql="Select * from users").display_data()

    # Source configured with a table name plus column list.
    table_display = ReadFromSpanner(
        project_id=TEST_PROJECT_ID,
        instance_id=TEST_INSTANCE_ID,
        database_id=_generate_database_name(),
        table="users",
        columns=['id', 'name']).display_data()

    # Table source that also carries an existing transaction mapping.
    txn_display = ReadFromSpanner(
        project_id=TEST_PROJECT_ID,
        instance_id=TEST_INSTANCE_ID,
        database_id=_generate_database_name(),
        table="users",
        columns=['id', 'name'],
        transaction={
            "transaction_id": "test123", "session_id": "test456"
        }).display_data()

    self.assertTrue("sql" in sql_display)
    self.assertTrue("table" in table_display)
    self.assertTrue("table" in txn_display)
    self.assertTrue("transaction" in txn_display)
def test_read_with_query_batch(self, mock_batch_snapshot_class, mock_client_class):
    """Batched SQL read exercised through all three supported entry styles."""
    # The mocked snapshot hands out three partitions, each of which
    # contributes a slice of FAKE_ROWS when processed.
    snapshot_mock = mock.MagicMock()
    snapshot_mock.generate_query_batches.return_value = [
        {'query': {'sql': 'SELECT * FROM users'},
         'partition': 'test_partition'} for _ in range(3)]
    snapshot_mock.process_query_batch.side_effect = [
        FAKE_ROWS[0:2], FAKE_ROWS[2:4], FAKE_ROWS[4:]]

    operations = [ReadOperation.query("Select * from users")]
    p = TestPipeline()

    # Style 1: SQL given directly to the source.
    read_sql = (
        p
        | 'read' >> ReadFromSpanner(
            TEST_PROJECT_ID,
            TEST_INSTANCE_ID,
            _generate_database_name(),
            sql="SELECT * FROM users"))
    # Style 2: a list of ReadOperations given to the source constructor.
    read_ops = (
        p
        | 'read all' >> ReadFromSpanner(
            TEST_PROJECT_ID,
            TEST_INSTANCE_ID,
            _generate_database_name(),
            read_operations=operations))
    # Style 3: ReadOperations flowing in as PCollection elements.
    read_pcoll = (
        p
        | 'create reads' >> beam.Create(operations)
        | 'reads' >> ReadFromSpanner(
            TEST_PROJECT_ID, TEST_INSTANCE_ID, _generate_database_name()))
    p.run()

    assert_that(read_sql, equal_to(FAKE_ROWS), label='checkRead')
    assert_that(read_ops, equal_to(FAKE_ROWS), label='checkReadAll')
    assert_that(read_pcoll, equal_to(FAKE_ROWS), label='checkReadPipeline')
def test_read_with_table_batch(self, mock_batch_snapshot_class, mock_client_class):
    """Batched table read exercised through all three supported entry styles."""
    # Three identical partitions; processing each yields a slice of FAKE_ROWS.
    snapshot_mock = mock.MagicMock()
    snapshot_mock.generate_read_batches.return_value = [
        {
            'read': {
                'table': 'users',
                'keyset': {'all': True},
                'columns': ['Key', 'Value'],
                'index': ''
            },
            'partition': 'test_partition'
        } for _ in range(3)]
    snapshot_mock.process_read_batch.side_effect = [
        FAKE_ROWS[0:2], FAKE_ROWS[2:4], FAKE_ROWS[4:]]

    operations = [ReadOperation.table("users", ["Key", "Value"])]
    p = TestPipeline()

    # Style 1: table + columns given directly to the source.
    read_table = (
        p
        | 'read' >> ReadFromSpanner(
            TEST_PROJECT_ID,
            TEST_INSTANCE_ID,
            _generate_database_name(),
            table="users",
            columns=["Key", "Value"]))
    # Style 2: ReadOperations given to the source constructor.
    read_ops = (
        p
        | 'read all' >> ReadFromSpanner(
            TEST_PROJECT_ID,
            TEST_INSTANCE_ID,
            _generate_database_name(),
            read_operations=operations))
    # Style 3: ReadOperations flowing in as PCollection elements.
    read_pcoll = (
        p
        | 'create reads' >> beam.Create(operations)
        | 'reads' >> ReadFromSpanner(
            TEST_PROJECT_ID, TEST_INSTANCE_ID, _generate_database_name()))
    p.run()

    assert_that(read_table, equal_to(FAKE_ROWS), label='checkRead')
    assert_that(read_ops, equal_to(FAKE_ROWS), label='checkReadAll')
    assert_that(read_pcoll, equal_to(FAKE_ROWS), label='checkReadPipeline')

    with self.assertRaises(ValueError):
        # NOTE(review): the original comment claimed this covers passing read
        # operations both in the constructor and in the pipeline, but no
        # read_operations are supplied here — presumably the ValueError comes
        # from specifying `table` without `columns`; confirm against
        # ReadFromSpanner's validation.
        _ = (
            p
            | 'reads error' >> ReadFromSpanner(
                project_id=TEST_PROJECT_ID,
                instance_id=TEST_INSTANCE_ID,
                database_id=_generate_database_name(),
                table="users"))
        p.run()
def test_read_via_sql(self):
    """End-to-end read from the test database using a SQL query."""
    _LOGGER.info("Running Spanner via sql")
    with beam.Pipeline(argv=self.args) as p:
        rows = p | ReadFromSpanner(
            self.project,
            self.instance,
            self.TEST_DATABASE,
            sql="select * from Users")
        assert_that(rows, equal_to(self._data))
def test_read_via_table(self):
    """End-to-end read from the test database using table + columns."""
    _LOGGER.info("Spanner Read via table")
    with beam.Pipeline(argv=self.args) as p:
        rows = p | ReadFromSpanner(
            self.project,
            self.instance,
            self.TEST_DATABASE,
            table="Users",
            columns=["UserId", "Key"])
        assert_that(rows, equal_to(self._data))
def test_invalid_transaction(self, mock_batch_snapshot_class, mock_client_class):
    """A transaction side input with the wrong shape must raise ValueError."""
    with self.assertRaises(ValueError):
        p = TestPipeline()
        # A valid transaction comes from create_transaction; feed a plain
        # dict with an unexpected key instead.
        bogus_transaction = (p | beam.Create([{"invalid": "transaction"}]))
        _ = (
            p
            | 'with query' >> ReadFromSpanner(
                project_id=TEST_PROJECT_ID,
                instance_id=TEST_INSTANCE_ID,
                database_id=_generate_database_name(),
                transaction=bogus_transaction,
                sql="Select * from users"))
        p.run()
def test(self):
    """Performance test: read all records, collect count/time metrics, and
    check the global count matches the configured number of records."""
    record_count = (
        self.pipeline
        | 'Read from Spanner' >> ReadFromSpanner(
            self.project,
            self.spanner_instance,
            self.spanner_database,
            sql="select data from test_data")
        | 'Count messages' >> ParDo(CountMessages(self.metrics_namespace))
        | 'Measure time' >> ParDo(MeasureTime(self.metrics_namespace))
        | 'Count' >> Count.Globally())
    assert_that(record_count, equal_to([self.input_options['num_records']]))
def test_sql_metrics_ok_call(self):
    """A successful SQL read should record one 'ok' service-call metric."""
    if 'DirectRunner' not in self.runner_name:
        raise unittest.SkipTest('This test only runs with DirectRunner.')
    # Start from a clean process-wide metrics container.
    MetricsEnvironment.process_wide_container().reset()
    with beam.Pipeline(argv=self.args) as p:
        rows = p | ReadFromSpanner(
            self.project,
            self.instance,
            self.TEST_DATABASE,
            sql="select * from Users",
            query_name='query-1')
        assert_that(rows, equal_to(self._data))
    self.verify_sql_read_call_metric(
        self.project, self.TEST_DATABASE, 'query-1', 'ok', 1)
def test_table_metrics_ok_call(self):
    """A successful table read should record one 'ok' service-call metric."""
    if 'DirectRunner' not in self.runner_name:
        raise unittest.SkipTest('This test only runs with DirectRunner.')
    # Start from a clean process-wide metrics container.
    MetricsEnvironment.process_wide_container().reset()
    with beam.Pipeline(argv=self.args) as p:
        rows = p | ReadFromSpanner(
            self.project,
            self.instance,
            self.TEST_DATABASE,
            table="Users",
            columns=["UserId", "Key"])
        assert_that(rows, equal_to(self._data))
    self.verify_table_read_call_metric(
        self.project, self.TEST_DATABASE, 'Users', 'ok', 1)
def test_sql_metrics_error_call(self):
    """A failing SQL read should record one '400' service-call metric."""
    if 'DirectRunner' not in self.runner_name:
        raise unittest.SkipTest('This test only runs with DirectRunner.')
    # Start from a clean process-wide metrics container.
    MetricsEnvironment.process_wide_container().reset()
    with self.assertRaises(Exception):
        p = beam.Pipeline(argv=self.args)
        _ = p | ReadFromSpanner(
            self.project,
            self.instance,
            self.TEST_DATABASE,
            sql="select * from NonExistent",
            query_name='query-2')
        res = p.run()
        res.wait_until_finish()
    self.verify_sql_read_call_metric(
        self.project, self.TEST_DATABASE, 'query-2', '400', 1)
def test_table_metrics_error_call(self):
    """A failing table read should record one '404' service-call metric."""
    if 'DirectRunner' not in self.runner_name:
        raise unittest.SkipTest('This test only runs with DirectRunner.')
    # Start from a clean process-wide metrics container.
    MetricsEnvironment.process_wide_container().reset()
    with self.assertRaises(Exception):
        p = beam.Pipeline(argv=self.args)
        _ = p | ReadFromSpanner(
            self.project,
            self.instance,
            self.TEST_DATABASE,
            table="INVALID_TABLE",
            columns=["UserId", "Key"])
        res = p.run()
        res.wait_until_finish()
    self.verify_table_read_call_metric(
        self.project, self.TEST_DATABASE, 'INVALID_TABLE', '404', 1)
def test_read_with_transaction(self, mock_batch_snapshot_class, mock_client_class):
    """Mocked end-to-end test of reads that share a user-created transaction.

    Builds one transaction side input via create_transaction and feeds it to
    reads configured by sql, by table, by table+index, by read_operations, and
    by read operations flowing through the pipeline; each must yield FAKE_ROWS.
    """
    # Wire the mocked Spanner client hierarchy:
    # client -> instance -> database -> batch_snapshot.
    mock_client = mock.MagicMock()
    mock_instance = mock.MagicMock()
    mock_database = mock.MagicMock()
    mock_snapshot = mock.MagicMock()
    mock_client_class.return_value = mock_client
    mock_client.instance.return_value = mock_instance
    mock_instance.database.return_value = mock_database
    mock_database.batch_snapshot.return_value = mock_snapshot
    mock_batch_snapshot_class.return_value = mock_snapshot
    mock_batch_snapshot_class.from_dict.return_value = mock_snapshot
    mock_snapshot.to_dict.return_value = FAKE_TRANSACTION_INFO

    # The transaction context obtained from the snapshot's session returns
    # FAKE_ROWS for any SQL it executes.
    mock_session = mock.MagicMock()
    mock_transaction_ctx = mock.MagicMock()
    mock_transaction = mock.MagicMock()
    mock_snapshot._get_session.return_value = mock_session
    mock_session.transaction.return_value = mock_transaction
    mock_transaction.__enter__.return_value = mock_transaction_ctx
    mock_transaction_ctx.execute_sql.return_value = FAKE_ROWS

    ro = [ReadOperation.query("Select * from users")]
    p = TestPipeline()

    # One shared transaction side input for all of the reads below.
    transaction = (
        p | create_transaction(
            project_id=TEST_PROJECT_ID,
            instance_id=TEST_INSTANCE_ID,
            database_id=_generate_database_name(),
            exact_staleness=datetime.timedelta(seconds=10)))

    # Read configured with a SQL statement.
    read_query = (
        p | 'with query' >> ReadFromSpanner(
            project_id=TEST_PROJECT_ID,
            instance_id=TEST_INSTANCE_ID,
            database_id=_generate_database_name(),
            transaction=transaction,
            sql="Select * from users"))
    # Read configured with a table and columns.
    read_table = (
        p | 'with table' >> ReadFromSpanner(
            project_id=TEST_PROJECT_ID,
            instance_id=TEST_INSTANCE_ID,
            database_id=_generate_database_name(),
            transaction=transaction,
            table="users",
            columns=["Key", "Value"]))
    # Read configured with a table plus a secondary index.
    read_indexed_table = (
        p | 'with index' >> ReadFromSpanner(
            project_id=TEST_PROJECT_ID,
            instance_id=TEST_INSTANCE_ID,
            database_id=_generate_database_name(),
            transaction=transaction,
            table="users",
            index="Key",
            columns=["Key", "Value"]))
    # Read configured with a list of ReadOperations.
    read = (
        p | 'read all' >> ReadFromSpanner(
            TEST_PROJECT_ID,
            TEST_INSTANCE_ID,
            _generate_database_name(),
            transaction=transaction,
            read_operations=ro))
    # Read operations supplied as elements of the input PCollection.
    read_pipeline = (
        p
        | 'create read operations' >> beam.Create(ro)
        | 'reads' >> ReadFromSpanner(
            TEST_PROJECT_ID,
            TEST_INSTANCE_ID,
            _generate_database_name(),
            transaction=transaction))
    p.run()

    assert_that(read_query, equal_to(FAKE_ROWS), label='checkQuery')
    assert_that(read_table, equal_to(FAKE_ROWS), label='checkTable')
    assert_that(read_indexed_table, equal_to(FAKE_ROWS), label='checkTableIndex')
    assert_that(read, equal_to(FAKE_ROWS), label='checkReadAll')
    assert_that(read_pipeline, equal_to(FAKE_ROWS), label='checkReadPipeline')

    with self.assertRaises(ValueError):
        # Test the exception raised when user passes the read operations in the
        # constructor and also in the pipeline.
        _ = (
            p
            | 'create read operations2' >> beam.Create(ro)
            | 'reads with error' >> ReadFromSpanner(
                TEST_PROJECT_ID,
                TEST_INSTANCE_ID,
                _generate_database_name(),
                transaction=transaction,
                read_operations=ro))
        p.run()