def test_multiple_mixins(self):
        """Two clients built from different mixin classes stay independent.

        The first client is the fake thrift client. The second mixes
        ThriftClientMixin into a client class defined locally here. Each one
        must connect to an endpoint taken from its own host provider.
        """
        class AnotherThriftClient(object):
            def __init__(self, protocol):
                pass

            def method_hello(self):
                return "hello"

        class AnotherThriftClientMixin(AnotherThriftClient, ThriftClientMixin):
            pass

        def assert_freshly_connected(mixin_client, host, port):
            # Exactly one request has been served over an open connection
            # to the given endpoint.
            self.assertEqual(host, mixin_client.host)
            self.assertEqual(port, mixin_client.port)
            self.assertTrue(mixin_client.connected)
            self.assertEqual(1, mixin_client.requests_served)

        # Both clients are built up front, before any call is made.
        client = FakeThriftClientMixin(HostsProvider(HOSTS))
        another_provider = HostsProvider(['anotherhost1:1000',
                                          'anotherhost2:2000'])
        another_client = AnotherThriftClientMixin(another_provider)

        # The first client connects to the first fake endpoint.
        self.assertTrue(client.method_success())
        assert_freshly_connected(client, 'fakehost1', 100)

        # The second client connects as well. Because of how random.choice
        # is mocked, it is expected to land on the second endpoint.
        greeting = another_client.method_hello()
        self.assertEqual("hello", greeting)
        assert_freshly_connected(another_client, 'anotherhost2', 2000)
    def test_connection_expiration(self):
        """A pooled call fails once it outlives connection_expiration.

        The same 0.1 s call is made against a client with a tiny expiration,
        which must raise ExpiredConnection, and against one with a generous
        expiration, which must succeed.
        """
        sleep_secs = 0.1

        # connection_expiration=1 is far shorter than the call, so the
        # connection expires while the call is still running.
        client = FakePooledThriftClientMixin(
            host_provider=HostsProvider(HOSTS), connection_expiration=1)
        self.assertRaises(ExpiredConnection, client.method_sleep, sleep_secs)

        # connection_expiration=1000 leaves plenty of room for the call.
        client = FakePooledThriftClientMixin(
            host_provider=HostsProvider(HOSTS), connection_expiration=1000)
        self.assertTrue(client.method_sleep(sleep_secs))
 def test_network_error(self):
     """A simulated network failure surfaces as ThriftConnectionError.

     method_network_error raises TTransportException; the thrift mixin is
     expected to rewrite it into ThriftConnectionError for callers.
     """
     pooled_client = FakePooledThriftClientMixin(
         host_provider=HostsProvider(HOSTS))
     failing_rpc = pooled_client.method_network_error
     self.assertRaises(ThriftConnectionError, failing_rpc)
    def test_teardown(self):
        """The client periodically tears its connection down and moves on.

        Expected sequence over four successful calls:
          1. connect to fakehost1 and serve one request;
          2. tear down, with the endpoint reset to fakehost2;
          3. reconnect on fakehost2 and serve one request;
          4. tear down again, with the endpoint reset to fakehost3.
        """
        client = FakeThriftClientMixin(HostsProvider(HOSTS))

        def expect_state(host, port, connected, requests_served):
            # Verify the endpoint, the connection flag and the request count.
            self.assertEqual(host, client.host)
            self.assertEqual(port, client.port)
            if connected:
                self.assertTrue(client.connected)
            else:
                self.assertFalse(client.connected)
            self.assertEqual(requests_served, client.requests_served)

        expected_after_each_call = [
            ('fakehost1', 100, True, 1),   # established on first endpoint
            ('fakehost2', 200, False, 0),  # torn down, endpoint reset
            ('fakehost2', 200, True, 1),   # established on second endpoint
            ('fakehost3', 300, False, 0),  # torn down again, endpoint reset
        ]
        for state in expected_after_each_call:
            self.assertTrue(client.method_success())
            expect_state(*state)
