예제 #1
0
def test_execute_sql_async(stub):
    """Run _TEST_SQL through the future-based ExecuteSql call.

    The session created for the query is always deleted afterwards, even when
    starting the call or waiting on its result raises.

    Args:
      stub: An object of SpannerStub.
    """
    session = stub.CreateSession(
        spanner_pb2.CreateSessionRequest(database=_DATABASE))
    try:
        response_future = stub.ExecuteSql.future(
            spanner_pb2.ExecuteSqlRequest(session=session.name, sql=_TEST_SQL))
        # Block until the asynchronous call completes (raises if it failed).
        response_future.result()
    finally:
        stub.DeleteSession(spanner_pb2.DeleteSessionRequest(name=session.name))
예제 #2
0
 def test_create_list_delete_session(self):
     """Create, list and delete a session while tracking channel affinity.

     Throughout, the pool holds a single channel with no active stream. That
     channel's affinity count is 1 while the session exists and 0 once it is
     deleted; ListSessions reports the session only before deletion.
     """
     stub = spanner_pb2_grpc.SpannerStub(self.channel)

     def list_sessions():
         return stub.ListSessions(
             spanner_pb2.ListSessionsRequest(database=_DATABASE))

     def check_pool(expected_affinity):
         # One channel, the given affinity count, and no open streams.
         pool = self.channel._channel_refs
         self.assertEqual(1, len(pool))
         self.assertEqual(expected_affinity, pool[0]._affinity_ref)
         self.assertEqual(0, pool[0]._active_stream_ref)

     session = stub.CreateSession(
         spanner_pb2.CreateSessionRequest(database=_DATABASE))
     self.assertIsNotNone(session)
     check_pool(1)

     listing = list_sessions()
     self.assertIsNotNone(listing.sessions)
     self.assertIn(session.name, (entry.name for entry in listing.sessions))
     check_pool(1)

     stub.DeleteSession(spanner_pb2.DeleteSessionRequest(name=session.name))
     check_pool(0)

     listing = list_sessions()
     self.assertNotIn(session.name,
                      (entry.name for entry in listing.sessions))
     check_pool(0)
예제 #3
0
 def test_execute_sql_future(self):
     """ExecuteSql.future is tracked as an active stream until it resolves.

     Also checks that the query yields exactly one row whose first value is
     _TEST_COLUMN_DATA, and that deleting the session clears its affinity.
     """
     def check_pool(affinity, active_streams):
         # The pool must hold one channel with the given counters.
         pool = self.channel._channel_refs
         self.assertEqual(1, len(pool))
         self.assertEqual(affinity, pool[0]._affinity_ref)
         self.assertEqual(active_streams, pool[0]._active_stream_ref)

     stub = spanner_pb2_grpc.SpannerStub(self.channel)
     session = stub.CreateSession(
         spanner_pb2.CreateSessionRequest(database=_DATABASE))
     check_pool(1, 0)
     self.assertIsNotNone(session)

     rendezvous = stub.ExecuteSql.future(
         spanner_pb2.ExecuteSqlRequest(session=session.name, sql=_TEST_SQL))
     # The started call is expected to count as one active stream.
     check_pool(1, 1)

     result_set = rendezvous.result()
     check_pool(1, 0)
     self.assertIsNotNone(result_set)
     self.assertEqual(1, len(result_set.rows))
     first_value = result_set.rows[0].values[0].string_value
     self.assertEqual(_TEST_COLUMN_DATA, first_value)

     stub.DeleteSession(spanner_pb2.DeleteSessionRequest(name=session.name))
     check_pool(0, 0)
예제 #4
0
def test_execute_sql():
    """Benchmark blocking unary-unary ExecuteSql calls.

    Creates one session and issues _NUM_WARM_UP_CALLS untimed queries. It then
    hands execute_sql to _run_test; each invocation of execute_sql times
    _NUM_OF_RPC sequential calls. The session is deleted even if warm-up or
    the benchmark raises.
    """
    channel = _create_channel()
    stub = _create_stub(channel)

    session = stub.CreateSession(
        spanner_pb2.CreateSessionRequest(database=_DATABASE))
    # Every call issues the same query, so build the SQL text once.
    sql = 'select data from {}'.format(_TABLE)

    def execute_sql(result):
        # Appends each call's duration (seconds) to `result`.
        for _ in range(_NUM_OF_RPC):
            start = timeit.default_timer()
            stub.ExecuteSql(
                spanner_pb2.ExecuteSqlRequest(session=session.name, sql=sql))
            dur = timeit.default_timer() - start
            print('single call latency: {} ms'.format(dur * 1000))
            result.append(dur)

    try:
        # warm up
        for _ in range(_NUM_WARM_UP_CALLS):
            stub.ExecuteSql(
                spanner_pb2.ExecuteSqlRequest(session=session.name, sql=sql))

        print('Executing blocking unary-unary call.')
        _run_test(channel, execute_sql)
    finally:
        stub.DeleteSession(spanner_pb2.DeleteSessionRequest(name=session.name))
