def __init__(self, transport, host, service, mechanism=six.u('GSSAPI'),
             generate_tickets=False, using_keytab=False, principal=None,
             keytab_file=None, ccache_file=None, password=None,
             **sasl_kwargs):
    """Build a SASL client transport on top of ``transport``.

    transport: the wrapped transport, typically just a TSocket
    host: server name as seen by SASL
    service: server's service name as seen by SASL
    mechanism: preferred SASL mechanism name
    generate_tickets: when true, create ``self.krb_context`` from
        ``krbContext(using_keytab, principal, keytab_file, ccache_file,
        password)`` and call its ``init_with_keytab()``
    using_keytab, principal, keytab_file, ccache_file, password: passed
        positionally to ``krbContext``; unused unless ``generate_tickets``

    Every remaining keyword argument goes straight to the
    puresasl.client.SASLClient constructor.
    """
    self.transport = transport
    if six.PY3:
        # Presumably swaps in a Python 3 capable mechanism before the
        # SASLClient below is created; see _patch_pure_sasl.
        self._patch_pure_sasl()
    self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)
    self.__wbuf = BytesIO()
    self.__rbuf = BytesIO()

    self.generate_tickets = generate_tickets
    if not generate_tickets:
        return
    self.krb_context = krbContext(using_keytab, principal, keytab_file,
                                  ccache_file, password)
    self.krb_context.init_with_keytab()
def authenticate_xmpp(self):
    """Authenticate the user to the XMPP server via the BOSH connection.

    Requests a session id, picks a SASL mechanism from ``self.server_auth``
    (anonymous mechanisms are refused), answers the server's challenge and,
    if the server replies with a further challenge instead of a result,
    answers it with an empty response.

    :return: the value of ``self.check_authenticate_success`` for the last
        server reply.
    """
    self.request_sid()

    self.log.debug('Prepare the XMPP authentication')

    # Instantiate a sasl object
    sasl = SASLClient(host=self.to, service='xmpp', username=self.jid, password=self.password)

    # Choose an auth mechanism
    sasl.choose_mechanism(self.server_auth, allow_anonymous=False)

    # Request challenge
    challenge = self.get_challenge(sasl.mechanism)

    # base64.b64decode always returns bytes, so the realm test and the
    # appended realm must be bytes too (str literals raise TypeError on
    # Python 3; b'' literals are plain str on Python 2, so both work).
    challenge_bytes = base64.b64decode(challenge)
    if b'realm=' not in challenge_bytes:
        # The server did not send a realm; supply a placeholder one.
        challenge_bytes += b',realm="random"'

    # Process challenge and generate response
    response = sasl.process(challenge_bytes)

    # Send response
    resp_root = self.send_challenge_response(response)

    success = self.check_authenticate_success(resp_root)
    if success is None and \
            resp_root.find('{{{0}}}challenge'.format(XMPP_SASL_NS)) is not None:
        # One more challenge from the server: acknowledge with an empty response.
        resp_root = self.send_challenge_response('')
        return self.check_authenticate_success(resp_root)
    return success
def sasl_bind(client, host):
    """Send an LDAP v3 SASL/GSSAPI bind request and print the reply.

    Generator-based coroutine: it uses ``yield from client.request(...)``,
    so it has to be driven by whatever runs ``client``.

    :param client: object with a ``request(protocol_op)`` coroutine —
        presumably the async LDAP client; TODO confirm.
    :param host: server host name handed to ``SASLClient``.
    """
    sasl_client = SASLClient(host, service='ldap', mechanism='GSSAPI')

    sasl_credentials = SaslCredentials()
    # SASL mechanism names are upper-case ASCII (RFC 4422 sec 3.1); send
    # the same name the SASLClient was created with, not "gssapi".
    sasl_credentials.setComponentByName("mechanism", LDAPString("GSSAPI"))
    sasl_credentials.setComponentByName("credentials", sasl_client.process(None))

    authentication_choice = AuthenticationChoice()
    authentication_choice.setComponentByName('sasl', sasl_credentials)

    bind_request = BindRequest()
    bind_request.setComponentByName('version', Version(3))
    bind_request.setComponentByName('name', LDAPDN(''))
    bind_request.setComponentByName('authentication', authentication_choice)

    protocol_op = ProtocolOp()
    protocol_op.setComponentByName("bindRequest", bind_request)

    # NOTE(review): the encoded results below are discarded — these look
    # like debug checks that each structure encodes; confirm before removing.
    ber_encode(authentication_choice)
    ber_encode(sasl_credentials)
    print(bind_request.prettyPrint())
    ber_encode(bind_request)
    ber_encode(protocol_op)

    response = yield from client.request(protocol_op)

    print(response)
class _BaseMechanismTests(unittest.TestCase):
    """Common smoke tests for one SASL mechanism.

    Subclasses point ``mechanism_class`` / ``sasl_kwargs`` at another
    mechanism to reuse the same checks.
    """

    mechanism_class = AnonymousMechanism
    sasl_kwargs = {}

    def setUp(self):
        mech_name = self.mechanism_class.name
        self.sasl = SASLClient('localhost', mechanism=mech_name, **self.sasl_kwargs)
        self.mechanism = self.sasl._chosen_mech

    def test_init_basic(self, *args):
        client = SASLClient('localhost', mechanism=self.mechanism_class.name,
                            **self.sasl_kwargs)
        chosen = client._chosen_mech
        self.assertIs(chosen.sasl, client)
        self.assertIsInstance(chosen, self.mechanism_class)

    def test_process_basic(self, *args):
        # Two consecutive process() calls must each return bytes.
        for _ in range(2):
            self.assertIsInstance(self.sasl.process(six.b('string')), six.binary_type)

    def test_dispose_basic(self, *args):
        self.sasl.dispose()

    def test_wrap_unwrap(self, *args):
        # The base mechanism supports neither direction of message protection.
        for operation in (self.sasl.wrap, self.sasl.unwrap):
            self.assertRaises(NotImplementedError, operation, 'msg')

    def test__pick_qop(self, *args):
        pick_qop = self.sasl._chosen_mech._pick_qop
        self.assertRaises(SASLProtocolException, pick_qop, set())
        pick_qop(set(QOP.all))
def authenticate_xmpp(self):
    """Authenticate the user to the XMPP server via the BOSH connection.

    Negotiates a non-anonymous SASL mechanism from ``self.server_auth``,
    answers the server's challenge and, when the server comes back with
    another challenge instead of a result, answers it with an empty
    response.

    :return: what ``self.check_authenticate_success`` reports for the final
        server reply.
    """
    self.request_sid()
    self.log.debug('Prepare the XMPP authentication')

    client = SASLClient(
        host=self.to,
        service='xmpp',
        username=self.jid,
        password=self.password,
    )
    client.choose_mechanism(self.server_auth, allow_anonymous=False)

    encoded_challenge = self.get_challenge(client.mechanism)
    answer = client.process(base64.b64decode(encoded_challenge))
    reply_root = self.send_challenge_response(answer)

    outcome = self.check_authenticate_success(reply_root)
    challenge_tag = '{{{0}}}challenge'.format(XMPP_SASL_NS)
    if outcome is not None or reply_root.find(challenge_tag) is None:
        return outcome

    # Server sent one more challenge: acknowledge it with an empty response.
    reply_root = self.send_challenge_response('')
    return self.check_authenticate_success(reply_root)
def __init__(self, transport, host, service, mechanism='GSSAPI', **sasl_kwargs):
    """Create a SASL client transport over ``transport``.

    ``host``, ``service``, ``mechanism`` and any extra keyword arguments
    are handed to ``puresasl.client.SASLClient``.
    """
    # Function-local import: puresasl is presumably only needed when this
    # transport is actually used.
    from puresasl.client import SASLClient

    self.transport = transport
    self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)
    self.__wbuf = BufferIO()
    self.__rbuf = BufferIO(b'')
def authenticate_xmpp(self):
    """Authenticate the user to the XMPP server via the BOSH connection.

    After a successful SASL exchange the stream is restarted and a resource
    is bound.

    :return: True on success, False if the challenge response was rejected.
    """
    self.request_sid()
    self.log.debug('Prepare the XMPP authentication')

    sasl_client = SASLClient(host=self.to, service='xmpp',
                             username=self.jid, password=self.password)
    sasl_client.choose_mechanism(self.server_auth, allow_anonymous=False)

    encoded_challenge = self.get_challenge(sasl_client.mechanism)
    answer = sasl_client.process(base64.b64decode(encoded_challenge))

    if self.send_challenge_response(answer):
        self.request_restart()
        self.bind_resource()
        return True
    return False
def __init__(self, host, service, qops, properties):
    """Create the GSSAPI ``SASLClient`` for this authenticator.

    :param qops: forwarded to ``SASLClient`` as ``qops``.
    :param properties: extra ``SASLClient`` keyword arguments; a falsy value
        (e.g. None) means none.
    """
    extra_options = properties if properties else {}
    self.sasl = SASLClient(host, service, 'GSSAPI', qops=qops, **extra_options)
def test_init(self):
    """SASLClient builds with defaults or a known mechanism, and rejects an unknown one."""
    # Host only: must construct without error.
    SASLClient('localhost')

    client = SASLClient('localhost', mechanism=AnonymousMechanism.name)
    chosen = client._chosen_mech
    self.assertIsInstance(chosen, AnonymousMechanism)
    self.assertIs(chosen.sasl, client)

    with self.assertRaises(SASLError):
        SASLClient('localhost', mechanism='WRONG')
def connect(self):
    """Run the Hadoop RPC SASL handshake over ``self._trans``.

    Creates ``self.sasl`` on first use, sends a state-1 message and then
    reacts to each server message:

    - state 1: choose a mechanism from the offered auths and send the
      initial token in a state-2 message;
    - state 3: answer the challenge token in a state-4 message;
    - state 0: handshake done.

    Any other state is ignored and the next message is read.

    :return: True once a state-0 message is received.
    """
    # Service name = principal text before the first "/" or "@".
    # Raw string: '\/' in a plain literal is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+); the regex is unchanged.
    service = re.split(r'[/@]', str(self.hdfs_namenode_principal))[0]
    if not self.sasl:
        self.sasl = SASLClient(self._trans.host, service)

    negotiate = RpcSaslProto()
    negotiate.state = 1
    self._send_sasl_message(negotiate)

    while True:
        res = self._recv_sasl_message()
        # TODO: check mechanisms
        if res.state == 1:
            mechs = [auth.mechanism for auth in res.auths]
            log.debug("Available mechs: %s", ",".join(mechs))
            self.sasl.choose_mechanism(mechs, allow_anonymous=False)
            log.debug("Chosen mech: %s", self.sasl.mechanism)

            initiate = RpcSaslProto()
            initiate.state = 2
            initiate.token = self.sasl.process()

            # Copy back only the server's auth entry for the chosen mechanism.
            for auth in res.auths:
                if auth.mechanism == self.sasl.mechanism:
                    auth_method = initiate.auths.add()
                    auth_method.mechanism = self.sasl.mechanism
                    auth_method.method = auth.method
                    auth_method.protocol = auth.protocol
                    auth_method.serverId = self._trans.host

            self._send_sasl_message(initiate)
        elif res.state == 3:
            res_token = self._evaluate_token(res)
            response = RpcSaslProto()
            response.token = res_token
            response.state = 4
            self._send_sasl_message(response)
        elif res.state == 0:
            return True
class SaslAuthenticator(Authenticator):
    """
    A pass-through :class:`~.Authenticator` that hands every step of the
    exchange to a ``SASLClient`` from the third party package 'pure-sasl'.
    """

    def __init__(self, host, service, mechanism='GSSAPI', **sasl_kwargs):
        # SASLClient is presumably bound to None at import time when
        # pure-sasl is missing.
        if SASLClient is not None:
            self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)
        else:
            raise ImportError('The puresasl library has not been installed')

    def initial_response(self):
        """Return the client's first SASL token."""
        return self.sasl.process()

    def evaluate_challenge(self, challenge):
        """Return the SASL token answering the server's ``challenge``."""
        return self.sasl.process(challenge)
class GSSAPIAuthenticator(BaseDSEAuthenticator):
    """DSE authenticator that runs a GSSAPI SASL exchange."""

    def __init__(self, host, service, qops, properties):
        sasl_options = properties or {}
        self.sasl = SASLClient(host, service, 'GSSAPI', qops=qops, **sasl_options)

    def get_mechanism(self):
        """Return the SASL mechanism name."""
        return "GSSAPI"

    def get_initial_challenge(self):
        """Return the marker that starts the exchange."""
        return "GSSAPI-START"

    def evaluate_challenge(self, challenge):
        # The start marker (see get_initial_challenge) triggers the client's
        # first, argument-less step; anything else is a real server token.
        if challenge != 'GSSAPI-START':
            return self.sasl.process(challenge)
        return self.sasl.process()