# Example #5
# 0
    def test_retrieving_and_invalidation(self):
        """Test host retrieval and invalidation on BaseHostSelector.

        _choose_host is patched so that get_host() returns a known host.
        invalidate() marks the current host as bad. With retry_time=0, the
        next get_host() call clears the bad-host list again.
        """
        host_list = HostSelectorTestCase.HOST_LIST
        choose_host_target = hosts.__name__ + ".BaseHostSelector._choose_host"
        host_provider = HostsProvider(host_list)
        base_host_selector = BaseHostSelector(host_provider,
                                              expire_time=0,
                                              retry_time=0,
                                              invalidation_threshold=1.0)
        self.assertIsNone(base_host_selector.get_last_host())

        with patch(choose_host_target, new=Mock(return_value=host_list[0])):
            # Get one host.
            host1 = base_host_selector.get_host()
            self.assertEqual(host1, host_list[0])
            # Invalidating it changes the selector's state.
            self.assertNotIn(host1, base_host_selector._bad_hosts)
            base_host_selector.invalidate()
            self.assertIn(host1, base_host_selector._bad_hosts)

        # With retry_time=0 the next call should clear the bad hosts.
        with patch(choose_host_target, new=Mock(return_value=host_list[1])):
            host2 = base_host_selector.get_host()
            # The bad-host list should now be empty.
            self.assertFalse(base_host_selector._bad_hosts)
            self.assertEqual(host2, host_list[1])
            base_host_selector.invalidate()
            self.assertIn(host2, base_host_selector._bad_hosts)
    def test_teardown_secs(self):
        """An old connection is torn down on the next call, then re-opened.

        After the first call, connected_at is moved back by
        SECS_FOR_CONNECTION_TEARDOWN seconds. The next call must tear the
        connection down (moving on to fakehost2), and the call after that
        must open a fresh connection.
        """
        def expect_state(c, host, port, connected, requests_served):
            # Verify the endpoint, the connection flag and the request count.
            self.assertEqual(host, c.host)
            self.assertEqual(port, c.port)
            if connected:
                self.assertTrue(c.connected)
            else:
                self.assertFalse(c.connected)
            self.assertEqual(requests_served, c.requests_served)

        before_connect = time.time()
        client = FakeThriftClientMixin(HostsProvider(HOSTS))
        self.assertTrue(client.method_success())
        # A fresh connection on the first endpoint.
        expect_state(client, 'fakehost1', 100, True, 1)
        self.assertGreater(client.connected_at, before_connect)

        # Pretend the connection was opened SECS_FOR_CONNECTION_TEARDOWN
        # seconds earlier than it really was.
        client.connected_at -= SECS_FOR_CONNECTION_TEARDOWN

        self.assertTrue(client.method_success())
        # The connection was torn down.
        expect_state(client, 'fakehost2', 200, False, 0)

        before_connect = time.time()
        self.assertTrue(client.method_success())
        # The connection was re-established.
        expect_state(client, 'fakehost2', 200, True, 1)
        self.assertGreater(client.connected_at, before_connect)
    def test_timeout(self):
        """rpc_timeout_ms overrides the socket timeout for a single call.

        Without rpc_timeout_ms, the socket must have been given both the
        connection timeout (1000) and the request timeout (5000). With it,
        the latest setTimeout call uses the per-call value, while
        client.timeout keeps its configured 5000.
        """
        client = FakeThriftClientMixin(HostsProvider(HOSTS),
                                       timeout=5000,
                                       socket_connection_timeout=1000)
        self.assertTrue(client.method_success())
        # No rpc_timeout: the socket saw the connection timeout...
        client._socket.setTimeout.assert_any_call(1000)
        # ...and the configured request timeout.
        client._socket.setTimeout.assert_any_call(5000)

        # Two calls with a per-call timeout on the same connection.
        for expected_served, rpc_timeout_ms in ((2, 100), (3, 200)):
            self.assertTrue(
                client.method_success(rpc_timeout_ms=rpc_timeout_ms))
            self.assertEqual(expected_served, client.requests_served)
            self.assertEqual('fakehost1', client.host)
            self.assertEqual(100, client.port)
            self.assertTrue(client.connected)
            self.assertEqual(client.timeout, 5000)
            client._socket.setTimeout.assert_called_with(rpc_timeout_ms)
