class EzSecurityTest(EzThriftServerTestHarness):
    """Starts an EzSecurity thrift server under the test harness and hands out clients for it."""

    def setUp(self):
        """Configure properties, start the EzSecurity server on localhost:8449, and create a client pool."""
        super(EzSecurityTest, self).setUp()

        # Properties shared by every client this test creates.
        self.ez_props = EzConfiguration().getProperties()
        self.appId = "SecurityClientTest"

        # Since 2.1-preview the code passes `ez_props` (a `dict` subclass)
        # around instead of the EzConfiguration object itself.
        overrides = (
            ("zookeeper.connection.string", self.hosts),
            ("thrift.use.ssl", "false"),
            ("ezbake.security.ssl.dir", "test/certs/client"),
            ("ezbake.security.app.id", self.appId),
        )
        for key, value in overrides:
            self.ez_props[key] = value

        # Private key handed to the EzSecurityHandler (presumably for signing tokens).
        with open("test/certs/server/application.priv", "r") as key_file:
            private_key = key_file.read()

        security_handler = EzSecurityHandler(private_key, token_ttl_millis=500)
        self.add_server(security_handler.app_name,
                        security_handler.service_name,
                        host="localhost",
                        port=8449,
                        processor=EzSecurity.Processor(security_handler),
                        wait=3)

        self.global_client_pool = ThriftClientPool(self.ez_props)

    def tearDown(self):
        """Close the shared client pool, then let the harness shut down its servers."""
        self.global_client_pool.close()
        super(EzSecurityTest, self).tearDown()

    def get_client(self):
        """Return a new EzSecurityClient backed by the shared client pool."""
        return EzSecurityClient(self.ez_props, self.global_client_pool)
class EzSecurityTest(unittest.TestCase):
    """Integration checks (``IT_*`` methods) against an EzSecurity service reached via ZOO_CON_STR."""

    def setUp(self):
        """Build the properties, a shared client pool and an EzSecurityClient for the "client" app."""
        self.appId = "client"
        props = EzConfiguration().getProperties()
        for key, value in (("application.name", "client_name"),
                           ("ezbake.security.app.id", self.appId),
                           ("zookeeper.connection.string", ZOO_CON_STR),
                           ("thrift.use.ssl", "false"),
                           ("ezbake.security.ssl.dir", "test/certs/client/")):
            props[key] = value

        self.global_client_pool = ThriftClientPool(props)

        self.es_client = EzSecurityClient(props, self.global_client_pool)

    def tearDown(self):
        """Release the shared client pool."""
        self.global_client_pool.close()

    def IT_ping(self):
        """The security service answers ping() with a truthy value."""
        nt.assert_true(self.es_client.ping())

    def IT_app_info(self):
        """An app token requested for our own app id names us as principal and as issuedFor."""
        app_token = self.es_client.fetch_app_token(self.appId)

        nt.assert_is_not_none(app_token)
        nt.assert_equal(self.appId, app_token.tokenPrincipal.principal)
        nt.assert_equal(self.appId, app_token.validity.issuedFor)

    @staticmethod
    def _make_dn(subject):
        """Return a jsonpickle-encoded ProxyUserToken for ``subject`` that expires 720000 ms from now."""
        expires_at = util.current_time_millis() + 720000
        proxy_token = ProxyUserToken(x509=X509Info(subject=subject),
                                     issuedBy="EzSecurity",
                                     issuedTo="EFE",
                                     notAfter=expires_at)
        return jsonpickle.encode(proxy_token)

    @staticmethod
    def _sign(data):
        """Sign ``data`` with the test server private key (SHA-256) and return it base64-encoded."""
        with open('test/certs/server/application.priv', 'r') as key_file:
            pem = key_file.read()
        signing_key = ossl.load_privatekey(ossl.FILETYPE_PEM, pem)
        raw_signature = ossl.sign(signing_key, data, 'sha256')
        return base64.b64encode(raw_signature)

    def IT_user_info(self):
        """A signed proxy DN yields a user token for that subject, issued to and for our app."""
        subject = "CN=EzbakeClient, OU=42six, O=CSC, C=US"
        encoded_dn = self._make_dn(subject)
        request_headers = {
            HTTP_HEADER_USER_INFO: encoded_dn,
            HTTP_HEADER_SIGNATURE: self._sign(encoded_dn),
        }
        user_token = self.es_client.fetch_user_token(request_headers)

        nt.assert_equal(self.appId, user_token.validity.issuedTo)
        nt.assert_equal(self.appId, user_token.validity.issuedFor)
        nt.assert_equal(subject, user_token.tokenPrincipal.principal)
    def setUp(self):
        """Start an ezpz server per ENDPOINTS entry and register extra discovery endpoints."""
        super(ThriftClientPoolTest, self).setUp()

        props = EzConfiguration().getProperties()
        props["thrift.use.ssl"] = "false"
        props["zookeeper.connection.string"] = self.hosts
        app_name = ApplicationConfiguration(props).getApplicationName()

        for host, port in (endpoint.split(':') for endpoint in ENDPOINTS):
            self.add_server(app_name, "ezpz", host, int(port), EzPz.Processor(EzPzHandler()))

        # Endpoints registered under this application's name.
        for service, port in (("service_one", 8083), ("service_two", 8084), ("service_three", 8085)):
            self.sd_client.register_endpoint(app_name, service, 'localhost', port)

        # Common endpoints; 'common_service_multi' deliberately gets two of them.
        common_endpoints = (
            ('common_service_one', 'localhost', 8080),
            ('common_service_two', 'localhost', 8081),
            ('common_service_three', 'localhost', 8082),
            ('common_service_multi', '192.168.1.1', 6060),
            ('common_service_multi', '192.168.1.2', 6161),
        )
        for service, host, port in common_endpoints:
            self.sd_client.register_common_endpoint(service, host, port)

        # Endpoints registered under a different application name.
        for port in (8091, 8092, 8093):
            self.sd_client.register_endpoint("NotThriftClientPool", "unknown_service_three", 'localhost', port)

        self.clientPool = ThriftClientPool(props)