def test_init_basic(self, *args):
    """A client created for this mechanism holds a matching, back-linked mechanism object."""
    expected_cls = self.mechanism_class
    client = SASLClient('localhost', mechanism=expected_cls.name, **self.sasl_kwargs)
    chosen = client._chosen_mech
    self.assertIs(chosen.sasl, client)
    self.assertIsInstance(chosen, expected_cls)
class SaslAuthenticator(Authenticator):
    """
    An :class:`~.Authenticator` for DSE's KerberosAuthenticator; each SASL
    step is delegated to a pure-sasl ``SASLClient``.

    .. versionadded:: 2.1.3-post
    """

    def __init__(self, host, service, mechanism='GSSAPI', **sasl_kwargs):
        # Presumably None when the optional puresasl import failed.
        if SASLClient is not None:
            self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)
        else:
            raise ImportError('The puresasl library has not been installed')

    def initial_response(self):
        """Produce the client's opening SASL token."""
        return self.sasl.process()

    def evaluate_challenge(self, challenge):
        """Produce the SASL token that answers ``challenge``."""
        return self.sasl.process(challenge)
def test_unchosen_mechanism(self):
    """Every mechanism-dependent operation raises SASLError before a mechanism is chosen."""
    client = SASLClient('localhost')
    self.assertRaises(SASLError, client.process)
    for operation in (client.wrap, client.unwrap):
        self.assertRaises(SASLError, operation, 'msg')
    with self.assertRaises(SASLError):
        client.complete
    self.assertRaises(SASLError, client.dispose)
class SaslAuthenticator(Authenticator):
    """
    A pass-through :class:`~.Authenticator` backed by a ``SASLClient`` from
    the third party package 'pure-sasl'.

    .. versionadded:: 2.1.4
    """

    def __init__(self, host, service, mechanism='GSSAPI', **sasl_kwargs):
        # Presumably None when pure-sasl could not be imported.
        if SASLClient is not None:
            self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)
        else:
            raise ImportError('The puresasl library has not been installed')

    def initial_response(self):
        """Return the first token of the SASL exchange."""
        return self.sasl.process()

    def evaluate_challenge(self, challenge):
        """Return the token answering the server's ``challenge``."""
        return self.sasl.process(challenge)
def test_chosen_mechanism(self):
    """With PLAIN chosen, one process() call completes and wrap/unwrap leave 'msg' unchanged."""
    client = SASLClient('localhost', mechanism=PlainMechanism.name,
                        username='******', password='******')
    self.assertTrue(client.process())
    self.assertTrue(client.complete)

    payload = 'msg'
    for transform in (client.wrap, client.unwrap):
        self.assertEqual(transform(payload), payload)

    client.dispose()
def __init__(self, transport, host, service, mechanism=six.u('GSSAPI'), **sasl_kwargs):
    """Layer a SASL client over ``transport``.

    transport: the wrapped transport, typically just a TSocket
    host: server name as seen by SASL
    service: server's service name as seen by SASL
    mechanism: preferred SASL mechanism name

    Remaining keyword arguments go to the puresasl.client.SASLClient
    constructor.
    """
    self.transport = transport
    if six.PY3:
        # Presumably installs a Python 3 compatible mechanism before the
        # SASLClient is built; see _patch_pure_sasl.
        self._patch_pure_sasl()
    self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)
    self.__wbuf = BytesIO()
    self.__rbuf = BytesIO()
def __init__(self, transport, host, service, mechanism='GSSAPI', **sasl_kwargs):
    """Layer a SASL client over ``transport``.

    transport: the wrapped transport, typically just a TSocket
    host: server name as seen by SASL
    service: server's service name as seen by SASL
    mechanism: preferred SASL mechanism name

    Remaining keyword arguments go to the puresasl.client.SASLClient
    constructor.
    """
    from puresasl.client import SASLClient

    self.transport = transport
    self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)
    self.__wbuf = BufferIO()
    self.__rbuf = BufferIO(b'')
def __init__(self, transport, host, service, mechanism='GSSAPI', **sasl_kwargs):
    """Create a SASL client transport over ``transport``.

    ``host``, ``service``, ``mechanism`` and extra keyword arguments are
    forwarded to ``puresasl.client.SASLClient``.
    """
    from puresasl.client import SASLClient

    self.transport = transport
    self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)
    # Separate, initially empty write and read buffers.
    self.__wbuf = StringIO()
    self.__rbuf = StringIO()
class GSSAPIAuthenticator(BaseDSEAuthenticator):
    """DSE authenticator running a GSSAPI SASL exchange with byte-string markers."""

    def __init__(self, host, service, qops, properties):
        sasl_options = properties or {}
        self.sasl = SASLClient(host, service, 'GSSAPI', qops=qops, **sasl_options)

    def get_mechanism(self):
        """Return the SASL mechanism name as bytes."""
        return six.b("GSSAPI")

    def get_initial_challenge(self):
        """Return the byte marker that starts the exchange."""
        return six.b("GSSAPI-START")

    def evaluate_challenge(self, challenge):
        # The start marker (what get_initial_challenge returns) triggers the
        # client's first, argument-less step; anything else is a server token.
        if challenge != six.b('GSSAPI-START'):
            return self.sasl.process(challenge)
        return self.sasl.process()
def test_process_with_authorization_id_or_identity(self):
    """The response is ``authzid NUL username NUL password`` as bytes.

    ``identity`` supplies the authzid, and ``authorization_id`` overrides
    it when both are given.
    """
    challenge = u"\U0001F44D"
    kwargs = dict(self.sasl_kwargs, identity='user2')

    def check(expected_authzid):
        client = SASLClient('localhost', mechanism=self.mechanism_class.name, **kwargs)
        response = client.process(challenge)
        expected = six.b('{0}\x00{1}\x00{2}'.format(
            expected_authzid, self.username, self.password))
        self.assertEqual(response, expected)
        self.assertIsInstance(response, six.binary_type)

    # identity alone is used as the authzid
    check('user2')

    # authorization_id wins over identity
    kwargs['authorization_id'] = 'user3'
    check('user3')
def __init__(self, transport, host, service, mechanism="GSSAPI", **sasl_kwargs):
    """Layer a SASL client over ``transport``.

    transport: the wrapped transport, typically just a TSocket
    host: server name as seen by SASL
    service: server's service name as seen by SASL
    mechanism: preferred SASL mechanism name

    Remaining keyword arguments go to the puresasl.client.SASLClient
    constructor.
    """
    self.transport = transport
    self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)
    # Separate, initially empty write and read buffers.
    self.__wbuf = StringIO()
    self.__rbuf = StringIO()
def test_process_server_answer(self):
    """Process a fixed digest challenge, then the server's rspauth reply.

    The test passes when neither ``process`` call raises.
    """
    credentials = {'username': "******", 'password': "******"}
    client = SASLClient('elwood.innosoft.com', service="imap",
                        mechanism=self.mechanism_class.name,
                        mutual_auth=True, **credentials)

    challenge = (
        b'utf-8,username="******",realm="elwood.innosoft.com",'
        b'nonce="OA6MG9tEQGm2hh",nc=00000001,cnonce="OA6MHXh6VqTrRk",'
        b'digest-uri="imap/elwood.innosoft.com",'
        b'response=d388dad90d4bbd760a152321f2143af7,qop=auth')
    client.process(challenge)

    # The cnonce is random; pin it so the rspauth below is checked
    # against a deterministic value.
    client._chosen_mech.cnonce = b"OA6MHXh6VqTrRk"

    server_reply = b'rspauth=ea40f60335c427b5527b84dbabcdfffd'
    client.process(server_reply)
def __init__(self, host, service, mechanism='GSSAPI', **sasl_kwargs):
    """Create the pure-sasl ``SASLClient`` backing this authenticator.

    :raises ImportError: when ``SASLClient`` is None (puresasl unavailable).
    """
    if SASLClient is not None:
        self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)
    else:
        raise ImportError('The puresasl library has not been installed')
def test_choose_mechanism(self):
    """choose_mechanism picks the best-scoring candidate and honours each allow_* flag."""
    client = SASLClient('localhost', service='something')
    self.assertRaises(SASLError, client.choose_mechanism, ['invalid'])

    candidates = [mech for mech in mechanisms.values() if mech is not DigestMD5Mechanism]
    client.choose_mechanism({mech.name for mech in candidates})
    best = max(candidates, key=lambda mech: mech.score)
    self.assertIsInstance(client._chosen_mech, best)

    # (keyword that forbids the subset, predicate selecting that subset)
    restrictions = (
        ('allow_anonymous', lambda mech: mech.allows_anonymous),
        ('allow_plaintext', lambda mech: mech.uses_plaintext),
        ('allow_active', lambda mech: not mech.active_safe),
        ('allow_dictionary', lambda mech: not mech.dictionary_safe),
    )
    for flag, selects in restrictions:
        names = {mech.name for mech in candidates if selects(mech)}
        client.choose_mechanism(names)
        self.assertIn(client.mechanism, names)
        self.assertRaises(SASLError, client.choose_mechanism, names, **{flag: False})
class ThriftSASLClientProtocol(ThriftClientProtocol):
    """Twisted Thrift client protocol that performs a SASL handshake first.

    On connect it sends START with the mechanism name, then exchanges OK
    messages until the server sends COMPLETE. Only then is control handed
    to ``ThriftClientProtocol``. After the handshake every outgoing message
    is SASL-wrapped and every incoming frame is SASL-unwrapped.
    """

    # SASL negotiation status codes (first byte of each handshake message).
    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    # Presumably raises the frame-size limit of Twisted's IntNStringReceiver.
    MAX_LENGTH = 2 ** 31 - 1

    def __init__(self, client_class, iprot_factory, oprot_factory=None,
                 host=None, service=None, mechanism='GSSAPI', **sasl_kwargs):
        ThriftClientProtocol.__init__(self, client_class, iprot_factory, oprot_factory)

        self._sasl_negotiation_deferred = None
        self._sasl_negotiation_status = None
        self.client = None

        if host is not None:
            self.createSASLClient(host, service, mechanism, **sasl_kwargs)

    def createSASLClient(self, host, service, mechanism, **kwargs):
        """Create the ``SASLClient`` used for the handshake and framing."""
        self.sasl = SASLClient(host, service, mechanism, **kwargs)

    def dispatch(self, msg):
        """SASL-wrap ``msg``, prefix its 4-byte length and send it."""
        encoded = self.sasl.wrap(msg)
        # struct.pack returns bytes; a str separator breaks on Python 3,
        # while b'' is plain str on Python 2.
        len_and_encoded = b''.join((struct.pack('!i', len(encoded)), encoded))
        ThriftClientProtocol.dispatch(self, len_and_encoded)

    @defer.inlineCallbacks
    def connectionMade(self):
        self._sendSASLMessage(self.START, self.sasl.mechanism)

        initial_message = yield deferToThread(self.sasl.process)
        self._sendSASLMessage(self.OK, initial_message)

        while True:
            status, challenge = yield self._receiveSASLMessage()
            if status == self.OK:
                response = yield deferToThread(self.sasl.process, challenge)
                self._sendSASLMessage(self.OK, response)
            elif status == self.COMPLETE:
                if not self.sasl.complete:
                    msg = "The server erroneously indicated that SASL " \
                          "negotiation was complete"
                    raise TTransportException(msg, message=msg)
                else:
                    break
            else:
                msg = "Bad SASL negotiation status: %d (%s)" % (status, challenge)
                raise TTransportException(msg, message=msg)

        self._sasl_negotiation_deferred = None
        ThriftClientProtocol.connectionMade(self)

    def _sendSASLMessage(self, status, body):
        """Send one handshake message: 1-byte status, 4-byte length, body.

        ``None`` is sent as an empty body. Text bodies (e.g. the mechanism
        name) are UTF-8 encoded so the length and concatenation are done on
        bytes on both Python 2 and 3.
        """
        if body is None:
            body = b""
        elif not isinstance(body, bytes):
            body = body.encode('utf-8')
        header = struct.pack(">BI", status, len(body))
        self.transport.write(header + body)

    def _receiveSASLMessage(self):
        """Return a Deferred fired with ``(status, body)`` for the next handshake message."""
        self._sasl_negotiation_deferred = defer.Deferred()
        # None marks "status byte of the next message not seen yet".
        self._sasl_negotiation_status = None
        return self._sasl_negotiation_deferred

    def connectionLost(self, reason=connectionDone):
        # Only forward once a Thrift client exists — presumably the base
        # implementation needs one; TODO confirm.
        if self.client:
            ThriftClientProtocol.connectionLost(self, reason)

    def dataReceived(self, data):
        if self._sasl_negotiation_deferred and self._sasl_negotiation_status is None:
            if not data:
                return
            # First chunk of a SASL message (status, length, challenge): keep
            # the status and let IntNStringReceiver assemble the rest. Later
            # chunks of the same message arrive with the status already set
            # and are passed through untouched. Slicing yields a 1-byte
            # string on Python 2 and 3 (data[0] is an int on Python 3).
            self._sasl_negotiation_status, = struct.unpack("B", data[:1])
            ThriftClientProtocol.dataReceived(self, data[1:])
        else:
            # normal frame (or continuation of a SASL message), let
            # IntNStringReceiver piece it together
            ThriftClientProtocol.dataReceived(self, data)

    def stringReceived(self, frame):
        if self._sasl_negotiation_deferred:
            # the frame is just a SASL challenge
            response = (self._sasl_negotiation_status, frame)
            self._sasl_negotiation_deferred.callback(response)
        else:
            # there's a second 4 byte length prefix inside the frame
            decoded_frame = self.sasl.unwrap(frame[4:])
            ThriftClientProtocol.stringReceived(self, decoded_frame)