예제 #5
0
def _read(stub, metrics):
    """Probe to test Read and StreamingRead grpc call from Spanner stub.

    Args:
      stub: An object of SpannerStub.
      metrics: A dict of metrics; this probe stores 'read_latency_ms' and
        'streaming_read_latency_ms' (milliseconds) in it.

    Raises:
      ValueError: An error occurred when read result is not as expected.
    """
    session = None
    try:
        session = stub.CreateSession(
            spanner_pb2.CreateSessionRequest(database=_DATABASE))
        # Both probes read the same columns of every row of 'users'.
        read_request = spanner_pb2.ReadRequest(
            session=session.name,
            table='users',
            columns=['username', 'firstname', 'lastname'],
            key_set=keys_pb2.KeySet(all=True))

        # Probing Read call
        start = time.time()
        result_set = stub.Read(read_request)
        metrics['read_latency_ms'] = (time.time() - start) * 1000

        if result_set is None:
            raise ValueError('result_set is None')
        if len(result_set.rows) != 1:
            raise ValueError('incorrect result_set rows %d' %
                             len(result_set.rows))
        if result_set.rows[0].values[0].string_value != _TEST_USERNAME:
            raise ValueError('incorrect sql result %s' %
                             result_set.rows[0].values[0].string_value)

        # Probing StreamingRead call
        partial_result_set = stub.StreamingRead(read_request)

        if partial_result_set is None:
            raise ValueError('streaming_result_set is None')

        # Only the latency until the first streamed message is measured.
        start = time.time()
        first_result = next(partial_result_set)
        metrics['streaming_read_latency_ms'] = (time.time() - start) * 1000

        # Report an empty first message as a probe failure rather than an
        # IndexError, matching the documented ValueError contract.
        if not first_result.values:
            raise ValueError('streaming read first result has no values')
        if first_result.values[0].string_value != _TEST_USERNAME:
            raise ValueError('incorrect streaming sql first result %s' %
                             first_result.values[0].string_value)

    finally:
        if session is not None:
            stub.DeleteSession(
                spanner_pb2.DeleteSessionRequest(name=session.name))
예제 #6
0
def test_execute_streaming_sql(stub):
    """Run _TEST_SQL through ExecuteStreamingSql and drain the stream.

    The session used for the query is deleted even if the call or the
    iteration raises.

    Args:
      stub: An object of SpannerStub.
    """
    session = stub.CreateSession(
        spanner_pb2.CreateSessionRequest(database=_DATABASE))
    try:
        rendezvous = stub.ExecuteStreamingSql(
            spanner_pb2.ExecuteSqlRequest(session=session.name, sql=_TEST_SQL))
        # Consume every streamed message so the call runs to completion.
        for _ in rendezvous:
            pass
    finally:
        stub.DeleteSession(spanner_pb2.DeleteSessionRequest(name=session.name))
예제 #7
0
 def test_create_session_reuse_channel(self):
     """Sequential create/delete cycles keep reusing a single channel.

     Runs 2 * _DEFAULT_MAX_CHANNELS_PER_TARGET create/delete cycles and checks
     that the pool holds exactly one channel after each CreateSession.
     """
     stub = spanner_pb2_grpc.SpannerStub(self.channel)
     for _ in range(_DEFAULT_MAX_CHANNELS_PER_TARGET * 2):
         session = stub.CreateSession(
             spanner_pb2.CreateSessionRequest(database=_DATABASE))
         try:
             self.assertIsNotNone(session)
             self.assertEqual(1, len(self.channel._channel_refs))
         finally:
             # Delete even when an assertion fails so no session leaks.
             stub.DeleteSession(
                 spanner_pb2.DeleteSessionRequest(name=session.name))
예제 #8
0
 def test_bound_after_unbind(self):
     """Deleting a session removes its affinity binding and the session.

     After DeleteSession the channel has no affinity-key bindings left, and a
     GetSession for the deleted name fails with NOT_FOUND.
     """
     stub = spanner_pb2_grpc.SpannerStub(self.channel)
     session = stub.CreateSession(
         spanner_pb2.CreateSessionRequest(database=_DATABASE))
     self.assertEqual(1, len(self.channel._channel_ref_by_affinity_key))
     stub.DeleteSession(spanner_pb2.DeleteSessionRequest(name=session.name))
     self.assertEqual(0, len(self.channel._channel_ref_by_affinity_key))
     # Expect a gRPC failure specifically. Catching any Exception would let an
     # unrelated error through here and then fail confusingly on .code().
     with self.assertRaises(grpc.RpcError) as context:
         stub.GetSession(spanner_pb2.GetSessionRequest(name=session.name))
     self.assertEqual(grpc.StatusCode.NOT_FOUND, context.exception.code())
