コード例 #1
0
    def __init__(self,
                 transport,
                 host,
                 service,
                 mechanism=six.u('GSSAPI'),
                 generate_tickets=False,
                 using_keytab=False,
                 principal=None,
                 keytab_file=None,
                 ccache_file=None,
                 password=None,
                 **sasl_kwargs):
        """Wrap ``transport`` in a SASL-negotiating Thrift transport.

        transport: the underlying transport, usually a TSocket
        host: server name as seen by SASL
        service: server's service name as seen by SASL
        mechanism: name of the preferred SASL mechanism
        generate_tickets: when true, build ``self.krb_context`` from the
            Kerberos arguments below and call its ``init_with_keytab()``
        using_keytab, principal, keytab_file, ccache_file, password:
            passed positionally to ``krbContext``; unused unless
            ``generate_tickets`` is true
        Any remaining keyword arguments are handed unchanged to
        ``puresasl.client.SASLClient``.
        """
        self.transport = transport

        if six.PY3:
            # Python 3 only: install the patched pure-sasl mechanism before
            # the client below is constructed.
            self._patch_pure_sasl()
        self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)

        self.__wbuf, self.__rbuf = BytesIO(), BytesIO()

        self.generate_tickets = generate_tickets
        if not generate_tickets:
            return
        self.krb_context = krbContext(using_keytab, principal, keytab_file,
                                      ccache_file, password)
        self.krb_context.init_with_keytab()
コード例 #2
0
    def authenticate_xmpp(self):
        """Authenticate the user to the XMPP server via the BOSH connection.

        Returns the value of ``check_authenticate_success`` for the server's
        reply. If that value is ``None`` and the reply contains another SASL
        ``challenge`` element, an empty response is sent and the result of
        that second round is returned instead.
        """

        self.request_sid()

        self.log.debug('Prepare the XMPP authentication')

        # Instantiate a sasl object
        sasl = SASLClient(host=self.to,
                          service='xmpp',
                          username=self.jid,
                          password=self.password)

        # Choose an auth mechanism (anonymous login is never allowed)
        sasl.choose_mechanism(self.server_auth, allow_anonymous=False)

        # Request challenge
        challenge = self.get_challenge(sasl.mechanism)

        # Process challenge and generate response.
        # b64decode returns bytes. The realm test and the appended default realm
        # must therefore be bytes as well. With str operands Python 3 raises
        # TypeError. On Python 2, b'...' is plain str, so nothing changes there.
        challenge_bytes = base64.b64decode(challenge)
        if b'realm=' not in challenge_bytes:
            # Presumably some servers omit the realm; supply a placeholder one.
            challenge_bytes += b',realm="random"'
        response = sasl.process(challenge_bytes)
        # Send response
        resp_root = self.send_challenge_response(response)

        success = self.check_authenticate_success(resp_root)
        if success is None and\
                resp_root.find('{{{0}}}challenge'.format(XMPP_SASL_NS)) is not None:
            resp_root = self.send_challenge_response('')
            return self.check_authenticate_success(resp_root)
        return success
コード例 #3
0
ファイル: sasl.py プロジェクト: ox-it/aioldap
def sasl_bind(client, host):
    """Send an LDAPv3 SASL/GSSAPI bind request for ``host`` through ``client``.

    This is a generator-based coroutine (it uses ``yield from``). The initial
    GSSAPI token produced by pure-sasl is placed in the SASL credentials of a
    BindRequest with version 3 and an empty bind DN.

    :param client: object whose ``request`` coroutine sends a ProtocolOp and
        produces the server's reply
    :param host: server host name passed to pure-sasl for the ``ldap`` service
    :returns: the reply produced by ``client.request``
    """
    sasl_client = SASLClient(host, service='ldap', mechanism='GSSAPI')

    sasl_credentials = SaslCredentials()
    # SASL mechanism names are upper-case (RFC 4422 section 3.1); send the
    # same name as the mechanism pure-sasl is actually running.
    sasl_credentials.setComponentByName("mechanism", LDAPString("GSSAPI"))
    sasl_credentials.setComponentByName("credentials", sasl_client.process(None))

    authentication_choice = AuthenticationChoice()
    authentication_choice.setComponentByName('sasl', sasl_credentials)

    bind_request = BindRequest()
    bind_request.setComponentByName('version', Version(3))
    bind_request.setComponentByName('name', LDAPDN(''))
    bind_request.setComponentByName('authentication', authentication_choice)

    protocol_op = ProtocolOp()
    protocol_op.setComponentByName("bindRequest", bind_request)

    response = yield from client.request(protocol_op)
    return response
コード例 #4
0
ファイル: test_mechanism.py プロジェクト: ryan-pip/pure-sasl
class _BaseMechanismTests(unittest.TestCase):
    """Mechanism-agnostic checks run against one SASL mechanism class.

    ``mechanism_class`` and ``sasl_kwargs`` select the mechanism under test
    (AnonymousMechanism with no extra arguments by default); presumably
    subclasses override them.
    """

    mechanism_class = AnonymousMechanism
    sasl_kwargs = {}

    def _make_client(self):
        """Return a fresh SASLClient for ``mechanism_class`` on localhost."""
        return SASLClient('localhost', mechanism=self.mechanism_class.name, **self.sasl_kwargs)

    def setUp(self):
        self.sasl = self._make_client()
        self.mechanism = self.sasl._chosen_mech

    def test_init_basic(self, *args):
        client = self._make_client()
        chosen = client._chosen_mech
        self.assertIs(chosen.sasl, client)
        self.assertIsInstance(chosen, self.mechanism_class)

    def test_process_basic(self, *args):
        # Two consecutive rounds must each yield bytes.
        for _ in range(2):
            self.assertIsInstance(self.sasl.process(six.b('string')), six.binary_type)

    def test_dispose_basic(self, *args):
        self.sasl.dispose()

    def test_wrap_unwrap(self, *args):
        # wrap/unwrap are expected to be unimplemented for this mechanism.
        for operation in (self.sasl.wrap, self.sasl.unwrap):
            self.assertRaises(NotImplementedError, operation, 'msg')

    def test__pick_qop(self, *args):
        pick_qop = self.sasl._chosen_mech._pick_qop
        # An empty offer is a protocol error; the full QOP set is accepted.
        self.assertRaises(SASLProtocolException, pick_qop, set())
        pick_qop(set(QOP.all))
コード例 #5
0
    def authenticate_xmpp(self):
        """Run SASL authentication for the user over the BOSH connection.

        Returns the result of ``check_authenticate_success``. When that result
        is ``None`` and the server answered with one more SASL ``challenge``
        element, an empty response is sent and that round's result is
        returned instead.
        """
        self.request_sid()
        self.log.debug('Prepare the XMPP authentication')

        client = SASLClient(host=self.to, service='xmpp',
                            username=self.jid, password=self.password)
        # Never fall back to anonymous authentication.
        client.choose_mechanism(self.server_auth, allow_anonymous=False)

        raw_challenge = self.get_challenge(client.mechanism)
        answer = client.process(base64.b64decode(raw_challenge))
        reply = self.send_challenge_response(answer)

        outcome = self.check_authenticate_success(reply)
        if outcome is not None:
            return outcome
        if reply.find('{{{0}}}challenge'.format(XMPP_SASL_NS)) is None:
            return outcome
        # The server issued a further challenge: acknowledge it with an empty reply.
        reply = self.send_challenge_response('')
        return self.check_authenticate_success(reply)
コード例 #6
0
class _BaseMechanismTests(unittest.TestCase):
    """Common test cases exercised for a single SASL mechanism.

    The mechanism under test is ``mechanism_class`` (AnonymousMechanism by
    default), constructed with ``sasl_kwargs``; presumably subclasses swap
    both out.
    """

    mechanism_class = AnonymousMechanism
    sasl_kwargs = {}

    def setUp(self):
        mech_name = self.mechanism_class.name
        self.sasl = SASLClient('localhost', mechanism=mech_name, **self.sasl_kwargs)
        self.mechanism = self.sasl._chosen_mech

    def test_init_basic(self, *args):
        mech_name = self.mechanism_class.name
        fresh = SASLClient('localhost', mechanism=mech_name, **self.sasl_kwargs)
        picked = fresh._chosen_mech
        self.assertIs(picked.sasl, fresh)
        self.assertIsInstance(picked, self.mechanism_class)

    def test_process_basic(self, *args):
        token = six.b('string')
        first = self.sasl.process(token)
        self.assertIsInstance(first, six.binary_type)
        second = self.sasl.process(six.b('string'))
        self.assertIsInstance(second, six.binary_type)

    def test_dispose_basic(self, *args):
        self.sasl.dispose()

    def test_wrap_unwrap(self, *args):
        client = self.sasl
        self.assertRaises(NotImplementedError, client.wrap, 'msg')
        self.assertRaises(NotImplementedError, client.unwrap, 'msg')

    def test__pick_qop(self, *args):
        mech = self.sasl._chosen_mech
        self.assertRaises(SASLProtocolException, mech._pick_qop, set())
        mech._pick_qop(set(QOP.all))
コード例 #7
0
 def __init__(self, transport, host, service, mechanism='GSSAPI',
              **sasl_kwargs):
     """Set up a SASL client transport layered over ``transport``.

     ``host``, ``service`` and ``mechanism`` plus any extra keyword
     arguments are forwarded to ``puresasl.client.SASLClient``.
     """
     # Imported here rather than at module level, presumably so that
     # puresasl remains an optional dependency.
     from puresasl.client import SASLClient
     self.transport = transport
     self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)
     # Write buffer starts empty; the read buffer is seeded with b''.
     self.__wbuf, self.__rbuf = BufferIO(), BufferIO(b'')
コード例 #8
0
    def authenticate_xmpp(self):
        """Authenticate over BOSH, then restart the stream and bind a resource.

        Returns False as soon as the challenge response is rejected (a falsy
        value from ``send_challenge_response``); otherwise True once the
        restart and resource binding have been performed.
        """
        self.request_sid()
        self.log.debug('Prepare the XMPP authentication')

        client = SASLClient(host=self.to, service='xmpp',
                            username=self.jid, password=self.password)
        client.choose_mechanism(self.server_auth, allow_anonymous=False)

        decoded = base64.b64decode(self.get_challenge(client.mechanism))
        accepted = self.send_challenge_response(client.process(decoded))
        if not accepted:
            return False

        self.request_restart()
        self.bind_resource()
        return True
コード例 #9
0
 def __init__(self, host, service, qops, properties):
     """Create a GSSAPI ``SASLClient`` for ``host``/``service``.

     ``qops`` is passed as the ``qops`` keyword; ``properties`` (which may
     be None or empty) supplies any further SASLClient keyword arguments.
     """
     if not properties:
         properties = {}
     self.sasl = SASLClient(host, service, 'GSSAPI', qops=qops, **properties)
コード例 #10
0
    def test_init(self):
        """Construct SASLClient with defaults, with a valid mechanism, and with a bad one."""
        SASLClient('localhost')  # every argument defaulted

        client = SASLClient('localhost', mechanism=AnonymousMechanism.name)
        chosen = client._chosen_mech
        self.assertIsInstance(chosen, AnonymousMechanism)
        self.assertIs(chosen.sasl, client)

        # An unknown mechanism name is rejected at construction time.
        self.assertRaises(SASLError, SASLClient, 'localhost', mechanism='WRONG')
コード例 #11
0
    def connect(self):
        """Perform the SASL handshake with the HDFS namenode over ``self._trans``.

        State numbers used below follow Hadoop's ``RpcSaslProto.SaslState``:
        0 SUCCESS, 1 NEGOTIATE, 2 INITIATE, 3 CHALLENGE, 4 RESPONSE.

        Returns True once the server reports state 0. Raises ``Exception``
        on any other server state, instead of silently waiting for yet
        another message.
        """
        # The service name is the principal's first component
        # ("nn/host@REALM" -> "nn"). Raw string: '\/' in a plain literal is an
        # invalid escape sequence (SyntaxWarning since Python 3.12).
        service = re.split(r'[/@]', str(self.hdfs_namenode_principal))[0]

        if not self.sasl:
            self.sasl = SASLClient(self._trans.host, service)

        negotiate = RpcSaslProto()
        negotiate.state = 1  # NEGOTIATE
        self._send_sasl_message(negotiate)

        while True:
            res = self._recv_sasl_message()

            if res.state == 0:  # SUCCESS
                return True

            if res.state == 1:  # NEGOTIATE: server lists the mechanisms it offers
                # TODO: check mechanisms
                mechs = [auth.mechanism for auth in res.auths]

                log.debug("Available mechs: %s", ",".join(mechs))
                self.sasl.choose_mechanism(mechs, allow_anonymous=False)
                log.debug("Chosen mech: %s", self.sasl.mechanism)

                initiate = RpcSaslProto()
                initiate.state = 2  # INITIATE
                initiate.token = self.sasl.process()

                # Echo back the server's auth entry for the chosen mechanism.
                for auth in res.auths:
                    if auth.mechanism == self.sasl.mechanism:
                        auth_method = initiate.auths.add()
                        auth_method.mechanism = self.sasl.mechanism
                        auth_method.method = auth.method
                        auth_method.protocol = auth.protocol
                        auth_method.serverId = self._trans.host

                self._send_sasl_message(initiate)
            elif res.state == 3:  # CHALLENGE
                res_token = self._evaluate_token(res)
                response = RpcSaslProto()
                response.token = res_token
                response.state = 4  # RESPONSE
                self._send_sasl_message(response)
            else:
                raise Exception("Unexpected SASL state from server: %s" % res.state)
コード例 #12
0
class SaslAuthenticator(Authenticator):
    """
    :class:`~.Authenticator` that forwards every SASL step, unchanged, to a
    ``SASLClient`` from the third-party 'pure-sasl' package.
    """
    def __init__(self, host, service, mechanism='GSSAPI', **sasl_kwargs):
        # SASLClient is None when pure-sasl is unavailable (presumably set by
        # a guarded import elsewhere in the module).
        client_factory = SASLClient
        if client_factory is None:
            raise ImportError('The puresasl library has not been installed')
        self.sasl = client_factory(host, service, mechanism, **sasl_kwargs)

    def initial_response(self):
        """Return the token sent before any server challenge."""
        return self.sasl.process()

    def evaluate_challenge(self, challenge):
        """Return the token answering ``challenge``."""
        return self.sasl.process(challenge)