class SaslRpcClient:
    """Client side of the Hadoop RPC SASL handshake and SASL message wrapping.

    :param trans: transport providing ``host``, ``write``,
        ``write_delimited``, ``recv_rpc_message`` and ``parse_response``.
    :param hdfs_namenode_principal: principal whose first component is used
        as the SASL service name.
    """

    def __init__(self, trans, hdfs_namenode_principal=None):
        self.sasl = None
        self._trans = trans
        self.hdfs_namenode_principal = hdfs_namenode_principal

    def _send_sasl_message(self, message):
        """Send ``message`` as a SASL RPC call (callId -33)."""
        rpcheader = RpcRequestHeaderProto()
        rpcheader.rpcKind = 2  # RPC_PROTOCOL_BUFFER
        rpcheader.rpcOp = 0
        rpcheader.callId = -33  # SASL
        rpcheader.retryCount = -1
        rpcheader.clientId = b""

        s_rpcheader = rpcheader.SerializeToString()
        s_message = message.SerializeToString()

        # Total size of both varint-length-delimited payloads.
        header_length = (len(s_rpcheader) + encoder._VarintSize(len(s_rpcheader))
                         + len(s_message) + encoder._VarintSize(len(s_message)))

        self._trans.write(struct.pack('!I', header_length))
        self._trans.write_delimited(s_rpcheader)
        self._trans.write_delimited(s_message)

        log_protobuf_message("Send out", message)

    def _recv_sasl_message(self):
        """Read the next RPC message and parse it as ``RpcSaslProto``."""
        bytestream = self._trans.recv_rpc_message()
        sasl_response = self._trans.parse_response(bytestream, RpcSaslProto)

        return sasl_response

    def connect(self):
        """Run the SASL handshake; return True once the server sends state 0.

        State 1 from the server: choose a mechanism and send the initial
        token (state 2). State 3: answer the challenge (state 4). Other
        states are ignored and the next message is read.
        """
        # Service name = principal text before the first "/" or "@".
        # Raw string: '\/' in a plain literal is an invalid escape sequence.
        service = re.split(r'[/@]', str(self.hdfs_namenode_principal))[0]
        if not self.sasl:
            self.sasl = SASLClient(self._trans.host, service)

        negotiate = RpcSaslProto()
        negotiate.state = 1
        self._send_sasl_message(negotiate)

        while True:
            res = self._recv_sasl_message()
            # TODO: check mechanisms
            if res.state == 1:
                mechs = [auth.mechanism for auth in res.auths]
                log.debug("Available mechs: %s", ",".join(mechs))
                self.sasl.choose_mechanism(mechs, allow_anonymous=False)
                log.debug("Chosen mech: %s", self.sasl.mechanism)

                initiate = RpcSaslProto()
                initiate.state = 2
                initiate.token = self.sasl.process()

                # Copy back only the server's auth entry for the chosen mechanism.
                for auth in res.auths:
                    if auth.mechanism == self.sasl.mechanism:
                        auth_method = initiate.auths.add()
                        auth_method.mechanism = self.sasl.mechanism
                        auth_method.method = auth.method
                        auth_method.protocol = auth.protocol
                        auth_method.serverId = self._trans.host

                self._send_sasl_message(initiate)
            elif res.state == 3:
                res_token = self._evaluate_token(res)
                response = RpcSaslProto()
                response.token = res_token
                response.state = 4
                self._send_sasl_message(response)
            elif res.state == 0:
                return True

    def _evaluate_token(self, sasl_response):
        """Feed the server's challenge token to the SASL client."""
        return self.sasl.process(challenge=sasl_response.token)

    def wrap(self, message):
        """SASL-wrap ``message`` and send it in a state-5 (WRAP) message."""
        encoded = self.sasl.wrap(message)

        sasl_message = RpcSaslProto()
        sasl_message.state = 5  # WRAP
        sasl_message.token = encoded

        self._send_sasl_message(sasl_message)

    def unwrap(self):
        """Receive a state-5 (WRAP) message and return its unwrapped token."""
        response = self._recv_sasl_message()
        if response.state != 5:
            raise Exception("Server send non-wrapped response")

        return self.sasl.unwrap(response.token)

    def use_wrap(self):
        """Return True when the negotiated QOP requires message wrapping.

        Wrapping is only used for the integrity ('auth-int') and
        confidentiality ('auth-conf') QOPs, not for plain 'auth'.
        """
        return self.sasl.qop.decode() in ('auth-int', 'auth-conf')
class TSaslClientTransport(TTransportBase):
    """
    SASL transport: performs a SASL handshake on ``open()`` and then frames
    every flushed write / every read as a 4-byte length plus SASL-wrapped
    payload.
    """

    # SASL negotiation status codes (first byte of each handshake message).
    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    def __init__(self, transport, host, service, mechanism=six.u('GSSAPI'),
                 generate_tickets=False, using_keytab=False, principal=None,
                 keytab_file=None, ccache_file=None, password=None,
                 **sasl_kwargs):
        """
        transport: an underlying transport to use, typically just a TSocket
        host: the name of the server, from a SASL perspective
        service: the name of the server's service, from a SASL perspective
        mechanism: the name of the preferred mechanism to use
        generate_tickets: when true, build ``self.krb_context`` from the
            Kerberos arguments below and call ``init_with_keytab()`` on it
        using_keytab, principal, keytab_file, ccache_file, password:
            passed positionally to ``krbContext`` when generating tickets

        All other kwargs will be passed to the puresasl.client.SASLClient
        constructor.
        """

        self.transport = transport

        if six.PY3:
            self._patch_pure_sasl()
        self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)

        self.__wbuf = BytesIO()
        self.__rbuf = BytesIO()

        self.generate_tickets = generate_tickets
        if self.generate_tickets:
            self.krb_context = krbContext(using_keytab, principal, keytab_file,
                                          ccache_file, password)
            self.krb_context.init_with_keytab()

    def _patch_pure_sasl(self):
        ''' we need to patch pure_sasl to support python 3 '''
        puresasl.mechanisms.mechanisms['GSSAPI'] = CustomGSSAPIMechanism

    def is_open(self):
        return self.transport.is_open() and bool(self.sasl)

    @with_ticket
    def open(self):
        """Open the underlying transport if needed and run the SASL handshake."""
        if not self.transport.is_open():
            self.transport.open()

        self.send_sasl_msg(self.START, self.sasl.mechanism.encode('utf8'))
        self.send_sasl_msg(self.OK, self.sasl.process())

        while True:
            status, challenge = self.recv_sasl_msg()
            if status == self.OK:
                self.send_sasl_msg(self.OK, self.sasl.process(challenge))
            elif status == self.COMPLETE:
                if not self.sasl.complete:
                    raise TTransportException(
                        TTransportException.NOT_OPEN,
                        "The server erroneously indicated "
                        "that SASL negotiation was complete")
                else:
                    break
            else:
                raise TTransportException(
                    TTransportException.NOT_OPEN,
                    "Bad SASL negotiation status: %d (%s)"
                    % (status, challenge))

    def send_sasl_msg(self, status, body):
        '''Send one handshake message: 1-byte status, 4-byte length, body.

        body: bytes; None is sent as an empty body (SASL steps may produce
        no token).
        '''
        if body is None:
            body = b''
        header = pack(">BI", status, len(body))
        self.transport.write(header + body)
        self.transport.flush()

    def recv_sasl_msg(self):
        '''Read one handshake message and return ``(status, payload_bytes)``.'''
        header = readall(self.transport.read, 5)
        status, length = unpack(">BI", header)
        if length > 0:
            payload = readall(self.transport.read, length)
        else:
            # Keep the empty payload bytes like non-empty ones; "" is text
            # on Python 3 and would reach sasl.process as the wrong type.
            payload = b""
        return status, payload

    @with_ticket
    def write(self, data):
        self.__wbuf.write(data)

    @with_ticket
    def flush(self):
        """SASL-wrap the buffered writes and send them as one length-prefixed frame."""
        data = self.__wbuf.getvalue()
        encoded = self.sasl.wrap(data)
        # b''.join works on Python 2 (bytes is str) and Python 3 alike.
        self.transport.write(b''.join((pack("!i", len(encoded)), encoded)))
        self.transport.flush()
        self.__wbuf = BytesIO()

    def read(self, sz):
        """Return up to ``sz`` bytes, reading a new frame when the buffer is empty."""
        ret = self.__rbuf.read(sz)
        if len(ret) != 0 or sz == 0:
            return ret

        self._read_frame()
        return self.__rbuf.read(sz)

    def _read_frame(self):
        """Read one length-prefixed frame and buffer its unwrapped contents."""
        header = readall(self.transport.read, 4)
        length, = unpack('!i', header)
        encoded = readall(self.transport.read, length)
        self.__rbuf = BytesIO(self.sasl.unwrap(encoded))

    def close(self):
        self.sasl.dispose()
        self.transport.close()