예제 #9
0
 def test_create_session_new_channel(self):
     """Concurrent CreateSession calls grow the pool up to its channel limit.

     In the first round every in-flight call adds one channel. In the second
     round the pool already holds _DEFAULT_MAX_CHANNELS_PER_TARGET channels
     and must not grow further.
     """
     stub = spanner_pb2_grpc.SpannerStub(self.channel)

     def delete_sessions(pending):
         # Wait for each CreateSession future, then delete its session.
         for pending_call in pending:
             created_name = pending_call.result().name
             stub.DeleteSession(
                 spanner_pb2.DeleteSessionRequest(name=created_name))

     first_round = []
     for index in range(_DEFAULT_MAX_CHANNELS_PER_TARGET):
         first_round.append(
             stub.CreateSession.future(
                 spanner_pb2.CreateSessionRequest(database=_DATABASE)))
         self.assertEqual(index + 1, len(self.channel._channel_refs))
     delete_sessions(first_round)

     second_round = []
     for _ in range(_DEFAULT_MAX_CHANNELS_PER_TARGET):
         second_round.append(
             stub.CreateSession.future(
                 spanner_pb2.CreateSessionRequest(database=_DATABASE)))
         self.assertEqual(_DEFAULT_MAX_CHANNELS_PER_TARGET,
                          len(self.channel._channel_refs))
     delete_sessions(second_round)
예제 #10
0
def _execute_sql(stub, metrics):
    """Probes to test ExecuteSql and ExecuteStreamingSql call from Spanner stub.

    Args:
      stub: An object of SpannerStub.
      metrics: A dict of metrics; this probe stores 'execute_sql_latency_ms'
        and 'execute_streaming_sql_latency_ms' (milliseconds) in it.

    Raises:
      ValueError: An error occurred when sql result is not as expected.
    """
    session = None
    try:
        session = stub.CreateSession(
            spanner_pb2.CreateSessionRequest(database=_DATABASE))
        # Both probes run the same query in the same session.
        sql_request = spanner_pb2.ExecuteSqlRequest(session=session.name,
                                                    sql='select * FROM users')

        # Probing ExecuteSql call
        start = time.time()
        result_set = stub.ExecuteSql(sql_request)
        metrics['execute_sql_latency_ms'] = (time.time() - start) * 1000

        if result_set is None:
            raise ValueError('result_set is None')
        if len(result_set.rows) != 1:
            raise ValueError('incorrect result_set rows %d' %
                             len(result_set.rows))
        if result_set.rows[0].values[0].string_value != _TEST_USERNAME:
            raise ValueError('incorrect sql result %s' %
                             result_set.rows[0].values[0].string_value)

        # Probing ExecuteStreamingSql call
        partial_result_set = stub.ExecuteStreamingSql(sql_request)

        if partial_result_set is None:
            raise ValueError('streaming_result_set is None')

        # Only the latency until the first streamed message is measured.
        start = time.time()
        first_result = next(partial_result_set)
        metrics['execute_streaming_sql_latency_ms'] = (
            (time.time() - start) * 1000)

        # Report an empty first message as a probe failure rather than an
        # IndexError, matching the documented ValueError contract.
        if not first_result.values:
            raise ValueError('streaming sql first result has no values')
        if first_result.values[0].string_value != _TEST_USERNAME:
            raise ValueError('incorrect streaming sql first result %s' %
                             first_result.values[0].string_value)

    finally:
        if session is not None:
            stub.DeleteSession(
                spanner_pb2.DeleteSessionRequest(name=session.name))