# Example #4
# 0
    def __init__(self, ez_props, client_pool=None, cache_evict_cycle=DEFAULT_EVICT_CYCLE,
                 log=logging.getLogger(__name__), handler=None):
        """
        Set up configuration helpers, the thrift client and the mock flag.

        :param ez_props: EzBake properties (also queried for USE_MOCK_KEY)
        :param client_pool: shared ThriftClientPool; a private one is created when None
        :param cache_evict_cycle: handed to TokenCache when the class-wide cache is first created
        :param log: logger used by this client
        :param handler: stored unchanged as ``self.handler``
        """
        # One TokenCache is shared by all instances; only the first constructor creates it.
        if EzSecurityClient.token_cache is None:
            EzSecurityClient.token_cache = TokenCache(cache_evict_cycle)

        self.ez_props = ez_props
        self.securityConfig = ezc_helpers.SecurityConfiguration(ez_props)
        self.appConfig = ezc_helpers.ApplicationConfiguration(ez_props)
        zookeeper_config = ezc_helpers.ZookeeperConfiguration(ez_props)
        self.zk_con_str = zookeeper_config.getZookeeperConnectionString()

        # Remember whether the pool is ours, so only a pool we created gets closed.
        owns_pool = client_pool is None
        self.client_pool = ThriftClientPool(ez_props) if owns_pool else client_pool
        self.__local_pool = owns_pool

        self.client = self.client_pool.get_client(service_name=SECURITY_SERVICE_NAME, clazz=EzSecurity.Client)

        # Key material is loaded lazily elsewhere.
        self.privateKey = self.servicePublic = self.servicePrivate = None

        self.log = log
        self.handler = handler

        self.mock = ez_props.getBoolean(USE_MOCK_KEY, False)
        self.log.info("%s has mock config set to %s",
                      self.__class__.__name__, self.mock)
    def setUp(self):
        """Create an EzSecurityClient for app id "client" on top of a shared ThriftClientPool."""
        self.appId = "client"
        properties = EzConfiguration().getProperties()
        settings = [
            ("application.name", "client_name"),
            ("ezbake.security.app.id", self.appId),
            ("zookeeper.connection.string", ZOO_CON_STR),
            ("thrift.use.ssl", "false"),
            ("ezbake.security.ssl.dir", "test/certs/client/"),
        ]
        for name, value in settings:
            properties[name] = value

        self.global_client_pool = ThriftClientPool(properties)
        self.es_client = EzSecurityClient(properties, self.global_client_pool)
    def setUp(self):
        """Start an SSL-enabled "ezpz_ssl" server per ENDPOINTS entry and create an SSL client pool."""
        super(ThriftClientPoolTest, self).setUp()

        props = EzConfiguration().getProperties()
        props["thrift.use.ssl"] = "true"
        props["zookeeper.connection.string"] = self.hosts
        app_name = ApplicationConfiguration(props).getApplicationName()

        for endpoint in ENDPOINTS:
            host, port_text = endpoint.split(':')
            port_number = int(port_text)
            processor = EzPz.Processor(EzPzHandler())
            self.add_server(app_name, "ezpz_ssl", host, port_number, processor,
                            use_ssl=True,
                            ca_certs=servercapath,
                            cert=servercertpath,
                            key=serverprivpath)

        self.clientPool = ThriftClientPool(props)
    def setUp(self):
        """Start non-simple ezpz servers and a pool tuned for quick idle-client eviction."""
        super(ThriftClientPoolTest, self).setUp()

        props = EzConfiguration().getProperties()
        pool_settings = (
            ("thrift.use.ssl", "false"),
            ("zookeeper.connection.string", self.hosts),
            ("thrift.max.idle.clients", 6),
            # ("thrift.max.pool.clients", 6),
            ("thrift.millis.between.client.eviction.checks", 1000),
            ("thrift.millis.idle.before.eviction", 1.5 * 1000),
        )
        for key, value in pool_settings:
            props[key] = value
        app_name = ApplicationConfiguration(props).getApplicationName()

        for endpoint in ENDPOINTS:
            host, port_text = endpoint.split(':')
            self.add_server(app_name, "ezpz", host, int(port_text),
                            EzPz.Processor(EzPzHandler()),
                            use_simple_server=False)

        self.clientPool = ThriftClientPool(props)
    def setUp(self):
        """
        Spawn one ezpz server process per ENDPOINTS entry and register endpoints in service discovery.
        """
        super(ThriftClientPoolTest, self).setUp()
        discovery = ServiceDiscoveryClient(self.hosts)

        props = EzConfiguration().getProperties()
        props["thrift.use.ssl"] = "false"
        props["zookeeper.connection.string"] = self.hosts
        app_name = ApplicationConfiguration(props).getApplicationName()

        self.serverProcesses = []
        for endpoint in ENDPOINTS:
            host, port_text = endpoint.split(':')
            port_number = int(port_text)
            worker = Process(target=start_ezpz, args=(EzPzHandler(), port_number,))
            worker.start()
            # Presumably gives the child process time to start listening.
            time.sleep(1)
            self.serverProcesses.append(worker)
            discovery.register_endpoint(app_name, "ezpz", host, port_number)

        app_services = [("service_one", 8083), ("service_two", 8084), ("service_three", 8085)]
        for service_name, port_number in app_services:
            discovery.register_endpoint(app_name, service_name, 'localhost', port_number)

        common_services = [
            ('common_service_one', 'localhost', 8080),
            ('common_service_two', 'localhost', 8081),
            ('common_service_three', 'localhost', 8082),
            ('common_service_multi', '192.168.1.1', 6060),
            ('common_service_multi', '192.168.1.2', 6161),
        ]
        for service_name, host, port_number in common_services:
            discovery.register_common_endpoint(service_name, host, port_number)

        # Registered under another application's name.
        for port_number in [8091, 8092, 8093]:
            discovery.register_endpoint("NotThriftClientPool", "unknown_service_three", 'localhost', port_number)

        self.clientPool = ThriftClientPool(props)