class LDAPSocket(object):
    """Holds a connection to an LDAP server.

    :param str host_uri: "scheme://netloc" to connect to
    :param int connect_timeout: Number of seconds to wait for connection to be accepted
    :param bool ssl_verify: Validate the certificate and hostname on an SSL/TLS connection
    :param str ssl_ca_file: Path to PEM-formatted concatenated CA certificates file
    :param str ssl_ca_path: Path to directory with CA certs under hashed file names. See
                            https://www.openssl.org/docs/man1.1.0/ssl/SSL_CTX_load_verify_locations.html for more
                            information about the format of this directory.
    :param ssl_ca_data: An ASCII string of one or more PEM-encoded certs or a bytes object containing DER-encoded
                        certificates.
    :type ssl_ca_data: str or bytes
    """

    RECV_BUFFER = 4096

    # For ldapi:/// try to connect to these socket files in order
    # Globs must match exactly one result
    LDAPI_SOCKET_PATHS = ['/var/run/ldapi', '/var/run/slapd/ldapi', '/var/run/slapd-*.socket']

    # OIDs of unsolicited messages

    OID_DISCONNECTION_NOTICE = '1.3.6.1.4.1.1466.20036'  # RFC 4511 sec 4.4.1 Notice of Disconnection

    def __init__(self, host_uri, connect_timeout=5, ssl_verify=True, ssl_ca_file=None, ssl_ca_path=None,
                 ssl_ca_data=None):
        self._prop_init(connect_timeout)
        self._uri_connect(host_uri, ssl_verify, ssl_ca_file, ssl_ca_path, ssl_ca_data)

    def _prop_init(self, connect_timeout=5):
        """Initialize per-connection bookkeeping (no I/O)."""
        # get socket ID number
        global _next_sock_id
        self.ID = _next_sock_id
        _next_sock_id += 1

        # misc init
        self._message_queues = {}
        self._next_message_id = 1
        self._sasl_client = None

        self.refcount = 0
        self.bound = False
        self.unbound = False
        self.abandoned_mids = []
        self.started_tls = False
        self.connect_timeout = connect_timeout

    def _parse_uri(self, host_uri):
        """Split ``host_uri`` into ``(scheme, netloc)`` and store ``self.uri``.

        Without an explicit scheme, a netloc starting with "/" means
        ``ldapi`` and anything else means ``ldap``.
        """
        parts = host_uri.split('://')
        if len(parts) == 1:
            netloc = unquote(parts[0])
            # startswith rather than netloc[0]: an empty netloc must not raise IndexError
            if netloc.startswith('/'):
                scheme = 'ldapi'
            else:
                scheme = 'ldap'
        elif len(parts) == 2:
            scheme = parts[0]
            netloc = unquote(parts[1])
        else:
            raise LDAPError('Invalid host_uri')
        self.uri = '{0}://{1}'.format(scheme, netloc)
        return scheme, netloc

    def _uri_connect(self, host_uri, ssl_verify, ssl_ca_file, ssl_ca_path, ssl_ca_data):
        """Open the socket for ``host_uri`` (ldap, ldaps or ldapi)."""
        # connect
        scheme, netloc = self._parse_uri(host_uri)

        logger.info('Connecting to {0} on #{1}'.format(self.uri, self.ID))
        if scheme == 'ldap':
            self._inet_connect(netloc, 389)
        elif scheme == 'ldaps':
            self._inet_connect(netloc, 636)
            self.start_tls(ssl_verify, ssl_ca_file, ssl_ca_path, ssl_ca_data)
            logger.info('Connected with TLS on #{0}'.format(self.ID))
        elif scheme == 'ldapi':
            if not _have_unix_socket:
                # Note the trailing space: the two literals are concatenated.
                raise LDAPError('Unix sockets are not supported on your platform, please choose a protocol other '
                                'than ldapi')
            self.sock_path = None
            self._sock = socket(AF_UNIX)
            self.host = 'localhost'

            if netloc == '/':
                # No explicit path: try the well-known socket locations in order.
                for sockGlob in LDAPSocket.LDAPI_SOCKET_PATHS:
                    fn = glob(sockGlob)
                    if not fn:
                        continue
                    if len(fn) > 1:
                        logger.debug('Multiple results for glob {0}'.format(sockGlob))
                        continue
                    fn = fn[0]
                    try:
                        self._connect(fn)
                        self.sock_path = fn
                        break
                    except SocketError:
                        continue
                if self.sock_path is None:
                    raise LDAPConnectionError('Could not find any local LDAPI unix socket - full '
                                              'socket path must be supplied in URI')
            else:
                try:
                    self._connect(netloc)
                    self.sock_path = netloc
                except SocketError as e:
                    raise LDAPConnectionError('failed connect to unix socket {0} - {1} ({2})'.format(
                        netloc, e.strerror, e.errno
                    ))

            logger.debug('Connected to unix socket {0} on #{1}'.format(self.sock_path, self.ID))
        else:
            raise LDAPError('Unsupported scheme "{0}"'.format(scheme))

    def _connect(self, addr):
        """Connect ``self._sock`` to ``addr`` with the connect timeout, then make it blocking."""
        self._sock.settimeout(self.connect_timeout)
        self._sock.connect(addr)
        self._sock.settimeout(None)

    def _inet_connect(self, netloc, default_port):
        """Open a TCP connection to ``host[:port]``, using ``default_port`` if none is given."""
        ap = netloc.rsplit(':', 1)
        self.host = ap[0]
        if len(ap) == 1:
            port = default_port
        else:
            port = int(ap[1])
        try:
            self._sock = create_connection((self.host, port), self.connect_timeout)
            logger.debug('Connected to {0}:{1} on #{2}'.format(self.host, port, self.ID))
        except SocketError as e:
            raise LDAPConnectionError('failed connect to {0}:{1} - {2} ({3})'.format(
                self.host, port, e.strerror, e.errno))

    def start_tls(self, verify=True, ca_file=None, ca_path=None, ca_data=None):
        """Install TLS layer on this socket connection.

        :param bool verify: Validate the certificate and hostname on an SSL/TLS connection
        :param str ca_file: Path to PEM-formatted concatenated CA certificates file
        :param str ca_path: Path to directory with CA certs under hashed file names. See
                            https://www.openssl.org/docs/man1.1.0/ssl/SSL_CTX_load_verify_locations.html for more
                            information about the format of this directory.
        :param ca_data: An ASCII string of one or more PEM-encoded certs or a bytes object containing DER-encoded
                        certificates.
        :type ca_data: str or bytes
        """
        if self.started_tls:
            raise LDAPError('TLS layer already installed')

        if verify:
            verify_mode = ssl.CERT_REQUIRED
        else:
            verify_mode = ssl.CERT_NONE

        try:
            proto = ssl.PROTOCOL_TLS
        except AttributeError:
            proto = ssl.PROTOCOL_SSLv23

        try:
            ctx = ssl.SSLContext(proto)
            ctx.verify_mode = verify_mode
            ctx.check_hostname = False  # we do this ourselves
            if verify:
                ctx.load_default_certs()
            if ca_file or ca_path or ca_data:
                ctx.load_verify_locations(cafile=ca_file, capath=ca_path, cadata=ca_data)
            self._sock = ctx.wrap_socket(self._sock)
        except AttributeError:
            # SSLContext wasn't added until 2.7.9
            if ca_path or ca_data:
                raise RuntimeError('python version >= 2.7.9 required for SSL ca_path/ca_data')

            self._sock = ssl.wrap_socket(self._sock, ca_certs=ca_file, cert_reqs=verify_mode, ssl_version=proto)

        if verify:
            cert = self._sock.getpeercert()
            # First attribute of each subject RDN; a certificate may carry no
            # commonName at all (identity only in subjectAltName).
            subject = dict(rdn[0] for rdn in cert.get('subject', ()))
            cert_cn = subject.get('commonName')
            self.check_hostname(cert_cn, cert)
        else:
            logger.debug('Skipping hostname validation')
        self.started_tls = True
        logger.debug('Installed TLS layer on #{0}'.format(self.ID))

    def check_hostname(self, cert_cn, cert):
        """SSL check_hostname according to RFC 4513 sec 3.1.3. Compares supplied values against ``self.host`` to
        determine the validity of the cert.

        :param cert_cn: The common name of the cert, or None if it has none
        :type cert_cn: str or None
        :param dict cert: A dictionary representing the rest of the cert. Checks key subjectAltNames for a list of
                          (type, value) tuples, where type is 'DNS' or 'IP'. DNS supports a wildcard as the single
                          left-most label only (``*.example.com`` matches ``a.example.com`` but neither
                          ``example.com`` nor ``a.b.example.com``).
        :rtype: None
        :raises LDAPConnectionError: if no supplied values match ``self.host``
        """
        if cert_cn is not None and self.host == cert_cn:
            logger.debug('Matched server identity to cert commonName')
            return

        valid = False
        tried = [cert_cn] if cert_cn is not None else []
        for name_type, value in cert.get('subjectAltName', []):
            if name_type == 'DNS' and value.startswith('*.'):
                # RFC 4513 sec 3.1.3.1: "*" matches exactly one left-most label.
                suffix = value[1:]
                label = self.host[:-len(suffix)]
                valid = self.host.endswith(suffix) and label != '' and '.' not in label
            else:
                valid = (self.host == value)
            tried.append(value)
            if valid:
                logger.debug('Matched server identity to cert {0} subjectAltName'.format(name_type))
                break
        if not valid:
            raise LDAPConnectionError('Server identity "{0}" does not match any cert names: {1}'.format(
                self.host, ', '.join(tried)))

    def sasl_init(self, mechs, **props):
        """Initialize a :class:`.puresasl.client.SASLClient`"""
        self._sasl_client = SASLClient(self.host, 'ldap', **props)
        self._sasl_client.choose_mechanism(mechs)

    def _has_sasl_client(self):
        return self._sasl_client is not None

    def _require_sasl_client(self):
        if not self._has_sasl_client():
            raise LDAPSASLError('SASL init not complete')

    @property
    def sasl_qop(self):
        """Obtain the chosen quality of protection"""
        self._require_sasl_client()
        return self._sasl_client.qop

    @property
    def sasl_mech(self):
        """Obtain the chosen mechanism"""
        self._require_sasl_client()
        mech = self._sasl_client.mechanism
        if mech is None:
            raise LDAPSASLError('SASL init not complete - no mech chosen')
        else:
            return mech

    def sasl_process_auth_challenge(self, challenge):
        """Process an auth challenge and return the correct response"""
        self._require_sasl_client()
        return self._sasl_client.process(challenge)

    def _prep_message(self, op, obj, controls=None):
        """Prepare a message for transmission"""
        mid = self._next_message_id
        self._next_message_id += 1
        lm = pack(mid, op, obj, controls)
        raw = ber_encode(lm)
        if self._has_sasl_client():
            raw = self._sasl_client.wrap(raw)
        return mid, raw

    def send_message(self, op, obj, controls=None):
        """Create and send an LDAPMessage given an operation name and a corresponding object.

        Operation names must be defined as component names in laurelin.ldap.rfc4511.ProtocolOp and
        the object must be of the corresponding type.

        :param str op: The protocol operation name
        :param object obj: The associated protocol object (see :class:`.rfc4511.ProtocolOp` for mapping.
        :param controls: Any request controls for the message
        :type controls: rfc4511.Controls or None
        :return: The message ID for this message
        :rtype: int
        """
        mid, raw = self._prep_message(op, obj, controls)
        self._sock.sendall(raw)
        return mid

    def recv_one(self, want_message_id):
        """Get the next message with ``want_message_id`` being sent by the server

        :param int want_message_id: The desired message ID.
        :return: The LDAP message
        :rtype: rfc4511.LDAPMessage
        """
        return next(self.recv_messages(want_message_id))

    def recv_messages(self, want_message_id):
        """Iterate all messages with ``want_message_id`` being sent by the server.

        Messages for other IDs are queued for later callers. The iterator
        ends when ``want_message_id`` is in ``self.abandoned_mids``.

        :param int want_message_id: The desired message ID.
        :return: An iterator over :class:`.rfc4511.LDAPMessage`.
        :raises LDAPConnectionError: if the server closes the connection.
        :raises LDAPUnsolicitedMessage: on any message with ID 0.
        """
        flush_queue = True
        raw = b''
        while True:
            if flush_queue:
                # First drain anything already queued for this message ID.
                if want_message_id in self._message_queues:
                    q = self._message_queues[want_message_id]
                    while q:
                        obj = q.popleft()
                        if not q:
                            del self._message_queues[want_message_id]
                        yield obj
            else:
                flush_queue = True
            if want_message_id in self.abandoned_mids:
                return
            try:
                newraw = self._sock.recv(LDAPSocket.RECV_BUFFER)
                if not newraw:
                    # recv() returns b'' once the peer has closed the socket;
                    # without this check the loop would spin forever.
                    raise LDAPConnectionError('Server closed the connection on #{0}'.format(self.ID))
                if self._has_sasl_client():
                    newraw = self._sasl_client.unwrap(newraw)
                raw += newraw
                while len(raw) > 0:
                    response, raw = ber_decode(raw, asn1Spec=LDAPMessage())
                    have_message_id = response.getComponentByName('messageID')
                    if want_message_id == have_message_id:
                        yield response
                    elif have_message_id == 0:
                        msg = 'Received unsolicited message (default message - should never be seen)'
                        try:
                            mid, xr, ctrls = unpack('extendedResp', response)
                            res_code = xr.getComponentByName('resultCode')
                            xr_oid = six.text_type(xr.getComponentByName('responseName'))
                            if xr_oid == LDAPSocket.OID_DISCONNECTION_NOTICE:
                                mtype = 'Notice of Disconnection'
                            else:
                                mtype = 'Unhandled ({0})'.format(xr_oid)
                            diag = xr.getComponentByName('diagnosticMessage')
                            msg = 'Got unsolicited message: {0}: {1}: {2}'.format(mtype, res_code, diag)
                            if res_code == ResultCode('protocolError'):
                                msg += (' (This may indicate an incompatability between laurelin-ldap and your server '
                                        'distribution)')
                            elif res_code == ResultCode('strongerAuthRequired'):
                                # this is a direct quote from RFC 4511 sec 4.4.1
                                msg += (' (The server has detected that an established security association between the'
                                        ' client and server has unexpectedly failed or been compromised)')
                        except UnexpectedResponseType:
                            msg = 'Unhandled unsolicited message from server'
                        finally:
                            # Always raise for ID 0, with the most specific message built above.
                            raise LDAPUnsolicitedMessage(response, msg)
                    else:
                        if have_message_id not in self._message_queues:
                            self._message_queues[have_message_id] = deque()
                        self._message_queues[have_message_id].append(response)
            except SubstrateUnderrunError:
                # Partial message buffered in ``raw``; read more before flushing queues again.
                flush_queue = False
                continue

    def close(self):
        """Close the low-level socket connection."""
        return self._sock.close()
def sasl_init(self, mechs, **props):
    """Create this connection's SASL client and let it pick a mechanism.

    :param mechs: mechanism names offered by the server, passed to
        ``SASLClient.choose_mechanism``.
    :param props: extra keyword arguments forwarded verbatim to the
        ``SASLClient`` constructor (service name is fixed to ``'ldap'``).
    """
    client = SASLClient(self.host, 'ldap', **props)
    self._sasl_client = client
    client.choose_mechanism(mechs)