# Example #8
# 0
 def test_reject_invalidation(self):
     """Test rejecting invalidation with a file-backed host provider.

     invalidate() is rejected when marking the host bad would exceed the
     invalidation threshold, and accepted after the threshold is raised.
     The arithmetic in the comments below assumes HOST_LIST has two hosts
     and the default threshold is 0.2 (TODO confirm).
     """
     host_list = HostSelectorWithLocalFileTestCase.HOST_LIST
     fd, tmp_file = tempfile.mkstemp()
     # mkstemp hands back an open OS-level descriptor. Only the path is used
     # below, so close the descriptor instead of leaking it.
     os.close(fd)
     try:
         with open(tmp_file, 'w') as f:
             f.write('\n'.join(host_list))
         host_provider = HostsProvider(host_list, file_path=tmp_file)
         base_host_selector = BaseHostSelector(host_provider,
                                               expire_time=0,
                                               retry_time=0)
         with patch(hosts.__name__ + ".BaseHostSelector._choose_host",
                    new=Mock(return_value=host_list[0])):
             # Get one host.
             host1 = base_host_selector.get_host()
             self.assertEqual(host1, host_list[0])
             self.assertNotIn(host1, base_host_selector._bad_hosts)
             base_host_selector.invalidate()
             # Rejected: 1 is larger than 2 * 0.2 = 0.4.
             self.assertNotIn(host1, base_host_selector._bad_hosts)
             base_host_selector._invalidation_threshold = 0.5
             host1 = base_host_selector.get_host()
             self.assertEqual(host1, host_list[0])
             base_host_selector.invalidate()
             # Accepted: 1 <= 2 * 0.5 = 1.0.
             self.assertIn(host1, base_host_selector._bad_hosts)
     finally:
         # Clean up even when an assertion above fails.
         HostSelectorWithLocalFileTestCase.FILE_WATCH._clear_all_watches()
         os.remove(tmp_file)
    def test_e2e_timeout(self):
        """Compare connection_expiration with the end-to-end (e2e) timeout.

        Every scenario uses a single-connection pool. A background greenlet
        first takes the only connection and sleeps 0.2 s, so the foreground
        call waits about 0.2 s to acquire a connection and then sleeps
        another 0.2 s.
        """
        def occupy_pool(client):
            # Start a 0.2 s call in the background, then wait (up to 0.1 s)
            # for it to signal that it holds the pooled connection.
            greenlet = gevent.spawn(client.method_sleep_set_event, 0.2)
            AnotherFakeClient.sleep_event.wait(0.1)
            AnotherFakeClient.sleep_event.clear()
            return greenlet

        # 1) connection_expiration only. The call takes about 400 ms in
        # total, but time spent acquiring the connection does not count
        # towards connection_expiration, so it succeeds.
        client = FakePooledThriftClientMixin(
            host_provider=HostsProvider(HOSTS),
            pool_size=1,
            connection_wait_timeout=300,
            connection_expiration=300)
        background = occupy_pool(client)
        self.assertTrue(client.method_sleep(0.2))
        background.join()

        # 2) The e2e timeout is tighter than connection_expiration even with
        # the same value. About 200 ms acquiring plus 200 ms executing
        # exceeds the 300 ms e2e budget.
        client = FakePooledThriftClientMixin(
            host_provider=HostsProvider(HOSTS),
            pool_size=1,
            connection_wait_timeout=300,
            connection_expiration=300,
            e2e_timeout=300)
        background = occupy_pool(client)
        self.assertRaises(ExpiredConnection, client.method_sleep, 0.2)
        background.join()

        # 3) Only e2e_timeout is specified: it still applies on its own.
        client = FakePooledThriftClientMixin(
            host_provider=HostsProvider(HOSTS), pool_size=1, e2e_timeout=300)
        background = occupy_pool(client)
        self.assertRaises(ExpiredConnection, client.method_sleep, 0.2)
        background.join()
 def test_other_error(self):
     """A TApplicationException propagates and the connection stays up.

     The call fails, but the connection to the first endpoint must remain
     established, and the failed request is not counted as served.
     """
     client = FakeThriftClientMixin(HostsProvider(HOSTS))
     failing_rpc = client.method_other_error
     self.assertRaises(TApplicationException, failing_rpc)
     # Still connected to the first endpoint.
     self.assertEqual('fakehost1', client.host)
     self.assertEqual(100, client.port)
     self.assertTrue(client.connected)
     # The failed request does not increase requests_served.
     self.assertEqual(0, client.requests_served)
    def test_default_replace_policy(self):
        """With the default policy, application errors keep the connection.

        Neither a TApplicationException nor a user-defined thrift exception
        should cause the pool to drop its connection.
        """
        client = FakePooledThriftClientMixin(
            host_provider=HostsProvider(HOSTS))

        cases = (
            (TApplicationException, 'method_other_error'),
            (FakeThriftException, 'method_user_defined_error'),
        )
        for exc_type, rpc_name in cases:
            self.assertRaises(exc_type, getattr(client, rpc_name))
            self.assertEqual(1, client.client_pool.num_connected)
    def test_connection_wait_timeout(self):
        """connection_wait_timeout bounds the wait for a pooled connection.

        With pool_size=0 no connection is available, so each call raises
        gevent.queue.Empty. The elapsed time shows how long the client
        waited before giving up.
        """
        # A tiny wait timeout: the attempt is abandoned almost immediately.
        started = time.time()
        client = FakePooledThriftClientMixin(
            host_provider=HostsProvider(HOSTS),
            pool_size=0,
            connection_wait_timeout=1)
        self.assertRaises(gevent.queue.Empty, client.method_success)
        self.assertLess(time.time() - started, 0.2)

        # A longer wait timeout: the attempt lingers before giving up.
        started = time.time()
        client = FakePooledThriftClientMixin(
            host_provider=HostsProvider(HOSTS),
            pool_size=0,
            connection_wait_timeout=300)
        self.assertRaises(gevent.queue.Empty, client.method_success)
        self.assertGreater(time.time() - started, 0.2)
    def test_replace_if(self):
        """A custom conn_replace_policy decides which errors drop a connection.

        The policy used here asks for replacement only on
        ThriftConnectionError.
        """
        def only_connection_errors(ex):
            return isinstance(ex, ThriftConnectionError)

        client = FakePooledThriftClientMixin(
            host_provider=HostsProvider(HOSTS),
            conn_replace_policy=only_connection_errors)

        # A network error matches the policy: the connection is discarded.
        self.assertRaises(ThriftConnectionError, client.method_network_error)
        self.assertEqual(0, client.client_pool.num_connected)

        # Any other error does not match: the connection stays in the pool.
        self.assertRaises(TApplicationException, client.method_other_error)
        self.assertEqual(1, client.client_pool.num_connected)
    def test_concurrency(self):
        """Ten greenlets share a pool of five connections.

        Each greenlet runs _run_method_success(client, 3). In total, 10 * 3
        calls must reach the fake client.
        """
        greenlet_count, calls_per_greenlet = 10, 3
        client = FakePooledThriftClientMixin(
            host_provider=HostsProvider(HOSTS), pool_size=5)

        self.assertEqual(0, AnotherFakeClient.in_flight_calls)
        AnotherFakeClient.num_calls = 0

        workers = [
            gevent.spawn(self._run_method_success, client, calls_per_greenlet)
            for _ in range(greenlet_count)
        ]
        gevent.joinall(workers)
        self.assertEqual(greenlet_count * calls_per_greenlet,
                         AnotherFakeClient.num_calls)
 def test_wrap(self):
     """Thrift RPC methods are exposed as plain functions; helpers are not.

     On the pooled client, the RPC methods are plain functions (presumably
     the mixin's wrappers). get_connection_exception_class stays an
     ordinary bound method whose underlying function is the one defined on
     FakePooledThriftClientMixin.
     """
     client = FakePooledThriftClientMixin(
         host_provider=HostsProvider(HOSTS))
     self.assertEqual(types.FunctionType, type(client.method_success))
     self.assertEqual(types.FunctionType, type(client.method_network_error))
     self.assertEqual(types.FunctionType, type(client.method_other_error))
     self.assertEqual(types.MethodType,
                      type(client.get_connection_exception_class))
     # __func__ works on both Python 2.6+ and Python 3, whereas im_func only
     # exists on Python 2.
     self.assertIs(
         client.get_connection_exception_class.__func__,
         FakePooledThriftClientMixin.
         __dict__['get_connection_exception_class'])