class ThriftClientPoolTest(EzThriftServerTestHarness):
    """Checks service discovery lookups and client checkout in ThriftClientPool."""

    def setUp(self):
        """Start ezpz servers, register extra discovery endpoints, and create the pool under test."""
        super(ThriftClientPoolTest, self).setUp()

        ez_props = EzConfiguration().getProperties()
        ez_props["thrift.use.ssl"] = "false"
        ez_props["zookeeper.connection.string"] = self.hosts
        application_name = ApplicationConfiguration(ez_props).getApplicationName()

        for endpoint in ENDPOINTS:
            host, port = endpoint.split(':')
            self.add_server(application_name, "ezpz", host, int(port), EzPz.Processor(EzPzHandler()))

        self.sd_client.register_endpoint(application_name, "service_one", 'localhost', 8083)
        self.sd_client.register_endpoint(application_name, "service_two", 'localhost', 8084)
        self.sd_client.register_endpoint(application_name, "service_three", 'localhost', 8085)

        self.sd_client.register_common_endpoint('common_service_one', 'localhost', 8080)
        self.sd_client.register_common_endpoint('common_service_two', 'localhost', 8081)
        self.sd_client.register_common_endpoint('common_service_three', 'localhost', 8082)
        self.sd_client.register_common_endpoint('common_service_multi', '192.168.1.1', 6060)
        self.sd_client.register_common_endpoint('common_service_multi', '192.168.1.2', 6161)

        # Registered under a different application name; test_endpoints expects
        # these NOT to appear in this pool's service map.
        self.sd_client.register_endpoint("NotThriftClientPool", "unknown_service_three", 'localhost', 8091)
        self.sd_client.register_endpoint("NotThriftClientPool", "unknown_service_three", 'localhost', 8092)
        self.sd_client.register_endpoint("NotThriftClientPool", "unknown_service_three", 'localhost', 8093)

        self.clientPool = ThriftClientPool(ez_props)

    def tearDown(self):
        """Close the pool and verify it dropped every service and client entry."""
        self.clientPool.close()
        nt.assert_false(self.clientPool._get_service_map())
        nt.assert_false(self.clientPool._get_client_map())
        super(ThriftClientPoolTest, self).tearDown()

    def test_endpoints(self):
        """The service map holds this app's and common services, but not another app's."""
        service_map = self.clientPool._get_service_map()
        for expected in ("common_service_one", "common_service_two", "common_service_three",
                         "service_one", "service_two", "service_three"):
            self.assertIn(expected, service_map)
        for unexpected in ("unknown_service_one", "unknown_service_two", "unknown_service_three"):
            self.assertNotIn(unexpected, service_map)

        self.assertIn("common_service_multi", service_map)
        self.assertEqual(len(service_map["common_service_multi"]), 2)

        self.assertIn("ezpz", service_map)
        self.assertEqual(len(service_map["ezpz"]), len(ENDPOINTS))

    def test_get_client(self):
        """A known service yields a working client; an unknown one yields a falsy result."""
        client = self.clientPool.get_client(service_name='ezpz', clazz=EzPz.Client)
        try:
            nt.assert_equal('pz', client.ez())
        finally:
            client.close()

        client = self.clientPool.get_client(service_name='ezpz1', clazz=EzPz.Client)            # Non-existent service
        nt.assert_false(client)

    def test_multi_get_client(self):
        """Two clients for the same service can be checked out and used side by side."""
        client1 = self.clientPool.get_client(service_name='ezpz', clazz=EzPz.Client)
        client2 = self.clientPool.get_client(service_name='ezpz', clazz=EzPz.Client)
        try:
            nt.assert_equal('pz', client1.ez())
            nt.assert_equal('pz', client2.ez())
        finally:
            # Closes the whole pool while both clients are still checked out;
            # tearDown closes it a second time before checking it is empty.
            self.clientPool.close()

    def test_get_client_app(self):
        """Checkout with an explicit app_name works; an unknown app/service yields a falsy result."""
        client = self.clientPool.get_client(app_name="testApp", service_name="ezpz", clazz=EzPz.Client)
        try:
            nt.assert_equal("pz", client.ez())
        finally:
            client.close()

        client = self.clientPool.get_client(app_name="testApp1", service_name="ezpz0", clazz=EzPz.Client) # Non-existent app
        nt.assert_false(client)