class TSaslClientTransport(TTransportBase, CReadableTransport):
    """Thrift client transport that authenticates and frames traffic with SASL.

    Handshake messages are ``[status: 1 byte][length: 4 bytes BE][body]``.
    After the handshake, every flushed write is wrapped by the SASL mechanism
    and sent as ``[length: 4-byte signed BE][wrapped bytes]``.
    """

    # Status codes carried in the first byte of each handshake message.
    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    def __init__(self, transport, host, service, mechanism='GSSAPI',
                 **sasl_kwargs):
        """
        transport: an underlying transport to use, typically just a TSocket
        host: the name of the server, from a SASL perspective
        service: the name of the server's service, from a SASL perspective
        mechanism: the name of the preferred mechanism to use

        All other kwargs will be passed to the puresasl.client.SASLClient
        constructor.
        """
        from puresasl.client import SASLClient

        self.transport = transport
        self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)

        self.__wbuf = StringIO()
        self.__rbuf = StringIO()

    def open(self):
        """Open the underlying transport if needed and run the SASL handshake.

        Raises TTransportException when the server sends a status other than
        OK/COMPLETE, or claims completion before the client side is complete.
        """
        if not self.transport.isOpen():
            self.transport.open()

        self.send_sasl_msg(self.START, self.sasl.mechanism)
        self.send_sasl_msg(self.OK, self.sasl.process())

        while True:
            status, challenge = self.recv_sasl_msg()
            if status == self.OK:
                self.send_sasl_msg(self.OK, self.sasl.process(challenge))
            elif status == self.COMPLETE:
                if not self.sasl.complete:
                    raise TTransportException(
                        "The server erroneously indicated "
                        "that SASL negotiation was complete")
                break
            else:
                raise TTransportException(
                    "Bad SASL negotiation status: %d (%s)"
                    % (status, challenge))

    def send_sasl_msg(self, status, body):
        """Send one handshake message with the given status byte.

        ``body`` may be None (``SASLClient.process`` can return None), text
        (sent as UTF-8) or bytes. The length prefix is computed on the encoded
        bytes so it always matches what is written.
        """
        if body is None:
            body = b''
        elif not isinstance(body, bytes):
            body = body.encode('utf-8')
        header = pack(">BI", status, len(body))
        self.transport.write(header + body)
        self.transport.flush()

    def recv_sasl_msg(self):
        """Read one handshake message and return ``(status, payload)``."""
        header = self.transport.readAll(5)
        status, length = unpack(">BI", header)
        if length > 0:
            payload = self.transport.readAll(length)
        else:
            payload = b''
        return status, payload

    def write(self, data):
        """Buffer ``data``; nothing is sent until ``flush``."""
        self.__wbuf.write(data)

    def flush(self):
        """Wrap the buffered data with SASL and send it as one length-prefixed frame."""
        data = self.__wbuf.getvalue()
        encoded = self.sasl.wrap(data)
        # One write for prefix + payload; b'' keeps this valid on Python 3
        # (on Python 2 b'' is simply '').
        self.transport.write(b''.join((pack("!i", len(encoded)), encoded)))
        self.transport.flush()
        self.__wbuf = StringIO()

    def read(self, sz):
        """Return up to ``sz`` bytes, reading a new frame when the buffer is empty."""
        ret = self.__rbuf.read(sz)
        if len(ret) != 0:
            return ret

        self._read_frame()
        return self.__rbuf.read(sz)

    def _read_frame(self):
        """Read one length-prefixed frame, unwrap it and make it the read buffer."""
        header = self.transport.readAll(4)
        length, = unpack('!i', header)
        encoded = self.transport.readAll(length)
        self.__rbuf = StringIO(self.sasl.unwrap(encoded))

    def close(self):
        """Dispose of the SASL state and close the underlying transport."""
        self.sasl.dispose()
        self.transport.close()

    # based on TFramedTransport
    @property
    def cstringio_buf(self):
        return self.__rbuf

    def cstringio_refill(self, prefix, reqlen):
        # self.__rbuf will already be empty here because fastbinary doesn't
        # ask for a refill until the previous buffer is empty.  Therefore,
        # we can start reading new frames immediately.
        while len(prefix) < reqlen:
            self._read_frame()
            prefix += self.__rbuf.getvalue()
        self.__rbuf = StringIO(prefix)
        return self.__rbuf
class Client(object):
    """pure-sasl backed client exposing a cyrus-sasl style API.

    Most methods return a ``(success, value)`` tuple, mirroring the C API
    they emulate.
    """

    def __init__(self):
        self.lasterror = ''
        self.attributes = {}
        self.sasl = None  # created by start()

    # Set attributes to be used in authenticating the session.  All attributes should be set
    # before init() is called.
    #
    # @param key Name of attribute being set
    # @param value Value of attribute being set
    # @return true iff success.  If false is returned, call getError() for error details.
    #
    # Available attribute keys:
    #
    #    service      - Name of the service being accessed
    #    username     - User identity for authentication
    #    authname     - User identity for authorization (if different from username)
    #    password     - Password associated with username
    #    host         - Fully qualified domain name of the server host
    #    maxbufsize   - Maximum receive buffer size for the security layer
    #    minssf       - Minimum acceptable security strength factor (integer)
    #    maxssf       - Maximum acceptable security strength factor (integer)
    #    externalssf  - Security strength factor supplied by external mechanism (i.e. SSL/TLS)
    #    externaluser - Authentication ID (of client) as established by external mechanism
    def setAttr(self, key, value):
        self.attributes[key] = value
        # Storing into a dict cannot fail, so honour the documented contract.
        return True

    # Return the value previously stored with setAttr().  Also used as the
    # SASLClient callback in start().
    #
    # @raise KeyError if the attribute was never set.
    def getAttr(self, key):
        return self.attributes[key]

    # Initialize the client object.  This should be called after all of the properties have been set.
    #
    # @return true iff success.  If false is returned, call getError() for error details.
    def init(self):
        return True

    # Start the SASL exchange with the server.
    #
    # @param chosen The mechanism chosen by the client (text)
    #
    # @return (true, chosen mechanism as bytes, initial response from the mechanism)
    def start(self, chosen):
        self.sasl = SASLClient(self.attributes['host'], mechanism=chosen, callback=self.getAttr)
        # ret, (bytes)chosen_mech, (bytes)initial_response = self.sasl.start(self.mechanism)
        return True, chosen.encode(), self.sasl.process()

    # Step the SASL handshake.
    #
    # @param challenge The challenge supplied by the server
    #
    # @return (true, response to be sent back to the server)
    def step(self, challenge):
        return True, self.sasl.process(challenge)

    # Encode data for secure transmission to the server.
    #
    # @param clearText Clear text data to be encrypted
    #
    # @return (true, encrypted data to be transmitted)
    def encode(self, clearText):
        return True, self.sasl.wrap(clearText)

    # Decode data received from the server.
    #
    # @param cipherText Encrypted data received from the server
    #
    # @return (true, decrypted clear text data)
    def decode(self, cipherText):
        # Previously unwrapped the undefined name ``clearText`` (NameError).
        return True, self.sasl.unwrap(cipherText)

    # Get the user identity (used for authentication) associated with this session.
    # Note that this is particularly useful for single-sign-on mechanisms in which the
    # username is not supplied by the application.
    #
    # @return the 'externaluser' attribute (KeyError if it was never set).
    def getUserId(self):
        return self.attributes['externaluser']

    # Get the security strength factor associated with this session.
    #
    # NOTE(review): always returns True rather than a negotiated SSF value.
    def getSSF(self):
        return True

    # Get error message for last error.
    # This function will return the last error message then clear the error state.
    # If there was no error or the error state has been cleared, this function will output
    # an empty string.
    def getError(self):
        error = self.lasterror[:]
        self.lasterror = ''
        return error
def setUp(self):
    """Build a SASLClient for the mechanism under test and keep its chosen mechanism."""
    mech_name = self.mechanism_class.name
    sasl_client = SASLClient('localhost', mechanism=mech_name, **self.sasl_kwargs)
    self.sasl = sasl_client
    self.mechanism = sasl_client._chosen_mech
def negotiate_sasl(self, token):
    """Authenticate the connection with SASL DIGEST-MD5 using a delegation token.

    Sends a NEGOTIATE request, selects the server-offered TOKEN/DIGEST-MD5
    auth, answers its challenge with an INITIATE request, then reads the
    server's reply through ``parse_response``.

    :param token: mapping with ``"identifier"`` and ``"password"`` entries;
        both are base64-encoded before being handed to the SASL client.
    :raises IOError: if the server sends no negotiate payload or does not
        offer TOKEN/DIGEST-MD5.
    """
    log.debug("##############NEGOTIATING SASL#####################")

    header_bytes = self.create_sasl_header().SerializeToString()

    def send_sasl_message(sasl_message):
        # Frame: 4-byte total length, then the delimited RPC header and the
        # delimited RpcSaslProto.
        sasl_bytes = sasl_message.SerializeToString()
        total_length = (
            len(header_bytes)
            + len(sasl_bytes)
            + encoder._VarintSize(len(header_bytes))
            + encoder._VarintSize(len(sasl_bytes))
        )
        self.write(struct.pack("!I", total_length))
        self.write_delimited(header_bytes)
        self.write_delimited(sasl_bytes)

    # Negotiate: ask the server which auth methods it supports.
    negotiate_request = RpcSaslProto()
    negotiate_request.state = RpcSaslProto.NEGOTIATE
    negotiate_request.version = 0
    send_sasl_message(negotiate_request)

    resp = self.parse_response(self.recv_rpc_message(), RpcSaslProto)
    if resp is None:
        raise IOError("Empty SASL negotiate response from server")

    chosen_auth = None
    for auth in resp.auths:
        # No break: if several entries match, the last one is used.
        if auth.method == "TOKEN" and auth.mechanism == "DIGEST-MD5":
            chosen_auth = auth

    if chosen_auth is None:
        raise IOError("Token digest-MD5 authentication not supported by server")

    # Initiate: answer the server's challenge.
    self.sasl = SASLClient(
        chosen_auth.serverId,
        chosen_auth.protocol,
        mechanism=chosen_auth.mechanism,
        username=base64.b64encode(token["identifier"]),
        password=base64.b64encode(token["password"]),
    )
    challenge_resp = self.sasl.process(chosen_auth.challenge)

    auth = RpcSaslProto.SaslAuth()
    auth.method = chosen_auth.method
    auth.mechanism = chosen_auth.mechanism
    auth.protocol = chosen_auth.protocol
    auth.serverId = chosen_auth.serverId

    initiate_request = RpcSaslProto()
    initiate_request.state = RpcSaslProto.INITIATE
    initiate_request.version = 0
    initiate_request.auths.extend([auth])
    initiate_request.token = challenge_resp
    send_sasl_message(initiate_request)

    # parse_response raises on an error header; the parsed reply itself is
    # not inspected (its rspauth could be used to authenticate the server).
    self.parse_response(self.recv_rpc_message(), RpcSaslProto)
