Example #1
0
class ServiceDiscoveryClientTest(KazooTestCase):
    """Basic set of tests for ServiceDiscoveryClient."""

    def setUp(self):
        """Create a ServiceDiscoveryClient pointed at the test Zookeeper hosts."""
        super(ServiceDiscoveryClientTest, self).setUp()
        self.ezDiscovery = ServiceDiscoveryClient(self.hosts)

    def tearDown(self):
        """Delegate Zookeeper cleanup to KazooTestCase."""
        super(ServiceDiscoveryClientTest, self).tearDown()

    def test_register_endpoint(self):
        """Register an endpoint and make sure it ends up in Zookeeper."""
        self.ezDiscovery.register_endpoint('foo', 'bar', 'localhost', 8080)
        endpoints = self.ezDiscovery.get_endpoints("foo", "bar")
        self.assertEqual(endpoints[0], "localhost:8080")

    def test_register_common_endpoint(self):
        """Register a common endpoint and make sure it ends up in Zookeeper."""
        self.ezDiscovery.register_common_endpoint('bar', 'localhost', 8080)
        endpoints = self.ezDiscovery.get_common_endpoints("bar")
        self.assertEqual(endpoints[0], "localhost:8080")

    def test_unregister_endpoint(self):
        """Register and unregister an endpoint and make sure it is gone."""
        self.ezDiscovery.register_endpoint('foo', 'bar', 'localhost', 8080)
        self.ezDiscovery.unregister_endpoint('foo', 'bar', 'localhost', 8080)
        endpoints = self.ezDiscovery.get_endpoints("foo", "bar")
        self.assertEqual(len(endpoints), 0)

    def test_unregister_common_endpoint(self):
        """Register and unregister a common endpoint and make sure it is gone."""
        self.ezDiscovery.register_common_endpoint('bar', 'localhost', 8080)
        self.ezDiscovery.unregister_common_endpoint('bar', 'localhost', 8080)
        endpoints = self.ezDiscovery.get_common_endpoints("bar")
        self.assertEqual(len(endpoints), 0)

    def test_unregister_none_exist_endpoint(self):
        """Unregistering an endpoint that was never registered must not raise."""
        self.ezDiscovery.unregister_endpoint('foo', 'bar', 'localhost', 8000)

    def test_unregister_multiple_endpoints(self):
        """Test that when multiple endpoints get made and some removed the tree
        of endpoints stays correct.

        After all endpoints are removed, the service's endpoints node itself
        must still exist in Zookeeper.
        """
        self.ezDiscovery.register_endpoint('foo', 'bar', 'localhost', 8000)
        self.ezDiscovery.register_endpoint('foo', 'bar', 'localhost', 8888)

        # Unregister the first endpoint.
        self.ezDiscovery.unregister_endpoint('foo', 'bar', 'localhost', 8000)
        endpoints = self.ezDiscovery.get_endpoints('foo', 'bar')
        self.assertEqual(len(endpoints), 1)
        self.assertEqual(endpoints[0], 'localhost:8888')

        # Unregister the second endpoint.
        self.ezDiscovery.unregister_endpoint('foo', 'bar', 'localhost', 8888)
        endpoints = self.ezDiscovery.get_endpoints('foo', 'bar')
        self.assertEqual(len(endpoints), 0)

        base_path = '/'.join([
            ServiceDiscoveryClient.NAMESPACE,
            'foo',
            'bar',
            ServiceDiscoveryClient.ENDPOINTS
        ])
        self.assertTrue(self.client.exists(base_path))

    def test_get_applications(self):
        """Test application list."""
        # Create a few application endpoints.
        self.ezDiscovery.register_endpoint('foo', 'bar', 'localhost', 8000)
        self.ezDiscovery.register_endpoint('harry', 'sally', 'localhost', 8080)
        self.assertEqual(2, len(self.ezDiscovery.get_applications()))

    def test_get_services(self):
        """Test the application services list."""
        # Create a few application endpoints.
        self.ezDiscovery.register_endpoint('foo', 'bar', 'localhost', 8000)
        self.ezDiscovery.register_endpoint('foo', 'baz', 'localhost', 8001)
        self.ezDiscovery.register_endpoint('harry', 'sally', 'localhost', 8080)

        # Make sure it returns the right count for a single service.
        self.assertEqual(2, len(self.ezDiscovery.get_services('foo')))

        self.assertEqual(1, len(self.ezDiscovery.get_services('harry')))
        self.assertEqual('sally', self.ezDiscovery.get_services('harry')[0])

    def test_get_common_services(self):
        """Test fetching common services."""
        # Make a few common services and an application-specific one; only
        # the common ones should be returned.
        self.ezDiscovery.register_common_endpoint('foo', 'localhost', 8000)
        self.ezDiscovery.register_common_endpoint('bar', 'localhost', 8001)
        self.ezDiscovery.register_endpoint('harry', 'sally', 'localhost', 8080)
        self.assertEqual(2, len(self.ezDiscovery.get_common_services()))

    def test_get_endpoints(self):
        """Test endpoint list fetching."""
        # Create a few application endpoints.
        self.ezDiscovery.register_endpoint('foo', 'bar', 'localhost', 8000)
        self.ezDiscovery.register_endpoint('foo', 'bar', 'localhost', 8001)
        self.ezDiscovery.register_endpoint('harry', 'sally', 'localhost', 8080)
        self.assertEqual(2, len(self.ezDiscovery.get_endpoints('foo', 'bar')))

    def test_get_common_endpoints(self):
        """Test fetching common endpoints."""
        # Create a few common endpoints and one that is not common, test results.
        self.ezDiscovery.register_common_endpoint('foo', 'localhost', 8000)
        self.ezDiscovery.register_common_endpoint('foo', 'localhost', 8001)
        self.ezDiscovery.register_endpoint('harry', 'sally', 'localhost', 8080)
        self.assertEqual(2, len(self.ezDiscovery.get_common_endpoints('foo')))
        self.assertEqual(0, len(self.ezDiscovery.get_common_endpoints('sally')))

    def test_is_service_common(self):
        """Ensure only common services return true."""
        # Test one that does not exist.
        self.assertFalse(self.ezDiscovery.is_service_common('foo'))
        self.ezDiscovery.register_common_endpoint('foo', 'localhost', 8000)
        self.assertTrue(self.ezDiscovery.is_service_common('foo'))
        self.ezDiscovery.register_endpoint('harry', 'sally', 'localhost', 8080)
        self.assertFalse(self.ezDiscovery.is_service_common('sally'))

    def test_set_security_id_for_application(self):
        """Ensure security ids get set for applications."""
        self.ezDiscovery.register_endpoint('foo', 'bar', 'localhost', 8000)
        self.ezDiscovery.set_security_id_for_application('foo', 'sid')

        path = '/'.join([
            ServiceDiscoveryClient.NAMESPACE,
            'foo',
            ServiceDiscoveryClient.SECURITY,
            ServiceDiscoveryClient.SECURITY_ID
        ])
        self.assertTrue(self.client.exists(path))
        self.assertEqual('sid', self.client.get(path)[0])

    def test_set_security_id_for_common_service(self):
        """Ensure security ids get set for common services."""
        self.ezDiscovery.register_common_endpoint('foo', 'localhost', 8000)
        self.ezDiscovery.set_security_id_for_common_service('foo', 'sid')
        path = '/'.join([
            ServiceDiscoveryClient.NAMESPACE,
            '/'.join([ServiceDiscoveryClient.COMMON_APP_NAME, 'foo']),
            ServiceDiscoveryClient.SECURITY,
            ServiceDiscoveryClient.SECURITY_ID
        ])
        self.assertTrue(self.client.exists(path))
        self.assertEqual('sid', self.client.get(path)[0])

    def test_get_security_id_for_application(self):
        """Ensure fetching application security ids returns properly."""
        # Clear the cache so a value cached by another test cannot leak in
        # (mirrors test_get_security_id_for_common_service).
        self.ezDiscovery.securityIdCache.clear()

        # Fetch one that does not exist.
        self.assertIsNone(self.ezDiscovery.get_security_id('foo'))
        self.ezDiscovery.register_endpoint('foo', 'bar', 'localhost', 8000)
        self.ezDiscovery.set_security_id_for_application('foo', 'sid')
        self.assertEqual('sid', self.ezDiscovery.get_security_id('foo'))

    def test_get_security_id_for_common_service(self):
        """Ensure fetching common service security ids returns properly."""

        # Clear the cache.
        self.ezDiscovery.securityIdCache.clear()

        # Fetch one that does not exist.
        self.assertIsNone(self.ezDiscovery.get_security_id('foo'))
        self.ezDiscovery.register_common_endpoint('foo', 'localhost', 8000)
        self.ezDiscovery.set_security_id_for_common_service('foo', 'sid')
        self.assertEqual('sid', self.ezDiscovery.get_security_id('foo'))