コード例 #13
0
ファイル: auth.py プロジェクト: datastax/python-driver-dse
class GSSAPIAuthenticator(BaseDSEAuthenticator):
    """DSE authenticator that runs a GSSAPI SASL exchange through pure-sasl."""

    def __init__(self, host, service, qops, properties):
        """Create the GSSAPI ``SASLClient``.

        ``properties`` may be None; otherwise its items are passed to
        ``SASLClient`` as extra keyword arguments alongside ``qops``.
        """
        properties = properties or {}
        self.sasl = SASLClient(host, service, 'GSSAPI', qops=qops, **properties)

    def get_mechanism(self):
        return "GSSAPI"

    def get_initial_challenge(self):
        return "GSSAPI-START"

    def evaluate_challenge(self, challenge):
        """Produce the response to ``challenge``.

        The ``GSSAPI-START`` marker means the exchange is just starting, so the
        initial GSSAPI token is produced. Any other value is handed to pure-sasl.
        The marker is matched as either text or bytes. On Python 3 a bytes
        challenge never equals the str literal, so the marker itself would
        otherwise be fed to GSSAPI as a token.
        """
        if challenge in ('GSSAPI-START', b'GSSAPI-START'):
            return self.sasl.process()
        return self.sasl.process(challenge)
コード例 #14
0
 def test_init_basic(self, *args):
     """A freshly built client owns a mechanism of the requested class."""
     client = SASLClient('localhost', mechanism=self.mechanism_class.name, **self.sasl_kwargs)
     chosen = client._chosen_mech
     self.assertIs(chosen.sasl, client)
     self.assertIsInstance(chosen, self.mechanism_class)
コード例 #15
0
ファイル: auth.py プロジェクト: IChocolateKapa/python-driver
class SaslAuthenticator(Authenticator):
    """
    :class:`~.Authenticator` for DSE's KerberosAuthenticator that delegates
    each SASL step to a pure-sasl ``SASLClient``.

    .. versionadded:: 2.1.3-post
    """

    def __init__(self, host, service, mechanism='GSSAPI', **sasl_kwargs):
        if SASLClient is None:
            # pure-sasl could not be imported.
            raise ImportError('The puresasl library has not been installed')
        sasl_args = (host, service, mechanism)
        self.sasl = SASLClient(*sasl_args, **sasl_kwargs)

    def initial_response(self):
        """First token, sent without a server challenge."""
        return self.sasl.process()

    def evaluate_challenge(self, challenge):
        """Token produced in reply to the server's ``challenge``."""
        return self.sasl.process(challenge)
コード例 #16
0
 def test_unchosen_mechanism(self):
     """Every operation fails with SASLError while no mechanism is chosen."""
     client = SASLClient('localhost')
     self.assertRaises(SASLError, client.process)
     for operation in (client.wrap, client.unwrap):
         self.assertRaises(SASLError, operation, 'msg')
     with self.assertRaises(SASLError):
         client.complete  # merely reading ``complete`` is expected to raise
     self.assertRaises(SASLError, client.dispose)
コード例 #17
0
ファイル: auth.py プロジェクト: StuartAxelOwen/python-driver
class SaslAuthenticator(Authenticator):
    """
    Pass-through :class:`~.Authenticator`: every SASL step is delegated to the
    third-party 'pure-sasl' package.

    .. versionadded:: 2.1.4
    """

    def __init__(self, host, service, mechanism='GSSAPI', **sasl_kwargs):
        library_missing = SASLClient is None
        if library_missing:
            raise ImportError('The puresasl library has not been installed')
        self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)

    def initial_response(self):
        """Return the opening SASL token."""
        return self.sasl.process()

    def evaluate_challenge(self, challenge):
        """Return pure-sasl's answer to ``challenge``."""
        return self.sasl.process(challenge)
コード例 #18
0
 def test_chosen_mechanism(self):
     """With PLAIN chosen up front, one process() call completes and wrap/unwrap return their input."""
     client = SASLClient('localhost', mechanism=PlainMechanism.name, username='******', password='******')
     self.assertTrue(client.process())
     self.assertTrue(client.complete)
     payload = 'msg'
     for transform in (client.wrap, client.unwrap):
         self.assertEqual(transform(payload), payload)
     client.dispose()
コード例 #19
0
    def __init__(self, transport, host, service, mechanism=six.u('GSSAPI'),
                 **sasl_kwargs):
        """Layer SASL negotiation on top of ``transport``.

        transport: the wrapped transport, normally a TSocket
        host: server name as SASL sees it
        service: server's service name as SASL sees it
        mechanism: preferred SASL mechanism name
        Extra keyword arguments go unchanged to puresasl.client.SASLClient.
        """
        self.transport = transport

        if six.PY3:
            # Python 3 only: patch pure-sasl before building the client.
            self._patch_pure_sasl()
        self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)

        self.__wbuf, self.__rbuf = BytesIO(), BytesIO()
コード例 #20
0
ファイル: TTransport.py プロジェクト: zhaoche27/thrift
    def __init__(self, transport, host, service, mechanism='GSSAPI',
                 **sasl_kwargs):
        """Create a SASL client transport wrapping ``transport``.

        transport: the wrapped transport, normally a TSocket
        host: server name from the SASL point of view
        service: server's service name from the SASL point of view
        mechanism: preferred SASL mechanism name

        Remaining keyword arguments are forwarded to
        puresasl.client.SASLClient.
        """
        # Function-level import, presumably keeping puresasl optional until
        # a SASL transport is actually created.
        from puresasl.client import SASLClient

        self.transport = transport
        self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)

        # Empty write buffer; read buffer seeded with b''.
        self.__wbuf, self.__rbuf = BufferIO(), BufferIO(b'')
コード例 #21
0
ファイル: connection.py プロジェクト: EmergingThreats/pycassa
    def __init__(self, transport, host, service,
            mechanism='GSSAPI', **sasl_kwargs):
        """Wrap ``transport`` with a pure-sasl client for ``host``/``service``.

        Extra keyword arguments are passed on to ``SASLClient``.
        """
        # Local import: puresasl is only needed once this object is built.
        from puresasl.client import SASLClient

        self.transport = transport
        self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)

        self.__wbuf, self.__rbuf = StringIO(), StringIO()
コード例 #22
0
class GSSAPIAuthenticator(BaseDSEAuthenticator):
    """DSE authenticator performing GSSAPI SASL via pure-sasl, using byte strings on the wire."""

    def __init__(self, host, service, qops, properties):
        # ``properties`` may be None; its items become extra SASLClient kwargs.
        sasl_options = properties or {}
        self.sasl = SASLClient(host, service, 'GSSAPI', qops=qops, **sasl_options)

    def get_mechanism(self):
        return six.b("GSSAPI")

    def get_initial_challenge(self):
        return six.b("GSSAPI-START")

    def evaluate_challenge(self, challenge):
        # The start marker asks for the initial token; anything else is a real challenge.
        is_start = challenge == six.b('GSSAPI-START')
        return self.sasl.process() if is_start else self.sasl.process(challenge)
コード例 #23
0
ファイル: test_mechanism.py プロジェクト: beltran/pure-sasl
    def test_process_with_authorization_id_or_identity(self):
        """The response is ``authzid NUL username NUL password``.

        The first pass sets only ``identity``. The second pass also adds
        ``authorization_id``, which must take priority over ``identity``.
        """
        challenge = u"\U0001F44D"
        kwargs = dict(self.sasl_kwargs, identity='user2')

        rounds = (
            ({}, 'user2'),
            ({'authorization_id': 'user3'}, 'user3'),
        )
        for extra, expected_authz in rounds:
            kwargs.update(extra)
            client = SASLClient('localhost',
                                mechanism=self.mechanism_class.name,
                                **kwargs)
            response = client.process(challenge)
            expected = six.b('{0}\x00{1}\x00{2}'.format(expected_authz, self.username,
                                                        self.password))
            self.assertEqual(response, expected)
            self.assertIsInstance(response, six.binary_type)
コード例 #24
0
ファイル: thrift_sasl.py プロジェクト: yoziru-desu/pyhs2
 def __init__(self, transport, host, service, mechanism="GSSAPI", **sasl_kwargs):
     """Build a SASL transport on top of ``transport``.

     transport: the wrapped transport, usually a TSocket
     host: server name from SASL's point of view
     service: server's service name from SASL's point of view
     mechanism: preferred SASL mechanism name
     Other keyword arguments are handed to puresasl.client.SASLClient.
     """
     self.transport = transport
     self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)
     self.__wbuf, self.__rbuf = StringIO(), StringIO()
コード例 #25
0
 def test_chosen_mechanism(self):
     """PLAIN chosen at construction: a single step completes, and wrap/unwrap are identity."""
     client = SASLClient('localhost', mechanism=PlainMechanism.name, username='******', password='******')
     first_token = client.process()
     self.assertTrue(first_token)
     self.assertTrue(client.complete)
     text = 'msg'
     wrapped = client.wrap(text)
     self.assertEqual(wrapped, text)
     unwrapped = client.unwrap(text)
     self.assertEqual(unwrapped, text)
     client.dispose()
コード例 #26
0
ファイル: test_mechanism.py プロジェクト: ksauzz/pure-sasl
    def test_process_server_answer(self):
        """Feed a canned challenge, pin the client nonce, then process the server's rspauth reply.

        ``mutual_auth=True`` is set, so the final ``rspauth`` message is
        handed to the mechanism after the cnonce has been fixed.
        """
        client = SASLClient('elwood.innosoft.com',
                            service="imap",
                            mechanism=self.mechanism_class.name,
                            mutual_auth=True,
                            username="******",
                            password="******")
        canned_challenge = (
            b'utf-8,username="******",realm="elwood.innosoft.com",'
            b'nonce="OA6MG9tEQGm2hh",nc=00000001,cnonce="OA6MHXh6VqTrRk",'
            b'digest-uri="imap/elwood.innosoft.com",'
            b'response=d388dad90d4bbd760a152321f2143af7,qop=auth')
        client.process(canned_challenge)

        # The client cnonce is random; pin it to the value the expected
        # rspauth below was computed with.
        client._chosen_mech.cnonce = b"OA6MHXh6VqTrRk"

        client.process(b'rspauth=ea40f60335c427b5527b84dbabcdfffd')
コード例 #27
0
 def __init__(self, host, service, mechanism='GSSAPI', **sasl_kwargs):
     """Create the pure-sasl client, failing early if the library is absent."""
     # Presumably the module sets SASLClient to None when puresasl is missing.
     factory = SASLClient
     if factory is None:
         raise ImportError('The puresasl library has not been installed')
     self.sasl = factory(host, service, mechanism, **sasl_kwargs)
コード例 #28
0
    def test_choose_mechanism(self):
        """choose_mechanism picks the highest-scoring candidate and honours each allow_* filter."""
        client = SASLClient('localhost', service='something')
        self.assertRaises(SASLError, client.choose_mechanism, ['invalid'])

        # DigestMD5Mechanism is deliberately excluded from the candidate pool.
        candidates = [mech for mech in mechanisms.values() if mech is not DigestMD5Mechanism]
        client.choose_mechanism({mech.name for mech in candidates})
        best = max(candidates, key=lambda mech: mech.score)
        self.assertIsInstance(client._chosen_mech, best)

        # (mechanism flag, flag value selecting the restricted mechanisms,
        #  choose_mechanism keyword that forbids them)
        filters = (
            ('allows_anonymous', True, 'allow_anonymous'),
            ('uses_plaintext', True, 'allow_plaintext'),
            ('active_safe', False, 'allow_active'),
            ('dictionary_safe', False, 'allow_dictionary'),
        )
        for flag, wanted, keyword in filters:
            names = {mech.name for mech in candidates if bool(getattr(mech, flag)) is wanted}
            client.choose_mechanism(names)
            self.assertIn(client.mechanism, names)
            self.assertRaises(SASLError, client.choose_mechanism, names, **{keyword: False})
コード例 #29
0
ファイル: _sasl.py プロジェクト: ClearwaterCore/Telephus
class ThriftSASLClientProtocol(ThriftClientProtocol):
    """Twisted Thrift client protocol that performs a SASL handshake first.

    Handshake frames are ``status (1 byte) | length (4 bytes) | body``.
    After negotiation, each outgoing Thrift message is SASL-wrapped and
    sent with a 4-byte length prefix.
    """

    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    MAX_LENGTH = 2 ** 31 - 1

    def __init__(self, client_class, iprot_factory, oprot_factory=None,
            host=None, service=None, mechanism='GSSAPI', **sasl_kwargs):
        ThriftClientProtocol.__init__(self, client_class, iprot_factory, oprot_factory)

        self._sasl_negotiation_deferred = None
        self._sasl_negotiation_status = None
        self.client = None

        # Without a host the SASL client must be created later via createSASLClient.
        if host is not None:
            self.createSASLClient(host, service, mechanism, **sasl_kwargs)

    def createSASLClient(self, host, service, mechanism, **kwargs):
        self.sasl = SASLClient(host, service, mechanism, **kwargs)

    def dispatch(self, msg):
        """Wrap ``msg`` with the negotiated SASL layer and send it as one frame."""
        encoded = self.sasl.wrap(msg)
        # b'' rather than '': struct.pack returns bytes on Python 3, where
        # ''.join would raise TypeError. On Python 2, b'' is the same as ''.
        len_and_encoded = b''.join((struct.pack('!i', len(encoded)), encoded))
        ThriftClientProtocol.dispatch(self, len_and_encoded)

    @defer.inlineCallbacks
    def connectionMade(self):
        """Negotiate SASL, then hand the connection to the Thrift client."""
        self._sendSASLMessage(self.START, self.sasl.mechanism)
        # pure-sasl steps run in a worker thread, presumably because they may block.
        initial_message = yield deferToThread(self.sasl.process)
        self._sendSASLMessage(self.OK, initial_message)

        while True:
            status, challenge = yield self._receiveSASLMessage()
            if status == self.OK:
                response = yield deferToThread(self.sasl.process, challenge)
                self._sendSASLMessage(self.OK, response)
            elif status == self.COMPLETE:
                if not self.sasl.complete:
                    msg = "The server erroneously indicated that SASL " \
                          "negotiation was complete"
                    # The first positional argument of TTransportException is
                    # its type code, not the message.
                    raise TTransportException(TTransportException.NOT_OPEN, msg)
                break
            else:
                msg = "Bad SASL negotiation status: %d (%s)" % (status, challenge)
                raise TTransportException(TTransportException.NOT_OPEN, msg)

        self._sasl_negotiation_deferred = None
        ThriftClientProtocol.connectionMade(self)

    def _sendSASLMessage(self, status, body):
        """Write one handshake frame.

        A ``None`` body is sent as empty. A text body (e.g. the mechanism
        name) is UTF-8 encoded, so the length header counts bytes and the
        concatenation below stays bytes + bytes.
        """
        if body is None:
            body = b""
        elif not isinstance(body, bytes):
            body = body.encode('utf-8')
        header = struct.pack(">BI", status, len(body))
        self.transport.write(header + body)

    def _receiveSASLMessage(self):
        """Return a Deferred fired with ``(status, body)`` for the next handshake frame."""
        self._sasl_negotiation_deferred = defer.Deferred()
        self._sasl_negotiation_status = None
        return self._sasl_negotiation_deferred

    def connectionLost(self, reason=connectionDone):
        if self.client:
            ThriftClientProtocol.connectionLost(self, reason)

    def dataReceived(self, data):
        if self._sasl_negotiation_deferred:
            # A handshake frame is (status, length, challenge). Only the first
            # chunk of a frame starts with the status byte. When a frame is
            # split across several reads, the later chunks are pure
            # length/challenge data and must be passed through untouched.
            if self._sasl_negotiation_status is None:
                if not data:
                    return
                # data[:1] (not data[0]) works on Python 2 str and Python 3 bytes.
                self._sasl_negotiation_status, = struct.unpack("B", data[:1])
                data = data[1:]
            # let IntNStringReceiver piece the challenge data together
            ThriftClientProtocol.dataReceived(self, data)
        else:
            # normal frame, let IntNStringReceiver piece it together
            ThriftClientProtocol.dataReceived(self, data)

    def stringReceived(self, frame):
        if self._sasl_negotiation_deferred:
            # The frame is a complete SASL challenge. Pair it with its status,
            # and clear the status so the next handshake frame reads a fresh one.
            response = (self._sasl_negotiation_status, frame)
            self._sasl_negotiation_status = None
            self._sasl_negotiation_deferred.callback(response)
        else:
            # there's a second 4 byte length prefix inside the frame
            decoded_frame = self.sasl.unwrap(frame[4:])
            ThriftClientProtocol.stringReceived(self, decoded_frame)