def _connect(self, host, port):
    """Open a socket to ``host:port`` and establish a ZooKeeper session.

    Sends the Connect request and stores the negotiated session values on the
    client. When ``client.use_sasl`` is set and pure-sasl is importable, it
    starts a DIGEST-MD5 exchange. Finally it replays all "digest" auth
    credentials.

    :returns: ``(read_timeout, connect_timeout)`` derived from the
        negotiated session timeout.
    :raises SessionExpiredError: if the server returns a non-positive timeout.
    """
    client = self.client
    self.logger.info('Connecting to %s:%s, use_ssl: %r',
                     host, port, self.client.use_ssl)

    self.logger.log(BLATHER,
                    '    Using session_id: %r session_passwd: %s',
                    client._session_id,
                    hexlify(client._session_passwd))

    with self._socket_error_handling():
        self._socket = self.handler.create_connection(
            address=(host, port),
            timeout=client._session_timeout / 1000.0,
            use_ssl=self.client.use_ssl,
            keyfile=self.client.keyfile,
            certfile=self.client.certfile,
            ca=self.client.ca,
            keyfile_password=self.client.keyfile_password,
            verify_certs=self.client.verify_certs,
        )

    self._socket.setblocking(0)

    connect = Connect(0, client.last_zxid, client._session_timeout,
                      client._session_id or 0, client._session_passwd,
                      client.read_only)

    connect_result, zxid = self._invoke(
        client._session_timeout / 1000.0, connect)

    if connect_result.time_out <= 0:
        raise SessionExpiredError("Session has expired")

    if zxid:
        client.last_zxid = zxid

    # Load return values
    client._session_id = connect_result.session_id
    client._protocol_version = connect_result.protocol_version
    negotiated_session_timeout = connect_result.time_out
    connect_timeout = negotiated_session_timeout / len(client.hosts)
    read_timeout = negotiated_session_timeout * 2.0 / 3.0
    client._session_passwd = connect_result.passwd

    self.logger.log(BLATHER,
                    'Session created, session_id: %r session_passwd: %s\n'
                    '    negotiated session timeout: %s\n'
                    '    connect timeout: %s\n'
                    '    read timeout: %s', client._session_id,
                    hexlify(client._session_passwd),
                    negotiated_session_timeout, connect_timeout,
                    read_timeout)

    if connect_result.read_only:
        client._session_callback(KeeperState.CONNECTED_RO)
        self._ro_mode = iter(self._server_pinger())
    else:
        self._ro_mode = None

    # Get a copy of the auth data before iterating, in case it is
    # changed.
    client_auth_data_copy = copy.copy(client.auth_data)

    # NOTE(review): in read-only mode the non-SASL paths below still fire
    # KeeperState.CONNECTED after CONNECTED_RO — confirm this is intended.
    if client.use_sasl and self.sasl_cli is None:
        if PURESASL_AVAILABLE:
            for scheme, auth in client_auth_data_copy:
                if scheme == 'sasl':
                    # Split on the first ':' only, so passwords may contain ':'.
                    username, password = auth.split(":", 1)
                    self.sasl_cli = SASLClient(
                        host=client.sasl_server_principal,
                        service='zookeeper',
                        mechanism='DIGEST-MD5',
                        username=username,
                        password=password
                    )
                    break

            # NOTE(review): if no 'sasl' credential was supplied, sasl_cli is
            # still None here when the request is sent.
            # As described in rfc
            # https://tools.ietf.org/html/rfc2831#section-2.1
            # sending empty challenge
            self._send_sasl_request(challenge=b'',
                                    timeout=connect_timeout)
        else:
            # Logger.warn is a deprecated alias (removed in Python 3.13).
            self.logger.warning('Pure-sasl library is missing while sasl'
                                ' authentification is configured. Please'
                                ' install pure-sasl library to connect '
                                'using sasl. Now falling back '
                                'connecting WITHOUT any '
                                'authentification.')
            client.use_sasl = False
            client._session_callback(KeeperState.CONNECTED)
    else:
        client._session_callback(KeeperState.CONNECTED)

    for scheme, auth in client_auth_data_copy:
        if scheme == "digest":
            ap = Auth(0, scheme, auth)
            zxid = self._invoke(
                connect_timeout / 1000.0,
                ap,
                xid=AUTH_XID
            )
            if zxid:
                client.last_zxid = zxid

    return read_timeout, connect_timeout
def __init__(self, *args, **kwds):
    """Initialise the SASLClient base for host ``'testhost'``; all arguments are ignored."""
    parent_init = SASLClient.__init__
    parent_init(self, 'testhost')
class SocketRpcChannel(RpcChannel):
    """Socket implementation of an RpcChannel speaking Hadoop protobuf RPC.

    The TCP connection is opened lazily by ``CallMethod``. When ``self.token``
    is set, the channel authenticates with SASL DIGEST-MD5 before sending the
    connection context.
    """

    # No ``L`` suffix: this literal is still a long on Python 2 and it also
    # parses on Python 3.
    ERROR_BYTES = 18446744073709551615
    RPC_HEADER = "hrpc"
    RPC_SERVICE_CLASS = 0x00
    AUTH_PROTOCOL_NONE = 0x00
    AUTH_PROTOCOL_SASL = 0xDF  # -33 as a signed byte
    RPC_PROTOCOL_BUFFFER = 0x02

    def __init__(self, host, port, version, context_protocol, timeout=30):
        """SocketRpcChannel to connect to a socket server on a user defined port."""
        self.host = host
        self.port = port
        self.sock = None
        # The connection context is sent with callId -3; regular calls then
        # start at 0 and increment (see create_rpc_request_header).
        self.call_id = -3
        self.version = version
        self.client_id = str(uuid.uuid4())
        self.context_protocol = context_protocol
        self.timeout = timeout
        self.token = None

    def validate_request(self, request):
        """Validate the client request against the protocol file."""
        # Check the request is correctly initialized
        if not request.IsInitialized():
            raise Exception("Client request (%s) is missing mandatory fields" % type(request))

    def get_connection(self, host, port):
        """Open a socket connection to a given host and port and write the Hadoop header.

        The Hadoop RPC protocol looks like this when creating a connection:

        +---------------------------------------------------------------------+
        |  Header, 4 bytes ("hrpc")                                           |
        +---------------------------------------------------------------------+
        |  Version, 1 byte (default version 9)                                |
        +---------------------------------------------------------------------+
        |  RPC service class, 1 byte (0x00)                                   |
        +---------------------------------------------------------------------+
        |  Auth protocol, 1 byte (Auth method None = 0)                       |
        +---------------------------------------------------------------------+
        |  Length of the RpcRequestHeaderProto  + length of the               |
        |  of the IpcConnectionContextProto (4 bytes/32 bit int)              |
        +---------------------------------------------------------------------+
        |  Serialized delimited RpcRequestHeaderProto                         |
        +---------------------------------------------------------------------+
        |  Serialized delimited IpcConnectionContextProto                     |
        +---------------------------------------------------------------------+
        """
        log.debug("############## CONNECTING ##############")

        auth = self.AUTH_PROTOCOL_NONE if self.token is None else self.AUTH_PROTOCOL_SASL
        # Open socket
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        self.sock.settimeout(self.timeout)
        # Connect socket to server - defined by host and port arguments
        self.sock.connect((host, port))

        # Send RPC headers
        self.write(self.RPC_HEADER)  # header
        self.write(struct.pack("B", self.version))  # version
        self.write(struct.pack("B", self.RPC_SERVICE_CLASS))  # RPC service class
        self.write(struct.pack("B", auth))  # serialization type (default none)

        if auth == self.AUTH_PROTOCOL_SASL:
            self.negotiate_sasl(self.token)

        # The connection context must always carry callId -3, including on
        # reconnection after earlier calls advanced the counter.
        self.call_id = -3
        rpc_header = self.create_rpc_request_header()
        if auth == self.AUTH_PROTOCOL_NONE:
            context = self.create_connection_context()
        else:
            context = self.create_connection_context_auth()

        header_length = (
            len(rpc_header)
            + encoder._VarintSize(len(rpc_header))
            + len(context)
            + encoder._VarintSize(len(context))
        )

        if log.getEffectiveLevel() == logging.DEBUG:
            log.debug("Header length: %s (%s)", header_length, format_bytes(struct.pack("!I", header_length)))

        self.write(struct.pack("!I", header_length))

        self.write_delimited(rpc_header)
        self.write_delimited(context)

    def write(self, data):
        """Send ``data`` on the socket, logging it at DEBUG level."""
        if log.getEffectiveLevel() == logging.DEBUG:
            log.debug("Sending: %s", format_bytes(data))
        self.sock.send(data)

    def write_delimited(self, data):
        """Send ``data`` prefixed with its protobuf varint length."""
        self.write(encoder._VarintBytes(len(data)))
        self.write(data)

    def create_rpc_request_header(self):
        """Create and serialize an RpcRequestHeaderProto, then advance the call id."""
        rpcheader = RpcRequestHeaderProto()
        rpcheader.rpcKind = 2  # rpcheaderproto.RpcKindProto.Value('RPC_PROTOCOL_BUFFER')
        rpcheader.rpcOp = 0  # rpcheaderproto.RpcPayloadOperationProto.Value('RPC_FINAL_PACKET')
        rpcheader.callId = self.call_id
        rpcheader.retryCount = -1
        rpcheader.clientId = self.client_id[0:16]

        # After the connection context (-3) regular calls start at 0.
        if self.call_id == -3:
            self.call_id = 0
        else:
            self.call_id += 1

        # Serialize delimited
        s_rpcHeader = rpcheader.SerializeToString()
        log_protobuf_message("RpcRequestHeaderProto (len: %d)" % (len(s_rpcHeader)), rpcheader)
        return s_rpcHeader

    def create_connection_context(self):
        """Create and serialize an IpcConnectionContextProto (not delimited) for the local user."""
        context = IpcConnectionContextProto()
        local_user = pwd.getpwuid(os.getuid())[0]
        context.userInfo.effectiveUser = local_user
        context.protocol = self.context_protocol

        s_context = context.SerializeToString()
        log_protobuf_message("RequestContext (len: %d)" % len(s_context), context)
        return s_context

    def create_connection_context_auth(self):
        """Create and serialize an IpcConnectionContextProto (not delimited) for token auth.

        NOTE(review): reads ``self.appid`` (keys ``cluster_timestamp`` and
        ``id``), which this class never sets; presumably callers assign it
        together with ``self.token`` — confirm.
        """
        context = IpcConnectionContextProto()
        # TODO do this better
        context.userInfo.effectiveUser = (
            "appattempt_" + str(self.appid["cluster_timestamp"]) + "_" + str(self.appid["id"]).zfill(4) + "_000001"
        )
        context.protocol = self.context_protocol
        # A leftover ``ipdb.set_trace()`` debugger breakpoint was removed here.

        s_context = context.SerializeToString()
        log_protobuf_message("RequestContext (len: %d)" % len(s_context), context)
        return s_context

    def create_sasl_header(self):
        """Return (unserialized) the RpcRequestHeaderProto used for SASL messages (callId -33)."""
        rpcheader = RpcRequestHeaderProto()
        rpcheader.rpcKind = 2  # rpcheaderproto.RpcKindProto.Value('RPC_PROTOCOL_BUFFER')
        rpcheader.rpcOp = 0  # rpcheaderproto.RpcPayloadOperationProto.Value('RPC_FINAL_PACKET')
        rpcheader.callId = -33
        rpcheader.retryCount = -1
        rpcheader.clientId = self.client_id[0:16]
        return rpcheader

    def negotiate_sasl(self, token):
        """Authenticate the connection with SASL DIGEST-MD5 using a delegation token.

        :param token: mapping with ``"identifier"`` and ``"password"`` entries;
            both are base64-encoded before being handed to the SASL client.
        :raises IOError: if the server sends no negotiate payload or does not
            offer TOKEN/DIGEST-MD5.
        """
        log.debug("##############NEGOTIATING SASL#####################")

        header_bytes = self.create_sasl_header().SerializeToString()

        def send_sasl_message(sasl_message):
            # Frame: 4-byte total length, then the delimited RPC header and
            # the delimited RpcSaslProto.
            sasl_bytes = sasl_message.SerializeToString()
            total_length = (
                len(header_bytes)
                + len(sasl_bytes)
                + encoder._VarintSize(len(header_bytes))
                + encoder._VarintSize(len(sasl_bytes))
            )
            self.write(struct.pack("!I", total_length))
            self.write_delimited(header_bytes)
            self.write_delimited(sasl_bytes)

        # Negotiate: ask the server which auth methods it supports.
        negotiate_request = RpcSaslProto()
        negotiate_request.state = RpcSaslProto.NEGOTIATE
        negotiate_request.version = 0
        send_sasl_message(negotiate_request)

        resp = self.parse_response(self.recv_rpc_message(), RpcSaslProto)
        if resp is None:
            raise IOError("Empty SASL negotiate response from server")

        chosen_auth = None
        for auth in resp.auths:
            # No break: if several entries match, the last one is used.
            if auth.method == "TOKEN" and auth.mechanism == "DIGEST-MD5":
                chosen_auth = auth

        if chosen_auth is None:
            raise IOError("Token digest-MD5 authentication not supported by server")

        # Initiate: answer the server's challenge.
        self.sasl = SASLClient(
            chosen_auth.serverId,
            chosen_auth.protocol,
            mechanism=chosen_auth.mechanism,
            username=base64.b64encode(token["identifier"]),
            password=base64.b64encode(token["password"]),
        )
        challenge_resp = self.sasl.process(chosen_auth.challenge)

        auth = RpcSaslProto.SaslAuth()
        auth.method = chosen_auth.method
        auth.mechanism = chosen_auth.mechanism
        auth.protocol = chosen_auth.protocol
        auth.serverId = chosen_auth.serverId

        initiate_request = RpcSaslProto()
        initiate_request.state = RpcSaslProto.INITIATE
        initiate_request.version = 0
        initiate_request.auths.extend([auth])
        initiate_request.token = challenge_resp
        send_sasl_message(initiate_request)

        # parse_response raises on an error header. If desired, the server
        # can be authenticated using the rspauth in the response.
        self.parse_response(self.recv_rpc_message(), RpcSaslProto)

    def send_rpc_message(self, method, request):
        """Send a Hadoop RPC request to the NameNode.

        The RpcRequestHeaderProto, RequestHeaderProto and request are each
        serialized and written delimited.

        The Hadoop RPC protocol looks like this for sending requests:

        When sending requests
        +---------------------------------------------------------------------+
        |  Length of the next three parts (4 bytes/32 bit int)                |
        +---------------------------------------------------------------------+
        |  Delimited serialized RpcRequestHeaderProto (varint len + header)   |
        +---------------------------------------------------------------------+
        |  Delimited serialized RequestHeaderProto (varint len + header)      |
        +---------------------------------------------------------------------+
        |  Delimited serialized Request (varint len + request)                |
        +---------------------------------------------------------------------+
        """
        log.debug("############## SENDING ##############")

        # 0. RpcRequestHeaderProto
        rpc_request_header = self.create_rpc_request_header()
        # 1. RequestHeaderProto
        request_header = self.create_request_header(method)
        # 2. Param
        param = request.SerializeToString()
        if log.getEffectiveLevel() == logging.DEBUG:
            log_protobuf_message("Request", request)

        rpc_message_length = (
            len(rpc_request_header)
            + encoder._VarintSize(len(rpc_request_header))
            + len(request_header)
            + encoder._VarintSize(len(request_header))
            + len(param)
            + encoder._VarintSize(len(param))
        )

        if log.getEffectiveLevel() == logging.DEBUG:
            log.debug(
                "RPC message length: %s (%s)",
                rpc_message_length,
                format_bytes(struct.pack("!I", rpc_message_length)),
            )
        self.write(struct.pack("!I", rpc_message_length))

        self.write_delimited(rpc_request_header)
        self.write_delimited(request_header)
        self.write_delimited(param)

    def create_request_header(self, method):
        """Create and serialize the RequestHeaderProto for ``method``."""
        header = RequestHeaderProto()
        header.methodName = method.name
        header.declaringClassProtocolName = self.context_protocol
        header.clientProtocolVersion = 1

        s_header = header.SerializeToString()
        log_protobuf_message("RequestHeaderProto (len: %d)" % len(s_header), header)
        return s_header

    def recv_rpc_message(self):
        """Handle reading an RPC reply from the server.

        This is done by wrapping the socket in an RpcBufferedReader that
        allows for rewinding of the buffer stream.
        """
        log.debug("############## RECVING ##############")
        byte_stream = RpcBufferedReader(self.sock)
        return byte_stream

    def get_length(self, byte_stream):
        """Read a 4-byte big-endian length from ``byte_stream``.

        In Hadoop protobuf RPC, some parts of the stream are delimited with a
        protobuf varint, while others are delimited with 4-byte integers.
        This reads 4 bytes and returns the length of the part that follows.
        """
        length = struct.unpack("!i", byte_stream.read(4))[0]
        log.debug("4 bytes delimited part length: %d", length)
        return length

    def parse_response(self, byte_stream, response_class):
        """Parse a Hadoop RPC response.

        The RpcResponseHeaderProto contains a status field that marks SUCCESS
        or ERROR. The Hadoop RPC protocol looks like the diagram below for
        receiving SUCCESS requests.

        +-----------------------------------------------------------+
        |  Length of the RPC response (4 bytes/32 bit int)          |
        +-----------------------------------------------------------+
        |  Delimited serialized RpcResponseHeaderProto              |
        +-----------------------------------------------------------+
        |  Serialized delimited RPC response                        |
        +-----------------------------------------------------------+

        In case of an error, the header status is set to ERROR and the error
        fields are set; ``handle_error`` then raises RequestError.

        :returns: a ``response_class`` instance (left at defaults when the
            payload is empty), or None when the frame holds only the header.
        """
        log.debug("############## PARSING ##############")
        log.debug("Payload class: %s", response_class)

        # Read first 4 bytes to get the total length
        len_bytes = byte_stream.read(4)
        total_length = struct.unpack("!I", len_bytes)[0]
        log.debug("Total response length: %s", total_length)

        header = RpcResponseHeaderProto()
        (header_len, header_bytes) = get_delimited_message_bytes(byte_stream)

        log.debug("Header read %d", header_len)
        header.ParseFromString(header_bytes)
        log_protobuf_message("RpcResponseHeaderProto", header)

        if header.status != 0:
            self.handle_error(header)
            return None

        log.debug("header: %s, total: %s", header_len, total_length)
        if header_len >= total_length:
            return None
        response = response_class()
        response_bytes = get_delimited_message_bytes(byte_stream, total_length - header_len)[1]
        if len(response_bytes) > 0:
            response.ParseFromString(response_bytes)
            if log.getEffectiveLevel() == logging.DEBUG:
                log_protobuf_message("Response", response)
        return response

    def handle_error(self, header):
        """Raise RequestError built from the exception class and message in ``header``."""
        raise RequestError("\n".join([header.exceptionClassName, header.errorMsg]))

    def close_socket(self):
        """Close the socket and reset the channel (best effort)."""
        log.debug("Closing socket")
        if self.sock:
            try:
                self.sock.close()
            except Exception:
                # Best effort: a failing close must not mask the caller's error.
                pass

            self.sock = None

    def CallMethod(self, method, controller, request, response_class, done):
        """Call the RPC method.

        The name doesn't conform to PEP 8 because protobuf calls this method.
        """
        try:
            self.validate_request(request)

            if not self.sock:
                self.get_connection(self.host, self.port)

            self.send_rpc_message(method, request)

            byte_stream = self.recv_rpc_message()
            return self.parse_response(byte_stream, response_class)
        except RequestError:  # Raise a request error, but don't close the socket
            raise
        except Exception:  # All other errors close the socket
            self.close_socket()
            raise