class ThriftClientPoolTest(EzThriftServerTestHarness):
    """Checks that idle pooled connections are evicted after the configured idle time."""

    def setUp(self):
        """Start non-simple ezpz servers and a pool configured for fast idle eviction."""
        super(ThriftClientPoolTest, self).setUp()

        ez_props = EzConfiguration().getProperties()
        ez_props["thrift.use.ssl"] = "false"
        ez_props["zookeeper.connection.string"] = self.hosts
        ez_props["thrift.max.idle.clients"] = 6
        # ez_props["thrift.max.pool.clients"] = 6
        ez_props["thrift.millis.between.client.eviction.checks"] = 1000
        ez_props["thrift.millis.idle.before.eviction"] = 1.5 * 1000
        application_name = ApplicationConfiguration(ez_props).getApplicationName()

        for endpoint in ENDPOINTS:
            host, port = endpoint.split(':')
            self.add_server(application_name, "ezpz", host, int(port),
                            EzPz.Processor(EzPzHandler()), use_simple_server=False)

        self.clientPool = ThriftClientPool(ez_props)

    def tearDown(self):
        """Close the pool and verify it dropped every service and client entry."""
        self.clientPool.close()
        nt.assert_false(self.clientPool._get_service_map())
        nt.assert_false(self.clientPool._get_client_map())
        super(ThriftClientPoolTest, self).tearDown()

    def client_thread(self, tid):
        """Thread body: check out an 'ezpz' client and call ez2 with a 1.0-2.0 argument."""
        print(tid)
        client = self.clientPool.get_client(service_name='ezpz', clazz=EzPz.Client)
        res = client.ez2(random.randint(10, 20) * 0.1)
        print(chr(ord('a') + tid))
        return res

    def _run_clients_then_check_eviction(self, target, **client_kwargs):
        """Run ``target`` on 10 threads, wait past the idle timeout, and assert the connection queue is empty.

        ``client_kwargs`` are extra keyword arguments for the final ``get_client`` call.
        """
        threads = []
        for i in range(10):
            thread = threading.Thread(target=target, args=(i,))
            thread.start()
            threads.append(thread)

        for thread in threads:
            thread.join()

        # setUp sets a 1.5 s idle timeout with eviction checks every 1 s, so
        # sleeping 3 s should be long enough for every idle connection to be evicted.
        time.sleep(3)
        client = self.clientPool.get_client(service_name='ezpz', clazz=EzPz.Client, **client_kwargs)
        nt.assert_equal(0, client._pool._connection_queue.qsize())

    def test_eviction(self):
        """Connections checked out via service name only are evicted once idle."""
        self._run_clients_then_check_eviction(self.client_thread)

    def client_thread_for_apps(self, tid):
        """Thread body: like client_thread, but checks the client out for app 'testApp'."""
        print(tid)
        client = self.clientPool.get_client(app_name='testApp', service_name='ezpz', clazz=EzPz.Client)
        res = client.ez2(random.randint(10, 20) * 0.1)
        print(chr(ord('a') + tid))
        return res

    def test_eviction_for_apps(self):
        """Connections checked out with an explicit app_name are evicted once idle."""
        # Uses the app-scoped thread body (previously client_thread ran here by mistake).
        self._run_clients_then_check_eviction(self.client_thread_for_apps, app_name='testApp')