예제 #11
0
    def test_channel_connectivity(self):
        """Subscribed callback sees IDLE -> CONNECTING -> READY, then stops.

        After the callback is unsubscribed, issuing another RPC must not add
        any connectivity state to what the callback has recorded.
        """
        callback = _Callback()

        self.channel.subscribe(callback.update_first, try_to_connect=False)
        stub = spanner_pb2_grpc.SpannerStub(self.channel)
        session = stub.CreateSession(
            spanner_pb2.CreateSessionRequest(database=_DATABASE))
        connectivities = callback.block_until_connectivities_satisfy(
            lambda connectivities: grpc.ChannelConnectivity.READY in
            connectivities)
        self.assertEqual(3, len(connectivities))
        self.assertSequenceEqual((grpc.ChannelConnectivity.IDLE,
                                  grpc.ChannelConnectivity.CONNECTING,
                                  grpc.ChannelConnectivity.READY),
                                 connectivities)
        stub.DeleteSession(spanner_pb2.DeleteSessionRequest(name=session.name))

        self.channel.unsubscribe(callback.update_first)
        session = stub.CreateSession(
            spanner_pb2.CreateSessionRequest(database=_DATABASE))
        # Re-read the callback's recorded states. `connectivities` above was
        # returned before unsubscribing and may be a snapshot, in which case
        # re-checking it would prove nothing. NOTE(review): assumes an
        # always-true predicate returns the current states immediately.
        latest_connectivities = callback.block_until_connectivities_satisfy(
            lambda _: True)
        self.assertEqual(3, len(latest_connectivities))
        stub.DeleteSession(spanner_pb2.DeleteSessionRequest(name=session.name))
예제 #12
0
def test_max_concurrent_streams():
    """Observe an extra call while many streaming calls are left open.

    Starts _NUM_OF_RPC ExecuteStreamingSql calls on one session without
    reading them. It then issues an asynchronous ListSessions call, whose
    done-callback prints its completion time, and drains the open streams:
    first one on its own, then the rest. The session is deleted even if a
    step raises.
    """
    channel = _create_channel()
    stub = _create_stub(channel)

    session = stub.CreateSession(
        spanner_pb2.CreateSessionRequest(database=_DATABASE))

    # Streaming call objects, kept open (unread) until drained below.
    streams = []

    def execute_streaming_sql():
        rendezvous = stub.ExecuteStreamingSql(
            spanner_pb2.ExecuteSqlRequest(
                session=session.name,
                sql='select * from {}'.format(_LARGE_TABLE)))
        streams.append(rendezvous)

    def print_callback(start):
        dur = (time.time() - start) * 1000
        print('Finished ListSessions async call with {} ms...'.format(dur))

    try:
        for i in range(_NUM_OF_RPC):
            # Both ends of the measurement use timeit.default_timer. Its
            # reference point need not match time.time(), so mixing the two
            # clocks would produce meaningless durations.
            start = timeit.default_timer()
            execute_streaming_sql()
            print('{} --> started ExecuteStreamingSql with {} ms'.format(
                i + 1, (timeit.default_timer() - start) * 1000))
        print('Successfully started {} ExecuteStreamingSql calls.'.format(
            _NUM_OF_RPC))

        print('Starting ListSessions async call....')
        new_call_start = time.time()
        list_sessions_future = stub.ListSessions.future(
            spanner_pb2.ListSessionsRequest(database=_DATABASE))
        list_sessions_future.add_done_callback(
            lambda resp: print_callback(new_call_start))
        print('Started ListSessions async call.')

        for _ in streams[0]:
            pass
        print('Freed one active stream')

        print('Free all active streams...')
        for stream in streams[1:]:
            for _ in stream:
                pass
        print('Done')
    finally:
        stub.DeleteSession(spanner_pb2.DeleteSessionRequest(name=session.name))
예제 #13
0
def _transaction(stub, metrics):
    """Probe to test BeginTransaction, Commit and Rollback grpc from Spanner stub.

    Args:
      stub: An object of SpannerStub.
      metrics: A dict of metrics; this probe stores
        'begin_transaction_latency_ms', 'commit_latency_ms' and
        'rollback_latency_ms' (milliseconds) in it.
    """
    session = None
    try:
        session = stub.CreateSession(
            spanner_pb2.CreateSessionRequest(database=_DATABASE))

        # Read-write BeginTransaction request, reused for both transactions.
        txn_request = spanner_pb2.BeginTransactionRequest(
            session=session.name,
            options=transaction_pb2.TransactionOptions(
                read_write=transaction_pb2.TransactionOptions.ReadWrite()),
        )

        # Probing BeginTransaction call
        start = time.time()
        txn = stub.BeginTransaction(txn_request)
        metrics['begin_transaction_latency_ms'] = (time.time() - start) * 1000

        # Probing Commit call
        commit_request = spanner_pb2.CommitRequest(session=session.name,
                                                   transaction_id=txn.id)
        start = time.time()
        stub.Commit(commit_request)
        metrics['commit_latency_ms'] = (time.time() - start) * 1000

        # Probing Rollback call: a second transaction is begun for it (the
        # first was committed above); only the Rollback itself is timed.
        txn = stub.BeginTransaction(txn_request)
        rollback_request = spanner_pb2.RollbackRequest(session=session.name,
                                                       transaction_id=txn.id)
        start = time.time()
        stub.Rollback(rollback_request)
        metrics['rollback_latency_ms'] = (time.time() - start) * 1000

    finally:
        if session is not None:
            stub.DeleteSession(
                spanner_pb2.DeleteSessionRequest(name=session.name))