# Example #16
# 0
    def test_init_base_host_selector_class(self):
        """Test base initialization and functionality."""
        host_provider = HostsProvider([])
        base_host_selector = BaseHostSelector(host_provider)
        # Check that some base states are set.
        self.assertIsNone(base_host_selector._last)
        self.assertIsNone(base_host_selector._current)
        self.assertIsNone(base_host_selector._select_time)
        self.assertEqual(base_host_selector._bad_hosts, {})
        self.assertEqual(base_host_selector._retry_time, 60)
        self.assertIs(base_host_selector._host_provider, host_provider)

        # This is an abstract class: _choose_host() should raise.
        self.assertRaises(NotImplementedError, base_host_selector._choose_host)
    def test_success(self):
        """Consecutive successful calls reuse one connection to fakehost1."""
        client = FakeThriftClientMixin(HostsProvider(HOSTS))
        for expected_served in (1, 2):
            self.assertTrue(client.method_success())
            self.assertEqual(expected_served, client.requests_served)
            self.assertEqual('fakehost1', client.host)
            self.assertEqual(100, client.port)
            self.assertTrue(client.connected)
 def test_rpc_timeout(self):
     """An RPC hitting socket.timeout ends in ThriftConnectionError.

     With retry_count=2 the request is retried, and each attempt tries to
     connect on a new host endpoint. The client therefore ends up on
     fakehost3 (port 300), not connected, with no requests served.
     """
     client = FakeThriftClientMixin(HostsProvider(HOSTS), retry_count=2)
     timing_out_rpc = client.method_rpc_timeout
     self.assertRaises(ThriftConnectionError, timing_out_rpc)
     # NOTE: tearing down and reconnecting only once is disabled for now;
     # every retry moves on to a new endpoint instead.
     self.assertEqual('fakehost3', client.host)
     self.assertEqual(300, client.port)
     self.assertFalse(client.connected)
     # The failed request does not increase requests_served.
     self.assertEqual(0, client.requests_served)
