Example #1
0
  def test_read_with_query_batch(self, mock_batch_snapshot_class,
                                 mock_client_class):
    """Issues the same query through ReadFromSpanner in three ways.

    The query is supplied (1) through the ``sql`` argument, (2) through
    ``read_operations`` in the constructor and (3) as a PCollection of
    ``ReadOperation`` objects piped into the transform. Each resulting
    PCollection is compared against ``FAKE_ROWS``.

    NOTE(review): ``snapshot`` is configured here but is never attached to
    ``mock_batch_snapshot_class`` or ``mock_client_class`` inside this
    method -- confirm whether that wiring is missing.
    NOTE(review): the ``assert_that`` checks are added after
    ``test_pipeline.run()``; they presumably are not part of that run --
    confirm, and consider attaching them before running.
    """
    snapshot = mock.MagicMock()

    # Three identical query partitions ...
    snapshot.generate_query_batches.return_value = [
        dict(query=dict(sql='SELECT * FROM users'), partition='test_partition')
        for _ in range(3)
    ]
    # ... and one slice of FAKE_ROWS per process_query_batch call.
    row_chunks = [FAKE_ROWS[0:2], FAKE_ROWS[2:4], FAKE_ROWS[4:]]
    snapshot.process_query_batch.side_effect = row_chunks

    operations = [ReadOperation.query("Select * from users")]
    test_pipeline = TestPipeline()

    # (1) Query passed through the `sql` argument.
    from_sql = test_pipeline | 'read' >> ReadFromSpanner(
        TEST_PROJECT_ID,
        TEST_INSTANCE_ID,
        _generate_database_name(),
        sql="SELECT * FROM users")

    # (2) Read operations passed to the constructor.
    from_constructor_ops = test_pipeline | 'read all' >> ReadFromSpanner(
        TEST_PROJECT_ID,
        TEST_INSTANCE_ID,
        _generate_database_name(),
        read_operations=operations)

    # (3) Read operations supplied as the transform's input PCollection.
    operations_pcoll = test_pipeline | 'create reads' >> beam.Create(operations)
    from_pcoll_ops = operations_pcoll | 'reads' >> ReadFromSpanner(
        TEST_PROJECT_ID, TEST_INSTANCE_ID, _generate_database_name())

    test_pipeline.run()

    checks = (
        (from_sql, 'checkRead'),
        (from_constructor_ops, 'checkReadAll'),
        (from_pcoll_ops, 'checkReadPipeline'),
    )
    for rows, check_label in checks:
      assert_that(rows, equal_to(FAKE_ROWS), label=check_label)
Example #2
0
    def test_read_with_transaction(self, mock_batch_snapshot_class,
                                   mock_client_class):
        """Runs ReadFromSpanner with a shared transaction for each read style.

        A transaction PCollection is created once and passed to five reads:
        by SQL, by table, by table with an index, by ``read_operations`` in
        the constructor, and by a PCollection of ``ReadOperation`` objects.
        Each result is compared against ``FAKE_ROWS``. Finally, supplying
        read operations both in the constructor and through the pipeline is
        expected to raise ``ValueError``.

        NOTE(review): the ``assert_that`` checks are added after
        ``pipeline.run()``; they presumably are not part of that run --
        confirm, and consider attaching them before running.
        NOTE(review): only ``execute_sql`` is stubbed on the transaction
        context -- verify the table/index reads do not need another stub.
        """
        client, instance, database, snapshot = (
            mock.MagicMock() for _ in range(4))
        session, transaction_ctx, session_transaction = (
            mock.MagicMock() for _ in range(3))

        # Patched client class -> client -> instance -> database -> snapshot.
        mock_client_class.return_value = client
        client.instance.return_value = instance
        instance.database.return_value = database
        database.batch_snapshot.return_value = snapshot

        # Every way of obtaining a batch snapshot yields the same mock.
        mock_batch_snapshot_class.return_value = snapshot
        mock_batch_snapshot_class.from_dict.return_value = snapshot
        snapshot.to_dict.return_value = FAKE_TRANSACTION_INFO

        # snapshot session -> transaction context manager -> execute_sql.
        snapshot._get_session.return_value = session
        session.transaction.return_value = session_transaction
        session_transaction.__enter__.return_value = transaction_ctx
        transaction_ctx.execute_sql.return_value = FAKE_ROWS

        operations = [ReadOperation.query("Select * from users")]
        pipeline = TestPipeline()

        transaction = pipeline | create_transaction(
            project_id=TEST_PROJECT_ID,
            instance_id=TEST_INSTANCE_ID,
            database_id=_generate_database_name(),
            exact_staleness=datetime.timedelta(seconds=10))

        def transactional_read(**read_args):
            # Each transform gets its own freshly generated database name.
            return ReadFromSpanner(
                project_id=TEST_PROJECT_ID,
                instance_id=TEST_INSTANCE_ID,
                database_id=_generate_database_name(),
                transaction=transaction,
                **read_args)

        by_query = pipeline | 'with query' >> transactional_read(
            sql="Select * from users")

        by_table = pipeline | 'with table' >> transactional_read(
            table="users", columns=["Key", "Value"])

        by_index = pipeline | 'with index' >> transactional_read(
            table="users", index="Key", columns=["Key", "Value"])

        by_constructor_ops = pipeline | 'read all' >> transactional_read(
            read_operations=operations)

        operations_pcoll = (
            pipeline | 'create read operations' >> beam.Create(operations))
        by_pcoll_ops = operations_pcoll | 'reads' >> transactional_read()

        pipeline.run()

        checks = (
            (by_query, 'checkQuery'),
            (by_table, 'checkTable'),
            (by_index, 'checkTableIndex'),
            (by_constructor_ops, 'checkReadAll'),
            (by_pcoll_ops, 'checkReadPipeline'),
        )
        for rows, check_label in checks:
            assert_that(rows, equal_to(FAKE_ROWS), label=check_label)

        with self.assertRaises(ValueError):
            # Read operations given both to the constructor and through the
            # input PCollection must be rejected.
            _ = (pipeline
                 | 'create read operations2' >> beam.Create(operations)
                 | 'reads with error' >> transactional_read(
                     read_operations=operations))
            pipeline.run()