コード例 #30
0
class SaslRpcClient:
    """SASL negotiation and message wrapping for a Hadoop RPC transport.

    State numbers used below follow Hadoop's ``RpcSaslProto.SaslState``:
    0 SUCCESS, 1 NEGOTIATE, 2 INITIATE, 3 CHALLENGE, 4 RESPONSE, 5 WRAP.
    """

    def __init__(self, trans, hdfs_namenode_principal=None):
        self.sasl = None  # created lazily in connect() unless supplied
        self._trans = trans
        self.hdfs_namenode_principal = hdfs_namenode_principal

    def _send_sasl_message(self, message):
        """Frame ``message`` with a SASL RPC request header and write it to the transport."""
        rpcheader = RpcRequestHeaderProto()
        rpcheader.rpcKind = 2  # RPC_PROTOCOL_BUFFER
        rpcheader.rpcOp = 0
        rpcheader.callId = -33  # SASL
        rpcheader.retryCount = -1
        rpcheader.clientId = b""

        s_rpcheader = rpcheader.SerializeToString()
        s_message = message.SerializeToString()

        # Total length covers both varint-delimited parts: each payload plus its length prefix.
        header_length = len(s_rpcheader) + encoder._VarintSize(
            len(s_rpcheader)) + len(s_message) + encoder._VarintSize(
                len(s_message))

        self._trans.write(struct.pack('!I', header_length))
        self._trans.write_delimited(s_rpcheader)
        self._trans.write_delimited(s_message)

        log_protobuf_message("Send out", message)

    def _recv_sasl_message(self):
        """Read one RPC message from the transport and parse it as RpcSaslProto."""
        bytestream = self._trans.recv_rpc_message()
        sasl_response = self._trans.parse_response(bytestream, RpcSaslProto)

        return sasl_response

    def connect(self):
        """Perform the SASL handshake with the namenode.

        Returns True once the server reports SUCCESS. Raises ``Exception``
        on any unexpected server state, instead of silently waiting for
        another message.
        """
        # The service name is the principal's first component
        # ("nn/host@REALM" -> "nn"). Raw string: '\/' in a plain literal is an
        # invalid escape sequence (SyntaxWarning since Python 3.12).
        service = re.split(r'[/@]', str(self.hdfs_namenode_principal))[0]

        if not self.sasl:
            self.sasl = SASLClient(self._trans.host, service)

        negotiate = RpcSaslProto()
        negotiate.state = 1  # NEGOTIATE
        self._send_sasl_message(negotiate)

        while True:
            res = self._recv_sasl_message()

            if res.state == 0:  # SUCCESS
                return True

            if res.state == 1:  # NEGOTIATE: server lists the mechanisms it offers
                # TODO: check mechanisms
                mechs = [auth.mechanism for auth in res.auths]

                log.debug("Available mechs: %s", ",".join(mechs))
                self.sasl.choose_mechanism(mechs, allow_anonymous=False)
                log.debug("Chosen mech: %s", self.sasl.mechanism)

                initiate = RpcSaslProto()
                initiate.state = 2  # INITIATE
                initiate.token = self.sasl.process()

                # Echo back the server's auth entry for the chosen mechanism.
                for auth in res.auths:
                    if auth.mechanism == self.sasl.mechanism:
                        auth_method = initiate.auths.add()
                        auth_method.mechanism = self.sasl.mechanism
                        auth_method.method = auth.method
                        auth_method.protocol = auth.protocol
                        auth_method.serverId = self._trans.host

                self._send_sasl_message(initiate)
            elif res.state == 3:  # CHALLENGE
                res_token = self._evaluate_token(res)
                response = RpcSaslProto()
                response.token = res_token
                response.state = 4  # RESPONSE
                self._send_sasl_message(response)
            else:
                raise Exception("Unexpected SASL state from server: %s" % res.state)

    def _evaluate_token(self, sasl_response):
        """Feed the server's challenge token to the SASL client and return its answer."""
        return self.sasl.process(challenge=sasl_response.token)

    def wrap(self, message):
        """SASL-wrap ``message`` and send it as a WRAP-state SASL message."""
        encoded = self.sasl.wrap(message)

        sasl_message = RpcSaslProto()
        sasl_message.state = 5  #  WRAP
        sasl_message.token = encoded

        self._send_sasl_message(sasl_message)

    def unwrap(self):
        """Receive one WRAP-state message and return its unwrapped payload."""
        response = self._recv_sasl_message()
        if response.state != 5:
            raise Exception("Server send non-wrapped response")

        return self.sasl.unwrap(response.token)

    def use_wrap(self):
        """Return True when the negotiated QOP protects message payloads.

        SASL wrapping is only used when the QOP is ``auth-int`` (integrity)
        or ``auth-conf`` (confidentiality). Plain ``auth``, or no QOP at all,
        means messages are sent unwrapped.
        """
        qop = self.sasl.qop
        if isinstance(qop, bytes):
            qop = qop.decode()
        return qop in ('auth-int', 'auth-conf')
コード例 #31
0
class TSaslClientTransport(TTransportBase):
    """
    SASL transport

    Wraps another Thrift transport. ``open`` runs the SASL handshake using
    ``status byte | 4-byte length | body`` frames. After that, each flushed
    write is SASL-wrapped and sent with a 4-byte length prefix, and reads
    unwrap one such frame at a time.
    """

    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    def __init__(self,
                 transport,
                 host,
                 service,
                 mechanism=six.u('GSSAPI'),
                 generate_tickets=False,
                 using_keytab=False,
                 principal=None,
                 keytab_file=None,
                 ccache_file=None,
                 password=None,
                 **sasl_kwargs):
        """
        transport: an underlying transport to use, typically just a TSocket
        host: the name of the server, from a SASL perspective
        service: the name of the server's service, from a SASL perspective
        mechanism: the name of the preferred mechanism to use
        generate_tickets: if true, create ``self.krb_context`` from the
            Kerberos arguments below and call its ``init_with_keytab()``
        using_keytab, principal, keytab_file, ccache_file, password:
            passed positionally to ``krbContext``; ignored unless
            ``generate_tickets`` is true
        All other kwargs will be passed to the puresasl.client.SASLClient
        constructor.
        """

        self.transport = transport

        if six.PY3:
            self._patch_pure_sasl()
        self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)

        self.__wbuf = BytesIO()
        self.__rbuf = BytesIO()

        self.generate_tickets = generate_tickets
        if self.generate_tickets:
            self.krb_context = krbContext(using_keytab, principal, keytab_file,
                                          ccache_file, password)
            self.krb_context.init_with_keytab()

    def _patch_pure_sasl(self):
        ''' we need to patch pure_sasl to support python 3 '''
        # NOTE: this replaces the GSSAPI entry in pure-sasl's module-level
        # mechanism registry, so it affects every SASLClient in the process.
        puresasl.mechanisms.mechanisms['GSSAPI'] = CustomGSSAPIMechanism

    def is_open(self):
        return self.transport.is_open() and bool(self.sasl)

    @with_ticket
    def open(self):
        """Open the underlying transport if needed and run the SASL handshake."""
        if not self.transport.is_open():
            self.transport.open()

        self.send_sasl_msg(self.START, self.sasl.mechanism.encode('utf8'))
        self.send_sasl_msg(self.OK, self.sasl.process())

        while True:
            status, challenge = self.recv_sasl_msg()
            if status == self.OK:
                self.send_sasl_msg(self.OK, self.sasl.process(challenge))
            elif status == self.COMPLETE:
                if not self.sasl.complete:
                    raise TTransportException(
                        TTransportException.NOT_OPEN,
                        "The server erroneously indicated "
                        "that SASL negotiation was complete")
                else:
                    break
            else:
                raise TTransportException(
                    TTransportException.NOT_OPEN,
                    "Bad SASL negotiation status: %d (%s)" %
                    (status, challenge))

    def send_sasl_msg(self, status, body):
        '''
        Send one handshake frame: 1-byte status, 4-byte big-endian length, body.

        body: bytes; None is sent as an empty body
        '''
        if body is None:
            body = b''
        header = pack(">BI", status, len(body))
        self.transport.write(header + body)
        self.transport.flush()

    def recv_sasl_msg(self):
        '''Read one handshake frame and return ``(status, payload)``.'''
        header = readall(self.transport.read, 5)
        status, length = unpack(">BI", header)
        if length > 0:
            payload = readall(self.transport.read, length)
        else:
            # b'' rather than '': the payload goes to sasl.process, which gets
            # bytes for every non-empty frame, so keep the empty case bytes too.
            payload = b""
        return status, payload

    @with_ticket
    def write(self, data):
        self.__wbuf.write(data)

    @with_ticket
    def flush(self):
        """SASL-wrap the buffered writes and send them as one length-prefixed frame."""
        data = self.__wbuf.getvalue()
        encoded = self.sasl.wrap(data)
        # b'' is '' on Python 2, so this single form serves both versions.
        self.transport.write(b''.join((pack("!i", len(encoded)), encoded)))
        self.transport.flush()
        self.__wbuf = BytesIO()

    def read(self, sz):
        """Return up to ``sz`` unwrapped bytes, reading a new frame when the buffer is empty."""
        ret = self.__rbuf.read(sz)
        if len(ret) != 0 or sz == 0:
            return ret

        self._read_frame()
        return self.__rbuf.read(sz)

    def _read_frame(self):
        """Read one length-prefixed frame, unwrap it and make it the read buffer."""
        header = readall(self.transport.read, 4)
        length, = unpack('!i', header)
        encoded = readall(self.transport.read, length)
        self.__rbuf = BytesIO(self.sasl.unwrap(encoded))

    def close(self):
        self.sasl.dispose()
        self.transport.close()