# Example #19
# 0
    def test_init_base_host_selector_class(self):
        """Test base initialization with a file-backed host provider."""
        fd, tmp_file = tempfile.mkstemp()
        # Only the path is needed; close the descriptor mkstemp opened so it
        # does not leak.
        os.close(fd)
        try:
            host_provider = HostsProvider([], file_path=tmp_file)
            base_host_selector = BaseHostSelector(host_provider)
            # Check that some base states are set.
            self.assertIsNone(base_host_selector._last)
            self.assertIsNone(base_host_selector._current)
            self.assertIsNone(base_host_selector._select_time)
            self.assertEqual(base_host_selector._bad_hosts, {})
            self.assertEqual(base_host_selector._retry_time, 60)
            self.assertIs(base_host_selector._host_provider, host_provider)

            # This is an abstract class: _choose_host() should raise.
            self.assertRaises(NotImplementedError,
                              base_host_selector._choose_host)
        finally:
            # Clean up even when an assertion above fails.
            HostSelectorWithLocalFileTestCase.FILE_WATCH._clear_all_watches()
            os.remove(tmp_file)
# Example #20
# 0
    def test_random_host_selector_with_serverset(self):
        """RandomHostSelector follows membership changes in a serverset file.

        The file first lists one host, so only that host is returned. After
        a second host is appended, both hosts must be returned.
        """
        host_list = HostSelectorWithLocalFileTestCase.HOST_LIST
        file_watch = HostSelectorWithLocalFileTestCase.FILE_WATCH
        no_of_iterations = 100
        fd, tmp_file = tempfile.mkstemp()
        # Only the path is used; close mkstemp's descriptor so it does not
        # leak.
        os.close(fd)
        try:
            # Write the first host into the local serverset file to
            # simulate a join.
            with open(tmp_file, 'w') as f:
                f.write(host_list[0])
            file_watch._check_file_updates()
            host_provider = HostsProvider(host_list, file_path=tmp_file)
            self.assertTrue(host_provider.initialized)
            self.assertTrue(host_provider.hosts)
            self.assertEqual(host_provider._current_host_tuple,
                             (host_list[0], ))
            random_host_selector = RandomHostSelector(
                host_provider,
                expire_time=0,
                retry_time=0,
                invalidation_threshold=1.0)
            self.assertIn(random_host_selector.get_host(), host_list)

            # After the first endpoint joins, the selector should only use
            # hosts present in the serverset.
            returned_hosts = [random_host_selector.get_host()
                              for _ in range(no_of_iterations)]
            self.assertEqual(len(set(returned_hosts)), 1)
            self.assertEqual(len(host_provider.hosts), 1)

            # NOTE(review): presumably this sleep gives the append a newer
            # modification time so the file watch notices it; TODO confirm.
            time.sleep(1)
            with open(tmp_file, 'a') as f:
                f.write('\n' + host_list[1])
            file_watch._check_file_updates()
            # After the second endpoint joins, both endpoints should be
            # returned.
            returned_hosts = [random_host_selector.get_host()
                              for _ in range(no_of_iterations)]
            self.assertEqual(len(set(returned_hosts)), 2)
            self.assertEqual(len(host_provider.hosts), 2)
        finally:
            # Clean up even when an assertion above fails.
            file_watch._clear_all_watches()
            os.remove(tmp_file)