class TSaslClientTransport(TTransportBase, CReadableTransport):
    """
    A SASL transport based on the pure-sasl library:
        https://github.com/thobbs/pure-sasl
    """

    # Status codes carried in the first byte of each handshake message.
    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    def __init__(self, transport, host, service, mechanism="GSSAPI", **sasl_kwargs):
        """
        transport: an underlying transport to use, typically just a TSocket
        host: the name of the server, from a SASL perspective
        service: the name of the server's service, from a SASL perspective
        mechanism: the name of the preferred mechanism to use

        All other kwargs will be passed to the puresasl.client.SASLClient
        constructor.
        """
        self.transport = transport
        self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)

        self.__wbuf = StringIO()
        self.__rbuf = StringIO()

        # NOTE(review): the original author reports that pure-sasl's base
        # Mechanism.wrap/unwrap raise NotImplementedError and PlainMechanism
        # does not override them. A considered (NOT applied) workaround was:
        #     self.sasl._chosen_mech.wrap = lambda x: x
        #     self.sasl._chosen_mech.unwrap = lambda x: x

    def open(self):
        """Open the underlying transport if needed and run the SASL handshake."""
        if not self.transport.isOpen():
            self.transport.open()

        self.send_sasl_msg(self.START, self.sasl.mechanism)
        self.send_sasl_msg(self.OK, self.sasl.process() or "")

        while True:
            status, challenge = self.recv_sasl_msg()
            if status == self.OK:
                self.send_sasl_msg(self.OK, self.sasl.process(challenge) or "")
            elif status == self.COMPLETE:
                # NOTE(review): the original author observed that
                # self.sasl.complete may not be set for PLAIN and kept a
                # commented-out unconditional ``break`` to skip this check.
                # The check is enforced here — confirm for PLAIN servers.
                if not self.sasl.complete:
                    raise TTransportException("The server erroneously indicated "
                                              "that SASL negotiation was complete")
                break
            else:
                raise TTransportException("Bad SASL negotiation status: %d (%s)" % (status, challenge))

    def send_sasl_msg(self, status, body):
        """Send one handshake message; ``body`` may be None, text (UTF-8) or bytes."""
        if body is None:
            body = ""
        # Encode before measuring so the length prefix counts bytes, not
        # characters (they differ for non-ASCII text).
        body = body if isinstance(body, bytes) else body.encode("utf-8")
        header = pack(">BI", status, len(body))
        self.transport.write(header + body)
        self.transport.flush()

    def recv_sasl_msg(self):
        """Read one handshake message and return ``(status, payload)``."""
        header = self.transport.readAll(5)
        status, length = unpack(">BI", header)
        if length > 0:
            payload = self.transport.readAll(length)
        else:
            payload = b""
        return status, payload

    def write(self, data):
        """Buffer ``data``; nothing is sent until ``flush``."""
        self.__wbuf.write(data)

    def flush(self):
        """Wrap the buffered data with SASL and send it as one length-prefixed frame."""
        data = self.__wbuf.getvalue()
        encoded = self.sasl.wrap(data)
        # b"" keeps the join valid for bytes on Python 3 (same as "" on Python 2).
        self.transport.write(b"".join((pack("!i", len(encoded)), encoded)))
        self.transport.flush()
        self.__wbuf = StringIO()

    def read(self, sz):
        """Return up to ``sz`` bytes, reading a new frame when the buffer is empty."""
        ret = self.__rbuf.read(sz)
        if len(ret) != 0:
            return ret

        self._read_frame()
        return self.__rbuf.read(sz)

    def _read_frame(self):
        """Read one length-prefixed frame, unwrap it and make it the read buffer."""
        header = self.transport.readAll(4)
        length, = unpack("!i", header)
        encoded = self.transport.readAll(length)
        self.__rbuf = StringIO(self.sasl.unwrap(encoded))

    def close(self):
        """Dispose of the SASL state and close the underlying transport."""
        self.sasl.dispose()
        self.transport.close()

    # based on TFramedTransport
    @property
    def cstringio_buf(self):
        return self.__rbuf

    def cstringio_refill(self, prefix, reqlen):
        # self.__rbuf will already be empty here because fastbinary doesn't
        # ask for a refill until the previous buffer is empty.  Therefore,
        # we can start reading new frames immediately.
        while len(prefix) < reqlen:
            self._read_frame()
            prefix += self.__rbuf.getvalue()
        self.__rbuf = StringIO(prefix)
        return self.__rbuf
def createSASLClient(self, host, service, mechanism, **kwargs):
    """Store a freshly built SASLClient on ``self.sasl``.

    All arguments are passed straight through to the SASLClient constructor.
    """
    new_client = SASLClient(host, service, mechanism, **kwargs)
    self.sasl = new_client
def start(self, chosen):
    """Create the SASL client for mechanism ``chosen`` and produce the first response.

    Returns ``(True, chosen encoded to bytes, initial response)``, mirroring
    the shape of the cyrus-sasl ``start`` result.
    """
    host = self.attributes['host']
    self.sasl = SASLClient(host, mechanism=chosen, callback=self.getAttr)
    mech_bytes = chosen.encode()
    initial_response = self.sasl.process()
    return True, mech_bytes, initial_response
class TSaslClientTransport(TTransportBase, CReadableTransport):
    """Thrift client transport that authenticates and frames traffic with SASL.

    Handshake messages are ``[status: 1 byte][length: 4 bytes BE][body]``.
    After the handshake, every flushed write is wrapped by the SASL mechanism
    and sent as ``[length: 4-byte signed BE][wrapped bytes]``.
    """

    # Status codes carried in the first byte of each handshake message.
    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    def __init__(self, transport, host, service, mechanism='GSSAPI',
                 **sasl_kwargs):
        """
        transport: an underlying transport to use, typically just a TSocket
        host: the name of the server, from a SASL perspective
        service: the name of the server's service, from a SASL perspective
        mechanism: the name of the preferred mechanism to use

        All other kwargs will be passed to the puresasl.client.SASLClient
        constructor.
        """
        from puresasl.client import SASLClient

        self.transport = transport
        self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)

        self.__wbuf = StringIO()
        self.__rbuf = StringIO()

    def open(self):
        """Open the underlying transport if needed and run the SASL handshake.

        Raises TTransportException when the server sends a status other than
        OK/COMPLETE, or claims completion before the client side is complete.
        """
        if not self.transport.isOpen():
            self.transport.open()

        self.send_sasl_msg(self.START, self.sasl.mechanism)
        self.send_sasl_msg(self.OK, self.sasl.process())

        while True:
            status, challenge = self.recv_sasl_msg()
            if status == self.OK:
                self.send_sasl_msg(self.OK, self.sasl.process(challenge))
            elif status == self.COMPLETE:
                if not self.sasl.complete:
                    raise TTransportException("The server erroneously indicated "
                                              "that SASL negotiation was complete")
                break
            else:
                raise TTransportException("Bad SASL negotiation status: %d (%s)"
                                          % (status, challenge))

    def send_sasl_msg(self, status, body):
        """Send one handshake message with the given status byte.

        ``body`` may be None (``SASLClient.process`` can return None), text
        (sent as UTF-8) or bytes. The length prefix is computed on the encoded
        bytes so it always matches what is written.
        """
        if body is None:
            body = b''
        elif not isinstance(body, bytes):
            body = body.encode('utf-8')
        header = pack(">BI", status, len(body))
        self.transport.write(header + body)
        self.transport.flush()

    def recv_sasl_msg(self):
        """Read one handshake message and return ``(status, payload)``."""
        header = self.transport.readAll(5)
        status, length = unpack(">BI", header)
        if length > 0:
            payload = self.transport.readAll(length)
        else:
            payload = b''
        return status, payload

    def write(self, data):
        """Buffer ``data``; nothing is sent until ``flush``."""
        self.__wbuf.write(data)

    def flush(self):
        """Wrap the buffered data with SASL and send it as one length-prefixed frame."""
        data = self.__wbuf.getvalue()
        encoded = self.sasl.wrap(data)
        # b'' keeps the join valid for bytes on Python 3 (same as '' on Python 2).
        self.transport.write(b''.join((pack("!i", len(encoded)), encoded)))
        self.transport.flush()
        self.__wbuf = StringIO()

    def read(self, sz):
        """Return up to ``sz`` bytes, reading a new frame when the buffer is empty."""
        ret = self.__rbuf.read(sz)
        if len(ret) != 0:
            return ret

        self._read_frame()
        return self.__rbuf.read(sz)

    def _read_frame(self):
        """Read one length-prefixed frame, unwrap it and make it the read buffer."""
        header = self.transport.readAll(4)
        length, = unpack('!i', header)
        encoded = self.transport.readAll(length)
        self.__rbuf = StringIO(self.sasl.unwrap(encoded))

    def close(self):
        """Dispose of the SASL state and close the underlying transport."""
        self.sasl.dispose()
        self.transport.close()

    # based on TFramedTransport
    @property
    def cstringio_buf(self):
        return self.__rbuf

    def cstringio_refill(self, prefix, reqlen):
        # self.__rbuf will already be empty here because fastbinary doesn't
        # ask for a refill until the previous buffer is empty.  Therefore,
        # we can start reading new frames immediately.
        while len(prefix) < reqlen:
            self._read_frame()
            prefix += self.__rbuf.getvalue()
        self.__rbuf = StringIO(prefix)
        return self.__rbuf