コード例 #32
0
ファイル: net.py プロジェクト: ashafer01/laurelin
class LDAPSocket(object):
    """Holds a connection to an LDAP server.

    :param str host_uri: "scheme://netloc" to connect to
    :param int connect_timeout: Number of seconds to wait for connection to be accepted
    :param bool ssl_verify: Validate the certificate and hostname on an SSL/TLS connection
    :param str ssl_ca_file: Path to PEM-formatted concatenated CA certificates file
    :param str ssl_ca_path: Path to directory with CA certs under hashed file names. See
                            https://www.openssl.org/docs/man1.1.0/ssl/SSL_CTX_load_verify_locations.html for more
                            information about the format of this directory.
    :param ssl_ca_data: An ASCII string of one or more PEM-encoded certs or a bytes object containing DER-encoded
                        certificates.
    :type ssl_ca_data: str or bytes
    """

    RECV_BUFFER = 4096

    # For ldapi:/// try to connect to these socket files in order
    # Globs must match exactly one result
    LDAPI_SOCKET_PATHS = ['/var/run/ldapi', '/var/run/slapd/ldapi', '/var/run/slapd-*.socket']

    # OIDs of unsolicited messages

    OID_DISCONNECTION_NOTICE = '1.3.6.1.4.1.1466.20036'  # RFC 4511 sec 4.4.1 Notice of Disconnection

    def __init__(self, host_uri, connect_timeout=5, ssl_verify=True, ssl_ca_file=None, ssl_ca_path=None,
                 ssl_ca_data=None):

        self._prop_init(connect_timeout)
        self._uri_connect(host_uri, ssl_verify, ssl_ca_file, ssl_ca_path, ssl_ca_data)

    def _prop_init(self, connect_timeout=5):
        """Initialise per-connection bookkeeping and assign this socket a unique ID."""
        # get socket ID number from the module-level counter
        global _next_sock_id
        self.ID = _next_sock_id
        _next_sock_id += 1

        # misc init
        self._message_queues = {}  # message ID -> deque of responses received while waiting for another ID
        self._next_message_id = 1
        self._sasl_client = None

        self.refcount = 0
        self.bound = False
        self.unbound = False
        self.abandoned_mids = []
        self.started_tls = False
        self.connect_timeout = connect_timeout

    def _parse_uri(self, host_uri):
        """Split *host_uri* into ``(scheme, netloc)`` and store the normalised URI in ``self.uri``.

        A URI without ``://`` is treated as ``ldapi`` when it starts with ``/``
        and as ``ldap`` otherwise.

        :raises LDAPError: if the URI is empty or contains more than one ``://``
        """
        parts = host_uri.split('://')
        if len(parts) == 1:
            netloc = unquote(parts[0])
            if not netloc:
                # previously an IndexError from netloc[0]
                raise LDAPError('Invalid host_uri')
            scheme = 'ldapi' if netloc.startswith('/') else 'ldap'
        elif len(parts) == 2:
            scheme = parts[0]
            netloc = unquote(parts[1])
        else:
            raise LDAPError('Invalid host_uri')
        self.uri = '{0}://{1}'.format(scheme, netloc)
        return scheme, netloc

    def _uri_connect(self, host_uri, ssl_verify, ssl_ca_file, ssl_ca_path, ssl_ca_data):
        """Connect ``self._sock`` according to the scheme of *host_uri* (ldap, ldaps or ldapi)."""
        scheme, netloc = self._parse_uri(host_uri)
        logger.info('Connecting to {0} on #{1}'.format(self.uri, self.ID))
        if scheme == 'ldap':
            self._inet_connect(netloc, 389)
        elif scheme == 'ldaps':
            self._inet_connect(netloc, 636)
            self.start_tls(ssl_verify, ssl_ca_file, ssl_ca_path, ssl_ca_data)
            logger.info('Connected with TLS on #{0}'.format(self.ID))
        elif scheme == 'ldapi':
            if not _have_unix_socket:
                raise LDAPError('Unix sockets are not supported on your platform, please choose a protocol other '
                                'than ldapi')
            self.sock_path = None
            self._sock = socket(AF_UNIX)
            self.host = 'localhost'

            if netloc == '/':
                # No explicit path: try the well-known socket locations in order.
                for sock_glob in LDAPSocket.LDAPI_SOCKET_PATHS:
                    fn = glob(sock_glob)
                    if not fn:
                        continue
                    if len(fn) > 1:
                        logger.debug('Multiple results for glob {0}'.format(sock_glob))
                        continue
                    fn = fn[0]
                    try:
                        self._connect(fn)
                        self.sock_path = fn
                        break
                    except SocketError:
                        continue
                if self.sock_path is None:
                    raise LDAPConnectionError('Could not find any local LDAPI unix socket - full '
                                              'socket path must be supplied in URI')
            else:
                try:
                    self._connect(netloc)
                    self.sock_path = netloc
                except SocketError as e:
                    raise LDAPConnectionError('failed connect to unix socket {0} - {1} ({2})'.format(
                        netloc, e.strerror, e.errno
                    ))

            logger.debug('Connected to unix socket {0} on #{1}'.format(self.sock_path, self.ID))
        else:
            raise LDAPError('Unsupported scheme "{0}"'.format(scheme))

    def _connect(self, addr):
        """Connect ``self._sock`` to *addr*; the timeout applies only to the connect itself."""
        self._sock.settimeout(self.connect_timeout)
        self._sock.connect(addr)
        self._sock.settimeout(None)

    def _inet_connect(self, netloc, default_port):
        """Open a TCP connection to ``host[:port]`` given in *netloc*.

        :raises LDAPConnectionError: if the connection cannot be established
        """
        ap = netloc.rsplit(':', 1)
        self.host = ap[0]
        if len(ap) == 1:
            port = default_port
        else:
            port = int(ap[1])
        try:
            self._sock = create_connection((self.host, port), self.connect_timeout)
            # create_connection leaves the timeout set on the socket; like _connect, apply
            # connect_timeout only to connecting so later reads/searches are not cut off.
            self._sock.settimeout(None)
            logger.debug('Connected to {0}:{1} on #{2}'.format(self.host, port, self.ID))
        except SocketError as e:
            raise LDAPConnectionError('failed connect to {0}:{1} - {2} ({3})'.format(
                                      self.host, port, e.strerror, e.errno))

    def start_tls(self, verify=True, ca_file=None, ca_path=None, ca_data=None):
        """Install TLS layer on this socket connection.

        :param bool verify: Validate the certificate and hostname on an SSL/TLS connection
        :param str ca_file: Path to PEM-formatted concatenated CA certificates file
        :param str ca_path: Path to directory with CA certs under hashed file names. See
                            https://www.openssl.org/docs/man1.1.0/ssl/SSL_CTX_load_verify_locations.html for more
                            information about the format of this directory.
        :param ca_data: An ASCII string of one or more PEM-encoded certs or a bytes object containing DER-encoded
                        certificates.
        :type ca_data: str or bytes
        :raises LDAPError: if TLS is already installed
        :raises LDAPConnectionError: if hostname verification fails
        """
        if self.started_tls:
            raise LDAPError('TLS layer already installed')

        if verify:
            verify_mode = ssl.CERT_REQUIRED
        else:
            verify_mode = ssl.CERT_NONE

        try:
            proto = ssl.PROTOCOL_TLS
        except AttributeError:
            proto = ssl.PROTOCOL_SSLv23

        try:
            ctx = ssl.SSLContext(proto)
            ctx.verify_mode = verify_mode
            ctx.check_hostname = False  # we do this ourselves
            if verify:
                ctx.load_default_certs()
            if ca_file or ca_path or ca_data:
                ctx.load_verify_locations(cafile=ca_file, capath=ca_path, cadata=ca_data)
            self._sock = ctx.wrap_socket(self._sock)
        except AttributeError:
            # SSLContext wasn't added until 2.7.9
            if ca_path or ca_data:
                raise RuntimeError('python version >= 2.7.9 required for SSL ca_path/ca_data')

            self._sock = ssl.wrap_socket(self._sock, ca_certs=ca_file, cert_reqs=verify_mode, ssl_version=proto)

        if verify:
            cert = self._sock.getpeercert()
            # Each subject RDN is a tuple of (key, value) pairs; certificates that carry
            # only subjectAltName entries have no commonName, so default to ''.
            cert_cn = dict(rdn[0] for rdn in cert.get('subject', ())).get('commonName', '')
            self.check_hostname(cert_cn, cert)
        else:
            logger.debug('Skipping hostname validation')
        self.started_tls = True
        logger.debug('Installed TLS layer on #{0}'.format(self.ID))

    @staticmethod
    def _hostname_matches(host, name, allow_wildcard=False):
        """Return True if *host* matches the certificate identifier *name*.

        Comparison is case-insensitive (DNS names are). When *allow_wildcard* is
        true, a leading ``*.`` label in *name* matches exactly one leftmost
        label of *host* (RFC 6125 sec 6.4.3), so ``*.example.com`` matches
        ``a.example.com`` but neither ``example.com`` nor ``a.b.example.com``.
        """
        host = host.lower()
        name = name.lower()
        if allow_wildcard and name.startswith('*.'):
            first_label, sep, rest = host.partition('.')
            return bool(first_label) and bool(sep) and rest == name[2:]
        return host == name

    def check_hostname(self, cert_cn, cert):
        """SSL check_hostname according to RFC 4513 sec 3.1.3. Compares supplied values against ``self.host`` to
        determine the validity of the cert.

        :param str cert_cn: The common name of the cert (may be empty if the cert has none)
        :param dict cert: A dictionary representing the rest of the cert. Checks key subjectAltNames for a list of
                          (type, value) tuples, where type is 'DNS' or 'IP'. DNS supports a leading wildcard label,
                          which matches exactly one label of the host name.
        :rtype: None
        :raises LDAPConnectionError: if no supplied values match ``self.host``
        """
        if cert_cn and self._hostname_matches(self.host, cert_cn):
            logger.debug('Matched server identity to cert commonName')
            return

        tried = [cert_cn] if cert_cn else []
        for san_type, value in cert.get('subjectAltName', []):
            tried.append(value)
            if self._hostname_matches(self.host, value, allow_wildcard=(san_type == 'DNS')):
                logger.debug('Matched server identity to cert {0} subjectAltName'.format(san_type))
                return
        raise LDAPConnectionError('Server identity "{0}" does not match any cert names: {1}'.format(
            self.host, ', '.join(tried)))

    def sasl_init(self, mechs, **props):
        """Initialize a :class:`.puresasl.client.SASLClient`"""
        self._sasl_client = SASLClient(self.host, 'ldap', **props)
        self._sasl_client.choose_mechanism(mechs)

    def _has_sasl_client(self):
        return self._sasl_client is not None

    def _require_sasl_client(self):
        if not self._has_sasl_client():
            raise LDAPSASLError('SASL init not complete')

    @property
    def sasl_qop(self):
        """Obtain the chosen quality of protection"""
        self._require_sasl_client()
        return self._sasl_client.qop

    @property
    def sasl_mech(self):
        """Obtain the chosen mechanism"""
        self._require_sasl_client()
        mech = self._sasl_client.mechanism
        if mech is None:
            raise LDAPSASLError('SASL init not complete - no mech chosen')
        return mech

    def sasl_process_auth_challenge(self, challenge):
        """Process an auth challenge and return the correct response"""
        self._require_sasl_client()
        return self._sasl_client.process(challenge)

    def _prep_message(self, op, obj, controls=None):
        """Prepare a message for transmission; returns ``(message_id, encoded_bytes)``."""
        mid = self._next_message_id
        self._next_message_id += 1
        lm = pack(mid, op, obj, controls)
        raw = ber_encode(lm)
        if self._has_sasl_client():
            raw = self._sasl_client.wrap(raw)
        return mid, raw

    def send_message(self, op, obj, controls=None):
        """Create and send an LDAPMessage given an operation name and a corresponding object.

        Operation names must be defined as component names in laurelin.ldap.rfc4511.ProtocolOp and
        the object must be of the corresponding type.

        :param str op: The protocol operation name
        :param object obj: The associated protocol object (see :class:`.rfc4511.ProtocolOp` for mapping.
        :param controls: Any request controls for the message
        :type controls: rfc4511.Controls or None
        :return: The message ID for this message
        :rtype: int
        """
        mid, raw = self._prep_message(op, obj, controls)
        self._sock.sendall(raw)
        return mid

    def recv_one(self, want_message_id):
        """Get the next message with ``want_message_id`` being sent by the server

        :param int want_message_id: The desired message ID.
        :return: The LDAP message
        :rtype: rfc4511.LDAPMessage
        """
        return next(self.recv_messages(want_message_id))

    def recv_messages(self, want_message_id):
        """Iterate all messages with ``want_message_id`` being sent by the server.

        Messages received for other IDs are queued in ``self._message_queues``
        and yielded later to whichever caller asks for that ID.

        :param int want_message_id: The desired message ID.
        :return: An iterator over :class:`.rfc4511.LDAPMessage`.
        :raises LDAPUnsolicitedMessage: if the server sends a message with ID 0
        :raises LDAPConnectionError: if the server closes the connection
        """
        flush_queue = True
        raw = b''
        while True:
            if flush_queue:
                # First hand out anything already queued for this ID.
                if want_message_id in self._message_queues:
                    q = self._message_queues[want_message_id]
                    while True:
                        if len(q) == 0:
                            break
                        obj = q.popleft()
                        if len(q) == 0:
                            del self._message_queues[want_message_id]
                        yield obj
            else:
                flush_queue = True
            if want_message_id in self.abandoned_mids:
                return
            try:
                newraw = self._sock.recv(LDAPSocket.RECV_BUFFER)
                if not newraw:
                    # recv() returns empty bytes only once the peer has closed the
                    # connection; without this check the loop would spin forever.
                    raise LDAPConnectionError('Connection closed by server on #{0}'.format(self.ID))
                if self._has_sasl_client():
                    newraw = self._sasl_client.unwrap(newraw)
                raw += newraw
                while len(raw) > 0:
                    # ber_decode returns the decoded message plus any unconsumed bytes;
                    # SubstrateUnderrunError means the buffered data is an incomplete message.
                    response, raw = ber_decode(raw, asn1Spec=LDAPMessage())
                    have_message_id = response.getComponentByName('messageID')
                    if want_message_id == have_message_id:
                        yield response
                    elif have_message_id == 0:
                        msg = 'Received unsolicited message (default message - should never be seen)'
                        try:
                            mid, xr, ctrls = unpack('extendedResp', response)
                            res_code = xr.getComponentByName('resultCode')
                            xr_oid = six.text_type(xr.getComponentByName('responseName'))
                            if xr_oid == LDAPSocket.OID_DISCONNECTION_NOTICE:
                                mtype = 'Notice of Disconnection'
                            else:
                                mtype = 'Unhandled ({0})'.format(xr_oid)
                            diag = xr.getComponentByName('diagnosticMessage')
                            msg = 'Got unsolicited message: {0}: {1}: {2}'.format(mtype, res_code, diag)
                            if res_code == ResultCode('protocolError'):
                                msg += (' (This may indicate an incompatability between laurelin-ldap and your server '
                                        'distribution)')
                            elif res_code == ResultCode('strongerAuthRequired'):
                                # this is a direct quote from RFC 4511 sec 4.4.1
                                msg += (' (The server has detected that an established security association between the'
                                        ' client and server has unexpectedly failed or been compromised)')
                        except UnexpectedResponseType:
                            msg = 'Unhandled unsolicited message from server'
                        finally:
                            # Always surface the unsolicited message, even if describing it failed.
                            raise LDAPUnsolicitedMessage(response, msg)
                    else:
                        if have_message_id not in self._message_queues:
                            self._message_queues[have_message_id] = deque()
                        self._message_queues[have_message_id].append(response)
            except SubstrateUnderrunError:
                # Need more bytes; the own-ID queue cannot have grown, so skip re-flushing it.
                flush_queue = False
                continue

    def close(self):
        """Close the low-level socket connection."""
        return self._sock.close()
コード例 #33
0
ファイル: net.py プロジェクト: ashafer01/laurelin
 def sasl_init(self, mechs, **props):
     """Create the pure-sasl client for the 'ldap' service on this host, then pick one of *mechs*."""
     client = SASLClient(self.host, 'ldap', **props)
     self._sasl_client = client
     client.choose_mechanism(mechs)