# Example #21
# 0
    def test_retrieving_and_invalidation(self):
        """Test host retrieval and invalidation with a file-backed provider.

        _choose_host is patched so that get_host() returns a known host.
        invalidate() marks the current host as bad. With retry_time=0, the
        next get_host() call clears the bad-host list again.
        """
        host_list = HostSelectorWithLocalFileTestCase.HOST_LIST
        choose_host_target = hosts.__name__ + ".BaseHostSelector._choose_host"
        fd, tmp_file = tempfile.mkstemp()
        # Only the path is used; close mkstemp's descriptor so it does not
        # leak.
        os.close(fd)
        try:
            with open(tmp_file, 'w') as f:
                f.write('\n'.join(host_list))
            host_provider = HostsProvider(host_list, file_path=tmp_file)
            base_host_selector = BaseHostSelector(host_provider,
                                                  expire_time=0,
                                                  retry_time=0,
                                                  invalidation_threshold=1.0)
            self.assertIsNone(base_host_selector.get_last_host())

            with patch(choose_host_target,
                       new=Mock(return_value=host_list[0])):
                # Get one host.
                host1 = base_host_selector.get_host()
                self.assertEqual(host1, host_list[0])
                # Invalidating it changes the selector's state.
                self.assertNotIn(host1, base_host_selector._bad_hosts)
                base_host_selector.invalidate()
                self.assertIn(host1, base_host_selector._bad_hosts)

            # With retry_time=0 the next call should clear the bad hosts.
            with patch(choose_host_target,
                       new=Mock(return_value=host_list[1])):
                host2 = base_host_selector.get_host()
                # The bad-host list should now be empty.
                self.assertFalse(base_host_selector._bad_hosts)
                self.assertEqual(host2, host_list[1])
                base_host_selector.invalidate()
                self.assertIn(host2, base_host_selector._bad_hosts)
        finally:
            # Clean up even when an assertion above fails.
            HostSelectorWithLocalFileTestCase.FILE_WATCH._clear_all_watches()
            os.remove(tmp_file)
# Example #22
# 0
 def test_reject_invalidation(self):
     """Test rejecting invalidation.

     invalidate() is rejected when marking the host bad would exceed the
     invalidation threshold, and accepted after the threshold is raised.
     The arithmetic in the comments below assumes HOST_LIST has two hosts
     and the default threshold is 0.2 (TODO confirm).
     """
     host_list = HostSelectorTestCase.HOST_LIST
     host_provider = HostsProvider(host_list)
     base_host_selector = BaseHostSelector(host_provider,
                                           expire_time=0,
                                           retry_time=0)
     with patch(hosts.__name__ + ".BaseHostSelector._choose_host",
                new=Mock(return_value=host_list[0])):
         # Get one host.
         host1 = base_host_selector.get_host()
         self.assertEqual(host1, host_list[0])
         self.assertNotIn(host1, base_host_selector._bad_hosts)
         base_host_selector.invalidate()
         # Rejected: 1 is larger than 2 * 0.2 = 0.4.
         self.assertNotIn(host1, base_host_selector._bad_hosts)
         base_host_selector._invalidation_threshold = 0.5
         host1 = base_host_selector.get_host()
         self.assertEqual(host1, host_list[0])
         base_host_selector.invalidate()
         # Accepted: 1 <= 2 * 0.5 = 1.0.
         self.assertIn(host1, base_host_selector._bad_hosts)