예제 #14
0
def _partition(stub, metrics):
    """Probe to test PartitionQuery and PartitionRead grpc call from Spanner stub.

    Each request carries a TransactionSelector with `begin` set, asking the
    server to start a read-only transaction for it.

    Args:
      stub: An object of SpannerStub.
      metrics: A dict of metrics; this probe stores
        'partition_query_latency_ms' and 'partition_read_latency_ms'
        (milliseconds) in it.
    """
    session = None
    try:
        session = stub.CreateSession(
            spanner_pb2.CreateSessionRequest(database=_DATABASE))
        txn_selector = transaction_pb2.TransactionSelector(
            begin=transaction_pb2.TransactionOptions(
                read_only=transaction_pb2.TransactionOptions.ReadOnly()))

        # Probing PartitionQuery call
        ptn_query_request = spanner_pb2.PartitionQueryRequest(
            session=session.name,
            sql='select * FROM users',
            transaction=txn_selector,
        )
        start = time.time()
        stub.PartitionQuery(ptn_query_request)
        metrics['partition_query_latency_ms'] = (time.time() - start) * 1000

        # Probing PartitionRead call
        ptn_read_request = spanner_pb2.PartitionReadRequest(
            session=session.name,
            table='users',
            transaction=txn_selector,
            key_set=keys_pb2.KeySet(all=True),
            columns=['username', 'firstname', 'lastname'])
        start = time.time()
        stub.PartitionRead(ptn_read_request)
        metrics['partition_read_latency_ms'] = (time.time() - start) * 1000

    finally:
        if session is not None:
            stub.DeleteSession(
                spanner_pb2.DeleteSessionRequest(name=session.name))
예제 #15
0
def prepare_test_data():
    """Replace the contents of _LARGE_TABLE with a test payload.

    Deletes every row of _LARGE_TABLE, then upserts rows 'payload0',
    'payload1', ... whose 'data' column holds
    min(_PAYLOAD_BYTES, _MAX_SIZE_PER_COLUMN) 'x' characters each, using
    enough rows to cover _PAYLOAD_BYTES. Every mutation is committed in its
    own single-use read-write transaction. The session is deleted even if a
    commit raises.
    """
    channel = _create_channel()
    stub = _create_stub(channel)

    print('Start adding payload to test table.')
    session = stub.CreateSession(
        spanner_pb2.CreateSessionRequest(database=_DATABASE))

    def commit(mutation):
        # Commit one mutation in a single-use read-write transaction.
        stub.Commit(
            spanner_pb2.CommitRequest(
                session=session.name,
                single_use_transaction=transaction_pb2.TransactionOptions(
                    read_write=transaction_pb2.TransactionOptions.ReadWrite()),
                mutations=[mutation]))

    try:
        commit(
            mutation_pb2.Mutation(delete=mutation_pb2.Mutation.Delete(
                table=_LARGE_TABLE, key_set=keys_pb2.KeySet(all=True))))

        # Because of the max data size per column, the payload is separated
        # into several rows.
        column_bytes = min(_PAYLOAD_BYTES, _MAX_SIZE_PER_COLUMN)
        # Ceiling division. `//` keeps this an int so range() accepts it on
        # Python 3, where `/` would produce a float.
        rows = (_PAYLOAD_BYTES - 1) // column_bytes + 1
        for i in range(rows):
            commit(
                mutation_pb2.Mutation(
                    insert_or_update=mutation_pb2.Mutation.Write(
                        table=_LARGE_TABLE,
                        columns=['id', 'data'],
                        values=[
                            google.protobuf.struct_pb2.ListValue(values=[
                                google.protobuf.struct_pb2.Value(
                                    string_value='payload{}'.format(i)),
                                google.protobuf.struct_pb2.Value(
                                    string_value='x' * column_bytes)
                            ])
                        ])))
        print('Successfully add payload to table.')
    finally:
        stub.DeleteSession(spanner_pb2.DeleteSessionRequest(name=session.name))