コード例 #34
0
ファイル: TTransport.py プロジェクト: mrofiq/pycharm_helpers
class TSaslClientTransport(TTransportBase, CReadableTransport):
    """
    SASL transport: authenticates over an underlying Thrift transport with a
    puresasl SASLClient, then exchanges length-prefixed, SASL-wrapped frames.
    """

    # Negotiation status codes sent as the first byte of each handshake message.
    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    def __init__(self,
                 transport,
                 host,
                 service,
                 mechanism='GSSAPI',
                 **sasl_kwargs):
        """
        transport: an underlying transport to use, typically just a TSocket
        host: the name of the server, from a SASL perspective
        service: the name of the server's service, from a SASL perspective
        mechanism: the name of the preferred mechanism to use

        All other kwargs will be passed to the puresasl.client.SASLClient
        constructor.
        """

        from puresasl.client import SASLClient

        self.transport = transport
        self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)

        self.__wbuf = StringIO()
        self.__rbuf = StringIO()

    def open(self):
        """Open the underlying transport if needed and run the SASL handshake.

        Raises TTransportException (type NOT_OPEN) if the server reports a
        status other than OK/COMPLETE, or claims completion before the client
        side of the negotiation is complete.
        """
        if not self.transport.isOpen():
            self.transport.open()

        self.send_sasl_msg(self.START, self.sasl.mechanism)
        self.send_sasl_msg(self.OK, self.sasl.process())

        while True:
            status, challenge = self.recv_sasl_msg()
            if status == self.OK:
                self.send_sasl_msg(self.OK, self.sasl.process(challenge))
            elif status == self.COMPLETE:
                if not self.sasl.complete:
                    # The first positional argument of TTransportException is
                    # the error *type*; the message must come second.
                    raise TTransportException(
                        TTransportException.NOT_OPEN,
                        "The server erroneously indicated "
                        "that SASL negotiation was complete")
                break
            else:
                raise TTransportException(
                    TTransportException.NOT_OPEN,
                    "Bad SASL negotiation status: %d (%s)" %
                    (status, challenge))

    def send_sasl_msg(self, status, body):
        """Send one handshake message: 1-byte status, 4-byte big-endian length, then *body*."""
        header = pack(">BI", status, len(body))
        self.transport.write(header + body)
        self.transport.flush()

    def recv_sasl_msg(self):
        """Read one handshake message and return ``(status, payload)``."""
        header = self.transport.readAll(5)
        status, length = unpack(">BI", header)
        if length > 0:
            payload = self.transport.readAll(length)
        else:
            payload = ""
        return status, payload

    def write(self, data):
        """Buffer *data* until flush()."""
        self.__wbuf.write(data)

    def flush(self):
        """SASL-wrap the buffered data and send it as one frame.

        The frame carries a 4-byte signed big-endian length prefix.
        """
        data = self.__wbuf.getvalue()
        encoded = self.sasl.wrap(data)
        self.transport.write(''.join((pack("!i", len(encoded)), encoded)))
        self.transport.flush()
        self.__wbuf = StringIO()

    def read(self, sz):
        """Return up to *sz* bytes, reading a new frame only when the buffer is exhausted.

        ``read(0)`` returns immediately. Refilling in that case would discard
        unread buffered data and block waiting for a new frame.
        """
        ret = self.__rbuf.read(sz)
        if len(ret) != 0 or sz == 0:
            return ret

        self._read_frame()
        return self.__rbuf.read(sz)

    def _read_frame(self):
        """Read one length-prefixed frame and replace the read buffer with its unwrapped payload."""
        header = self.transport.readAll(4)
        length, = unpack('!i', header)
        encoded = self.transport.readAll(length)
        self.__rbuf = StringIO(self.sasl.unwrap(encoded))

    def close(self):
        """Dispose of the SASL session, then close the underlying transport."""
        self.sasl.dispose()
        self.transport.close()

    # based on TFramedTransport
    @property
    def cstringio_buf(self):
        return self.__rbuf

    def cstringio_refill(self, prefix, reqlen):
        # self.__rbuf will already be empty here because fastbinary doesn't
        # ask for a refill until the previous buffer is empty.  Therefore,
        # we can start reading new frames immediately.
        while len(prefix) < reqlen:
            self._read_frame()
            prefix += self.__rbuf.getvalue()
        self.__rbuf = StringIO(prefix)
        return self.__rbuf
コード例 #35
0
ファイル: saslppwrapper.py プロジェクト: lexman/sasl-pp
class Client(object):
    """SASL client exposing a sasl-pp style interface on top of pure-sasl's SASLClient.

    Most methods return ``(success, value)`` tuples; when success is False,
    call getError() for details.
    """

    def __init__(self):
        self.lasterror = ''   # message returned (and cleared) by getError()
        self.attributes = {}  # settings supplied through setAttr()
        self.sasl = None      # pure-sasl SASLClient, created by start()

    def setAttr(self, key, value):
        """Set an attribute to be used in authenticating the session.

        All attributes should be set before init() is called.

        :param key: Name of attribute being set
        :param value: Value of attribute being set
        :return: True (storing an attribute cannot fail)

        Available attribute keys:

           service      - Name of the service being accessed
           username     - User identity for authentication
           authname     - User identity for authorization (if different from username)
           password     - Password associated with username
           host         - Fully qualified domain name of the server host
           maxbufsize   - Maximum receive buffer size for the security layer
           minssf       - Minimum acceptable security strength factor (integer)
           maxssf       - Maximum acceptable security strength factor (integer)
           externalssf  - Security strength factor supplied by external mechanism (i.e. SSL/TLS)
           externaluser - Authentication ID (of client) as established by external mechanism
        """
        self.attributes[key] = value
        return True

    def getAttr(self, key):
        """Return the value stored for *key* by setAttr().

        Also passed to pure-sasl as the ``callback`` in start().

        :raises KeyError: if *key* was never set
        """
        return self.attributes[key]

    def init(self):
        """Initialize the client object; call after all attributes have been set.

        This implementation does no work and always reports success.

        :return: True
        """
        return True

    def start(self, chosen):
        """Start the SASL exchange with the server.

        :param chosen: The mechanism chosen by the client (str)
        :return: ``(True, chosen encoded as bytes, initial response to send to the server)``
        :raises KeyError: if the 'host' attribute was never set
        """
        # Pass the 'service' attribute through; mechanisms such as GSSAPI need it.
        self.sasl = SASLClient(self.attributes['host'],
                               service=self.attributes.get('service'),
                               mechanism=chosen,
                               callback=self.getAttr)
        return True, chosen.encode(), self.sasl.process()

    def step(self, challenge):
        """Step the SASL handshake.

        :param challenge: The challenge supplied by the server
        :return: ``(True, response to be sent back to the server)``
        """
        return True, self.sasl.process(challenge)

    def encode(self, clearText):
        """Encode data for secure transmission to the server.

        :param clearText: Clear text data to be encrypted
        :return: ``(True, data to be transmitted)``
        """
        return True, self.sasl.wrap(clearText)

    def decode(self, cipherText):
        """Decode data received from the server.

        :param cipherText: Encrypted data received from the server
        :return: ``(True, decrypted clear text data)``
        """
        return True, self.sasl.unwrap(cipherText)

    def getUserId(self):
        """Return the user identity associated with this session.

        Taken from the 'externaluser' attribute.

        :raises KeyError: if 'externaluser' was never set
        """
        return self.attributes['externaluser']

    def getSSF(self):
        """Get the security strength factor associated with this session.

        NOTE(review): this returns True rather than a negotiated SSF value.
        """
        return True

    def getError(self):
        """Return the last error message and clear the error state.

        Returns an empty string if there was no error or it was already cleared.
        """
        error = self.lasterror[:]
        self.lasterror = ''
        return error
コード例 #36
0
 def setUp(self):
     """Build a SASLClient for the mechanism under test and expose the mechanism instance it chose."""
     mech_name = self.mechanism_class.name
     self.sasl = SASLClient('localhost', mechanism=mech_name, **self.sasl_kwargs)
     self.mechanism = self.sasl._chosen_mech
コード例 #37
0
ファイル: channel.py プロジェクト: alope107/py-yarn
    def negotiate_sasl(self, token):
        """Run the SASL handshake for this RPC connection using TOKEN / DIGEST-MD5.

        Sends a NEGOTIATE request, selects the server-offered auth whose method is
        "TOKEN" and mechanism "DIGEST-MD5", answers its challenge in an INITIATE
        request, and reads the server's reply.

        :param token: mapping with "identifier" and "password" entries, which are
                      base64-encoded to form the SASL username and password
        :raises IOError: if the server does not offer TOKEN / DIGEST-MD5
        """
        log.debug("##############NEGOTIATING SASL#####################")

        # The same serialized header precedes every SASL message in this exchange.
        header_bytes = self.create_sasl_header().SerializeToString()

        # Prepares and sends negotiate request
        negotiate_request = RpcSaslProto()
        negotiate_request.state = RpcSaslProto.NEGOTIATE
        negotiate_request.version = 0
        self._write_sasl_rpc_frame(header_bytes, negotiate_request.SerializeToString())

        # Gets negotiate response
        response_bytes = self.recv_rpc_message()
        resp = self.parse_response(response_bytes, RpcSaslProto)

        # The last matching offer wins, as before.
        chosen_auth = None
        for auth in resp.auths:
            if auth.method == "TOKEN" and auth.mechanism == "DIGEST-MD5":
                chosen_auth = auth

        if chosen_auth is None:
            raise IOError("Token digest-MD5 authentication not supported by server")

        # Prepares initiate request
        self.sasl = SASLClient(
            chosen_auth.serverId,
            chosen_auth.protocol,
            mechanism=chosen_auth.mechanism,
            username=base64.b64encode(token["identifier"]),
            password=base64.b64encode(token["password"]),
        )

        challenge_resp = self.sasl.process(chosen_auth.challenge)

        auth = RpcSaslProto.SaslAuth()
        auth.method = chosen_auth.method
        auth.mechanism = chosen_auth.mechanism
        auth.protocol = chosen_auth.protocol
        auth.serverId = chosen_auth.serverId

        initiate_request = RpcSaslProto()
        initiate_request.state = RpcSaslProto.INITIATE
        initiate_request.version = 0
        initiate_request.auths.extend([auth])
        initiate_request.token = challenge_resp

        # Sends initiate request
        self._write_sasl_rpc_frame(header_bytes, initiate_request.SerializeToString())

        # Read and parse the server's reply to the initiate request; the parsed
        # result is not used further here.
        response_bytes = self.recv_rpc_message()
        self.parse_response(response_bytes, RpcSaslProto)

    def _write_sasl_rpc_frame(self, header_bytes, sasl_bytes):
        """Write one RPC frame: a 4-byte big-endian total length, then the header and SASL message, each varint-delimited."""
        total_length = (
            len(header_bytes)
            + len(sasl_bytes)
            + encoder._VarintSize(len(header_bytes))
            + encoder._VarintSize(len(sasl_bytes))
        )
        self.write(struct.pack("!I", total_length))
        self.write_delimited(header_bytes)
        self.write_delimited(sasl_bytes)
コード例 #38
0
    def _connect(self, host, port):
        """Open a socket to *host*:*port* and establish (or resume) a session.

        Sends a Connect request and stores the negotiated session id, protocol
        version and password on the client. It then does one of the following:
        enters read-only mode, starts SASL authentication, or replays
        ``digest`` credentials.

        :returns: ``(read_timeout, connect_timeout)``, both derived from the
                  negotiated session timeout (same units as
                  ``connect_result.time_out``).
        :raises SessionExpiredError: if the server returns a non-positive
                                     session timeout.
        """
        client = self.client
        self.logger.info('Connecting to %s:%s, use_ssl: %r', host, port,
                         self.client.use_ssl)

        self.logger.log(BLATHER, '    Using session_id: %r session_passwd: %s',
                        client._session_id, hexlify(client._session_passwd))

        with self._socket_error_handling():
            self._socket = self.handler.create_connection(
                address=(host, port),
                timeout=client._session_timeout / 1000.0,
                use_ssl=self.client.use_ssl,
                keyfile=self.client.keyfile,
                certfile=self.client.certfile,
                ca=self.client.ca,
                keyfile_password=self.client.keyfile_password,
                verify_certs=self.client.verify_certs,
            )

        self._socket.setblocking(0)

        connect = Connect(0, client.last_zxid, client._session_timeout,
                          client._session_id or 0, client._session_passwd,
                          client.read_only)

        connect_result, zxid = self._invoke(client._session_timeout / 1000.0,
                                            connect)

        if connect_result.time_out <= 0:
            raise SessionExpiredError("Session has expired")

        if zxid:
            client.last_zxid = zxid

        # Load return values
        client._session_id = connect_result.session_id
        client._protocol_version = connect_result.protocol_version
        negotiated_session_timeout = connect_result.time_out
        connect_timeout = negotiated_session_timeout / len(client.hosts)
        read_timeout = negotiated_session_timeout * 2.0 / 3.0
        client._session_passwd = connect_result.passwd

        self.logger.log(
            BLATHER, 'Session created, session_id: %r session_passwd: %s\n'
            '    negotiated session timeout: %s\n'
            '    connect timeout: %s\n'
            '    read timeout: %s', client._session_id,
            hexlify(client._session_passwd), negotiated_session_timeout,
            connect_timeout, read_timeout)

        if connect_result.read_only:
            client._session_callback(KeeperState.CONNECTED_RO)
            self._ro_mode = iter(self._server_pinger())
        else:
            self._ro_mode = None

            # Get a copy of the auth data before iterating, in case it is
            # changed.
            client_auth_data_copy = copy.copy(client.auth_data)

            if client.use_sasl and self.sasl_cli is None:
                if PURESASL_AVAILABLE:
                    for scheme, auth in client_auth_data_copy:
                        if scheme == 'sasl':
                            # Split on the first ':' only so passwords may
                            # themselves contain ':'.
                            username, password = auth.split(":", 1)
                            self.sasl_cli = SASLClient(
                                host=client.sasl_server_principal,
                                service='zookeeper',
                                mechanism='DIGEST-MD5',
                                username=username,
                                password=password)
                            break

                    # NOTE(review): if no 'sasl' entry exists in auth_data,
                    # self.sasl_cli is still None here; confirm that
                    # _send_sasl_request handles that case.

                    # As described in rfc
                    # https://tools.ietf.org/html/rfc2831#section-2.1
                    # sending empty challenge
                    self._send_sasl_request(challenge=b'',
                                            timeout=connect_timeout)
                else:
                    self.logger.warning('Pure-sasl library is missing while sasl'
                                        ' authentification is configured. Please'
                                        ' install pure-sasl library to connect '
                                        'using sasl. Now falling back '
                                        'connecting WITHOUT any '
                                        'authentification.')
                    client.use_sasl = False
                    client._session_callback(KeeperState.CONNECTED)
            else:
                client._session_callback(KeeperState.CONNECTED)
                for scheme, auth in client_auth_data_copy:
                    if scheme == "digest":
                        ap = Auth(0, scheme, auth)
                        zxid = self._invoke(connect_timeout / 1000.0,
                                            ap,
                                            xid=AUTH_XID)
                        if zxid:
                            client.last_zxid = zxid

        return read_timeout, connect_timeout