# Example #23
# 0
 def test_random_host_selector_with_serverset(self):
     """RandomHostSelector moves from static hosts to live serverset members.

     While the ZooKeeper serverset is empty, the static PORT_LIST is used.
     Once endpoints join, only the joined endpoints are returned.
     """
     testutil.initialize_kazoo_client_manager(ZK_HOSTS)
     zk_client = KazooClientManager().get_client()
     zk_client.ensure_path(HostSelectorTestCase.SERVER_SET_PATH)
     endpoints = HostSelectorTestCase.PORT_LIST
     provider = HostsProvider(endpoints,
                              HostSelectorTestCase.SERVER_SET_PATH)
     self.assertTrue(provider.initialized)
     self.assertTrue(provider.hosts)
     # Nothing has joined the serverset yet, so the provider should still
     # be serving the static host list.
     self.assertEqual(provider._current_host_tuple,
                      provider._static_host_tuple)
     selector = RandomHostSelector(provider,
                                   expire_time=0,
                                   retry_time=0,
                                   invalidation_threshold=1.0)
     self.assertTrue(selector.get_host() in endpoints)
     server_set = ServerSet(HostSelectorTestCase.SERVER_SET_PATH, ZK_HOSTS)

     def sample_hosts(count=100):
         # Draw `count` hosts from the selector.
         return [selector.get_host() for _ in range(count)]

     # The first endpoint joins: only that host should be used now.
     joiner = server_set.join(endpoints[0], use_ip=False)
     joiner.get()
     picked = sample_hosts()
     self.assertEqual(len(set(picked)), 1)
     self.assertEqual(len(provider.hosts), 1)

     # The second endpoint joins: both endpoints should be returned.
     joiner = server_set.join(endpoints[1], use_ip=False)
     joiner.get()
     picked = sample_hosts()
     self.assertEqual(len(set(picked)), 2)
     self.assertEqual(len(provider.hosts), 2)
    def test_statsd_client_is_called(self):
        """Test that a statsd client gets called.

        A recording stand-in for a statsd client is passed through to
        :class:`PooledThriftClientMixinMetaclass`. After one successful call,
        its recorded timing data must contain the key for method_success.
        """
        class TestStatsdClient:
            def __init__(self, *args, **kwargs):
                self.val = 0
                self.timing_data = {}

            def increment(self, stats, sample_rate=1):
                # Counts calls only; the stat name is not recorded.
                self.val += 1

            def timing(self, key, val, **kwargs):
                # Keeps the latest value reported for each key.
                self.timing_data[key] = val

        recorder = TestStatsdClient()
        client = FakePooledThriftClientMixin(
            host_provider=HostsProvider(HOSTS), statsd_client=recorder)
        client.method_success()
        expected_key = (
            "client.requests.test_thrift_client_mixin.method_success")
        self.assertTrue(expected_key in recorder.timing_data)
# Example #25
# 0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from kingpin.thrift_utils.thrift_client_mixin import PooledThriftClientMixin
from kingpin.thrift_utils.base_thrift_exceptions import ThriftConnectionError
from kingpin.kazoo_utils.hosts import HostsProvider

import TestService


class TestServiceConnectionException(ThriftConnectionError):
    """Connection error specific to TestService.

    TestServiceClient.get_connection_exception_class() returns this class.
    """


class TestServiceClient(TestService.Client, PooledThriftClientMixin):
    """Pooled thrift client for TestService.

    Connection problems are reported with TestServiceConnectionException.
    """

    def get_connection_exception_class(self):
        """Return the exception class used for connection errors."""
        exception_class = TestServiceConnectionException
        return exception_class


# Pooled TestService client. The static host list is empty, so hosts
# presumably come only from the local serverset file below.
testservice_client = TestServiceClient(
    HostsProvider([], file_path="/var/serverset/discovery.test_service.prod"),
    timeout=3000,
    pool_size=10,
    always_retry_on_new_host=True)