예제 #16
0
def test_execute_streaming_sql():
    """Benchmark unary-streaming ExecuteStreamingSql calls.

    Creates one session and drains _NUM_WARM_UP_CALLS untimed streams. It then
    hands execute_streaming_sql to _run_test; each invocation runs _NUM_OF_RPC
    streams one after another. Each stream's done-callback appends its
    duration in seconds to the result list. The session is deleted even if
    the benchmark raises.
    """
    channel = _create_channel()
    stub = _create_stub(channel)

    session = stub.CreateSession(
        spanner_pb2.CreateSessionRequest(database=_DATABASE))
    # Every call issues the same query, so build the SQL text once.
    sql = 'select data from {}'.format(_TABLE)

    def execute_streaming_sql(result):
        for _ in range(_NUM_OF_RPC):
            start = timeit.default_timer()
            rendezvous = stub.ExecuteStreamingSql(
                spanner_pb2.ExecuteSqlRequest(session=session.name, sql=sql))

            # `start_copy` binds this iteration's start time. The callback
            # reads the same clock as `start`; timeit.default_timer and
            # time.time may have different reference points.
            def callback(resp, start_copy=start):
                dur = timeit.default_timer() - start_copy
                result.append(dur)

            rendezvous.add_done_callback(callback)

            for _ in rendezvous:
                pass

    try:
        # warm up
        print('Begin warm up calls.')
        for _ in range(_NUM_WARM_UP_CALLS):
            rendezvous = stub.ExecuteStreamingSql(
                spanner_pb2.ExecuteSqlRequest(session=session.name, sql=sql))
            for _ in rendezvous:
                pass
        print('Warm up finished.')

        print('Executing unary-streaming call.')
        _run_test(channel, execute_streaming_sql)
    finally:
        stub.DeleteSession(spanner_pb2.DeleteSessionRequest(name=session.name))
예제 #17
0
    def test_channel_connectivity_invalid_target(self):
        """Connectivity callbacks on a channel whose target is unreachable.

        A CreateSession call against 'localhost:1234' must fail with
        UNAVAILABLE. The first callback (try_to_connect=False) must record
        IDLE first and later both CONNECTING and TRANSIENT_FAILURE. The second
        callback (try_to_connect=True) must record CONNECTING and
        TRANSIENT_FAILURE but no IDLE. Both subscriptions are removed at the
        end.
        """
        config = grpc_gcp.api_config_from_text_pb(
            pkg_resources.resource_string(__name__, 'spanner.grpc.config'))
        http_request = Request()
        credentials, _ = google.auth.default([_OAUTH_SCOPE], http_request)
        invalid_channel = self._create_secure_gcp_channel(
            credentials,
            http_request,
            'localhost:1234',
            options=[(grpc_gcp.API_CONFIG_CHANNEL_ARG, config)])

        callback = _Callback()
        invalid_channel.subscribe(callback.update_first, try_to_connect=False)

        stub = spanner_pb2_grpc.SpannerStub(invalid_channel)
        # Expect a gRPC failure specifically, so an unrelated exception is
        # reported as itself instead of failing on the .code() call below.
        with self.assertRaises(grpc.RpcError) as context:
            stub.CreateSession(
                spanner_pb2.CreateSessionRequest(database=_DATABASE))
        self.assertEqual(grpc.StatusCode.UNAVAILABLE, context.exception.code())
        first_connectivities = callback.block_until_connectivities_satisfy(
            lambda connectivities: len(connectivities) >= 3)
        self.assertEqual(grpc.ChannelConnectivity.IDLE,
                         first_connectivities[0])
        self.assertIn(grpc.ChannelConnectivity.CONNECTING,
                      first_connectivities)
        self.assertIn(grpc.ChannelConnectivity.TRANSIENT_FAILURE,
                      first_connectivities)

        invalid_channel.subscribe(callback.update_second, try_to_connect=True)
        second_connectivities = callback.block_until_connectivities_satisfy(
            lambda connectivities: len(connectivities) >= 3, False)
        self.assertNotIn(grpc.ChannelConnectivity.IDLE, second_connectivities)
        self.assertIn(grpc.ChannelConnectivity.CONNECTING,
                      second_connectivities)
        self.assertIn(grpc.ChannelConnectivity.TRANSIENT_FAILURE,
                      second_connectivities)

        self.assertEqual(2, len(invalid_channel._subscribers))
        invalid_channel.unsubscribe(callback.update_first)
        invalid_channel.unsubscribe(callback.update_second)
        self.assertEqual(0, len(invalid_channel._subscribers))
예제 #18
0
    def test_channel_connectivity_multiple_subchannels(self):
        """Two concurrent CreateSession calls open two pooled channels.

        The callback (subscribed with try_to_connect=False) must have recorded
        exactly IDLE, CONNECTING, READY by the time READY first appears.
        """
        callback = _Callback()
        self.channel.subscribe(callback.update_first, try_to_connect=False)
        stub = spanner_pb2_grpc.SpannerStub(self.channel)

        pending = [
            stub.CreateSession.future(
                spanner_pb2.CreateSessionRequest(database=_DATABASE))
            for _ in range(2)
        ]

        def reached_ready(states):
            return grpc.ChannelConnectivity.READY in states

        connectivities = callback.block_until_connectivities_satisfy(
            reached_ready)

        self.assertEqual(2, len(self.channel._channel_refs))
        expected_states = (grpc.ChannelConnectivity.IDLE,
                           grpc.ChannelConnectivity.CONNECTING,
                           grpc.ChannelConnectivity.READY)
        self.assertSequenceEqual(expected_states, connectivities)

        for call in pending:
            created = call.result()
            stub.DeleteSession(
                spanner_pb2.DeleteSessionRequest(name=created.name))