class ThriftClientPool(object):
    """Pool of Thrift clients whose endpoints come from service discovery.

    Endpoints are looked up through a ``ServiceDiscoveryClient`` built from the
    Zookeeper connection string in the supplied EzProperties. ``get_client``
    lazily creates one ``PoolingThriftClient`` per (service, client class) and
    caches it. A daemon thread periodically runs an idle-eviction check on each
    cached client's pool.
    """

    def __init__(self, ez_props):
        """Build the pool and perform an initial endpoint refresh.

        :param ez_props: EzProperties used to build the Zookeeper, application,
            security and thrift configurations.
        :raises Exception: if ``ez_props`` is None.
        :raises: whatever ``get_common_services`` raises if the initial
            common-service lookup fails (logged, then re-raised).
        """
        if ez_props is None:
            raise Exception("Invalid EzProperties.")

        zk_con_str = ZookeeperConfiguration(ez_props).getZookeeperConnectionString()
        self.ezd_client = ServiceDiscoveryClient(zk_con_str)

        self.__applicationConfiguration = ApplicationConfiguration(ez_props)
        self.__applicationName = self.__applicationConfiguration.getApplicationName()
        self.__securityConfiguration = SecurityConfiguration(ez_props)
        self.__thriftConfiguration = ThriftConfiguration(ez_props)

        # Guards __serviceMap and __clientMap. Re-entrant because get_client
        # holds it while calling _add_endpoints / __get_endpoints, which also
        # acquire it.
        self.__rLock = threading.RLock()
        self.__serviceMap = {}   # service name -> list of endpoints
        self.__clientMap = {}    # "<service>|<client class>" -> client

        self.__log = logging.getLogger(__name__)

        if self.__applicationName is None:
            self.__log.warning("No application name was found. Only common services will be discoverable.")
        else:
            self.__log.info("Application name: %s", self.__applicationName)

        try:
            self.__common_services = list(self.ezd_client.get_common_services())
        except Exception:
            self.__log.exception("Unable to get common services")
            raise

        self.__refresh_end_points()
        self.__refresh_common_endpoints()

        thread = threading.Thread(target=self._evict_daemon)
        thread.daemon = True
        thread.start()

    def _evict_daemon(self):
        """Run idle-connection eviction on every cached client, forever.

        Executed on the daemon thread started by ``__init__``. A failure in one
        client's eviction check is logged and skipped so that it neither kills
        this thread nor stops eviction for the remaining clients.
        """
        check_interval_millis = self.__thriftConfiguration.getMillisBetweenClientEvictionChecks()
        idle_threshold_millis = self.__thriftConfiguration.getMillisIdleBeforeEviction()
        while True:
            time.sleep(check_interval_millis * 0.001)
            with self.__rLock:
                for key, client in self.__clientMap.items():
                    try:
                        client._pool.evict_check(idle_threshold_millis)
                    except Exception:
                        self.__log.exception("Eviction check failed for client %s", key)

    def _get_service_map(self):
        """Return the internal service-name -> endpoints map (not a copy)."""
        return self.__serviceMap

    def _get_client_map(self):
        """Return the internal connection-key -> client map (not a copy)."""
        return self.__clientMap

    def __refresh_end_points(self):
        """Reload endpoints for every service registered under this application.

        Does nothing when the application name is None or empty. Lookup
        failures are logged and swallowed: an application may legitimately
        have no registered services.
        """
        if not self.__applicationName:
            return
        try:
            for service in self.ezd_client.get_services(self.__applicationName):
                try:
                    endpoints = self.ezd_client.get_endpoints(self.__applicationName, service)
                    self._add_endpoints(service, endpoints)
                except Exception:
                    self.__log.warning("No %s for application %s was found", service, self.__applicationName)
        except Exception:
            self.__log.warning(
                "Failed to get application services. "
                "This might be okay if the application hasn't registered any services."
            )

    def __refresh_common_endpoints(self):
        """Reload endpoints for every common service.

        Lookup failures are logged and swallowed: there may be no common
        services defined.
        """
        try:
            for service in self.ezd_client.get_common_services():
                try:
                    endpoints = self.ezd_client.get_common_endpoints(service)
                    self._add_endpoints(service, endpoints)
                except Exception:
                    self.__log.warning("No common service %s was found.", service)
        except Exception:
            self.__log.warning("Failed to get common services. This might be okay if no common service has been defined.")

    def _add_endpoints(self, service, endpoints):
        """Replace the cached endpoint list for ``service`` with ``endpoints``.

        :param service: key to store the endpoints under.
        :param endpoints: iterable of endpoints; copied into a new list.
        """
        with self.__rLock:
            self.__serviceMap[service] = list(endpoints)

    @staticmethod
    def __get_thrift_connection_key(service_name, client_class):
        """Build the ``__clientMap`` key for a service / client-class pair."""
        return service_name + "|" + str(client_class)

    def __get_endpoints(self, service_name, retry=True):
        """Return the cached endpoints for ``service_name``.

        On a cache miss, refresh application and common endpoints once and
        look again.

        :returns: the endpoint list, or None if the service is still unknown
            after the refresh.
        """
        with self.__rLock:
            if service_name in self.__serviceMap:
                return self.__serviceMap[service_name]
        if retry:
            self.__refresh_end_points()
            self.__refresh_common_endpoints()
            return self.__get_endpoints(service_name, retry=False)
        return None

    def get_client(self, app_name=None, service_name=None, clazz=None):
        """Return a cached (or newly created) pooled client for a service.

        :param app_name: optional application that owns ``service_name``. When
            given, the service is resolved under that application's qualified
            service name (``getApplicationServiceName``). Its endpoints and
            client are cached under that name, so same-named services of
            different applications do not share endpoints or clients.
        :param service_name: name of the service to connect to (required).
        :param clazz: Thrift client class to instantiate (required).
        :returns: the client, or None if no endpoints could be found.
        :raises ValueError: if ``service_name`` or ``clazz`` is falsy.
        :raises TException: wrapping any error raised while resolving
            endpoints or creating the client.
        """
        if not service_name:
            raise ValueError("'service_name' does not have a valid value (%s)." % service_name)

        if not clazz:
            raise ValueError("'clazz' does not have a valid value (%s)." % clazz)

        try:
            with self.__rLock:
                # The name endpoints/clients are cached under: the bare service
                # name, or the application-qualified name when app_name is set.
                lookup_name = service_name
                if app_name:
                    lookup_name = self.__applicationConfiguration.getApplicationServiceName(app_name, service_name)
                    if lookup_name not in self.__serviceMap:
                        endpoints = self.ezd_client.get_endpoints(app_name, service_name)
                        self._add_endpoints(lookup_name, endpoints)

                key = self.__get_thrift_connection_key(lookup_name, clazz)

                # Get the client from the client pool, or initialize it.
                client = self.__clientMap.get(key)
                if client is None:
                    endpoints = self.__get_endpoints(lookup_name)
                    if endpoints is None:
                        return None

                    pool_size = self.__thriftConfiguration.getMaxIdleClients()
                    if self.__thriftConfiguration.useSSL():
                        sc = self.__securityConfiguration
                        ca_certs = sc.getTrustedSslCerts()
                        key_file = sc.getPrivateKey()
                        cert = sc.getSslCertificate()
                        client = PoolingThriftClient(endpoints, clazz, pool_size=pool_size,
                                                     use_ssl=True, ca_certs=ca_certs, cert=cert, key=key_file)
                    else:
                        client = PoolingThriftClient(endpoints, clazz, pool_size=pool_size)

                    self.__clientMap[key] = client
                return client

        except Exception as e:
            raise TException(str(e))