# With a single argument, print(...) behaves the same on Python 2 and 3,
# unlike the Python 2-only print statement.
print(testservice_client.ping())
# Example #26
# 0
    def test_random_host_selector(self):
        """Test the RandomHostSelector with a file-backed host provider."""
        host_list = HostSelectorWithLocalFileTestCase.HOST_LIST
        fd, tmp_file = tempfile.mkstemp()
        # Only the path is used; close mkstemp's descriptor so it does not
        # leak.
        os.close(fd)
        try:
            with open(tmp_file, 'w') as f:
                f.write('\n'.join(host_list))

            host_provider = HostsProvider(host_list, file_path=tmp_file)
            random_host_selector = RandomHostSelector(
                host_provider,
                expire_time=0,
                retry_time=0,
                invalidation_threshold=1.0)

            # _choose_host() is not mocked here: RandomHostSelector provides
            # its own implementation.
            some_host = random_host_selector.get_host()
            self.assertIn(some_host, host_list)
            self.assertEqual(random_host_selector._current, some_host)

            no_of_iterations = 250
            # With expire_time=0, running get_host() this many times should
            # give a fairly even spread and return every host in the list.
            returned_hosts = [random_host_selector.get_host()
                              for _ in range(no_of_iterations)]
            host_counter = Counter(returned_hosts)
            # Every call returned a host.
            self.assertEqual(sum(host_counter.values()), no_of_iterations)
            # Every host was seen.
            self.assertEqual(set(host_counter), set(host_list))

            # With the default (large) expire_time, a single host is picked
            # every time.
            random_host_selector = RandomHostSelector(
                host_provider, invalidation_threshold=1.0)
            returned_hosts = [random_host_selector.get_host()
                              for _ in range(no_of_iterations)]
            self.assertEqual(len(Counter(returned_hosts)), 1)

            # Invalidation: random.choice is mocked to pop from the end of
            # this list, so it yields host_list[1] four times, then
            # host_list[0]. (Named to avoid shadowing the `hosts` module.)
            scripted_choices = [host_list[0]] + [host_list[1]] * 4

            def random_select(*args):
                return scripted_choices.pop()

            choice_mock = Mock(side_effect=random_select)
            with patch("random.choice", new=choice_mock):
                random_host_selector = RandomHostSelector(
                    host_provider,
                    expire_time=0,
                    retry_time=60,
                    invalidation_threshold=1.0)
                host = random_host_selector.get_host()
                self.assertEqual(host, host_list[1])
                random_host_selector.invalidate()
                # The mock returns the now-bad host three more times in a
                # row. Per the original test, this forces the selector to
                # compute the set of good hosts, and the fifth random.choice
                # call yields host_list[0].
                host = random_host_selector.get_host()
                self.assertEqual(host, host_list[0])
                self.assertEqual(choice_mock.call_count, 5)
        finally:
            # Clean up even when an assertion above fails.
            HostSelectorWithLocalFileTestCase.FILE_WATCH._clear_all_watches()
            os.remove(tmp_file)
 def test_network_error(self):
     """An RPC raising TTransportException goes through the network-error checks."""
     fake_client = FakeThriftClientMixin(HostsProvider(HOSTS))
     rpc_under_test = fake_client.method_network_error
     self._test_network_error(fake_client, rpc_under_test)
 def test_socket_error(self):
     """An RPC raising socket.error goes through the network-error checks."""
     fake_client = FakeThriftClientMixin(HostsProvider(HOSTS))
     rpc_under_test = fake_client.method_socket_error
     self._test_network_error(fake_client, rpc_under_test)
 def test_thrift_connection_timeout_error(self):
     """An RPC raising ThriftConnectionTimeoutError goes through the network-error checks."""
     fake_client = FakeThriftClientMixin(HostsProvider(HOSTS))
     rpc_under_test = fake_client.method_thrift_connection_timeout_error
     self._test_network_error(fake_client, rpc_under_test)
 def test_success(self):
     """A successful call on the pooled client returns (True, 1)."""
     client = FakePooledThriftClientMixin(
         host_provider=HostsProvider(HOSTS))
     # NOTE(review): the meaning of the tuple's fields is not visible
     # here; confirm against the FakePooledThriftClientMixin fixture.
     self.assertEqual((True, 1), client.method_success())