예제 #19
0
def test_execute_sql_async():
    """Benchmark asynchronous unary-unary ExecuteSql calls.

    Creates one session and runs _NUM_WARM_UP_CALLS untimed calls, waiting on
    each. It then hands execute_sql_async to _run_test; each invocation starts
    _NUM_OF_RPC calls (with _TIMEOUT) without waiting on them. Each call's
    done-callback appends its duration in seconds to the result list. The
    session is deleted even if the benchmark raises.
    """
    channel = _create_channel()
    stub = _create_stub(channel)

    session = stub.CreateSession(
        spanner_pb2.CreateSessionRequest(database=_DATABASE))

    def execute_sql_async(result):
        for _ in range(_NUM_OF_RPC):
            start = timeit.default_timer()
            resp_future = stub.ExecuteSql.future(
                spanner_pb2.ExecuteSqlRequest(session=session.name,
                                              sql='select data from storage'),
                _TIMEOUT)

            # `start_copy` binds this iteration's start time. The callback
            # reads the same clock as `start`; timeit.default_timer and
            # time.time may have different reference points.
            def callback(resp, start_copy=start):
                dur = timeit.default_timer() - start_copy
                result.append(dur)

            resp_future.add_done_callback(callback)

    try:
        # warm up
        for _ in range(_NUM_WARM_UP_CALLS):
            resp_future = stub.ExecuteSql.future(
                spanner_pb2.ExecuteSqlRequest(session=session.name,
                                              sql='select data from storage'))
            resp_future.result()

        print('Executing async unary-unary call.')
        _run_test(channel, execute_sql_async)
    finally:
        stub.DeleteSession(spanner_pb2.DeleteSessionRequest(name=session.name))
예제 #20
0
def test_execute_sql(stub):
    """Run _TEST_SQL once through the blocking ExecuteSql call.

    The session used for the query is deleted even if the query raises.

    Args:
      stub: An object of SpannerStub.
    """
    session = stub.CreateSession(
        spanner_pb2.CreateSessionRequest(database=_DATABASE))
    try:
        stub.ExecuteSql(
            spanner_pb2.ExecuteSqlRequest(session=session.name, sql=_TEST_SQL))
    finally:
        stub.DeleteSession(spanner_pb2.DeleteSessionRequest(name=session.name))