コード例 #39
0
ファイル: test_mechanism.py プロジェクト: tristeng/pure-sasl
 def setUp(self):
     """Create a client on 'localhost' for the mechanism under test and keep its chosen mechanism."""
     client = SASLClient('localhost', mechanism=self.mechanism_class.name, **self.sasl_kwargs)
     self.sasl = client
     self.mechanism = client._chosen_mech
コード例 #40
0
ファイル: mock_saslclient.py プロジェクト: jpypi/laurelin
 def __init__(self, *args, **kwds):
     """Ignore all arguments; initialise the base SASLClient for a fixed test host."""
     fixed_host = 'testhost'
     SASLClient.__init__(self, fixed_host)
コード例 #41
0
ファイル: channel.py プロジェクト: alope107/py-yarn
class SocketRpcChannel(RpcChannel):
    """Socket implementation of an RpcChannel.

    Talks the protobuf RPC wire protocol over a TCP socket. When ``self.token``
    is set, the connection is authenticated with a SASL DIGEST-MD5 token
    handshake (see negotiate_sasl) before the connection context is sent.
    """

    # 2**64 - 1, written without the Python-2-only ``L`` suffix so the module
    # also parses on Python 3 (the value is unchanged).
    ERROR_BYTES = 18446744073709551615
    RPC_HEADER = b"hrpc"
    RPC_SERVICE_CLASS = 0x00
    AUTH_PROTOCOL_NONE = 0x00
    AUTH_PROTOCOL_SASL = 0xDF  # -33 as a signed byte
    RPC_PROTOCOL_BUFFFER = 0x02

    def __init__(self, host, port, version, context_protocol, timeout=30):
        """SocketRpcChannel to connect to a socket server on a user defined port.

        host, port: address of the RPC server.
        version: version byte written in the connection preamble.
        context_protocol: protocol name put in the connection context and in
            every request header.
        timeout: socket timeout in seconds (passed to socket.settimeout).
        """
        self.host = host
        self.port = port
        self.sock = None
        # The connection-context message is sent with call id -3; afterwards
        # ids start at 0 and increment (see create_rpc_request_header).
        self.call_id = -3
        self.version = version
        self.client_id = str(uuid.uuid4())
        self.context_protocol = context_protocol
        self.timeout = timeout
        # Optional token: a mapping with "identifier" and "password" entries.
        # When set, get_connection performs SASL authentication.
        self.token = None

    def validate_request(self, request):
        """Validate the client request against the protocol file.

        Raises Exception if the protobuf request is missing required fields.
        """
        # Check the request is correctly initialized
        if not request.IsInitialized():
            raise Exception("Client request (%s) is missing mandatory fields" % type(request))

    def get_connection(self, host, port):
        """Open a socket connection to a given host and port and writes the Hadoop header
        The Hadoop RPC protocol looks like this when creating a connection:

        +---------------------------------------------------------------------+
        |  Header, 4 bytes ("hrpc")                                           |
        +---------------------------------------------------------------------+
        |  Version, 1 byte (default version 9)                                |
        +---------------------------------------------------------------------+
        |  RPC service class, 1 byte (0x00)                                   |
        +---------------------------------------------------------------------+
        |  Auth protocol, 1 byte (Auth method None = 0)                       |
        +---------------------------------------------------------------------+
        |  Length of the RpcRequestHeaderProto  + length of the               |
        |  of the IpcConnectionContextProto (4 bytes/32 bit int)              |
        +---------------------------------------------------------------------+
        |  Serialized delimited RpcRequestHeaderProto                         |
        +---------------------------------------------------------------------+
        |  Serialized delimited IpcConnectionContextProto                     |
        +---------------------------------------------------------------------+
        """

        log.debug("############## CONNECTING ##############")

        auth = self.AUTH_PROTOCOL_NONE if self.token is None else self.AUTH_PROTOCOL_SASL

        # Open socket
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        self.sock.settimeout(self.timeout)
        # Connect socket to server - defined by host and port arguments
        self.sock.connect((host, port))

        # Send the connection preamble.
        self.write(self.RPC_HEADER)  # header
        self.write(struct.pack("B", self.version))  # version
        self.write(struct.pack("B", self.RPC_SERVICE_CLASS))  # RPC service class
        self.write(struct.pack("B", auth))  # auth protocol (none or SASL)

        if auth == self.AUTH_PROTOCOL_SASL:
            self.negotiate_sasl(self.token)

        # The connection context must always go out with call id -3, including
        # on reconnects after earlier calls have advanced the counter.
        self.call_id = -3
        rpc_header = self.create_rpc_request_header()
        context = (
            self.create_connection_context()
            if auth == self.AUTH_PROTOCOL_NONE
            else self.create_connection_context_auth()
        )

        header_length = self._delimited_length(rpc_header, context)

        if log.isEnabledFor(logging.DEBUG):
            log.debug("Header length: %s (%s)" % (header_length, format_bytes(struct.pack("!I", header_length))))

        self.write(struct.pack("!I", header_length))
        self.write_delimited(rpc_header)
        self.write_delimited(context)

    def write(self, data):
        """Send all of *data* on the socket.

        sendall is used because socket.send may transmit only part of the buffer.
        """
        if log.isEnabledFor(logging.DEBUG):
            log.debug("Sending: %s", format_bytes(data))
        self.sock.sendall(data)

    def write_delimited(self, data):
        """Send *data* prefixed with its length as a protobuf varint."""
        self.write(encoder._VarintBytes(len(data)))
        self.write(data)

    @staticmethod
    def _delimited_length(*parts):
        """Return the total wire size of *parts* when each is written varint-length-prefixed."""
        return sum(len(part) + encoder._VarintSize(len(part)) for part in parts)

    def _new_rpc_header(self, call_id):
        """Build an RpcRequestHeaderProto for a protobuf-RPC final packet with *call_id*."""
        rpcheader = RpcRequestHeaderProto()
        rpcheader.rpcKind = 2  # rpcheaderproto.RpcKindProto.Value('RPC_PROTOCOL_BUFFER')
        rpcheader.rpcOp = 0  # rpcheaderproto.RpcPayloadOperationProto.Value('RPC_FINAL_PACKET')
        rpcheader.callId = call_id
        rpcheader.retryCount = -1
        rpcheader.clientId = self.client_id[0:16]
        return rpcheader

    def create_rpc_request_header(self):
        """Create and serialize (not delimited) an RpcRequestHeaderProto, then advance the call id.

        After the -3 connection-context id the counter restarts at 0; otherwise it
        increments by one.
        """
        rpcheader = self._new_rpc_header(self.call_id)

        if self.call_id == -3:
            self.call_id = 0
        else:
            self.call_id += 1

        s_rpcHeader = rpcheader.SerializeToString()
        log_protobuf_message("RpcRequestHeaderProto (len: %d)" % (len(s_rpcHeader)), rpcheader)
        return s_rpcHeader

    def create_connection_context(self):
        """Create and serialize an IpcConnectionContextProto (not delimited) for the local Unix user."""
        context = IpcConnectionContextProto()
        local_user = pwd.getpwuid(os.getuid())[0]
        context.userInfo.effectiveUser = local_user
        context.protocol = self.context_protocol

        s_context = context.SerializeToString()
        log_protobuf_message("RequestContext (len: %d)" % len(s_context), context)
        return s_context

    def create_connection_context_auth(self):
        """Create and serialize an IpcConnectionContextProto (not delimited) for a token connection.

        The effective user is built from ``self.appid``.
        NOTE(review): ``self.appid`` is not initialised in ``__init__``; callers
        must assign a mapping with "cluster_timestamp" and "id" entries before
        connecting with a token.
        """
        context = IpcConnectionContextProto()
        # TODO do this better
        context.userInfo.effectiveUser = (
            "appattempt_" + str(self.appid["cluster_timestamp"]) + "_" + str(self.appid["id"]).zfill(4) + "_000001"
        )
        context.protocol = self.context_protocol

        s_context = context.SerializeToString()
        log_protobuf_message("RequestContext (len: %d)" % len(s_context), context)
        return s_context

    def create_sasl_header(self):
        """Return an (unserialized) RpcRequestHeaderProto for SASL messages, which use call id -33."""
        return self._new_rpc_header(-33)

    def _send_sasl_request(self, header_bytes, sasl_message):
        """Send one length-framed SASL request and return the parsed RpcSaslProto reply.

        header_bytes: serialized RpcRequestHeaderProto (see create_sasl_header).
        sasl_message: RpcSaslProto to serialize and send.
        """
        sasl_bytes = sasl_message.SerializeToString()
        self.write(struct.pack("!I", self._delimited_length(header_bytes, sasl_bytes)))
        self.write_delimited(header_bytes)
        self.write_delimited(sasl_bytes)
        return self.parse_response(self.recv_rpc_message(), RpcSaslProto)

    def negotiate_sasl(self, token):
        """Authenticate the freshly opened socket with SASL DIGEST-MD5 using *token*.

        Sends a NEGOTIATE request, picks the server's TOKEN/DIGEST-MD5 offer,
        answers its challenge with an INITIATE request and reads the reply.

        token: mapping with "identifier" and "password" entries; both are
            base64-encoded and used as the SASL username and password.
        Raises IOError if the negotiate response is empty or does not offer
        TOKEN/DIGEST-MD5. Error status in a response header is raised through
        handle_error.
        """
        log.debug("##############NEGOTIATING SASL#####################")

        header_bytes = self.create_sasl_header().SerializeToString()

        # Step 1: NEGOTIATE -- ask which authentication methods the server offers.
        negotiate_request = RpcSaslProto()
        negotiate_request.state = RpcSaslProto.NEGOTIATE
        negotiate_request.version = 0
        resp = self._send_sasl_request(header_bytes, negotiate_request)
        if resp is None:
            raise IOError("Empty SASL negotiate response from server")

        chosen_auth = None
        for offered in resp.auths:
            if offered.method == "TOKEN" and offered.mechanism == "DIGEST-MD5":
                chosen_auth = offered

        if chosen_auth is None:
            raise IOError("Token digest-MD5 authentication not supported by server")

        # Step 2: INITIATE -- answer the server's challenge for the chosen mechanism.
        self.sasl = SASLClient(
            chosen_auth.serverId,
            chosen_auth.protocol,
            mechanism=chosen_auth.mechanism,
            username=base64.b64encode(token["identifier"]),
            password=base64.b64encode(token["password"]),
        )
        challenge_resp = self.sasl.process(chosen_auth.challenge)

        auth = RpcSaslProto.SaslAuth()
        auth.method = chosen_auth.method
        auth.mechanism = chosen_auth.mechanism
        auth.protocol = chosen_auth.protocol
        auth.serverId = chosen_auth.serverId

        initiate_request = RpcSaslProto()
        initiate_request.state = RpcSaslProto.INITIATE
        initiate_request.version = 0
        initiate_request.auths.extend([auth])
        initiate_request.token = challenge_resp

        # The reply is parsed only so that an error status raises; if desired,
        # the server could be authenticated using the rspauth in this response.
        self._send_sasl_request(header_bytes, initiate_request)

    def send_rpc_message(self, method, request):
        """Sends a Hadoop RPC request to the NameNode.

        The IpcConnectionContextProto, RpcPayloadHeaderProto and HadoopRpcRequestProto
        should already be serialized in the right way (delimited or not) before
        they are passed in this method.

        The Hadoop RPC protocol looks like this for sending requests:

        When sending requests
        +---------------------------------------------------------------------+
        |  Length of the next three parts (4 bytes/32 bit int)                |
        +---------------------------------------------------------------------+
        |  Delimited serialized RpcRequestHeaderProto (varint len + header)   |
        +---------------------------------------------------------------------+
        |  Delimited serialized RequestHeaderProto (varint len + header)      |
        +---------------------------------------------------------------------+
        |  Delimited serialized Request (varint len + request)                |
        +---------------------------------------------------------------------+
        """
        log.debug("############## SENDING ##############")

        # 0. RpcRequestHeaderProto
        rpc_request_header = self.create_rpc_request_header()
        # 1. RequestHeaderProto
        request_header = self.create_request_header(method)
        # 2. Param
        param = request.SerializeToString()
        if log.isEnabledFor(logging.DEBUG):
            log_protobuf_message("Request", request)

        rpc_message_length = self._delimited_length(rpc_request_header, request_header, param)

        if log.isEnabledFor(logging.DEBUG):
            log.debug(
                "RPC message length: %s (%s)"
                % (rpc_message_length, format_bytes(struct.pack("!I", rpc_message_length)))
            )
        self.write(struct.pack("!I", rpc_message_length))

        self.write_delimited(rpc_request_header)
        self.write_delimited(request_header)
        self.write_delimited(param)

    def create_request_header(self, method):
        """Create and serialize (not delimited) the RequestHeaderProto naming *method*."""
        header = RequestHeaderProto()
        header.methodName = method.name
        header.declaringClassProtocolName = self.context_protocol
        header.clientProtocolVersion = 1

        s_header = header.SerializeToString()
        log_protobuf_message("RequestHeaderProto (len: %d)" % len(s_header), header)
        return s_header

    def recv_rpc_message(self):
        """Handle reading an RPC reply from the server. This is done by wrapping the
        socket in a RpcBufferedReader that allows for rewinding of the buffer stream.
        """
        log.debug("############## RECVING ##############")
        byte_stream = RpcBufferedReader(self.sock)
        return byte_stream

    def get_length(self, byte_stream):
        """Read a 4-byte big-endian signed length from *byte_stream* and return it.

        In Hadoop protobuf RPC, some parts of the stream are delimited with protobuf
        varints, while others are delimited with 4-byte integers; this handles the latter.
        """
        length = struct.unpack("!i", byte_stream.read(4))[0]
        log.debug("4 bytes delimited part length: %d" % length)
        return length

    def parse_response(self, byte_stream, response_class):
        """Parses a Hadoop RPC response.

        The RpcResponseHeaderProto contains a status field that marks SUCCESS or ERROR.
        The Hadoop RPC protocol looks like the diagram below for receiving SUCCESS requests.
        +-----------------------------------------------------------+
        |  Length of the RPC response (4 bytes/32 bit int)          |
        +-----------------------------------------------------------+
        |  Delimited serialized RpcResponseHeaderProto              |
        +-----------------------------------------------------------+
        |  Serialized delimited RPC response                        |
        +-----------------------------------------------------------+

        In case of an error, the header status is set to ERROR and the error fields are set.

        Returns a parsed *response_class* instance, or None when the message
        carries no response body. Raises RequestError (via handle_error) when
        the header status is non-zero.
        """

        log.debug("############## PARSING ##############")
        log.debug("Payload class: %s" % response_class)

        # Read first 4 bytes to get the total length
        len_bytes = byte_stream.read(4)
        total_length = struct.unpack("!I", len_bytes)[0]
        log.debug("Total response length: %s" % total_length)

        header = RpcResponseHeaderProto()
        (header_len, header_bytes) = get_delimited_message_bytes(byte_stream)

        log.debug("Header read %d" % header_len)
        header.ParseFromString(header_bytes)
        log_protobuf_message("RpcResponseHeaderProto", header)

        if header.status != 0:
            self.handle_error(header)
            return None

        log.debug("header: %s, total: %s" % (header_len, total_length))
        if header_len >= total_length:
            return None
        response = response_class()
        response_bytes = get_delimited_message_bytes(byte_stream, total_length - header_len)[1]
        if len(response_bytes) > 0:
            response.ParseFromString(response_bytes)
            if log.isEnabledFor(logging.DEBUG):
                log_protobuf_message("Response", response)
            return response
        return None

    def handle_error(self, header):
        """Raise RequestError built from the exception class name and message in *header*."""
        raise RequestError("\n".join([header.exceptionClassName, header.errorMsg]))

    def close_socket(self):
        """Closes the socket and resets the channel."""
        log.debug("Closing socket")
        if self.sock:
            try:
                self.sock.close()
            except socket.error:
                # Best effort: the channel is reset below regardless.
                pass

            self.sock = None

    def CallMethod(self, method, controller, request, response_class, done):
        """Call the RPC method. The naming doesn't conform to PEP8, since it's
        a method called by protobuf.

        Connects lazily on first use. RequestError is re-raised with the socket
        kept open; any other exception closes the socket before propagating.
        """
        try:
            self.validate_request(request)

            if not self.sock:
                self.get_connection(self.host, self.port)

            self.send_rpc_message(method, request)

            byte_stream = self.recv_rpc_message()
            return self.parse_response(byte_stream, response_class)
        except RequestError:  # Raise a request error, but don't close the socket
            raise
        except Exception:  # All other errors close the socket
            self.close_socket()
            raise