# Example #11
# 0
class EzSecurityClient(object):
    """
    Wrapper around the Ezbake Security thrift client

    Handles the PKI stuff surrounding request/response data in ezbake security:
    signs outgoing TokenRequests with the application's private key, verifies
    received tokens with the security service's public key, and caches fetched
    tokens in a TokenCache shared by all instances.
    """

    # Shared by every instance; created by the first __init__ call.
    token_cache = None

    def __init__(self, ez_props, client_pool=None, cache_evict_cycle=DEFAULT_EVICT_CYCLE,
                 log=logging.getLogger(__name__), handler=None):
        """
        :param ez_props: EzBake properties (also queried for USE_MOCK_KEY)
        :param client_pool: shared ThriftClientPool; a private one is created when None
        :param cache_evict_cycle: handed to TokenCache when the shared cache is first created
        :param log: logger used by this client
        :param handler: stored unchanged as ``self.handler``
        """
        if EzSecurityClient.token_cache is None:
            EzSecurityClient.token_cache = TokenCache(cache_evict_cycle)

        self.ez_props = ez_props
        self.securityConfig = ezc_helpers.SecurityConfiguration(ez_props)
        self.appConfig = ezc_helpers.ApplicationConfiguration(ez_props)
        self.zk_con_str = ezc_helpers.ZookeeperConfiguration(ez_props).getZookeeperConnectionString()

        # Track pool ownership so close_client_pool only closes a pool we created.
        if client_pool is None:
            self.client_pool = ThriftClientPool(ez_props)
            self.__local_pool = True
        else:
            self.client_pool = client_pool
            self.__local_pool = False

        self.client = self.client_pool.get_client(service_name=SECURITY_SERVICE_NAME, clazz=EzSecurity.Client)

        # Key material, loaded lazily by _ensure_keys.
        self.privateKey = None
        self.servicePublic = None
        self.servicePrivate = None

        self.log = log
        self.handler = handler

        self.mock = ez_props.getBoolean(USE_MOCK_KEY, False)
        self.log.info("%s has mock config set to %s",
                      self.__class__.__name__, self.mock)

    @staticmethod
    def _read_file(filename):
        """
        Helper to read public/private keys where necessary
        @return the file's contents
        """
        with open(filename, 'r') as f:
            b = f.read()
        return b

    @staticmethod
    def principal_from_request(headers):
        """
        Builds a ProxyPrincipal object
        (ProxyPrincipal(proxyUser:string, signature:string)) from the
        "EZB_VERIFIED_USER_INFO" and "EZB_VERIFIED_SIGNATURE" headers.

        :param headers: dict
        :return: ProxyPrincipal, or None when the user-info header is missing or
                 empty, or the signature header is None/False
        """
        dn = headers.get(HTTP_HEADER_USER_INFO)
        signature = headers.get(HTTP_HEADER_SIGNATURE)

        # An empty-string signature passes this check (_user_dn sends '').
        if dn and signature not in (None, False):
            return ProxyPrincipal(dn, signature)
        return None

    @staticmethod
    def _cache_key(target_app, subject):
        return "{};{}".format(target_app, subject)

    @staticmethod
    def _get_cache_key(token_type, subject, exclude_auths=None, request_chain=None, target_security_id=None):
        """
        Build a '|'-joined cache key from the non-None parts: token type, subject,
        sorted ';'-joined excluded auths, ';'-joined request chain, target security id.
        """
        ea_str = None if exclude_auths is None else ';'.join(sorted(exclude_auths))
        rc_str = None if request_chain is None else ';'.join(request_chain)
        parts = (str(token_type), subject, ea_str, rc_str, target_security_id)
        return '|'.join(part for part in parts if part is not None)

    def get_client(self):
        """
        Returns an EzSecurity.Client object that users can use to call the
        EzSecurity service directly.
        """
        return self.client_pool.get_client(service_name=SECURITY_SERVICE_NAME, clazz=EzSecurity.Client)

    def close_client_pool(self):
        """Close the client pool, but only if this instance created it."""
        if self.__local_pool:
            self.client_pool.close()

    def _ensure_keys(self):
        """
        Lazily read the app private key and the service public key from the
        paths in SecurityConfiguration. In mock mode, also read the server
        private key from MOCK_SERVER_KEY_PRIVATE when that file exists.
        """
        if self.privateKey is None:
            self.privateKey = self._read_file(
                self.securityConfig.getPrivateKey())

        if self.servicePublic is None:
            self.servicePublic = self._read_file(
                self.securityConfig.getServicePublicKey())

        # attempt to get the server's private key if we're in the mock-mode
        if self.mock and not self.servicePrivate:
            private_key_path = self.ez_props.getProperty(
                MOCK_SERVER_KEY_PRIVATE)
            private_key_exists = \
                os.path.exists(private_key_path) if private_key_path else False
            if private_key_path and private_key_exists:
                self.servicePrivate = self._read_file(private_key_path)

    def _sign(self, data):
        """Sign ``data`` with this application's private key via util.ssl_sign."""
        self._ensure_keys()
        return util.ssl_sign(data, self.privateKey)

    def _mock_service_sign(self, data):
        """
        Looks up the service's private key if the client is in mock-mode, and
        signs the data with the server's private key.

        WARNING: DO NOT USE THIS METHOD FOR CODE THAT WILL BE USED IN PROD.

        :param data: the data to sign
        :return: the signature, or "" when no server private key is available
        :raises ValueError: when not in mock-mode
        """
        if self.mock:
            self._ensure_keys()
            if self.servicePrivate is not None:
                return util.ssl_sign(data, self.servicePrivate)
            else:
                return ""

        raise ValueError("_mock_service_sign can only be called in mock-mode.")

    def ping(self):
        """
        Ping the security service
        @return true if the service is healthy
        """
        ret = self.client.ping()
        return ret

    def _user_dn(self, dn):
        """
        Request a signed DN from the security service. Note this will most
        likely fail, since it only signs DNs for the EFE
        @param dn: the user's X509 subject
        @return an EzSecurityPrincipal with a valid signature
        """
        headers = {
            HTTP_HEADER_USER_INFO: dn,
            HTTP_HEADER_SIGNATURE: ""
        }
        request, signature = self.build_request(headers)
        dn = self.client.requestUserDN(request, signature)
        return dn

    def fetch_app_token(self, targetApp=None, excludedAuths=None, skipCache=False):
        """
        Request a token containing application info, optionally with a target
        securityId in the token. If the targetApp is specified, you will be
        able to send this token to another application, and it will validate on
        the other end. You should set targetApp to
        ApplicationConfiguration(ez_props).getSecurityID() if you are sending
        this to another thrift service within your application
        @param targetApp: optionally, the targetSecurityId to put in the token
        (defaults to this application's security id)
        @return the EzSecurityToken
        """

        app = self.appConfig.getApplicationName()

        headers = {
            HTTP_HEADER_USER_INFO: app,
            HTTP_HEADER_SIGNATURE: ''
        }

        if targetApp is None:
            targetApp = self.appConfig.getSecurityID()

        # look in the cache
        cache_key = self._get_cache_key(TokenType.APP, headers.get(HTTP_HEADER_USER_INFO), excludedAuths,
                                        target_security_id=targetApp)
        if not skipCache:
            token = self.__get_from_cache(cache_key)
            if token:
                return token

        request, signature = self.build_request(headers, targetApp, token_type=TokenType.APP,
                                                exclude_authorizations=excludedAuths)
        return self._request_token_and_store(request, signature, "app", app, cache_key)

    def fetch_user_token(self, headers, target_app=None, skipCache=False):
        """
        Request a token with user info. Includes a targetSecurityId
        in the token if target_app is passed. If targetSecurityId is set in the
        token, you will be able to pass this token to other thrift services.
        You should set target_app to
        ApplicationConfiguration(ez_props).getSecurityID() if you are sending
        this to another thrift service within your application,
        @param headers: dict carrying HTTP_HEADER_USER_INFO / HTTP_HEADER_SIGNATURE
        @param target_app: optionally, the targetSecurityId to put in the token
        (defaults to this application's security id)
        @return: the EzSecurityToken
        """
        dn = headers.get(HTTP_HEADER_USER_INFO)

        if target_app is None:
            target_app = self.appConfig.getSecurityID()
        if self.mock and dn is None:
            dn = self.ez_props.get(MOCK_USER_DN)
            if dn is None:
                raise RuntimeError("{0} is in mock mode, but {1} is None".
                                   format(self.__class__, MOCK_USER_DN))

        # look in the cache (and return immediately if in cache)
        cache_key = self._get_cache_key(TokenType.USER, dn, target_security_id=target_app)
        if not skipCache:
            token = self.__get_from_cache(cache_key)
            if token:
                return token

        # get token (since it wasn't found in the cache)
        request, signature = self.build_request(headers, target_app)
        return self._request_token_and_store(request, signature, "user", dn, cache_key)

    def fetch_derived_token(self, ezSecurityToken, targetApp,
                            excludedAuths=None, skipCache=False):
        """
        Used when an application receives an EzSecurityToken as part of it's
        API but needs to call another service that itself takes an
        EzSecurityToken.

        :param ezSecurityToken: the token this application received
        :param targetApp: name of the application or common service to call;
                          its security id is resolved via service discovery
        :param excludedAuths: authorizations to exclude from the new token
        :param skipCache: bypass the token cache lookup
        :return: the derived EzSecurityToken, or None if validation failed
        """

        # get the security id for target app (depending on if its a common
        # service or an application)
        dc = ServiceDiscoveryClient(self.zk_con_str)
        targetSecurityId = dc.get_security_id(targetApp)

        # look in the cache (and return immediately if in cache)
        dn = ezSecurityToken.tokenPrincipal.principal
        request_chain = ezSecurityToken.tokenPrincipal.requestChain
        cache_key = self._get_cache_key(ezSecurityToken.type, dn, excludedAuths, request_chain, targetSecurityId)
        if not skipCache:
            token = self.__get_from_cache(cache_key)
            if token:
                return token

        # get token (since it wasn't found in the cache)
        headers = {
            HTTP_HEADER_USER_INFO: dn,
            HTTP_HEADER_SIGNATURE: self._sign(dn)
        }
        # build_request stores its second argument as the request's
        # targetSecurityId, so pass the resolved security id (the same value
        # used in cache_key), not the application name.
        # NOTE(review): this sends a USER-type request carrying a proxy
        # principal; the incoming token itself is not forwarded as
        # tokenPrincipal -- confirm that is what the security service expects.
        request, signature = self.build_request(headers, targetSecurityId, exclude_authorizations=excludedAuths)
        return self._request_token_and_store(request, signature, "derived", dn, cache_key)

    def _request_token_and_store(self, request, signature, type_info, subject, cache_key):
        """
        Request a token, validate it unless in mock mode, and cache it under
        ``cache_key`` as (notAfter, token). Returns None if validation fails.
        """
        self.log.debug("Requesting %s token for %s from EzSecurity", type_info, subject)
        token = self.client.requestToken(request, signature)
        self.log.debug("Received %s token for %s from EzSecurity", type_info, subject)

        # validate the token we received if we're not mocking (i.e.: in dev)
        if not self.mock:
            if not self._validate_token(token):
                self.log.error("Invalid token received from EzSecurity")
                token = None

        if token is not None:
            self.log.info("Storing %s token %s into cache", type_info, subject)
            expires = token.validity.notAfter
            self.token_cache[cache_key] = (expires, token)

        return token

    def __get_from_cache(self, cache_key):
        """
        Shortcut for retrieving contents from cache.
        :param cache_key:
        :return: Contents of cache if found and valid, otherwise None
        """
        try:
            token = self.token_cache[cache_key]
            if self._validate_token(token):
                self.log.info("Using token from cache")
                return token
            else:
                self.log.info("Token in cache was invalid. getting new")
        except KeyError:
            # it's not in the cache, continue
            pass

        return None

    def _validate_token(self, token):
        """
        Internal method for verifying tokens received from the security service
        @param token: the received EzSecurityToken
        @return: true if the token is valid
        """
        self._ensure_keys()
        return util.verify(token, self.servicePublic,
                           self.appConfig.getSecurityID(), None)

    def validate_received_token(self, token):
        """
        Validate a token that was received in a thrift request. This must be
        called whenever your application receives an EzSecurityToken from an
        unknown source (even if you think you know where it came from)
        @param token: the received EzSecurityToken
        @return: true if the token is valid (always true in mock mode)
        """
        if self.mock:
            return True
        self._ensure_keys()
        return util.verify(token, self.servicePublic, None,
                           self.appConfig.getSecurityID())

    def validate_signed_dn(self, dn, signature):
        """
        Validate a DN/Signature pair that is expected to have been signed by
        the security service
        @param dn: the dn
        @param signature: the security service signature
        @return: true if the DN validates
        """
        self._ensure_keys()
        return util.verify_signed_dn(dn, signature, self.servicePublic)

    def build_request(self, headers, target_app=None, token_type=TokenType.USER, exclude_authorizations=None):
        """
        Build a TokenRequest for the given information.
        @param headers: request headers; used for the proxy principal of USER requests
        @param target_app: the optional targetSecurityId
        @return: (TokenRequest, signature); the signature is "" in mock mode
        """
        token = TokenRequest(securityId=self.appConfig.getSecurityID(),
                             targetSecurityId=target_app,
                             timestamp=util.current_time_millis(),
                             type=token_type,
                             excludeAuthorizations=exclude_authorizations)
        token.targetSecurityId = target_app

        if token_type == TokenType.USER:
            token.proxyPrincipal = self.principal_from_request(headers)

        # generate signature
        if not self.mock:
            signature = self._sign(util.serialize_token_request(token))
        else:
            signature = ""

        return token, signature

    def validate_current_request(self, headers):
        """
        Verify that the signed user-info header in ``headers`` carries a valid
        signature and that its ProxyUserToken has not expired.
        :return: True if the request is valid, False if it is invalid.
        """

        # if we're mocking, return True
        if self.mock and not self.servicePrivate:
            return True

        now = util.current_time_millis()
        try:
            dn = headers[HTTP_HEADER_USER_INFO]
            sig = headers[HTTP_HEADER_SIGNATURE]
            self._ensure_keys()
            pubkey = self.servicePublic

            # verify the user_info header with the signature
            verified = util.verify_proxy_token_signature(dn, sig, pubkey)
            if not verified:
                return False

            # verify that the ProxyUserToken has not expired
            json_dict = util.deserialize_from_json(dn)

            # populate X509
            x509 = X509Info()
            x509.__dict__.update(json_dict['x509'])

            # populate ProxyUserToken
            proxy_user_token = ProxyUserToken()
            proxy_user_token.__dict__.update(json_dict)
            proxy_user_token.x509 = x509

            if proxy_user_token.notAfter < now:
                return False

            return True
        except KeyError:
            self.log.exception("Unable to validate current request.")
            return False