예제 #21
0
    def test_concurrent_streams_watermark(self):
        """Channel pool opens a new channel once channel 0 hits the watermark.

        With the low watermark set to 2, the first two sessions (each with an
        open ExecuteStreamingSql call) stay on channel 0, and a third session
        lands on a newly created channel 1. Draining the streams resets the
        active-stream counts, and deleting the sessions resets the affinity
        counts. Both channels remain in the pool throughout.

        The assertions are interleaved with the RPCs on purpose: each one
        checks the pool's counters at a specific point in the sequence.
        """
        stub = spanner_pb2_grpc.SpannerStub(self.channel)
        watermark = 2
        # Override the pool's private low-watermark setting directly.
        self.channel._max_concurrent_streams_low_watermark = watermark
        self.assertEqual(self.channel._max_concurrent_streams_low_watermark,
                         watermark)

        session_list = []
        rendezvous_list = []

        # When active streams have not reached the concurrent_streams_watermark,
        # gRPC calls should be reusing the same channel.
        for i in range(watermark):
            session = stub.CreateSession(
                spanner_pb2.CreateSessionRequest(database=_DATABASE))
            # Affinity count grows by one per created session; streams from
            # earlier iterations are still open (not yet drained).
            self.assertEqual(1, len(self.channel._channel_refs))
            self.assertEqual(i + 1,
                             self.channel._channel_refs[0]._affinity_ref)
            self.assertEqual(i,
                             self.channel._channel_refs[0]._active_stream_ref)
            self.assertIsNotNone(session)
            session_list.append(session)

            rendezvous = stub.ExecuteStreamingSql(
                spanner_pb2.ExecuteSqlRequest(session=session.name,
                                              sql=_TEST_SQL))
            self.assertEqual(1, len(self.channel._channel_refs))
            self.assertEqual(i + 1,
                             self.channel._channel_refs[0]._affinity_ref)
            self.assertEqual(i + 1,
                             self.channel._channel_refs[0]._active_stream_ref)
            rendezvous_list.append(rendezvous)

        # When active streams reach the concurrent_streams_watermark,
        # channel pool will create a new channel.
        another_session = stub.CreateSession(
            spanner_pb2.CreateSessionRequest(database=_DATABASE))
        self.assertEqual(2, len(self.channel._channel_refs))
        self.assertEqual(2, self.channel._channel_refs[0]._affinity_ref)
        self.assertEqual(2, self.channel._channel_refs[0]._active_stream_ref)
        self.assertEqual(1, self.channel._channel_refs[1]._affinity_ref)
        self.assertEqual(0, self.channel._channel_refs[1]._active_stream_ref)
        self.assertIsNotNone(another_session)
        session_list.append(another_session)

        # A call on the new session is counted on channel 1, not channel 0.
        another_rendezvous = stub.ExecuteStreamingSql(
            spanner_pb2.ExecuteSqlRequest(session=another_session.name,
                                          sql=_TEST_SQL))
        self.assertEqual(2, len(self.channel._channel_refs))
        self.assertEqual(2, self.channel._channel_refs[0]._affinity_ref)
        self.assertEqual(2, self.channel._channel_refs[0]._active_stream_ref)
        self.assertEqual(1, self.channel._channel_refs[1]._affinity_ref)
        self.assertEqual(1, self.channel._channel_refs[1]._active_stream_ref)
        rendezvous_list.append(another_rendezvous)

        # Iterate through the rendezvous list to clean active streams.
        for rendezvous in rendezvous_list:
            for _ in rendezvous:
                continue

        # After cleaning, previously created channels will remain in the pool.
        self.assertEqual(2, len(self.channel._channel_refs))
        self.assertEqual(2, self.channel._channel_refs[0]._affinity_ref)
        self.assertEqual(0, self.channel._channel_refs[0]._active_stream_ref)
        self.assertEqual(1, self.channel._channel_refs[1]._affinity_ref)
        self.assertEqual(0, self.channel._channel_refs[1]._active_stream_ref)

        # Delete all sessions to clean affinity.
        for session in session_list:
            stub.DeleteSession(
                spanner_pb2.DeleteSessionRequest(name=session.name))

        self.assertEqual(2, len(self.channel._channel_refs))
        self.assertEqual(0, self.channel._channel_refs[0]._affinity_ref)
        self.assertEqual(0, self.channel._channel_refs[0]._active_stream_ref)
        self.assertEqual(0, self.channel._channel_refs[1]._affinity_ref)
        self.assertEqual(0, self.channel._channel_refs[1]._active_stream_ref)
예제 #22
0
def _session_management(stub, metrics):
    """Probes to test session related grpc call from Spanner stub.

    Includes tests against CreateSession, GetSession, ListSessions, and
    DeleteSession of Spanner stub.

    Args:
      stub: An object of SpannerStub.
      metrics: A dict of metrics. This probe stores
        'create_session_latency_ms', 'get_session_latency_ms',
        'list_sessions_latency_ms' and 'delete_session_latency_ms'
        (milliseconds) in it. The delete latency is recorded only when a
        session was created and DeleteSession returned.

    Raises:
      TypeError: An error occurred when result type is not as expected.
      ValueError: An error occurred when session name is not as expected.
    """
    session = None
    try:
        # Create session
        start = time.time()
        session = stub.CreateSession(
            spanner_pb2.CreateSessionRequest(database=_DATABASE))
        metrics['create_session_latency_ms'] = (time.time() - start) * 1000

        if not isinstance(session, spanner_pb2.Session):
            raise TypeError(
                'response is of type %s, not spanner_pb2.Session!' %
                type(session))

        # Get session
        start = time.time()
        response = stub.GetSession(
            spanner_pb2.GetSessionRequest(name=session.name))
        metrics['get_session_latency_ms'] = (time.time() - start) * 1000

        if not isinstance(response, spanner_pb2.Session):
            raise TypeError(
                'response is of type %s, not spanner_pb2.Session!' %
                type(response))
        if response.name != session.name:
            raise ValueError('incorrect session name %s' % response.name)

        # List sessions
        start = time.time()
        response = stub.ListSessions(
            spanner_pb2.ListSessionsRequest(database=_DATABASE))
        metrics['list_sessions_latency_ms'] = (time.time() - start) * 1000

        # The session created above must be among the listed ones.
        if not any(s.name == session.name for s in response.sessions):
            raise ValueError(
                'session name %s is not in the result session list!' %
                session.name)

    finally:
        if session is not None:
            start = time.time()
            stub.DeleteSession(
                spanner_pb2.DeleteSessionRequest(name=session.name))
            metrics['delete_session_latency_ms'] = (
                (time.time() - start) * 1000)