class TSaslClientTransport(TTransportBase, CReadableTransport):
    """SASL-wrapped client transport.

    ``open()`` runs a SASL handshake over the wrapped transport. Afterwards,
    buffered writes are wrapped by the SASL client and sent as one
    length-prefixed frame per ``flush()``, and reads unwrap one frame at a time.
    """

    # Status codes carried in the first byte of every handshake message.
    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    def __init__(self, transport, host, service, mechanism='GSSAPI',
                 **sasl_kwargs):
        """Create the transport.

        transport: the underlying transport to wrap
        host, service: server identity from a SASL perspective
        mechanism: the preferred SASL mechanism name
        Remaining keyword arguments go to puresasl.client.SASLClient.
        """
        from puresasl.client import SASLClient

        self.transport = transport
        self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)
        self.__wbuf = StringIO()
        self.__rbuf = StringIO()

    def isOpen(self):
        """Report whether the wrapped transport is open."""
        return self.transport.isOpen()

    def open(self):
        """Open the wrapped transport if needed and run the SASL handshake.

        Raises:
            TTransportException: if the server reports COMPLETE while the
                local SASL client is not complete, or replies with any status
                other than OK or COMPLETE.
        """
        if not self.transport.isOpen():
            self.transport.open()

        mechanism = self.sasl.mechanism
        if not isinstance(mechanism, bytes):
            # The name is appended to a packed (byte-string) header.
            mechanism = mechanism.encode('ascii')
        self.send_sasl_msg(self.START, mechanism)
        self.send_sasl_msg(self.OK, self.sasl.process())

        while True:
            status, challenge = self.recv_sasl_msg()
            if status == self.OK:
                self.send_sasl_msg(self.OK, self.sasl.process(challenge))
            elif status == self.COMPLETE:
                if not self.sasl.complete:
                    # TTransportException's first positional argument is the
                    # type code; the human-readable text goes second.
                    raise TTransportException(
                        TTransportException.NOT_OPEN,
                        "The server erroneously indicated "
                        "that SASL negotiation was complete")
                break
            else:
                raise TTransportException(
                    TTransportException.NOT_OPEN,
                    "Bad SASL negotiation status: %d (%s)" % (status, challenge))

    def send_sasl_msg(self, status, body):
        """Send one handshake message: 1-byte status, 4-byte big-endian
        unsigned length, then ``body``."""
        header = struct.pack(">BI", status, len(body))
        self.transport.write(header + body)
        self.transport.flush()

    def recv_sasl_msg(self):
        """Read one handshake message and return ``(status, payload)``."""
        header = self.transport.readAll(5)
        status, length = struct.unpack(">BI", header)
        if length > 0:
            payload = self.transport.readAll(length)
        else:
            payload = b""
        return status, payload

    def write(self, data):
        """Buffer ``data`` until the next ``flush()``."""
        self.__wbuf.write(data)

    def flush(self):
        """Wrap the buffered data and send it as one length-prefixed frame."""
        data = self.__wbuf.getvalue()
        encoded = self.sasl.wrap(data)
        # Note stolen from TFramedTransport:
        # N.B.: Doing this string concatenation is WAY cheaper than making
        # two separate calls to the underlying socket object. `+` is used
        # rather than ''.join so it also works when pack() returns bytes.
        self.transport.write(struct.pack("!i", len(encoded)) + encoded)
        self.transport.flush()
        self.__wbuf = StringIO()

    def read(self, sz):
        """Return up to ``sz`` bytes, reading a new frame if the buffer is empty."""
        ret = self.__rbuf.read(sz)
        if len(ret) != 0:
            return ret

        self._read_frame()
        return self.__rbuf.read(sz)

    def _read_frame(self):
        """Read a 4-byte length-prefixed frame, unwrap it, and make it the
        new read buffer."""
        header = self.transport.readAll(4)
        length, = struct.unpack('!i', header)
        encoded = self.transport.readAll(length)
        self.__rbuf = StringIO(self.sasl.unwrap(encoded))

    def close(self):
        """Dispose of the SASL client and close the wrapped transport."""
        try:
            self.sasl.dispose()
        finally:
            # Release the underlying connection even if dispose() fails.
            self.transport.close()

    # Implement the CReadableTransport interface.
    # Stolen shamelessly from TFramedTransport
    @property
    def cstringio_buf(self):
        """The current decoded read buffer."""
        return self.__rbuf

    def cstringio_refill(self, prefix, reqlen):
        """Read frames until ``prefix`` holds at least ``reqlen`` bytes, then
        install it as the read buffer and return that buffer."""
        # self.__rbuf will already be empty here because fastbinary doesn't
        # ask for a refill until the previous buffer is empty. Therefore,
        # we can start reading new frames immediately.
        while len(prefix) < reqlen:
            self._read_frame()
            prefix += self.__rbuf.getvalue()
        self.__rbuf = StringIO(prefix)
        return self.__rbuf
class TSaslClientTransport(TTransportBase, CReadableTransport):
    """SASL-wrapped client transport.

    ``open()`` runs a SASL handshake over the wrapped transport. Afterwards,
    buffered writes are wrapped by the SASL client and sent as one
    length-prefixed frame per ``flush()``, and reads unwrap one frame at a time.
    """

    # Status codes carried in the first byte of every handshake message.
    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    def __init__(self, transport, host, service, mechanism='GSSAPI',
                 **sasl_kwargs):
        """Create the transport.

        transport: the underlying transport to wrap
        host, service: server identity from a SASL perspective
        mechanism: the preferred SASL mechanism name
        Remaining keyword arguments go to puresasl.client.SASLClient.
        """
        from puresasl.client import SASLClient

        self.transport = transport
        self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)
        self.__wbuf = BufferIO()
        self.__rbuf = BufferIO(b'')

    def isOpen(self):
        """Report whether the wrapped transport is open."""
        return self.transport.isOpen()

    def open(self):
        """Open the wrapped transport if needed and run the SASL handshake.

        Raises:
            TTransportException: (NOT_OPEN) if the server reports COMPLETE
                while the local SASL client is not complete, or replies with
                any status other than OK or COMPLETE.
        """
        if not self.transport.isOpen():
            self.transport.open()

        mechanism = self.sasl.mechanism
        if not isinstance(mechanism, bytes):
            # The name is appended to a packed (bytes) header, so a text
            # mechanism name must be encoded first.
            mechanism = mechanism.encode('ascii')
        self.send_sasl_msg(self.START, mechanism)
        self.send_sasl_msg(self.OK, self.sasl.process())

        while True:
            status, challenge = self.recv_sasl_msg()
            if status == self.OK:
                self.send_sasl_msg(self.OK, self.sasl.process(challenge))
            elif status == self.COMPLETE:
                if not self.sasl.complete:
                    raise TTransportException(
                        TTransportException.NOT_OPEN,
                        "The server erroneously indicated "
                        "that SASL negotiation was complete")
                break
            else:
                raise TTransportException(
                    TTransportException.NOT_OPEN,
                    "Bad SASL negotiation status: %d (%s)" % (status, challenge))

    def send_sasl_msg(self, status, body):
        """Send one handshake message: 1-byte status, 4-byte big-endian
        unsigned length, then ``body``."""
        header = pack(">BI", status, len(body))
        self.transport.write(header + body)
        self.transport.flush()

    def recv_sasl_msg(self):
        """Read one handshake message and return ``(status, payload)``."""
        header = self.transport.readAll(5)
        status, length = unpack(">BI", header)
        if length > 0:
            payload = self.transport.readAll(length)
        else:
            payload = b""
        return status, payload

    def write(self, data):
        """Buffer ``data`` until the next ``flush()``."""
        self.__wbuf.write(data)

    def flush(self):
        """Wrap the buffered data and send it as one length-prefixed frame."""
        data = self.__wbuf.getvalue()
        encoded = self.sasl.wrap(data)
        # pack() returns bytes; ''.join would raise TypeError on Python 3.
        self.transport.write(pack("!i", len(encoded)) + encoded)
        self.transport.flush()
        self.__wbuf = BufferIO()

    def read(self, sz):
        """Return up to ``sz`` bytes, reading a new frame if the buffer is empty."""
        ret = self.__rbuf.read(sz)
        if len(ret) != 0:
            return ret

        self._read_frame()
        return self.__rbuf.read(sz)

    def _read_frame(self):
        """Read a 4-byte length-prefixed frame, unwrap it, and make it the
        new read buffer."""
        header = self.transport.readAll(4)
        length, = unpack('!i', header)
        encoded = self.transport.readAll(length)
        self.__rbuf = BufferIO(self.sasl.unwrap(encoded))

    def close(self):
        """Dispose of the SASL client and close the wrapped transport."""
        try:
            self.sasl.dispose()
        finally:
            # Release the underlying connection even if dispose() fails.
            self.transport.close()

    @property
    def cstringio_buf(self):
        """The current decoded read buffer."""
        return self.__rbuf

    def cstringio_refill(self, prefix, reqlen):
        """Read frames until ``prefix`` holds at least ``reqlen`` bytes, then
        install it as the read buffer and return that buffer."""
        # The current buffer is expected to be exhausted when a refill is
        # requested, so new frames can be read immediately.
        while len(prefix) < reqlen:
            self._read_frame()
            prefix += self.__rbuf.getvalue()
        self.__rbuf = BufferIO(prefix)
        return self.__rbuf
def test_choose_mechanism(self):
    """choose_mechanism picks the best candidate and honours restriction flags."""
    client = SASLClient('localhost', service='something')

    # An unknown mechanism name cannot be chosen.
    self.assertRaises(SASLError, client.choose_mechanism, ['invalid'])

    # DIGEST-MD5 is excluded from the candidate pool.
    candidates = [mech for mech in mechanisms.values()
                  if mech is not DigestMD5Mechanism]

    # Offering every candidate selects the highest-scoring one.
    client.choose_mechanism({mech.name for mech in candidates})
    best = max(candidates, key=lambda mech: mech.score)
    self.assertIsInstance(client._chosen_mech, best)

    # For each (selector, flag) pair, offer only the mechanisms picked by the
    # selector. The choice succeeds by default and fails with flag=False.
    restrictions = (
        (lambda mech: mech.allows_anonymous, 'allow_anonymous'),
        (lambda mech: mech.uses_plaintext, 'allow_plaintext'),
        (lambda mech: not mech.active_safe, 'allow_active'),
        (lambda mech: not mech.dictionary_safe, 'allow_dictionary'),
    )
    for selects, flag in restrictions:
        offered = {mech.name for mech in candidates if selects(mech)}
        client.choose_mechanism(offered)
        self.assertIn(client.mechanism, offered)
        self.assertRaises(SASLError, client.choose_mechanism, offered,
                          **{flag: False})