class ThriftClientPoolTest(KazooTestCase):
    """Integration tests for ThriftClientPool using plain (non-SSL) Thrift.

    Endpoints are registered through ServiceDiscoveryClient against the
    ZooKeeper connection string ``self.hosts`` (provided by KazooTestCase),
    and one EzPz server process is started per entry in ENDPOINTS.
    """

    def setUp(self):
        """Start EzPz servers, register the test endpoints and create the pool.

        Besides one "ezpz" endpoint per ENDPOINTS entry, this registers
        application services (service_one..three), common services
        (common_service_one..three and a two-endpoint common_service_multi),
        and endpoints under another application ("NotThriftClientPool") that
        test_endpoints expects the pool not to expose.
        """
        super(ThriftClientPoolTest, self).setUp()
        ezd_client = ServiceDiscoveryClient(self.hosts)

        ez_props = EzConfiguration().getProperties()
        ez_props["thrift.use.ssl"] = "false"
        ez_props["zookeeper.connection.string"] = self.hosts
        application_name = ApplicationConfiguration(ez_props).getApplicationName()

        self.serverProcesses = []
        for endpoint in ENDPOINTS:
            host, port = endpoint.split(':')
            port = int(port)
            server_process = Process(target=start_ezpz, args=(EzPzHandler(), port,))
            server_process.start()
            self.serverProcesses.append(server_process)
            ezd_client.register_endpoint(application_name, "ezpz", host, port)
        if self.serverProcesses:
            # One shared start-up grace period for all servers (instead of one
            # second per server) before any test tries to connect.
            time.sleep(1)

        ezd_client.register_endpoint(application_name, "service_one", 'localhost', 8083)
        ezd_client.register_endpoint(application_name, "service_two", 'localhost', 8084)
        ezd_client.register_endpoint(application_name, "service_three", 'localhost', 8085)

        ezd_client.register_common_endpoint('common_service_one', 'localhost', 8080)
        ezd_client.register_common_endpoint('common_service_two', 'localhost', 8081)
        ezd_client.register_common_endpoint('common_service_three', 'localhost', 8082)
        ezd_client.register_common_endpoint('common_service_multi', '192.168.1.1', 6060)
        ezd_client.register_common_endpoint('common_service_multi', '192.168.1.2', 6161)

        ezd_client.register_endpoint("NotThriftClientPool", "unknown_service_three", 'localhost', 8091)
        ezd_client.register_endpoint("NotThriftClientPool", "unknown_service_three", 'localhost', 8092)
        ezd_client.register_endpoint("NotThriftClientPool", "unknown_service_three", 'localhost', 8093)

        self.clientPool = ThriftClientPool(ez_props)

    def tearDown(self):
        """Close the pool and verify it is empty, then stop servers and ZooKeeper.

        The pool is closed while the ZooKeeper fixture is still running, which
        matches the ordering of the SSL test class. Server processes and the
        base-class teardown are handled in ``finally`` so that a failing
        assertion cannot leak child processes.
        """
        try:
            self.clientPool.close()
            nt.assert_false(self.clientPool._get_service_map())
            nt.assert_false(self.clientPool._get_client_map())
        finally:
            for server_process in self.serverProcesses:
                if server_process.is_alive():
                    server_process.terminate()
                # Reap the child; bounded so a stuck server cannot hang the suite.
                server_process.join(5)
            super(ThriftClientPoolTest, self).tearDown()

    def test_endpoints(self):
        """The pool sees this app's and common services, but not other apps' services."""
        service_map = self.clientPool._get_service_map()
        for name in ("common_service_one", "common_service_two", "common_service_three",
                     "service_one", "service_two", "service_three"):
            self.assertIn(name, service_map)
        for name in ("unknown_service_one", "unknown_service_two", "unknown_service_three"):
            self.assertNotIn(name, service_map)

        self.assertIn("common_service_multi", service_map)
        self.assertEqual(len(service_map["common_service_multi"]), 2)

        self.assertIn("ezpz", service_map)
        self.assertEqual(len(service_map["ezpz"]), len(ENDPOINTS))

    def test_get_client(self):
        """A registered service yields a working client; an unknown one yields a falsy value."""
        client = self.clientPool.get_client(service_name='ezpz', clazz=EzPz.Client)
        try:
            nt.assert_equal('pz', client.ez())
        finally:
            client.close()

        client = self.clientPool.get_client(service_name='ezpz1', clazz=EzPz.Client)            # Non-existent service
        nt.assert_false(client)

    def test_get_client_for_app(self):
        """A client can be requested with an explicit application name."""
        client = self.clientPool.get_client(app_name='testApp', service_name='ezpz', clazz=EzPz.Client)
        try:
            nt.assert_equal('pz', client.ez())
        finally:
            client.close()

    def test_multi_get_client(self):
        """Two clients for the same service can be used side by side."""
        client1 = self.clientPool.get_client(service_name='ezpz', clazz=EzPz.Client)
        client2 = self.clientPool.get_client(service_name='ezpz', clazz=EzPz.Client)
        try:
            nt.assert_equal('pz', client1.ez())
            nt.assert_equal('pz', client2.ez())
        finally:
            client1.close()
            client2.close()