コード例 #42
0
ファイル: thrift_sasl.py プロジェクト: yoziru-desu/pyhs2
class TSaslClientTransport(TTransportBase, CReadableTransport):
    """
    A SASL transport based on the pure-sasl library:
        https://github.com/thobbs/pure-sasl

    Negotiation frames are a 1-byte status plus a 4-byte big-endian length
    followed by the body; data frames are a 4-byte signed length followed by
    the SASL-wrapped payload.
    """

    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    def __init__(self, transport, host, service, mechanism="GSSAPI", **sasl_kwargs):
        """
        transport: an underlying transport to use, typically just a TSocket
        host: the name of the server, from a SASL perspective
        service: the name of the server's service, from a SASL perspective
        mechanism: the name of the preferred mechanism to use
        All other kwargs will be passed to the puresasl.client.SASLClient
        constructor.
        """
        self.transport = transport
        self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)
        self.__wbuf = StringIO()
        self.__rbuf = StringIO()

    def open(self):
        """Open the underlying transport if needed and run the SASL handshake.

        Raises TTransportException if the server reports completion before the
        local SASL client is complete, or sends a status other than OK/COMPLETE.
        """
        if not self.transport.isOpen():
            self.transport.open()

        self.send_sasl_msg(self.START, self.sasl.mechanism)
        self.send_sasl_msg(self.OK, self.sasl.process() or "")

        while True:
            status, challenge = self.recv_sasl_msg()
            if status == self.OK:
                self.send_sasl_msg(self.OK, self.sasl.process(challenge) or "")
            elif status == self.COMPLETE:
                # NOTE(review): an earlier revision skipped this check because
                # ``self.sasl.complete`` was reportedly not set for PLAIN; the
                # check is active here -- confirm it holds for the mechanisms in use.
                if not self.sasl.complete:
                    raise TTransportException("The server erroneously indicated " "that SASL negotiation was complete")
                break
            else:
                raise TTransportException("Bad SASL negotiation status: %d (%s)" % (status, challenge))

    def send_sasl_msg(self, status, body):
        """Send one negotiation frame; a None body is sent as empty, text is UTF-8 encoded."""
        if body is None:
            body = b""
        elif not isinstance(body, bytes):
            # Encode before packing so the length header counts bytes, not characters.
            body = body.encode("utf-8")
        header = pack(">BI", status, len(body))

        self.transport.write(header + body)
        self.transport.flush()

    def recv_sasl_msg(self):
        """Read one negotiation frame and return (status, payload bytes)."""
        header = self.transport.readAll(5)
        status, length = unpack(">BI", header)
        payload = self.transport.readAll(length) if length > 0 else b""
        return status, payload

    def write(self, data):
        """Buffer *data* until the next flush()."""
        self.__wbuf.write(data)

    def flush(self):
        """SASL-wrap the buffered data and send it as one length-prefixed frame.

        NOTE(review): an earlier comment here noted that some pure-sasl
        mechanisms (e.g. PLAIN) may not implement wrap/unwrap -- confirm for
        the mechanism in use.
        """
        data = self.__wbuf.getvalue()
        encoded = self.sasl.wrap(data)
        # Bytes concatenation works on both Python 2 and 3; "".join fails on
        # Python 3 because pack() returns bytes.
        self.transport.write(pack("!i", len(encoded)) + encoded)
        self.transport.flush()
        self.__wbuf = StringIO()

    def read(self, sz):
        """Return up to *sz* bytes, reading a new frame when the buffer is empty."""
        ret = self.__rbuf.read(sz)
        if len(ret) != 0:
            return ret

        self._read_frame()
        return self.__rbuf.read(sz)

    def _read_frame(self):
        """Read one length-prefixed frame and replace the read buffer with its unwrapped payload."""
        header = self.transport.readAll(4)
        length, = unpack("!i", header)
        encoded = self.transport.readAll(length)
        self.__rbuf = StringIO(self.sasl.unwrap(encoded))

    def close(self):
        """Dispose of the SASL context and close the underlying transport."""
        self.sasl.dispose()
        self.transport.close()

    # based on TFramedTransport
    @property
    def cstringio_buf(self):
        return self.__rbuf

    def cstringio_refill(self, prefix, reqlen):
        # self.__rbuf will already be empty here because fastbinary doesn't
        # ask for a refill until the previous buffer is empty.  Therefore,
        # we can start reading new frames immediately.
        while len(prefix) < reqlen:
            self._read_frame()
            prefix += self.__rbuf.getvalue()
        self.__rbuf = StringIO(prefix)
        return self.__rbuf
コード例 #43
0
ファイル: auth.py プロジェクト: datastax/python-driver-dse
 def __init__(self, host, service, qops, properties):
     """Create a GSSAPI SASLClient for host/service with the given QOPs.

     A falsy *properties* is treated as no extra keyword arguments.
     """
     extra_props = properties if properties else {}
     self.sasl = SASLClient(host, service, 'GSSAPI', qops=qops, **extra_props)
コード例 #44
0
ファイル: _sasl.py プロジェクト: ClearwaterCore/Telephus
 def createSASLClient(self, host, service, mechanism, **kwargs):
     """Build a SASLClient from the given arguments and store it as ``self.sasl``."""
     client = SASLClient(host, service, mechanism, **kwargs)
     self.sasl = client
コード例 #45
0
ファイル: saslppwrapper.py プロジェクト: lexman/sasl-pp
 def start(self, chosen):
     """Begin SASL with mechanism *chosen*.

     Returns a tuple (True, chosen encoded to bytes, result of the client's
     first ``process()`` call).
     """
     host = self.attributes['host']
     self.sasl = SASLClient(host, mechanism=chosen, callback=self.getAttr)
     mech_bytes = chosen.encode()
     initial_response = self.sasl.process()
     return True, mech_bytes, initial_response
コード例 #46
0
ファイル: auth.py プロジェクト: StuartAxelOwen/python-driver
 def __init__(self, host, service, mechanism='GSSAPI', **sasl_kwargs):
     """Create the underlying puresasl SASLClient.

     Raises ImportError when puresasl is unavailable (``SASLClient`` is None).
     """
     if SASLClient is not None:
         self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)
         return
     raise ImportError('The puresasl library has not been installed')
コード例 #47
0
ファイル: TTransport.py プロジェクト: Alpus/Eth
class TSaslClientTransport(TTransportBase, CReadableTransport):
  """
  SASL transport

  Negotiation frames are a 1-byte status plus a 4-byte big-endian length
  followed by the body; data frames are a 4-byte signed length followed by
  the SASL-wrapped payload.
  """

  START = 1
  OK = 2
  BAD = 3
  ERROR = 4
  COMPLETE = 5

  def __init__(self, transport, host, service, mechanism='GSSAPI',
      **sasl_kwargs):
    """
    transport: an underlying transport to use, typically just a TSocket
    host: the name of the server, from a SASL perspective
    service: the name of the server's service, from a SASL perspective
    mechanism: the name of the preferred mechanism to use

    All other kwargs will be passed to the puresasl.client.SASLClient
    constructor.
    """

    from puresasl.client import SASLClient

    self.transport = transport
    self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)

    self.__wbuf = StringIO()
    self.__rbuf = StringIO()

  def open(self):
    """Open the underlying transport if needed and run the SASL handshake.

    Raises TTransportException if the server reports completion before the
    local SASL client is complete, or sends a status other than OK/COMPLETE.
    """
    if not self.transport.isOpen():
      self.transport.open()

    self.send_sasl_msg(self.START, self.sasl.mechanism)
    self.send_sasl_msg(self.OK, self.sasl.process())

    while True:
      status, challenge = self.recv_sasl_msg()
      if status == self.OK:
        self.send_sasl_msg(self.OK, self.sasl.process(challenge))
      elif status == self.COMPLETE:
        if not self.sasl.complete:
          raise TTransportException("The server erroneously indicated "
              "that SASL negotiation was complete")
        else:
          break
      else:
        raise TTransportException("Bad SASL negotiation status: %d (%s)"
            % (status, challenge))

  def send_sasl_msg(self, status, body):
    """Send one negotiation frame; a None body is sent as empty, text is UTF-8 encoded."""
    if body is None:
      # process() may return None (e.g. no initial response).
      body = b""
    elif not isinstance(body, bytes):
      # Encode before packing so the length header counts bytes, not characters.
      body = body.encode("utf-8")
    header = pack(">BI", status, len(body))
    self.transport.write(header + body)
    self.transport.flush()

  def recv_sasl_msg(self):
    """Read one negotiation frame and return (status, payload bytes)."""
    header = self.transport.readAll(5)
    status, length = unpack(">BI", header)
    if length > 0:
      payload = self.transport.readAll(length)
    else:
      payload = b""
    return status, payload

  def write(self, data):
    """Buffer *data* until the next flush()."""
    self.__wbuf.write(data)

  def flush(self):
    """SASL-wrap the buffered data and send it as one length-prefixed frame."""
    data = self.__wbuf.getvalue()
    encoded = self.sasl.wrap(data)
    # Bytes concatenation works on both Python 2 and 3; ''.join fails on
    # Python 3 because pack() returns bytes.
    self.transport.write(pack("!i", len(encoded)) + encoded)
    self.transport.flush()
    self.__wbuf = StringIO()

  def read(self, sz):
    """Return up to *sz* bytes, reading a new frame when the buffer is empty."""
    ret = self.__rbuf.read(sz)
    if len(ret) != 0:
      return ret

    self._read_frame()
    return self.__rbuf.read(sz)

  def _read_frame(self):
    """Read one length-prefixed frame and replace the read buffer with its unwrapped payload."""
    header = self.transport.readAll(4)
    length, = unpack('!i', header)
    encoded = self.transport.readAll(length)
    self.__rbuf = StringIO(self.sasl.unwrap(encoded))

  def close(self):
    """Dispose of the SASL context and close the underlying transport."""
    self.sasl.dispose()
    self.transport.close()

  # based on TFramedTransport
  @property
  def cstringio_buf(self):
    return self.__rbuf

  def cstringio_refill(self, prefix, reqlen):
    # self.__rbuf will already be empty here because fastbinary doesn't
    # ask for a refill until the previous buffer is empty.  Therefore,
    # we can start reading new frames immediately.
    while len(prefix) < reqlen:
      self._read_frame()
      prefix += self.__rbuf.getvalue()
    self.__rbuf = StringIO(prefix)
    return self.__rbuf