class ThriftClientPoolTest(EzThriftServerTestHarness):
    """Integration tests for ThriftClientPool against SSL-enabled EzPz servers."""

    def setUp(self):
        """Register one SSL "ezpz_ssl" server per ENDPOINTS entry and build an SSL client pool."""
        super(ThriftClientPoolTest, self).setUp()

        ez_props = EzConfiguration().getProperties()
        ez_props["thrift.use.ssl"] = "true"
        ez_props["zookeeper.connection.string"] = self.hosts
        application_name = ApplicationConfiguration(ez_props).getApplicationName()

        for endpoint in ENDPOINTS:
            host, port = endpoint.split(':')
            self.add_server(application_name, "ezpz_ssl", host, int(port), EzPz.Processor(EzPzHandler()),
                            use_ssl=True, ca_certs=servercapath, cert=servercertpath, key=serverprivpath)

        self.clientPool = ThriftClientPool(ez_props)

    def tearDown(self):
        """Close the pool and verify it released its maps, then run the harness teardown.

        The harness teardown runs in ``finally`` so it still happens when one of
        the pool assertions fails.
        """
        try:
            self.clientPool.close()
            nt.assert_false(self.clientPool._get_service_map())
            nt.assert_false(self.clientPool._get_client_map())
        finally:
            super(ThriftClientPoolTest, self).tearDown()

    def test_get_client(self):
        """A registered SSL service yields a working client; an unknown one yields a falsy value."""
        client = self.clientPool.get_client(service_name='ezpz_ssl', clazz=EzPz.Client)
        try:
            resp = client.ez()
            nt.assert_equal('pz', resp)
        finally:
            client.close()

        client = self.clientPool.get_client(service_name='ezpz1', clazz=EzPz.Client)            # Non-existent service
        nt.assert_false(client)

    def test_get_client_for_apps(self):
        """Same as test_get_client, but with an explicit application name."""
        client = self.clientPool.get_client(app_name='testApp', service_name='ezpz_ssl', clazz=EzPz.Client)
        try:
            resp = client.ez()
            nt.assert_equal('pz', resp)
        finally:
            client.close()

        client = self.clientPool.get_client(app_name='testApp', service_name='ezpz1', clazz=EzPz.Client)            # Non-existent service
        nt.assert_false(client)

    def test_multi_get_client(self):
        """Two SSL clients for the same service can be used side by side."""
        client1 = self.clientPool.get_client(service_name='ezpz_ssl', clazz=EzPz.Client)
        client2 = self.clientPool.get_client(service_name='ezpz_ssl', clazz=EzPz.Client)
        try:
            nt.assert_equal('pz', client1.ez())
            nt.assert_equal('pz', client2.ez())
        finally:
            # Release the borrowed clients; the pool itself is closed in tearDown.
            client1.close()
            client2.close()

    def test_multi_get_client_for_apps(self):
        """Two SSL clients requested with an explicit application name can be used side by side."""
        client1 = self.clientPool.get_client(app_name='testApp', service_name='ezpz_ssl', clazz=EzPz.Client)
        client2 = self.clientPool.get_client(app_name='testApp', service_name='ezpz_ssl', clazz=EzPz.Client)
        try:
            nt.assert_equal('pz', client1.ez())
            nt.assert_equal('pz', client2.ez())
        finally:
            # Release the borrowed clients; the pool itself is closed in tearDown.
            client1.close()
            client2.close()