コード例 #48
0
ファイル: connection.py プロジェクト: EmergingThreats/pycassa
class TSaslClientTransport(TTransportBase, CReadableTransport):
    """SASL transport wrapping another Thrift transport via puresasl.

    Negotiation frames are a 1-byte status plus a 4-byte big-endian length
    followed by the body; data frames are a 4-byte signed length followed by
    the SASL-wrapped payload.
    """

    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    def __init__(self, transport, host, service,
            mechanism='GSSAPI', **sasl_kwargs):
        """
        transport: the underlying transport to wrap
        host, service: server name and service name from a SASL perspective
        mechanism: the preferred SASL mechanism
        Other kwargs are passed to puresasl.client.SASLClient.
        """

        from puresasl.client import SASLClient

        self.transport = transport
        self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)

        self.__wbuf = StringIO()
        self.__rbuf = StringIO()

    def open(self):
        """Open the underlying transport if needed and run the SASL handshake.

        Raises TTransportException if the server reports completion before the
        local SASL client is complete, or sends a status other than OK/COMPLETE.
        """
        if not self.transport.isOpen():
            self.transport.open()

        self.send_sasl_msg(self.START, self.sasl.mechanism)
        self.send_sasl_msg(self.OK, self.sasl.process())

        while True:
            status, challenge = self.recv_sasl_msg()
            if status == self.OK:
                self.send_sasl_msg(self.OK, self.sasl.process(challenge))
            elif status == self.COMPLETE:
                if not self.sasl.complete:
                    raise TTransportException("The server erroneously indicated "
                            "that SASL negotiation was complete")
                else:
                    break
            else:
                raise TTransportException("Bad SASL negotiation status: %d (%s)"
                        % (status, challenge))

    def send_sasl_msg(self, status, body):
        """Send one negotiation frame; a None body is sent as empty, text is UTF-8 encoded."""
        if body is None:
            # process() may return None (e.g. no initial response).
            body = b""
        elif not isinstance(body, bytes):
            # Encode before packing so the length header counts bytes, not characters.
            body = body.encode("utf-8")
        header = struct.pack(">BI", status, len(body))
        self.transport.write(header + body)
        self.transport.flush()

    def recv_sasl_msg(self):
        """Read one negotiation frame and return (status, payload bytes)."""
        header = self.transport.readAll(5)
        status, length = struct.unpack(">BI", header)
        if length > 0:
            payload = self.transport.readAll(length)
        else:
            payload = b""
        return status, payload

    def write(self, data):
        """Buffer *data* until the next flush()."""
        self.__wbuf.write(data)

    def flush(self):
        """SASL-wrap the buffered data and send it as one length-prefixed frame."""
        data = self.__wbuf.getvalue()
        encoded = self.sasl.wrap(data)
        # Note stolen from TFramedTransport:
        # N.B.: Building one buffer is WAY cheaper than making two separate
        # calls to the underlying socket object. Bytes concatenation is used
        # (rather than ''.join) so this also works on Python 3, where
        # struct.pack returns bytes.
        self.transport.write(struct.pack("!i", len(encoded)) + encoded)
        self.transport.flush()
        self.__wbuf = StringIO()

    def read(self, sz):
        """Return up to *sz* bytes, reading a new frame when the buffer is empty."""
        ret = self.__rbuf.read(sz)
        if len(ret) != 0:
            return ret

        self._read_frame()
        return self.__rbuf.read(sz)

    def _read_frame(self):
        """Read one length-prefixed frame and replace the read buffer with its unwrapped payload."""
        header = self.transport.readAll(4)
        length, = struct.unpack('!i', header)
        encoded = self.transport.readAll(length)
        self.__rbuf = StringIO(self.sasl.unwrap(encoded))

    def close(self):
        """Dispose of the SASL context and close the underlying transport."""
        self.sasl.dispose()
        self.transport.close()

    # Implement the CReadableTransport interface.
    # Stolen shamelessly from TFramedTransport
    @property
    def cstringio_buf(self):
        return self.__rbuf

    def cstringio_refill(self, prefix, reqlen):
        # self.__rbuf will already be empty here because fastbinary doesn't
        # ask for a refill until the previous buffer is empty.  Therefore,
        # we can start reading new frames immediately.
        while len(prefix) < reqlen:
            self._read_frame()
            prefix += self.__rbuf.getvalue()
        self.__rbuf = StringIO(prefix)
        return self.__rbuf
コード例 #49
0
 def createSASLClient(self, host, service, mechanism, **kwargs):
     """Construct a SASLClient with the given settings and keep it on ``self.sasl``."""
     new_client = SASLClient(host, service, mechanism, **kwargs)
     self.sasl = new_client
コード例 #50
0
class TSaslClientTransport(TTransportBase, CReadableTransport):
    """SASL transport wrapping another Thrift transport via puresasl.

    Negotiation frames are a 1-byte status plus a 4-byte big-endian length
    followed by the body; data frames are a 4-byte signed length followed by
    the SASL-wrapped payload.
    """

    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    def __init__(self,
                 transport,
                 host,
                 service,
                 mechanism='GSSAPI',
                 **sasl_kwargs):
        """
        transport: the underlying transport to wrap
        host, service: server name and service name from a SASL perspective
        mechanism: the preferred SASL mechanism
        Other kwargs are passed to puresasl.client.SASLClient.
        """
        from puresasl.client import SASLClient
        self.transport = transport
        self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)
        self.__wbuf = BufferIO()
        self.__rbuf = BufferIO(b'')

    def open(self):
        """Open the underlying transport if needed and run the SASL handshake.

        Raises TTransportException(NOT_OPEN) if the server reports completion
        before the local SASL client is complete, or sends a status other than
        OK/COMPLETE.
        """
        if not self.transport.isOpen():
            self.transport.open()
        self.send_sasl_msg(self.START, self.sasl.mechanism)
        self.send_sasl_msg(self.OK, self.sasl.process())
        while True:
            status, challenge = self.recv_sasl_msg()
            if status == self.OK:
                self.send_sasl_msg(self.OK, self.sasl.process(challenge))
            elif status == self.COMPLETE:
                if not self.sasl.complete:
                    raise TTransportException(
                        TTransportException.NOT_OPEN,
                        "The server erroneously indicated "
                        "that SASL negotiation was complete")
                else:
                    break
            else:
                raise TTransportException(
                    TTransportException.NOT_OPEN,
                    "Bad SASL negotiation status: %d (%s)" % (status, challenge))

    def send_sasl_msg(self, status, body):
        """Send one negotiation frame; a None body is sent as empty, text is UTF-8 encoded."""
        if body is None:
            # process() may return None (e.g. no initial response).
            body = b""
        elif not isinstance(body, bytes):
            # Encode before packing so the length header counts bytes, not characters.
            body = body.encode("utf-8")
        header = pack(">BI", status, len(body))
        self.transport.write(header + body)
        self.transport.flush()

    def recv_sasl_msg(self):
        """Read one negotiation frame and return (status, payload bytes)."""
        header = self.transport.readAll(5)
        status, length = unpack(">BI", header)
        if length > 0:
            payload = self.transport.readAll(length)
        else:
            payload = b""
        return status, payload

    def write(self, data):
        """Buffer *data* until the next flush()."""
        self.__wbuf.write(data)

    def flush(self):
        """SASL-wrap the buffered data and send it as one length-prefixed frame."""
        data = self.__wbuf.getvalue()
        encoded = self.sasl.wrap(data)
        # Bytes concatenation: ''.join would fail on Python 3 because pack()
        # returns bytes.
        self.transport.write(pack("!i", len(encoded)) + encoded)
        self.transport.flush()
        self.__wbuf = BufferIO()

    def read(self, sz):
        """Return up to *sz* bytes, reading a new frame when the buffer is empty."""
        ret = self.__rbuf.read(sz)
        if len(ret) != 0:
            return ret
        self._read_frame()
        return self.__rbuf.read(sz)

    def _read_frame(self):
        """Read one length-prefixed frame and replace the read buffer with its unwrapped payload."""
        header = self.transport.readAll(4)
        length, = unpack('!i', header)
        encoded = self.transport.readAll(length)
        self.__rbuf = BufferIO(self.sasl.unwrap(encoded))

    def close(self):
        """Dispose of the SASL context and close the underlying transport."""
        self.sasl.dispose()
        self.transport.close()

    @property
    def cstringio_buf(self):
        return self.__rbuf

    def cstringio_refill(self, prefix, reqlen):
        # The read buffer is empty when fastbinary asks for a refill, so new
        # frames can be read immediately until at least reqlen bytes are held.
        while len(prefix) < reqlen:
            self._read_frame()
            prefix += self.__rbuf.getvalue()
        self.__rbuf = BufferIO(prefix)
        return self.__rbuf
コード例 #51
0
    def test_choose_mechanism(self):
        """choose_mechanism rejects unknown names, prefers the highest score,
        and honours each allow_* security filter."""
        client = SASLClient('localhost', service='something')
        self.assertRaises(SASLError, client.choose_mechanism, ['invalid'])

        candidates = [
            m for m in mechanisms.values() if m is not DigestMD5Mechanism
        ]
        client.choose_mechanism(set(m.name for m in candidates))
        best = max(candidates, key=lambda m: m.score)
        self.assertIsInstance(client._chosen_mech, best)

        # Each entry pairs a predicate selecting the mechanisms a filter
        # targets with the keyword argument that forbids them.
        filters = (
            (lambda m: m.allows_anonymous, 'allow_anonymous'),
            (lambda m: m.uses_plaintext, 'allow_plaintext'),
            (lambda m: not m.active_safe, 'allow_active'),
            (lambda m: not m.dictionary_safe, 'allow_dictionary'),
        )
        for selects, flag in filters:
            names = set(m.name for m in candidates if selects(m))
            # Unrestricted, one of these mechanisms is chosen...
            client.choose_mechanism(names)
            self.assertIn(client.mechanism, names)
            # ...but forbidding that property must make the choice fail.
            self.assertRaises(SASLError,
                              client.choose_mechanism,
                              names,
                              **{flag: False})
コード例 #52
0
ファイル: connection.py プロジェクト: rkomartin/pycassa
class TSaslClientTransport(TTransportBase, CReadableTransport):
    """Thrift transport that authenticates and wraps traffic with SASL.

    ``open()`` runs a SASL negotiation over the underlying ``transport``
    using a ``puresasl.client.SASLClient``.

    After negotiation, outgoing data is framed as follows. Writes are
    buffered locally. Each ``flush()`` sends the buffer as one frame: a
    4-byte signed big-endian length, then the ``sasl.wrap()``-ed bytes.

    Reads consume frames in the same format. Each frame's payload is passed
    through ``sasl.unwrap()`` before it is handed out.
    """

    # Status codes carried in the first byte of every negotiation message
    # header (see send_sasl_msg / recv_sasl_msg).
    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    def __init__(self, transport, host, service,
            mechanism='GSSAPI', **sasl_kwargs):
        """
        transport: the underlying transport; open() opens it if needed
        host: the server's host name, from a SASL perspective
        service: the server's service name, from a SASL perspective
        mechanism: the SASL mechanism name handed to SASLClient
        All other kwargs are passed to the puresasl.client.SASLClient
        constructor.
        """
        # Imported here rather than at module level, presumably so that
        # puresasl is only needed when this transport is actually used.
        from puresasl.client import SASLClient

        self.transport = transport
        self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)

        self.__wbuf = StringIO()
        self.__rbuf = StringIO()

    def open(self):
        """Open the underlying transport (if needed) and negotiate SASL.

        Sends START with the mechanism name, then OK with the client's
        initial response. It then answers each OK challenge from the server
        until the server sends COMPLETE.

        Raises TTransportException if the server sends any other status, or
        sends COMPLETE before the local SASL client considers itself
        complete.
        """
        if not self.transport.isOpen():
            self.transport.open()

        self.send_sasl_msg(self.START, self.sasl.mechanism)
        self.send_sasl_msg(self.OK, self.sasl.process())

        while True:
            status, challenge = self.recv_sasl_msg()
            if status == self.OK:
                self.send_sasl_msg(self.OK, self.sasl.process(challenge))
            elif status == self.COMPLETE:
                if not self.sasl.complete:
                    raise TTransportException("The server erroneously indicated "
                            "that SASL negotiation was complete")
                else:
                    break
            else:
                raise TTransportException("Bad SASL negotiation status: %d (%s)"
                        % (status, challenge))

    def send_sasl_msg(self, status, body):
        """Send one negotiation message and flush the underlying transport.

        Wire format: 1-byte status, 4-byte big-endian unsigned payload
        length, then the payload itself.

        If the SASL client has nothing to send (``process()`` returned
        None), an empty payload is sent. Without this, ``len()`` would fail
        on None.
        """
        if body is None:
            body = ""
        header = struct.pack(">BI", status, len(body))
        self.transport.write(header + body)
        self.transport.flush()

    def recv_sasl_msg(self):
        """Read one negotiation message and return ``(status, payload)``.

        The payload is an empty string when the advertised length is zero.
        """
        header = self.transport.readAll(5)
        status, length = struct.unpack(">BI", header)
        if length > 0:
            payload = self.transport.readAll(length)
        else:
            payload = ""
        return status, payload

    def write(self, data):
        """Buffer ``data``; nothing is sent until flush()."""
        self.__wbuf.write(data)

    def flush(self):
        """SASL-wrap everything buffered and send it as a single frame."""
        data = self.__wbuf.getvalue()
        # Reset the buffer before touching the underlying transport. If the
        # write below fails, this data must not stay queued, or it would be
        # re-sent prepended to the next frame and corrupt the stream.
        self.__wbuf = StringIO()
        encoded = self.sasl.wrap(data)
        # Note stolen from TFramedTransport:
        # N.B.: Doing this string concatenation is WAY cheaper than making
        # two separate calls to the underlying socket object. Socket writes in
        # Python turn out to be REALLY expensive, but it seems to do a pretty
        # good job of managing string buffer operations without excessive copies
        self.transport.write(''.join((struct.pack("!i", len(encoded)), encoded)))
        self.transport.flush()

    def read(self, sz):
        """Return up to ``sz`` unwrapped bytes.

        Data is served from the current frame. A new frame is read only
        when the current one is exhausted. The result may therefore be
        shorter than ``sz``.
        """
        ret = self.__rbuf.read(sz)
        if len(ret) != 0:
            return ret

        self._read_frame()
        return self.__rbuf.read(sz)

    def _read_frame(self):
        """Read one length-prefixed frame and make its unwrapped payload
        the current read buffer."""
        header = self.transport.readAll(4)
        length, = struct.unpack('!i', header)
        encoded = self.transport.readAll(length)
        self.__rbuf = StringIO(self.sasl.unwrap(encoded))

    def close(self):
        """Dispose of the SASL client and close the underlying transport.

        The underlying transport is closed even if ``sasl.dispose()``
        raises, so the connection is not leaked.
        """
        try:
            self.sasl.dispose()
        finally:
            self.transport.close()

    # Implement the CReadableTransport interface.
    # Stolen shamelessly from TFramedTransport
    @property
    def cstringio_buf(self):
        """The current read buffer (CReadableTransport interface)."""
        return self.__rbuf

    def cstringio_refill(self, prefix, reqlen):
        """Read frames until at least ``reqlen`` bytes are available.

        The resulting buffer, prefixed with ``prefix``, becomes the new
        read buffer and is returned.
        """
        # self.__rbuf will already be empty here because fastbinary doesn't
        # ask for a refill until the previous buffer is empty.  Therefore,
        # we can start reading new frames immediately.
        while len(prefix) < reqlen:
            self._read_frame()
            prefix += self.__rbuf.getvalue()
        self.__rbuf = StringIO(prefix)
        return self.__rbuf