Ejemplo n.º 1
0
 def test_chosen_mechanism(self):
     # Build a client that is pinned to the PLAIN mechanism up front.
     sasl = SASLClient(
         'localhost',
         mechanism=PlainMechanism.name,
         username='******',
         password='******',
     )

     # A single process() call must return a truthy value and leave the
     # client marked as complete.
     first_response = sasl.process()
     self.assertTrue(first_response)
     self.assertTrue(sasl.complete)

     # wrap() and unwrap() are expected to hand the message back unchanged.
     payload = 'msg'
     wrapped = sasl.wrap(payload)
     self.assertEqual(wrapped, payload)
     unwrapped = sasl.unwrap(payload)
     self.assertEqual(unwrapped, payload)

     sasl.dispose()
Ejemplo n.º 2
0
 def test_chosen_mechanism(self):
     # The mechanism is selected explicitly by name when the client is built.
     credentials = dict(username='******', password='******')
     plain_client = SASLClient('localhost', mechanism=PlainMechanism.name, **credentials)

     self.assertTrue(plain_client.process())
     self.assertTrue(plain_client.complete)

     # Both directions are expected to return the message untouched.
     text = 'msg'
     for transform in (plain_client.wrap, plain_client.unwrap):
         self.assertEqual(transform(text), text)

     plain_client.dispose()
Ejemplo n.º 3
0
class TSaslClientTransport(TTransportBase, CReadableTransport):
  """
  SASL transport

  Wraps another transport: open() performs the SASL negotiation, after which
  each flush() emits one length-prefixed frame of sasl.wrap() output and reads
  are served from length-prefixed frames passed through sasl.unwrap().
  """

  # Status codes carried in the first byte of every negotiation message.
  START = 1
  OK = 2
  BAD = 3
  ERROR = 4
  COMPLETE = 5

  def __init__(self, transport, host, service, mechanism='GSSAPI',
      **sasl_kwargs):
    """
    transport: an underlying transport to use, typically just a TSocket
    host: the name of the server, from a SASL perspective
    service: the name of the server's service, from a SASL perspective
    mechanism: the name of the preferred mechanism to use

    All other kwargs will be passed to the puresasl.client.SASLClient
    constructor.
    """

    # Imported here so puresasl is only needed when this class is used.
    from puresasl.client import SASLClient

    self.transport = transport
    self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)

    # Pending outbound data (sent on flush) and the current decoded frame.
    self.__wbuf = StringIO()
    self.__rbuf = StringIO()

  def open(self):
    """
    Open the underlying transport if necessary and run the SASL handshake.

    Raises TTransportException (type NOT_OPEN) when the server reports
    COMPLETE before the local SASL client is complete, or when it sends any
    status other than OK or COMPLETE.
    """
    if not self.transport.isOpen():
      self.transport.open()

    self.send_sasl_msg(self.START, self.sasl.mechanism)
    self.send_sasl_msg(self.OK, self.sasl.process())

    while True:
      status, challenge = self.recv_sasl_msg()
      if status == self.OK:
        self.send_sasl_msg(self.OK, self.sasl.process(challenge))
      elif status == self.COMPLETE:
        if not self.sasl.complete:
          # The first positional argument of TTransportException is the
          # numeric error type, so the text has to be passed second.
          raise TTransportException(
              TTransportException.NOT_OPEN,
              "The server erroneously indicated "
              "that SASL negotiation was complete")
        break
      else:
        raise TTransportException(
            TTransportException.NOT_OPEN,
            "Bad SASL negotiation status: %d (%s)" % (status, challenge))

  def send_sasl_msg(self, status, body):
    """
    Send one negotiation message and flush the underlying transport.

    The message is a 1-byte status, a 4-byte unsigned big-endian body
    length, then the body itself.
    """
    header = pack(">BI", status, len(body))
    self.transport.write(header + body)
    self.transport.flush()

  def recv_sasl_msg(self):
    """
    Read one negotiation message and return (status, payload).

    payload is "" when the advertised length is zero.
    """
    header = self.transport.readAll(5)
    status, length = unpack(">BI", header)
    if length > 0:
      payload = self.transport.readAll(length)
    else:
      payload = ""
    return status, payload

  def write(self, data):
    """Buffer data; nothing reaches the underlying transport until flush()."""
    self.__wbuf.write(data)

  def flush(self):
    """
    Wrap everything buffered so far and send it as a single frame.

    The frame is a 4-byte signed big-endian length followed by the
    sasl.wrap() output; the write buffer is reset afterwards.
    """
    data = self.__wbuf.getvalue()
    encoded = self.sasl.wrap(data)
    self.transport.write(''.join((pack("!i", len(encoded)), encoded)))
    self.transport.flush()
    self.__wbuf = StringIO()

  def read(self, sz):
    """
    Return up to sz bytes of decoded data.

    Data left over from the current frame is served first; only when that
    is exhausted is the next frame read from the underlying transport.
    """
    ret = self.__rbuf.read(sz)
    if len(ret) != 0:
      return ret

    self._read_frame()
    return self.__rbuf.read(sz)

  def _read_frame(self):
    """Read one length-prefixed frame and make its unwrapped content the read buffer."""
    header = self.transport.readAll(4)
    length, = unpack('!i', header)
    encoded = self.transport.readAll(length)
    self.__rbuf = StringIO(self.sasl.unwrap(encoded))

  def close(self):
    """
    Dispose of the SASL client and close the underlying transport.

    The transport is closed even if sasl.dispose() raises.
    """
    try:
      self.sasl.dispose()
    finally:
      self.transport.close()

  # based on TFramedTransport
  @property
  def cstringio_buf(self):
    """The current read buffer."""
    return self.__rbuf

  def cstringio_refill(self, prefix, reqlen):
    """
    Append whole frames to prefix until it holds at least reqlen bytes,
    then install and return a read buffer containing the result.
    """
    # self.__rbuf will already be empty here because fastbinary doesn't
    # ask for a refill until the previous buffer is empty.  Therefore,
    # we can start reading new frames immediately.
    while len(prefix) < reqlen:
      self._read_frame()
      prefix += self.__rbuf.getvalue()
    self.__rbuf = StringIO(prefix)
    return self.__rbuf
Ejemplo n.º 4
0
class TSaslClientTransport(TTransportBase, CReadableTransport):
    """
    SASL transport

    Layers SASL on top of another transport. open() negotiates with the
    server; once that succeeds, outgoing data is buffered by write() and sent
    by flush() as one sasl.wrap()-ed frame, and incoming frames are passed
    through sasl.unwrap() before being handed to readers.
    """

    # Negotiation status codes (first byte of every negotiation message).
    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    def __init__(self,
                 transport,
                 host,
                 service,
                 mechanism='GSSAPI',
                 **sasl_kwargs):
        """
        transport: an underlying transport to use, typically just a TSocket
        host: the name of the server, from a SASL perspective
        service: the name of the server's service, from a SASL perspective
        mechanism: the name of the preferred mechanism to use

        All other kwargs will be passed to the puresasl.client.SASLClient
        constructor.
        """

        # Local import: puresasl is only required when this class is used.
        from puresasl.client import SASLClient

        self.transport = transport
        self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)

        # Outbound data waiting for flush(), and the current decoded frame.
        self.__wbuf = StringIO()
        self.__rbuf = StringIO()

    def open(self):
        """
        Open the underlying transport if needed, then negotiate SASL.

        Raises TTransportException (type NOT_OPEN) if the server signals
        COMPLETE while the local SASL client is not complete, or if it sends
        a status other than OK or COMPLETE.
        """
        if not self.transport.isOpen():
            self.transport.open()

        self.send_sasl_msg(self.START, self.sasl.mechanism)
        self.send_sasl_msg(self.OK, self.sasl.process())

        while True:
            status, challenge = self.recv_sasl_msg()
            if status == self.OK:
                self.send_sasl_msg(self.OK, self.sasl.process(challenge))
            elif status == self.COMPLETE:
                if not self.sasl.complete:
                    # TTransportException takes the numeric type first; the
                    # human-readable text must be the second argument.
                    raise TTransportException(
                        TTransportException.NOT_OPEN,
                        "The server erroneously indicated "
                        "that SASL negotiation was complete")
                break
            else:
                raise TTransportException(
                    TTransportException.NOT_OPEN,
                    "Bad SASL negotiation status: %d (%s)" %
                    (status, challenge))

    def send_sasl_msg(self, status, body):
        """
        Write one negotiation message and flush the underlying transport.

        Layout: 1-byte status, 4-byte unsigned big-endian length, body.
        """
        header = pack(">BI", status, len(body))
        self.transport.write(header + body)
        self.transport.flush()

    def recv_sasl_msg(self):
        """
        Read one negotiation message.

        Returns (status, payload); payload is "" when the length field is 0.
        """
        header = self.transport.readAll(5)
        status, length = unpack(">BI", header)
        if length > 0:
            payload = self.transport.readAll(length)
        else:
            payload = ""
        return status, payload

    def write(self, data):
        """Append data to the write buffer; it is sent on the next flush()."""
        self.__wbuf.write(data)

    def flush(self):
        """
        Send the buffered data as one frame and reset the write buffer.

        The frame is a 4-byte signed big-endian length followed by the
        output of sasl.wrap().
        """
        data = self.__wbuf.getvalue()
        encoded = self.sasl.wrap(data)
        self.transport.write(''.join((pack("!i", len(encoded)), encoded)))
        self.transport.flush()
        self.__wbuf = StringIO()

    def read(self, sz):
        """
        Return up to sz bytes of decoded data, reading a new frame from the
        underlying transport only when the current one is exhausted.
        """
        ret = self.__rbuf.read(sz)
        if len(ret) != 0:
            return ret

        self._read_frame()
        return self.__rbuf.read(sz)

    def _read_frame(self):
        """Replace the read buffer with the unwrapped content of the next frame."""
        header = self.transport.readAll(4)
        length, = unpack('!i', header)
        encoded = self.transport.readAll(length)
        self.__rbuf = StringIO(self.sasl.unwrap(encoded))

    def close(self):
        """
        Dispose of the SASL client and close the underlying transport.

        The transport is closed even when sasl.dispose() raises.
        """
        try:
            self.sasl.dispose()
        finally:
            self.transport.close()

    # based on TFramedTransport
    @property
    def cstringio_buf(self):
        """The current read buffer."""
        return self.__rbuf

    def cstringio_refill(self, prefix, reqlen):
        """
        Read frames until prefix plus their content is at least reqlen bytes
        long, then install and return a read buffer holding that data.
        """
        # self.__rbuf will already be empty here because fastbinary doesn't
        # ask for a refill until the previous buffer is empty.  Therefore,
        # we can start reading new frames immediately.
        while len(prefix) < reqlen:
            self._read_frame()
            prefix += self.__rbuf.getvalue()
        self.__rbuf = StringIO(prefix)
        return self.__rbuf
Ejemplo n.º 5
0
class Client(object):
    """
    Adapter that drives a SASLClient through a small step-style API.

    Attributes are collected with setAttr() and looked up through getAttr(),
    which is also handed to SASLClient as its callback in start().  All
    attributes should be set before start() is called.

    Available attribute keys:

       service      - Name of the service being accessed
       username     - User identity for authentication
       authname     - User identity for authorization (if different from username)
       password     - Password associated with username
       host         - Fully qualified domain name of the server host
       maxbufsize   - Maximum receive buffer size for the security layer
       minssf       - Minimum acceptable security strength factor (integer)
       maxssf       - Maximum acceptable security strength factor (integer)
       externalssf  - Security strength factor supplied by external mechanism (i.e. SSL/TLS)
       externaluser - Authentication ID (of client) as established by external mechanism
    """

    def __init__(self):
        # Text returned (and then cleared) by getError().
        self.lasterror = ''
        # Attribute store filled by setAttr().
        self.attributes = {}
        # SASLClient instance; created by start().
        self.sasl = None

    def setAttr(self, key, value):
        """
        Set an attribute to be used in authenticating the session.

        @param key Name of attribute being set
        @param value Value of attribute being set
        """
        self.attributes[key] = value

    def getAttr(self, key):
        """
        Return the value stored for key.

        Raises KeyError when the attribute has not been set.
        """
        return self.attributes[key]

    def init(self):
        """
        Initialize the client object.  This should be called after all of the
        attributes have been set.

        @return always True; no work is done here.
        """
        return True

    def start(self, chosen):
        """
        Start the SASL exchange with the server.

        Creates the SASLClient for the 'host' attribute (KeyError if it has
        not been set) using the chosen mechanism.

        @param chosen The mechanism chosen by the client
        @return (True, chosen.encode(), initial response from sasl.process())
        """
        self.sasl = SASLClient(self.attributes['host'], mechanism=chosen, callback=self.getAttr)
        return True, chosen.encode(), self.sasl.process()

    def step(self, challenge):
        """
        Step the SASL handshake.

        @param challenge The challenge supplied by the server
        @return (True, response to be sent back to the server)
        """
        return True, self.sasl.process(challenge)

    def encode(self, clearText):
        """
        Encode data for transmission to the server.

        @param clearText Clear text data to be passed through sasl.wrap()
        @return (True, wrapped data to be transmitted)
        """
        return True, self.sasl.wrap(clearText)

    def decode(self, cipherText):
        """
        Decode data received from the server.

        @param cipherText Data received from the server
        @return (True, result of sasl.unwrap(cipherText))
        """
        # Unwrap the argument that was actually passed in; the previous code
        # referenced an undefined name and always raised NameError.
        return True, self.sasl.unwrap(cipherText)

    def getUserId(self):
        """
        Get the user identity associated with this session.

        @return the 'externaluser' attribute (KeyError if it was never set)
        """
        return self.attributes['externaluser']

    def getSSF(self):
        """
        Get the security strength factor associated with this session.

        @return always True; no SSF value is reported by this implementation.
        """
        return True

    def getError(self):
        """
        Get the error message for the last error.

        Returns the last error message and then clears the error state.  If
        there was no error, or the state has already been cleared, an empty
        string is returned.
        """
        error = self.lasterror[:]
        self.lasterror = ''
        return error
Ejemplo n.º 6
0
class SaslRpcClient:
    """
    Runs a SASL negotiation over an RPC transport and wraps/unwraps messages.

    trans must provide host, write(), write_delimited(), recv_rpc_message()
    and parse_response(); those are the only members used here.
    """

    def __init__(self, trans, hdfs_namenode_principal=None):
        # SASLClient instance; created lazily by connect().
        self.sasl = None
        self._trans = trans
        self.hdfs_namenode_principal = hdfs_namenode_principal

    @staticmethod
    def _service_from_principal(principal):
        """
        Return the text before the first '/' or '@' in str(principal).

        Note that a principal of None yields the string 'None', because the
        value is passed through str() first.
        """
        return re.split(r'[/@]', str(principal))[0]

    def _send_sasl_message(self, message):
        """Serialize message behind an RPC request header and write both to the transport."""
        rpcheader = RpcRequestHeaderProto()
        rpcheader.rpcKind = 2  # RPC_PROTOCOL_BUFFER
        rpcheader.rpcOp = 0
        rpcheader.callId = -33  # SASL
        rpcheader.retryCount = -1
        rpcheader.clientId = b""

        s_rpcheader = rpcheader.SerializeToString()
        s_message = message.SerializeToString()

        # Total length of the two parts, each counted together with the
        # size of its own varint length prefix.
        header_length = (
            len(s_rpcheader) + encoder._VarintSize(len(s_rpcheader)) +
            len(s_message) + encoder._VarintSize(len(s_message)))

        self._trans.write(struct.pack('!I', header_length))
        self._trans.write_delimited(s_rpcheader)
        self._trans.write_delimited(s_message)

        log_protobuf_message("Send out", message)

    def _recv_sasl_message(self):
        """Receive one RPC message and parse it as an RpcSaslProto."""
        bytestream = self._trans.recv_rpc_message()
        return self._trans.parse_response(bytestream, RpcSaslProto)

    def connect(self):
        """
        Negotiate SASL with the server.

        Returns True once the server answers with state 0.  Responses with a
        state other than 0, 1 or 3 are ignored and the next message is read.
        """
        # use service name component from principal
        service = self._service_from_principal(self.hdfs_namenode_principal)

        if not self.sasl:
            self.sasl = SASLClient(self._trans.host, service)

        negotiate = RpcSaslProto()
        negotiate.state = 1
        self._send_sasl_message(negotiate)

        # do while true
        while True:
            res = self._recv_sasl_message()
            # TODO: check mechanisms
            if res.state == 1:
                # The server listed its mechanisms: pick one and send the
                # initial token together with the matching auth entry.
                mechs = [auth.mechanism for auth in res.auths]

                log.debug("Available mechs: %s" % (",".join(mechs)))
                self.sasl.choose_mechanism(mechs, allow_anonymous=False)
                log.debug("Chosen mech: %s" % self.sasl.mechanism)

                initiate = RpcSaslProto()
                initiate.state = 2
                initiate.token = self.sasl.process()

                for auth in res.auths:
                    if auth.mechanism == self.sasl.mechanism:
                        auth_method = initiate.auths.add()
                        auth_method.mechanism = self.sasl.mechanism
                        auth_method.method = auth.method
                        auth_method.protocol = auth.protocol
                        auth_method.serverId = self._trans.host

                self._send_sasl_message(initiate)
                continue

            if res.state == 3:
                # The server sent a challenge: answer with the token
                # produced by the SASL client.
                res_token = self._evaluate_token(res)
                response = RpcSaslProto()
                response.token = res_token
                response.state = 4
                self._send_sasl_message(response)
                continue

            if res.state == 0:
                return True

    def _evaluate_token(self, sasl_response):
        """Feed the token of sasl_response to the SASL client and return its reply."""
        return self.sasl.process(challenge=sasl_response.token)

    def wrap(self, message):
        """Wrap message with the SASL client and send it as a state-5 message."""
        encoded = self.sasl.wrap(message)

        sasl_message = RpcSaslProto()
        sasl_message.state = 5  #  WRAP
        sasl_message.token = encoded

        self._send_sasl_message(sasl_message)

    def unwrap(self):
        """
        Receive one message and return its unwrapped token.

        Raises Exception when the received message is not in state 5.
        """
        response = self._recv_sasl_message()
        if response.state != 5:
            raise Exception("Server send non-wrapped response")

        return self.sasl.unwrap(response.token)

    def use_wrap(self):
        """Return True when the negotiated QOP is 'auth-int' or 'auth-conf'."""
        # SASL wrapping is only used if the connection has a QOP, and
        # the value is not auth.  ex. auth-int & auth-priv
        return self.sasl.qop.decode() in ('auth-int', 'auth-conf')
Ejemplo n.º 7
0
class TSaslClientTransport(TTransportBase, CReadableTransport):
    """
    SASL client transport layered over another transport.

    open() negotiates SASL with the server.  Afterwards write() buffers
    data, flush() sends it as one frame (4-byte big-endian length followed by
    the sasl.wrap() output), and reads are served from received frames
    passed through sasl.unwrap().
    """

    # Negotiation status codes (first byte of every negotiation message).
    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    def __init__(self,
                 transport,
                 host,
                 service,
                 mechanism='GSSAPI',
                 **sasl_kwargs):
        """
        transport: the underlying transport that carries the frames
        host, service, mechanism, sasl_kwargs: passed to
        puresasl.client.SASLClient unchanged
        """
        from puresasl.client import SASLClient
        self.transport = transport
        self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)
        # Outbound data waiting for flush(), and the current decoded frame.
        self.__wbuf = BufferIO()
        self.__rbuf = BufferIO(b'')

    def open(self):
        """
        Open the underlying transport if needed, then negotiate SASL.

        Raises TTransportException (type NOT_OPEN) if the server signals
        COMPLETE while the local SASL client is not complete, or if it sends
        a status other than OK or COMPLETE.
        """
        if not self.transport.isOpen():
            self.transport.open()
        # The message header is packed bytes, so the mechanism name has to
        # be bytes as well before the two can be concatenated.
        mechanism = self.sasl.mechanism
        if not isinstance(mechanism, bytes):
            mechanism = mechanism.encode('ascii')
        self.send_sasl_msg(self.START, mechanism)
        self.send_sasl_msg(self.OK, self.sasl.process())
        while True:
            status, challenge = self.recv_sasl_msg()
            if status == self.OK:
                self.send_sasl_msg(self.OK, self.sasl.process(challenge))
            elif status == self.COMPLETE:
                if not self.sasl.complete:
                    # NOTE(review): this message text does not describe the
                    # failure (server reported COMPLETE too early) — confirm
                    # the intended wording.
                    raise TTransportException(
                        TTransportException.NOT_OPEN,
                        "importing server.. this "
                        "sudah dilakukan")
                break
            else:
                # NOTE(review): message text looks unrelated to the failure
                # (unexpected negotiation status) — confirm intended wording.
                raise TTransportException(
                    TTransportException.NOT_OPEN,
                    "statistik: %d (%s)" % (status, challenge))

    def send_sasl_msg(self, status, body):
        """
        Write one negotiation message and flush the underlying transport.

        Layout: 1-byte status, 4-byte unsigned big-endian length, body.
        """
        header = pack(">BI", status, len(body))
        self.transport.write(header + body)
        self.transport.flush()

    def recv_sasl_msg(self):
        """
        Read one negotiation message.

        Returns (status, payload); payload is b"" when the length field is 0.
        """
        header = self.transport.readAll(5)
        status, length = unpack(">BI", header)
        if length > 0:
            payload = self.transport.readAll(length)
        else:
            # Same type as a non-empty payload read from the transport.
            payload = b""
        return status, payload

    def write(self, data):
        """Append data to the write buffer; it is sent on the next flush()."""
        self.__wbuf.write(data)

    def flush(self):
        """Send the buffered data as one wrapped frame and reset the write buffer."""
        data = self.__wbuf.getvalue()
        encoded = self.sasl.wrap(data)
        # Join with a bytes separator: the packed length is bytes, and a
        # text separator cannot join bytes objects.
        self.transport.write(b''.join((pack("!i", len(encoded)), encoded)))
        self.transport.flush()
        self.__wbuf = BufferIO()

    def read(self, sz):
        """
        Return up to sz bytes of decoded data, reading a new frame from the
        underlying transport only when the current one is exhausted.
        """
        ret = self.__rbuf.read(sz)
        if len(ret) != 0:
            return ret
        self._read_frame()
        return self.__rbuf.read(sz)

    def _read_frame(self):
        """Replace the read buffer with the unwrapped content of the next frame."""
        header = self.transport.readAll(4)
        length, = unpack('!i', header)
        encoded = self.transport.readAll(length)
        self.__rbuf = BufferIO(self.sasl.unwrap(encoded))

    def close(self):
        """
        Dispose of the SASL client and close the underlying transport.

        The transport is closed even when sasl.dispose() raises.
        """
        try:
            self.sasl.dispose()
        finally:
            self.transport.close()

    @property
    def cstringio_buf(self):
        """The current read buffer."""
        return self.__rbuf

    def cstringio_refill(self, prefix, reqlen):
        """
        Read frames until prefix plus their content is at least reqlen bytes
        long, then install and return a read buffer holding that data.
        """
        while len(prefix) < reqlen:
            self._read_frame()
            prefix += self.__rbuf.getvalue()
        self.__rbuf = BufferIO(prefix)
        return self.__rbuf
Ejemplo n.º 8
0
class TSaslClientTransport(TTransportBase, CReadableTransport):
    """
    SASL client transport layered over another transport.

    open() negotiates SASL with the server; afterwards data is exchanged in
    frames consisting of a 4-byte big-endian length followed by content that
    is passed through sasl.wrap() / sasl.unwrap().
    """

    # Negotiation status codes (first byte of every negotiation message).
    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    def __init__(self, transport, host, service,
            mechanism='GSSAPI', **sasl_kwargs):
        """
        transport: the underlying transport that carries the frames
        host, service, mechanism, sasl_kwargs: passed to
        puresasl.client.SASLClient unchanged
        """

        from puresasl.client import SASLClient

        self.transport = transport
        self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)

        # Outbound data waiting for flush(), and the current decoded frame.
        self.__wbuf = StringIO()
        self.__rbuf = StringIO()

    def open(self):
        """
        Open the underlying transport if needed, then negotiate SASL.

        Raises TTransportException (type NOT_OPEN) if the server signals
        COMPLETE while the local SASL client is not complete, or if it sends
        a status other than OK or COMPLETE.
        """
        if not self.transport.isOpen():
            self.transport.open()

        self.send_sasl_msg(self.START, self.sasl.mechanism)
        self.send_sasl_msg(self.OK, self.sasl.process())

        while True:
            status, challenge = self.recv_sasl_msg()
            if status == self.OK:
                self.send_sasl_msg(self.OK, self.sasl.process(challenge))
            elif status == self.COMPLETE:
                if not self.sasl.complete:
                    # TTransportException takes the numeric type first; the
                    # human-readable text must be the second argument.
                    raise TTransportException(
                            TTransportException.NOT_OPEN,
                            "The server erroneously indicated "
                            "that SASL negotiation was complete")
                break
            else:
                raise TTransportException(
                        TTransportException.NOT_OPEN,
                        "Bad SASL negotiation status: %d (%s)"
                        % (status, challenge))

    def send_sasl_msg(self, status, body):
        """
        Write one negotiation message and flush the underlying transport.

        Layout: 1-byte status, 4-byte unsigned big-endian length, body.
        """
        header = struct.pack(">BI", status, len(body))
        self.transport.write(header + body)
        self.transport.flush()

    def recv_sasl_msg(self):
        """
        Read one negotiation message.

        Returns (status, payload); payload is "" when the length field is 0.
        """
        header = self.transport.readAll(5)
        status, length = struct.unpack(">BI", header)
        if length > 0:
            payload = self.transport.readAll(length)
        else:
            payload = ""
        return status, payload

    def write(self, data):
        """Append data to the write buffer; it is sent on the next flush()."""
        self.__wbuf.write(data)

    def flush(self):
        """Send the buffered data as one wrapped frame and reset the write buffer."""
        data = self.__wbuf.getvalue()
        encoded = self.sasl.wrap(data)
        # Note stolen from TFramedTransport:
        # N.B.: Doing this string concatenation is WAY cheaper than making
        # two separate calls to the underlying socket object. Socket writes in
        # Python turn out to be REALLY expensive, but it seems to do a pretty
        # good job of managing string buffer operations without excessive copies
        self.transport.write(''.join((struct.pack("!i", len(encoded)), encoded)))
        self.transport.flush()
        self.__wbuf = StringIO()

    def read(self, sz):
        """
        Return up to sz bytes of decoded data, reading a new frame from the
        underlying transport only when the current one is exhausted.
        """
        ret = self.__rbuf.read(sz)
        if len(ret) != 0:
            return ret

        self._read_frame()
        return self.__rbuf.read(sz)

    def _read_frame(self):
        """Replace the read buffer with the unwrapped content of the next frame."""
        header = self.transport.readAll(4)
        length, = struct.unpack('!i', header)
        encoded = self.transport.readAll(length)
        self.__rbuf = StringIO(self.sasl.unwrap(encoded))

    def close(self):
        """
        Dispose of the SASL client and close the underlying transport.

        The transport is closed even when sasl.dispose() raises.
        """
        try:
            self.sasl.dispose()
        finally:
            self.transport.close()

    # Implement the CReadableTransport interface.
    # Stolen shamelessly from TFramedTransport
    @property
    def cstringio_buf(self):
        """The current read buffer."""
        return self.__rbuf

    def cstringio_refill(self, prefix, reqlen):
        """
        Read frames until prefix plus their content is at least reqlen bytes
        long, then install and return a read buffer holding that data.
        """
        # self.__rbuf will already be empty here because fastbinary doesn't
        # ask for a refill until the previous buffer is empty.  Therefore,
        # we can start reading new frames immediately.
        while len(prefix) < reqlen:
            self._read_frame()
            prefix += self.__rbuf.getvalue()
        self.__rbuf = StringIO(prefix)
        return self.__rbuf
Ejemplo n.º 9
0
class TSaslClientTransport(TTransportBase, CReadableTransport):
    """
    A SASL transport based on the pure-sasl library:
        https://github.com/thobbs/pure-sasl

    open() negotiates SASL with the server; afterwards data is exchanged in
    frames consisting of a 4-byte big-endian length followed by content that
    is passed through sasl.wrap() / sasl.unwrap().
    """

    # Negotiation status codes (first byte of every negotiation message).
    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    def __init__(self, transport, host, service, mechanism="GSSAPI", **sasl_kwargs):
        """
        transport: an underlying transport to use, typically just a TSocket
        host: the name of the server, from a SASL perspective
        service: the name of the server's service, from a SASL perspective
        mechanism: the name of the preferred mechanism to use
        All other kwargs will be passed to the puresasl.client.SASLClient
        constructor.
        """
        self.transport = transport
        self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)
        # Outbound data waiting for flush(), and the current decoded frame.
        self.__wbuf = StringIO()
        self.__rbuf = StringIO()

        # NOTE: "wrap" and "unwrap" are defined on the base Mechanism class
        # and raise NotImplementedError by default, and PlainMechanism does
        # not implement its own versions.  A previously used (now disabled)
        # workaround replaced both on self.sasl._chosen_mech with identity
        # functions (lambda x: x).

    def open(self):
        """
        Open the underlying transport if needed, then negotiate SASL.

        A None result from sasl.process() is sent as an empty body.

        Raises TTransportException (type NOT_OPEN) if the server signals
        COMPLETE while the local SASL client is not complete, or if it sends
        a status other than OK or COMPLETE.
        """
        if not self.transport.isOpen():
            self.transport.open()

        self.send_sasl_msg(self.START, self.sasl.mechanism)
        self.send_sasl_msg(self.OK, self.sasl.process() or "")

        while True:
            status, challenge = self.recv_sasl_msg()
            if status == self.OK:
                self.send_sasl_msg(self.OK, self.sasl.process(challenge) or "")
            elif status == self.COMPLETE:
                # NOTE: an earlier revision broke out of the loop here without
                # checking self.sasl.complete, because it was reported as not
                # being set for PLAIN authentication.  The check is active.
                if not self.sasl.complete:
                    # TTransportException takes the numeric type first; the
                    # human-readable text must be the second argument.
                    raise TTransportException(
                        TTransportException.NOT_OPEN,
                        "The server erroneously indicated " "that SASL negotiation was complete",
                    )
                break
            else:
                raise TTransportException(
                    TTransportException.NOT_OPEN,
                    "Bad SASL negotiation status: %d (%s)" % (status, challenge),
                )

    def send_sasl_msg(self, status, body):
        """
        Write one negotiation message and flush the underlying transport.

        body may be None (sent as empty), bytes, or text; text is encoded as
        UTF-8.  Layout: 1-byte status, 4-byte unsigned big-endian length,
        body.
        """
        if body is None:
            body = ""
        if not isinstance(body, bytes):
            body = body.encode("utf-8")

        # Measure the body only after encoding: the header must carry the
        # number of bytes actually written, not the number of characters.
        header = pack(">BI", status, len(body))

        self.transport.write(header + body)
        self.transport.flush()

    def recv_sasl_msg(self):
        """
        Read one negotiation message.

        Returns (status, payload); payload is "" when the length field is 0.
        """
        header = self.transport.readAll(5)
        status, length = unpack(">BI", header)
        if length > 0:
            payload = self.transport.readAll(length)
        else:
            payload = ""
        return status, payload

    def write(self, data):
        """Append data to the write buffer; it is sent on the next flush()."""
        self.__wbuf.write(data)

    def flush(self):
        """Send the buffered data as one wrapped frame and reset the write buffer."""
        data = self.__wbuf.getvalue()
        encoded = self.sasl.wrap(data)
        self.transport.write("".join((pack("!i", len(encoded)), encoded)))
        self.transport.flush()
        self.__wbuf = StringIO()

    def read(self, sz):
        """
        Return up to sz bytes of decoded data, reading a new frame from the
        underlying transport only when the current one is exhausted.
        """
        ret = self.__rbuf.read(sz)
        if len(ret) != 0:
            return ret

        self._read_frame()
        return self.__rbuf.read(sz)

    def _read_frame(self):
        """Replace the read buffer with the unwrapped content of the next frame."""
        header = self.transport.readAll(4)
        length, = unpack("!i", header)
        encoded = self.transport.readAll(length)
        self.__rbuf = StringIO(self.sasl.unwrap(encoded))

    def close(self):
        """
        Dispose of the SASL client and close the underlying transport.

        The transport is closed even when sasl.dispose() raises.
        """
        try:
            self.sasl.dispose()
        finally:
            self.transport.close()

    # based on TFramedTransport
    @property
    def cstringio_buf(self):
        """The current read buffer."""
        return self.__rbuf

    def cstringio_refill(self, prefix, reqlen):
        """
        Read frames until prefix plus their content is at least reqlen bytes
        long, then install and return a read buffer holding that data.
        """
        # self.__rbuf will already be empty here because fastbinary doesn't
        # ask for a refill until the previous buffer is empty.  Therefore,
        # we can start reading new frames immediately.
        while len(prefix) < reqlen:
            self._read_frame()
            prefix += self.__rbuf.getvalue()
        self.__rbuf = StringIO(prefix)
        return self.__rbuf
class TSaslClientTransport(TTransportBase):
    """
    SASL transport

    open() negotiates SASL with the server; afterwards data is exchanged in
    frames consisting of a 4-byte big-endian length followed by content that
    is passed through sasl.wrap() / sasl.unwrap().  When generate_tickets is
    true a krbContext is created and initialised in the constructor.
    """

    # Negotiation status codes (first byte of every negotiation message).
    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    def __init__(self,
                 transport,
                 host,
                 service,
                 mechanism=six.u('GSSAPI'),
                 generate_tickets=False,
                 using_keytab=False,
                 principal=None,
                 keytab_file=None,
                 ccache_file=None,
                 password=None,
                 **sasl_kwargs):
        """
        transport: an underlying transport to use, typically just a TSocket
        host: the name of the server, from a SASL perspective
        service: the name of the server's service, from a SASL perspective
        mechanism: the name of the preferred mechanism to use
        generate_tickets: when true, build a krbContext from using_keytab,
            principal, keytab_file, ccache_file and password, and call its
            init_with_keytab()
        All other kwargs will be passed to the puresasl.client.SASLClient
        constructor.
        """

        self.transport = transport

        # The GSSAPI mechanism must be swapped out before the client is
        # created so that SASLClient picks up the replacement.
        if six.PY3:
            self._patch_pure_sasl()
        self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)

        # Outbound data waiting for flush(), and the current decoded frame.
        self.__wbuf = BytesIO()
        self.__rbuf = BytesIO()

        self.generate_tickets = generate_tickets
        if self.generate_tickets:
            self.krb_context = krbContext(using_keytab, principal, keytab_file,
                                          ccache_file, password)
            self.krb_context.init_with_keytab()

    def _patch_pure_sasl(self):
        ''' we need to patch pure_sasl to support python 3 '''
        puresasl.mechanisms.mechanisms['GSSAPI'] = CustomGSSAPIMechanism

    def is_open(self):
        """Return True when the underlying transport is open and a SASL client exists."""
        return self.transport.is_open() and bool(self.sasl)

    @with_ticket
    def open(self):
        """
        Open the underlying transport if needed, then negotiate SASL.

        Raises TTransportException (type NOT_OPEN) if the server signals
        COMPLETE while the local SASL client is not complete, or if it sends
        a status other than OK or COMPLETE.
        """
        if not self.transport.is_open():
            self.transport.open()

        self.send_sasl_msg(self.START, self.sasl.mechanism.encode('utf8'))
        self.send_sasl_msg(self.OK, self.sasl.process())

        while True:
            status, challenge = self.recv_sasl_msg()
            if status == self.OK:
                self.send_sasl_msg(self.OK, self.sasl.process(challenge))
            elif status == self.COMPLETE:
                if not self.sasl.complete:
                    raise TTransportException(
                        TTransportException.NOT_OPEN,
                        "The server erroneously indicated "
                        "that SASL negotiation was complete")
                break
            else:
                raise TTransportException(
                    TTransportException.NOT_OPEN,
                    "Bad SASL negotiation status: %d (%s)" %
                    (status, challenge))

    def send_sasl_msg(self, status, body):
        '''
        Write one negotiation message and flush the underlying transport.

        body:bytes

        Layout: 1-byte status, 4-byte unsigned big-endian length, body.
        '''
        header = pack(">BI", status, len(body))
        self.transport.write(header + body)
        self.transport.flush()

    def recv_sasl_msg(self):
        """
        Read one negotiation message.

        Returns (status, payload); payload is b"" when the length field is 0.
        """
        header = readall(self.transport.read, 5)
        status, length = unpack(">BI", header)
        if length > 0:
            payload = readall(self.transport.read, length)
        else:
            # Same type as a non-empty payload read from the transport.
            payload = b""
        return status, payload

    @with_ticket
    def write(self, data):
        """Append data to the write buffer; it is sent on the next flush()."""
        self.__wbuf.write(data)

    @with_ticket
    def flush(self):
        """Send the buffered data as one wrapped frame and reset the write buffer."""
        data = self.__wbuf.getvalue()
        encoded = self.sasl.wrap(data)
        # b'' is a plain str on Python 2 and bytes on Python 3, so a single
        # join covers both without branching on the interpreter version.
        self.transport.write(b''.join((pack("!i", len(encoded)), encoded)))
        self.transport.flush()
        self.__wbuf = BytesIO()

    def read(self, sz):
        """
        Return up to sz bytes of decoded data.

        A request for zero bytes returns immediately; otherwise a new frame
        is read from the underlying transport only when the current one is
        exhausted.
        """
        ret = self.__rbuf.read(sz)
        if len(ret) != 0 or sz == 0:
            return ret

        self._read_frame()
        return self.__rbuf.read(sz)

    def _read_frame(self):
        """Replace the read buffer with the unwrapped content of the next frame."""
        header = readall(self.transport.read, 4)
        length, = unpack('!i', header)
        encoded = readall(self.transport.read, length)
        self.__rbuf = BytesIO(self.sasl.unwrap(encoded))

    def close(self):
        """
        Dispose of the SASL client and close the underlying transport.

        The transport is closed even when sasl.dispose() raises.
        """
        try:
            self.sasl.dispose()
        finally:
            self.transport.close()
Ejemplo n.º 11
0
class TSaslClientTransport(TTransportBase, CReadableTransport):
    """
    Client transport that runs SASL on top of another transport.

    The handshake happens in open().  After it succeeds, flush() sends the
    buffered writes as a single frame (4-byte big-endian length plus the
    sasl.wrap() output) and reads decode incoming frames with sasl.unwrap().
    """

    # Status byte values used during the negotiation.
    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    def __init__(self, transport, host, service,
            mechanism='GSSAPI', **sasl_kwargs):
        """
        transport: the underlying transport that carries the frames
        host, service, mechanism, sasl_kwargs: forwarded unchanged to
        puresasl.client.SASLClient
        """

        from puresasl.client import SASLClient

        self.transport = transport
        self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)

        # Write side: data collected until flush(). Read side: current frame.
        self.__wbuf = StringIO()
        self.__rbuf = StringIO()

    def open(self):
        """
        Ensure the underlying transport is open and perform the handshake.

        Raises TTransportException (type NOT_OPEN) when the server reports
        COMPLETE although the local SASL client is not complete, or when the
        server sends a status other than OK or COMPLETE.
        """
        if not self.transport.isOpen():
            self.transport.open()

        self.send_sasl_msg(self.START, self.sasl.mechanism)
        self.send_sasl_msg(self.OK, self.sasl.process())

        while True:
            status, challenge = self.recv_sasl_msg()
            if status == self.OK:
                self.send_sasl_msg(self.OK, self.sasl.process(challenge))
            elif status == self.COMPLETE:
                if not self.sasl.complete:
                    # The exception's first positional parameter is its
                    # numeric type, so the text goes in second position.
                    raise TTransportException(
                            TTransportException.NOT_OPEN,
                            "The server erroneously indicated "
                            "that SASL negotiation was complete")
                break
            else:
                raise TTransportException(
                        TTransportException.NOT_OPEN,
                        "Bad SASL negotiation status: %d (%s)"
                        % (status, challenge))

    def send_sasl_msg(self, status, body):
        """
        Send a negotiation message (1-byte status, 4-byte unsigned big-endian
        length, body) and flush the underlying transport.
        """
        header = struct.pack(">BI", status, len(body))
        self.transport.write(header + body)
        self.transport.flush()

    def recv_sasl_msg(self):
        """
        Receive a negotiation message and return (status, payload).

        payload is "" when the message carries a zero length.
        """
        header = self.transport.readAll(5)
        status, length = struct.unpack(">BI", header)
        if length > 0:
            payload = self.transport.readAll(length)
        else:
            payload = ""
        return status, payload

    def write(self, data):
        """Collect data in the write buffer until flush() is called."""
        self.__wbuf.write(data)

    def flush(self):
        """Wrap the buffered data, send it as one frame, and start a new write buffer."""
        data = self.__wbuf.getvalue()
        encoded = self.sasl.wrap(data)
        # Note stolen from TFramedTransport:
        # N.B.: Doing this string concatenation is WAY cheaper than making
        # two separate calls to the underlying socket object. Socket writes in
        # Python turn out to be REALLY expensive, but it seems to do a pretty
        # good job of managing string buffer operations without excessive copies
        self.transport.write(''.join((struct.pack("!i", len(encoded)), encoded)))
        self.transport.flush()
        self.__wbuf = StringIO()

    def read(self, sz):
        """
        Return up to sz bytes of decoded data; the next frame is fetched only
        once the current read buffer has nothing left.
        """
        ret = self.__rbuf.read(sz)
        if len(ret) != 0:
            return ret

        self._read_frame()
        return self.__rbuf.read(sz)

    def _read_frame(self):
        """Read the next frame and store its unwrapped content as the read buffer."""
        header = self.transport.readAll(4)
        length, = struct.unpack('!i', header)
        encoded = self.transport.readAll(length)
        self.__rbuf = StringIO(self.sasl.unwrap(encoded))

    def close(self):
        """
        Dispose of the SASL client, then close the underlying transport.

        Closing the transport happens even if sasl.dispose() raises.
        """
        try:
            self.sasl.dispose()
        finally:
            self.transport.close()

    # Implement the CReadableTransport interface.
    # Stolen shamelessly from TFramedTransport
    @property
    def cstringio_buf(self):
        """The current read buffer."""
        return self.__rbuf

    def cstringio_refill(self, prefix, reqlen):
        """
        Extend prefix with whole frames until it is at least reqlen bytes
        long, then install and return a read buffer containing it.
        """
        # self.__rbuf will already be empty here because fastbinary doesn't
        # ask for a refill until the previous buffer is empty.  Therefore,
        # we can start reading new frames immediately.
        while len(prefix) < reqlen:
            self._read_frame()
            prefix += self.__rbuf.getvalue()
        self.__rbuf = StringIO(prefix)
        return self.__rbuf
Ejemplo n.º 12
0
class ThriftSASLClientProtocol(ThriftClientProtocol):
    """Thrift client protocol that runs a SASL negotiation before any RPC.

    During negotiation every message is ``status (1 byte) + length (4 bytes)
    + body``.  Once negotiation is complete, each outgoing message is passed
    through ``self.sasl.wrap`` and each incoming frame through
    ``self.sasl.unwrap``.
    """

    # Status codes carried in the first byte of a negotiation message.
    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    MAX_LENGTH = 2 ** 31 - 1

    def __init__(self, client_class, iprot_factory, oprot_factory=None,
            host=None, service=None, mechanism='GSSAPI', **sasl_kwargs):
        """
        client_class, iprot_factory, oprot_factory: passed unchanged to
            ThriftClientProtocol.
        host, service, mechanism, **sasl_kwargs: used to build the SASL
            client immediately when ``host`` is given.  Without ``host`` no
            SASL client exists until ``createSASLClient`` is called, which
            must happen before the connection is made.
        """
        ThriftClientProtocol.__init__(self, client_class, iprot_factory, oprot_factory)

        # Deferred the negotiation loop is currently waiting on (None when
        # no negotiation message is expected).
        self._sasl_negotiation_deferred = None
        # Status byte of the negotiation message being received; None until
        # the first byte of that message has arrived.
        self._sasl_negotiation_status = None
        self.client = None

        if host is not None:
            self.createSASLClient(host, service, mechanism, **sasl_kwargs)

    def createSASLClient(self, host, service, mechanism, **kwargs):
        """Create the SASLClient used for negotiation and for wrap/unwrap."""
        self.sasl = SASLClient(host, service, mechanism, **kwargs)

    def dispatch(self, msg):
        """SASL-wrap ``msg``, prefix it with its 4-byte length and dispatch it."""
        encoded = self.sasl.wrap(msg)
        len_and_encoded = ''.join((struct.pack('!i', len(encoded)), encoded))
        ThriftClientProtocol.dispatch(self, len_and_encoded)

    @defer.inlineCallbacks
    def connectionMade(self):
        """Run the SASL negotiation, then hand over to ThriftClientProtocol.

        Sends START with the mechanism name and OK with the initial response,
        then answers every OK challenge until the server reports COMPLETE.
        ``self.sasl.process`` is run through ``deferToThread``.

        Raises TTransportException when the server reports COMPLETE before
        the SASL client considers itself complete, or sends any other status.
        """
        self._sendSASLMessage(self.START, self.sasl.mechanism)
        initial_message = yield deferToThread(self.sasl.process)
        self._sendSASLMessage(self.OK, initial_message)

        while True:
            status, challenge = yield self._receiveSASLMessage()
            if status == self.OK:
                response = yield deferToThread(self.sasl.process, challenge)
                self._sendSASLMessage(self.OK, response)
            elif status == self.COMPLETE:
                if not self.sasl.complete:
                    msg = "The server erroneously indicated that SASL " \
                          "negotiation was complete"
                    raise TTransportException(msg, message=msg)
                else:
                    break
            else:
                msg = "Bad SASL negotiation status: %d (%s)" % (status, challenge)
                raise TTransportException(msg, message=msg)

        # Negotiation is over: from here on incoming data is normal frames.
        self._sasl_negotiation_deferred = None
        ThriftClientProtocol.connectionMade(self)

    def _sendSASLMessage(self, status, body):
        """Write one negotiation message: status byte, 4-byte length, body."""
        if body is None:
            body = ""
        header = struct.pack(">BI", status, len(body))
        self.transport.write(header + body)

    def _receiveSASLMessage(self):
        """Return a Deferred that fires with ``(status, body)`` of the next
        negotiation message."""
        self._sasl_negotiation_deferred = defer.Deferred()
        self._sasl_negotiation_status = None
        return self._sasl_negotiation_deferred

    def connectionLost(self, reason=connectionDone):
        """Forward the loss to ThriftClientProtocol only once a client exists."""
        if self.client:
            ThriftClientProtocol.connectionLost(self, reason)

    def dataReceived(self, data):
        """Feed incoming data to the length-prefixed frame parser.

        While a negotiation message is awaited, its leading status byte is
        removed and remembered before the rest is passed on.  The byte is
        taken only from the *first* chunk of a message (status still None):
        a stream transport may deliver one message in several chunks, and
        stripping a byte from every chunk would corrupt the length/body of
        a fragmented message.
        """
        if (self._sasl_negotiation_deferred
                and self._sasl_negotiation_status is None and data):
            # we got a sasl challenge in the format (status, length, challenge)
            # save the status, let IntNStringReceiver piece the challenge data together
            self._sasl_negotiation_status, = struct.unpack("B", data[:1])
            data = data[1:]
        # Continuation chunks and normal frames go through untouched.
        ThriftClientProtocol.dataReceived(self, data)

    def stringReceived(self, frame):
        """Handle one complete length-prefixed string from the parser."""
        if self._sasl_negotiation_deferred:
            # the frame is just a SASL challenge
            response = (self._sasl_negotiation_status, frame)
            self._sasl_negotiation_deferred.callback(response)
        else:
            # there's a second 4 byte length prefix inside the frame
            decoded_frame = self.sasl.unwrap(frame[4:])
            ThriftClientProtocol.stringReceived(self, decoded_frame)
Ejemplo n.º 13
0
class ThriftSASLClientProtocol(ThriftClientProtocol):
    """Thrift client protocol that runs a SASL negotiation before any RPC.

    During negotiation every message is ``status (1 byte) + length (4 bytes)
    + body``.  Once negotiation is complete, each outgoing message is passed
    through ``self.sasl.wrap`` and each incoming frame through
    ``self.sasl.unwrap``.
    """

    # Status codes carried in the first byte of a negotiation message.
    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    MAX_LENGTH = 2**31 - 1

    def __init__(self,
                 client_class,
                 iprot_factory,
                 oprot_factory=None,
                 host=None,
                 service=None,
                 mechanism='GSSAPI',
                 **sasl_kwargs):
        """
        client_class, iprot_factory, oprot_factory: passed unchanged to
            ThriftClientProtocol.
        host, service, mechanism, **sasl_kwargs: used to build the SASL
            client immediately when ``host`` is given.  Without ``host`` no
            SASL client exists until ``createSASLClient`` is called, which
            must happen before the connection is made.
        """
        ThriftClientProtocol.__init__(self, client_class, iprot_factory,
                                      oprot_factory)

        # Deferred the negotiation loop is currently waiting on (None when
        # no negotiation message is expected).
        self._sasl_negotiation_deferred = None
        # Status byte of the negotiation message being received; None until
        # the first byte of that message has arrived.
        self._sasl_negotiation_status = None
        self.client = None

        if host is not None:
            self.createSASLClient(host, service, mechanism, **sasl_kwargs)

    def createSASLClient(self, host, service, mechanism, **kwargs):
        """Create the SASLClient used for negotiation and for wrap/unwrap."""
        self.sasl = SASLClient(host, service, mechanism, **kwargs)

    def dispatch(self, msg):
        """SASL-wrap ``msg``, prefix it with its 4-byte length and dispatch it."""
        encoded = self.sasl.wrap(msg)
        len_and_encoded = ''.join((struct.pack('!i', len(encoded)), encoded))
        ThriftClientProtocol.dispatch(self, len_and_encoded)

    @defer.inlineCallbacks
    def connectionMade(self):
        """Run the SASL negotiation, then hand over to ThriftClientProtocol.

        Sends START with the mechanism name and OK with the initial response,
        then answers every OK challenge until the server reports COMPLETE.
        ``self.sasl.process`` is run through ``deferToThread``.

        Raises TTransportException when the server reports COMPLETE before
        the SASL client considers itself complete, or sends any other status.
        """
        self._sendSASLMessage(self.START, self.sasl.mechanism)
        initial_message = yield deferToThread(self.sasl.process)
        self._sendSASLMessage(self.OK, initial_message)

        while True:
            status, challenge = yield self._receiveSASLMessage()
            if status == self.OK:
                response = yield deferToThread(self.sasl.process, challenge)
                self._sendSASLMessage(self.OK, response)
            elif status == self.COMPLETE:
                if not self.sasl.complete:
                    msg = "The server erroneously indicated that SASL " \
                          "negotiation was complete"
                    raise TTransportException(msg, message=msg)
                else:
                    break
            else:
                msg = "Bad SASL negotiation status: %d (%s)" % (status,
                                                                challenge)
                raise TTransportException(msg, message=msg)

        # Negotiation is over: from here on incoming data is normal frames.
        self._sasl_negotiation_deferred = None
        ThriftClientProtocol.connectionMade(self)

    def _sendSASLMessage(self, status, body):
        """Write one negotiation message: status byte, 4-byte length, body."""
        if body is None:
            body = ""
        header = struct.pack(">BI", status, len(body))
        self.transport.write(header + body)

    def _receiveSASLMessage(self):
        """Return a Deferred that fires with ``(status, body)`` of the next
        negotiation message."""
        self._sasl_negotiation_deferred = defer.Deferred()
        self._sasl_negotiation_status = None
        return self._sasl_negotiation_deferred

    def connectionLost(self, reason=connectionDone):
        """Forward the loss to ThriftClientProtocol only once a client exists."""
        if self.client:
            ThriftClientProtocol.connectionLost(self, reason)

    def dataReceived(self, data):
        """Feed incoming data to the length-prefixed frame parser.

        While a negotiation message is awaited, its leading status byte is
        removed and remembered before the rest is passed on.  The byte is
        taken only from the *first* chunk of a message (status still None):
        a stream transport may deliver one message in several chunks, and
        stripping a byte from every chunk would corrupt the length/body of
        a fragmented message.
        """
        if (self._sasl_negotiation_deferred
                and self._sasl_negotiation_status is None and data):
            # we got a sasl challenge in the format (status, length, challenge)
            # save the status, let IntNStringReceiver piece the challenge data together
            self._sasl_negotiation_status, = struct.unpack("B", data[:1])
            data = data[1:]
        # Continuation chunks and normal frames go through untouched.
        ThriftClientProtocol.dataReceived(self, data)

    def stringReceived(self, frame):
        """Handle one complete length-prefixed string from the parser."""
        if self._sasl_negotiation_deferred:
            # the frame is just a SASL challenge
            response = (self._sasl_negotiation_status, frame)
            self._sasl_negotiation_deferred.callback(response)
        else:
            # there's a second 4 byte length prefix inside the frame
            decoded_frame = self.sasl.unwrap(frame[4:])
            ThriftClientProtocol.stringReceived(self, decoded_frame)
Ejemplo n.º 14
0
class LDAPSocket(object):
    """Holds a connection to an LDAP server.

    :param str host_uri: "scheme://netloc" to connect to
    :param int connect_timeout: Number of seconds to wait for connection to be accepted
    :param bool ssl_verify: Validate the certificate and hostname on an SSL/TLS connection
    :param str ssl_ca_file: Path to PEM-formatted concatenated CA certificates file
    :param str ssl_ca_path: Path to directory with CA certs under hashed file names. See
                            https://www.openssl.org/docs/man1.1.0/ssl/SSL_CTX_load_verify_locations.html for more
                            information about the format of this directory.
    :param ssl_ca_data: An ASCII string of one or more PEM-encoded certs or a bytes object containing DER-encoded
                        certificates.
    :type ssl_ca_data: str or bytes
    """

    # Maximum number of bytes requested from the socket per recv() call
    RECV_BUFFER = 4096

    # For ldapi:/// try to connect to these socket files in order
    # Globs must match exactly one result
    LDAPI_SOCKET_PATHS = ['/var/run/ldapi', '/var/run/slapd/ldapi', '/var/run/slapd-*.socket']

    # OIDs of unsolicited messages

    OID_DISCONNECTION_NOTICE = '1.3.6.1.4.1.1466.20036'  # RFC 4511 sec 4.4.1 Notice of Disconnection

    def __init__(self, host_uri, connect_timeout=5, ssl_verify=True, ssl_ca_file=None, ssl_ca_path=None,
                 ssl_ca_data=None):
        """Initialize per-socket state, then connect to ``host_uri`` (see the class docstring for parameters)."""
        self._prop_init(connect_timeout)
        self._uri_connect(host_uri, ssl_verify, ssl_ca_file, ssl_ca_path, ssl_ca_data)

    def _prop_init(self, connect_timeout=5):
        """Assign the socket its ID and reset all per-connection state.

        :param int connect_timeout: Number of seconds to wait for connection to be accepted
        """
        # get socket ID number
        global _next_sock_id
        self.ID = _next_sock_id
        _next_sock_id += 1

        # misc init
        self._message_queues = {}
        self._next_message_id = 1
        self._sasl_client = None

        self.refcount = 0
        self.bound = False
        self.unbound = False
        self.abandoned_mids = []
        self.started_tls = False
        self.connect_timeout = connect_timeout

    def _parse_uri(self, host_uri):
        """Split ``host_uri`` into scheme and netloc and store the normalized form on ``self.uri``.

        Without a ``://`` separator the scheme is inferred: ``ldapi`` when the value starts with ``/``, otherwise
        ``ldap``.

        :param str host_uri: The URI, or a bare host / socket path
        :return: A tuple of (scheme, netloc)
        :rtype: tuple
        :raises LDAPError: if ``host_uri`` is empty or contains more than one ``://``
        """
        # parse host_uri
        parts = host_uri.split('://')
        if len(parts) == 1:
            netloc = unquote(parts[0])
            if not netloc:
                # nothing to infer a scheme from (previously an IndexError)
                raise LDAPError('Invalid host_uri')
            if netloc.startswith('/'):
                scheme = 'ldapi'
            else:
                scheme = 'ldap'
        elif len(parts) == 2:
            scheme = parts[0]
            netloc = unquote(parts[1])
        else:
            raise LDAPError('Invalid host_uri')
        self.uri = '{0}://{1}'.format(scheme, netloc)
        return scheme, netloc

    def _uri_connect(self, host_uri, ssl_verify, ssl_ca_file, ssl_ca_path, ssl_ca_data):
        """Open the connection described by ``host_uri``.

        ``ldap`` connects to port 389 by default, ``ldaps`` connects to port 636 by default and installs TLS
        immediately, ``ldapi`` connects to a unix socket (a netloc of ``/`` searches ``LDAPI_SOCKET_PATHS``).

        :raises LDAPError: if the scheme is unsupported, or ldapi is requested without unix socket support
        :raises LDAPConnectionError: if the unix socket cannot be found or connected
        """
        # connect
        scheme, netloc = self._parse_uri(host_uri)
        logger.info('Connecting to {0} on #{1}'.format(self.uri, self.ID))
        if scheme == 'ldap':
            self._inet_connect(netloc, 389)
        elif scheme == 'ldaps':
            self._inet_connect(netloc, 636)
            self.start_tls(ssl_verify, ssl_ca_file, ssl_ca_path, ssl_ca_data)
            logger.info('Connected with TLS on #{0}'.format(self.ID))
        elif scheme == 'ldapi':
            if not _have_unix_socket:
                raise LDAPError('Unix sockets are not supported on your platform, please choose a protocol other '
                                'than ldapi')
            self.sock_path = None
            self._sock = socket(AF_UNIX)
            self.host = 'localhost'

            if netloc == '/':
                for sockGlob in LDAPSocket.LDAPI_SOCKET_PATHS:
                    fn = glob(sockGlob)
                    if not fn:
                        continue
                    if len(fn) > 1:
                        logger.debug('Multiple results for glob {0}'.format(sockGlob))
                        continue
                    fn = fn[0]
                    try:
                        self._connect(fn)
                        self.sock_path = fn
                        break
                    except SocketError:
                        # try the next candidate path
                        continue
                if self.sock_path is None:
                    raise LDAPConnectionError('Could not find any local LDAPI unix socket - full '
                                              'socket path must be supplied in URI')
            else:
                try:
                    self._connect(netloc)
                    self.sock_path = netloc
                except SocketError as e:
                    raise LDAPConnectionError('failed connect to unix socket {0} - {1} ({2})'.format(
                        netloc, e.strerror, e.errno
                    ))

            logger.debug('Connected to unix socket {0} on #{1}'.format(self.sock_path, self.ID))
        else:
            raise LDAPError('Unsupported scheme "{0}"'.format(scheme))

    def _connect(self, addr):
        """Connect ``self._sock`` to ``addr``, applying ``connect_timeout`` only for the connect call itself."""
        self._sock.settimeout(self.connect_timeout)
        self._sock.connect(addr)
        self._sock.settimeout(None)

    def _inet_connect(self, netloc, default_port):
        """Connect to ``host[:port]`` over TCP, storing the host part on ``self.host``.

        :param str netloc: "host" or "host:port"
        :param int default_port: Port to use when ``netloc`` does not name one
        :raises LDAPConnectionError: if the connection attempt fails with a socket error
        """
        ap = netloc.rsplit(':', 1)
        self.host = ap[0]
        if len(ap) == 1:
            port = default_port
        else:
            port = int(ap[1])
        try:
            self._sock = create_connection((self.host, port), self.connect_timeout)
            logger.debug('Connected to {0}:{1} on #{2}'.format(self.host, port, self.ID))
        except SocketError as e:
            raise LDAPConnectionError('failed connect to {0}:{1} - {2} ({3})'.format(
                                      self.host, port, e.strerror, e.errno))

    def start_tls(self, verify=True, ca_file=None, ca_path=None, ca_data=None):
        """Install TLS layer on this socket connection.

        :param bool verify: Validate the certificate and hostname on an SSL/TLS connection
        :param str ca_file: Path to PEM-formatted concatenated CA certificates file
        :param str ca_path: Path to directory with CA certs under hashed file names. See
                            https://www.openssl.org/docs/man1.1.0/ssl/SSL_CTX_load_verify_locations.html for more
                            information about the format of this directory.
        :param ca_data: An ASCII string of one or more PEM-encoded certs or a bytes object containing DER-encoded
                        certificates.
        :type ca_data: str or bytes
        :raises LDAPError: if a TLS layer was already installed
        :raises LDAPConnectionError: if ``verify`` is set and the certificate does not match ``self.host``
        """
        if self.started_tls:
            raise LDAPError('TLS layer already installed')

        if verify:
            verify_mode = ssl.CERT_REQUIRED
        else:
            verify_mode = ssl.CERT_NONE

        try:
            proto = ssl.PROTOCOL_TLS
        except AttributeError:
            proto = ssl.PROTOCOL_SSLv23

        try:
            ctx = ssl.SSLContext(proto)
            ctx.verify_mode = verify_mode
            ctx.check_hostname = False  # we do this ourselves
            if verify:
                ctx.load_default_certs()
            if ca_file or ca_path or ca_data:
                ctx.load_verify_locations(cafile=ca_file, capath=ca_path, cadata=ca_data)
            self._sock = ctx.wrap_socket(self._sock)
        except AttributeError:
            # SSLContext wasn't added until 2.7.9
            if ca_path or ca_data:
                raise RuntimeError('python version >= 2.7.9 required for SSL ca_path/ca_data')

            self._sock = ssl.wrap_socket(self._sock, ca_certs=ca_file, cert_reqs=verify_mode, ssl_version=proto)

        if verify:
            cert = self._sock.getpeercert()
            # A certificate may identify the server through subjectAltName only; a missing commonName is passed
            # on as None instead of raising KeyError here.
            subject = dict(rdn[0] for rdn in cert.get('subject', ()))
            cert_cn = subject.get('commonName')
            self.check_hostname(cert_cn, cert)
        else:
            logger.debug('Skipping hostname validation')
        self.started_tls = True
        logger.debug('Installed TLS layer on #{0}'.format(self.ID))

    def check_hostname(self, cert_cn, cert):
        """SSL check_hostname according to RFC 4513 sec 3.1.3. Compares supplied values against ``self.host`` to
        determine the validity of the cert.

        DNS names (the common name and subjectAltName entries of type 'DNS') are compared case-insensitively. A
        leading ``*.`` in a DNS subjectAltName stands for exactly one left-most label: ``*.example.com`` matches
        ``a.example.com`` but neither ``example.com`` nor ``a.b.example.com``. Other subjectAltName types must
        equal ``self.host`` exactly.

        :param cert_cn: The common name of the cert, or None if it has none
        :type cert_cn: str or None
        :param dict cert: A dictionary representing the rest of the cert. Checks key subjectAltName for a list of
                          (type, value) tuples. DNS supports leading wildcard.
        :rtype: None
        :raises LDAPConnectionError: if no supplied values match ``self.host``
        """
        host = self.host.lower()
        if cert_cn and host == cert_cn.lower():
            logger.debug('Matched server identity to cert commonName')
            return

        tried = [cert_cn] if cert_cn else []
        for san_type, value in cert.get('subjectAltName', []):
            tried.append(value)
            if san_type == 'DNS':
                pattern = value.lower()
                if pattern.startswith('*.'):
                    suffix = pattern[1:]  # ".example.com"
                    valid = host.endswith(suffix)
                    if valid:
                        # the wildcard must cover one non-empty label and nothing more
                        label = host[:len(host) - len(suffix)]
                        valid = bool(label) and '.' not in label
                else:
                    valid = (host == pattern)
            else:
                valid = (self.host == value)
            if valid:
                logger.debug('Matched server identity to cert {0} subjectAltName'.format(san_type))
                return
        raise LDAPConnectionError('Server identity "{0}" does not match any cert names: {1}'.format(
            self.host, ', '.join(tried)))

    def sasl_init(self, mechs, **props):
        """Initialize a :class:`.puresasl.client.SASLClient`"""
        self._sasl_client = SASLClient(self.host, 'ldap', **props)
        self._sasl_client.choose_mechanism(mechs)

    def _has_sasl_client(self):
        """Return True once :meth:`sasl_init` has created a SASL client."""
        return self._sasl_client is not None

    def _require_sasl_client(self):
        """Raise :class:`LDAPSASLError` unless a SASL client has been created."""
        if not self._has_sasl_client():
            raise LDAPSASLError('SASL init not complete')

    @property
    def sasl_qop(self):
        """Obtain the chosen quality of protection"""
        self._require_sasl_client()
        return self._sasl_client.qop

    @property
    def sasl_mech(self):
        """Obtain the chosen mechanism"""
        self._require_sasl_client()
        mech = self._sasl_client.mechanism
        if mech is None:
            raise LDAPSASLError('SASL init not complete - no mech chosen')
        else:
            return mech

    def sasl_process_auth_challenge(self, challenge):
        """Process an auth challenge and return the correct response"""
        self._require_sasl_client()
        return self._sasl_client.process(challenge)

    def _prep_message(self, op, obj, controls=None):
        """Prepare a message for transmission

        Takes the next message ID, BER-encodes the packed message and, when a SASL client exists, passes the
        encoded bytes through its ``wrap``.

        :return: A tuple of (message ID, raw data to send)
        :rtype: tuple
        """
        mid = self._next_message_id
        self._next_message_id += 1
        lm = pack(mid, op, obj, controls)
        raw = ber_encode(lm)
        if self._has_sasl_client():
            raw = self._sasl_client.wrap(raw)
        return mid, raw

    def send_message(self, op, obj, controls=None):
        """Create and send an LDAPMessage given an operation name and a corresponding object.

        Operation names must be defined as component names in laurelin.ldap.rfc4511.ProtocolOp and
        the object must be of the corresponding type.

        :param str op: The protocol operation name
        :param object obj: The associated protocol object (see :class:`.rfc4511.ProtocolOp` for mapping.
        :param controls: Any request controls for the message
        :type controls: rfc4511.Controls or None
        :return: The message ID for this message
        :rtype: int
        """
        mid, raw = self._prep_message(op, obj, controls)
        self._sock.sendall(raw)
        return mid

    def recv_one(self, want_message_id):
        """Get the next message with ``want_message_id`` being sent by the server

        :param int want_message_id: The desired message ID.
        :return: The LDAP message
        :rtype: rfc4511.LDAPMessage
        """
        return next(self.recv_messages(want_message_id))

    def recv_messages(self, want_message_id):
        """Iterate all messages with ``want_message_id`` being sent by the server.

        Messages carrying a different ID are queued for whoever asks for that ID later. A message with ID 0 is
        an unsolicited notification and is raised as :class:`LDAPUnsolicitedMessage`.

        :param int want_message_id: The desired message ID.
        :return: An iterator over :class:`.rfc4511.LDAPMessage`.
        :raises LDAPConnectionError: if the server closes the connection
        :raises LDAPUnsolicitedMessage: if an unsolicited message (ID 0) is received
        """
        flush_queue = True
        raw = b''
        while True:
            if flush_queue:
                # hand out anything already queued for this ID
                if want_message_id in self._message_queues:
                    q = self._message_queues[want_message_id]
                    while True:
                        if len(q) == 0:
                            break
                        obj = q.popleft()
                        if len(q) == 0:
                            del self._message_queues[want_message_id]
                        yield obj
            else:
                flush_queue = True
            if want_message_id in self.abandoned_mids:
                return
            try:
                newraw = self._sock.recv(LDAPSocket.RECV_BUFFER)
                if not newraw:
                    # recv() returning no data means the peer closed the connection; without this check the
                    # loop would call recv() forever.
                    raise LDAPConnectionError('Connection closed by server')
                if self._has_sasl_client():
                    newraw = self._sasl_client.unwrap(newraw)
                raw += newraw
                while len(raw) > 0:
                    response, raw = ber_decode(raw, asn1Spec=LDAPMessage())
                    have_message_id = response.getComponentByName('messageID')
                    if want_message_id == have_message_id:
                        yield response
                    elif have_message_id == 0:
                        msg = 'Received unsolicited message (default message - should never be seen)'
                        try:
                            mid, xr, ctrls = unpack('extendedResp', response)
                            res_code = xr.getComponentByName('resultCode')
                            xr_oid = six.text_type(xr.getComponentByName('responseName'))
                            if xr_oid == LDAPSocket.OID_DISCONNECTION_NOTICE:
                                mtype = 'Notice of Disconnection'
                            else:
                                mtype = 'Unhandled ({0})'.format(xr_oid)
                            diag = xr.getComponentByName('diagnosticMessage')
                            msg = 'Got unsolicited message: {0}: {1}: {2}'.format(mtype, res_code, diag)
                            if res_code == ResultCode('protocolError'):
                                msg += (' (This may indicate an incompatability between laurelin-ldap and your server '
                                        'distribution)')
                            elif res_code == ResultCode('strongerAuthRequired'):
                                # this is a direct quote from RFC 4511 sec 4.4.1
                                msg += (' (The server has detected that an established security association between the'
                                        ' client and server has unexpectedly failed or been compromised)')
                        except UnexpectedResponseType:
                            msg = 'Unhandled unsolicited message from server'
                        finally:
                            # always raised, with whichever message was built above
                            raise LDAPUnsolicitedMessage(response, msg)
                    else:
                        if have_message_id not in self._message_queues:
                            self._message_queues[have_message_id] = deque()
                        self._message_queues[have_message_id].append(response)
            except SubstrateUnderrunError:
                # incomplete message: keep what we have and read more without re-checking the queue
                flush_queue = False
                continue

    def close(self):
        """Close the low-level socket connection."""
        return self._sock.close()
Ejemplo n.º 15
0
class LDAPSocket(object):
    """Holds a connection to an LDAP server.

    :param str host_uri: "scheme://netloc" to connect to
    :param int connect_timeout: Number of seconds to wait for connection to be accepted
    :param bool ssl_verify: Validate the certificate and hostname on an SSL/TLS connection
    :param str ssl_ca_file: Path to PEM-formatted concatenated CA certficates file
    :param str ssl_ca_path: Path to directory with CA certs under hashed file names. See
                            https://www.openssl.org/docs/man1.1.0/ssl/SSL_CTX_load_verify_locations.html for more
                            information about the format of this directory.
    :param ssl_ca_data: An ASCII string of one or more PEM-encoded certs or a bytes object containing DER-encoded
                        certificates.
    :type ssl_ca_data: str or bytes
    """

    RECV_BUFFER = 4096

    # For ldapi:/// try to connect to these socket files in order
    # Globs must match exactly one result
    LDAPI_SOCKET_PATHS = [
        '/var/run/ldapi', '/var/run/slapd/ldapi', '/var/run/slapd-*.socket'
    ]

    # OIDs of unsolicited messages

    OID_DISCONNECTION_NOTICE = '1.3.6.1.4.1.1466.20036'  # RFC 4511 sec 4.4.1 Notice of Disconnection

    def __init__(self,
                 host_uri,
                 connect_timeout=5,
                 ssl_verify=True,
                 ssl_ca_file=None,
                 ssl_ca_path=None,
                 ssl_ca_data=None):

        self._prop_init(connect_timeout)
        self._uri_connect(host_uri, ssl_verify, ssl_ca_file, ssl_ca_path,
                          ssl_ca_data)

    def _prop_init(self, connect_timeout=5):
        # get socket ID number
        global _next_sock_id
        self.ID = _next_sock_id
        _next_sock_id += 1

        # misc init
        self._message_queues = {}
        self._next_message_id = 1
        self._sasl_client = None

        self.refcount = 0
        self.bound = False
        self.unbound = False
        self.abandoned_mids = []
        self.started_tls = False
        self.connect_timeout = connect_timeout

    def _parse_uri(self, host_uri):
        # parse host_uri
        parts = host_uri.split('://')
        if len(parts) == 1:
            netloc = unquote(parts[0])
            if netloc[0] == '/':
                scheme = 'ldapi'
            else:
                scheme = 'ldap'
        elif len(parts) == 2:
            scheme = parts[0]
            netloc = unquote(parts[1])
        else:
            raise LDAPError('Invalid host_uri')
        self.uri = '{0}://{1}'.format(scheme, netloc)
        return scheme, netloc

    def _uri_connect(self, host_uri, ssl_verify, ssl_ca_file, ssl_ca_path,
                     ssl_ca_data):
        # connect
        scheme, netloc = self._parse_uri(host_uri)
        logger.info('Connecting to {0} on #{1}'.format(self.uri, self.ID))
        if scheme == 'ldap':
            self._inet_connect(netloc, 389)
        elif scheme == 'ldaps':
            self._inet_connect(netloc, 636)
            self.start_tls(ssl_verify, ssl_ca_file, ssl_ca_path, ssl_ca_data)
            logger.info('Connected with TLS on #{0}'.format(self.ID))
        elif scheme == 'ldapi':
            if not _have_unix_socket:
                raise LDAPError(
                    'Unix sockets are not supported on your platform, please choose a protocol other'
                    'than ldapi')
            self.sock_path = None
            self._sock = socket(AF_UNIX)
            self.host = 'localhost'

            if netloc == '/':
                for sockGlob in LDAPSocket.LDAPI_SOCKET_PATHS:
                    fn = glob(sockGlob)
                    if not fn:
                        continue
                    if len(fn) > 1:
                        logger.debug(
                            'Multiple results for glob {0}'.format(sockGlob))
                        continue
                    fn = fn[0]
                    try:
                        self._connect(fn)
                        self.sock_path = fn
                        break
                    except SocketError:
                        continue
                if self.sock_path is None:
                    raise LDAPConnectionError(
                        'Could not find any local LDAPI unix socket - full '
                        'socket path must be supplied in URI')
            else:
                try:
                    self._connect(netloc)
                    self.sock_path = netloc
                except SocketError as e:
                    raise LDAPConnectionError(
                        'failed connect to unix socket {0} - {1} ({2})'.format(
                            netloc, e.strerror, e.errno))

            logger.debug('Connected to unix socket {0} on #{1}'.format(
                self.sock_path, self.ID))
        else:
            raise LDAPError('Unsupported scheme "{0}"'.format(scheme))

    def _connect(self, addr):
        self._sock.settimeout(self.connect_timeout)
        self._sock.connect(addr)
        self._sock.settimeout(None)

    def _inet_connect(self, netloc, default_port):
        ap = netloc.rsplit(':', 1)
        self.host = ap[0]
        if len(ap) == 1:
            port = default_port
        else:
            port = int(ap[1])
        try:
            self._sock = create_connection((self.host, port),
                                           self.connect_timeout)
            logger.debug('Connected to {0}:{1} on #{2}'.format(
                self.host, port, self.ID))
        except SocketError as e:
            raise LDAPConnectionError(
                'failed connect to {0}:{1} - {2} ({3})'.format(
                    self.host, port, e.strerror, e.errno))

    def start_tls(self, verify=True, ca_file=None, ca_path=None, ca_data=None):
        """Install TLS layer on this socket connection.

        :param bool verify: Validate the certificate and hostname on an SSL/TLS connection
        :param str ca_file: Path to PEM-formatted concatenated CA certficates file
        :param str ca_path: Path to directory with CA certs under hashed file names. See
                            https://www.openssl.org/docs/man1.1.0/ssl/SSL_CTX_load_verify_locations.html for more
                            information about the format of this directory.
        :param ca_data: An ASCII string of one or more PEM-encoded certs or a bytes object containing DER-encoded
                        certificates.
        :type ca_data: str or bytes
        """
        if self.started_tls:
            raise LDAPError('TLS layer already installed')

        if verify:
            verify_mode = ssl.CERT_REQUIRED
        else:
            verify_mode = ssl.CERT_NONE

        try:
            proto = ssl.PROTOCOL_TLS
        except AttributeError:
            proto = ssl.PROTOCOL_SSLv23

        try:
            ctx = ssl.SSLContext(proto)
            ctx.verify_mode = verify_mode
            ctx.check_hostname = False  # we do this ourselves
            if verify:
                ctx.load_default_certs()
            if ca_file or ca_path or ca_data:
                ctx.load_verify_locations(cafile=ca_file,
                                          capath=ca_path,
                                          cadata=ca_data)
            self._sock = ctx.wrap_socket(self._sock)
        except AttributeError:
            # SSLContext wasn't added until 2.7.9
            if ca_path or ca_data:
                raise RuntimeError(
                    'python version >= 2.7.9 required for SSL ca_path/ca_data')

            self._sock = ssl.wrap_socket(self._sock,
                                         ca_certs=ca_file,
                                         cert_reqs=verify_mode,
                                         ssl_version=proto)

        if verify:
            cert = self._sock.getpeercert()
            cert_cn = dict([e[0] for e in cert['subject']])['commonName']
            self.check_hostname(cert_cn, cert)
        else:
            logger.debug('Skipping hostname validation')
        self.started_tls = True
        logger.debug('Installed TLS layer on #{0}'.format(self.ID))

    def check_hostname(self, cert_cn, cert):
        """SSL check_hostname according to RFC 4513 sec 3.1.3. Compares supplied values against ``self.host`` to
        determine the validity of the cert.

        The common name and subjectAltName values of type ``DNS`` are compared case-insensitively. A ``DNS`` value
        starting with ``*.`` is a wildcard that stands for exactly one left-most label of the host name, so
        ``*.example.com`` matches ``a.example.com`` but neither ``example.com`` nor ``a.b.example.com``.
        subjectAltName values of any other type must equal ``self.host`` exactly.

        :param str cert_cn: The common name of the cert
        :param dict cert: A dictionary representing the rest of the cert. Checks key ``subjectAltName`` for a list of
                          (type, value) tuples.
        :rtype: None
        :raises LDAPConnectionError: if no supplied values match ``self.host``
        """
        host = self.host.lower()
        if cert_cn is not None and host == cert_cn.lower():
            logger.debug('Matched server identity to cert commonName')
            return

        # A wildcard may only replace the first label, so split it off once up front
        first_label, _, parent_domain = host.partition('.')
        tried = [] if cert_cn is None else [cert_cn]
        for san_type, value in cert.get('subjectAltName', []):
            tried.append(value)
            if san_type == 'DNS':
                name = value.lower()
                if name.startswith('*.'):
                    # Require a non-empty first label and an exact match on everything after it
                    valid = bool(first_label) and bool(parent_domain) and parent_domain == name[2:]
                else:
                    valid = (host == name)
            else:
                valid = (self.host == value)
            if valid:
                logger.debug(
                    'Matched server identity to cert {0} subjectAltName'.
                    format(san_type))
                return
        raise LDAPConnectionError(
            'Server identity "{0}" does not match any cert names: {1}'.
            format(self.host, ', '.join(tried)))

    def sasl_init(self, mechs, **props):
        """Create the :class:`.puresasl.client.SASLClient` for this socket and let it choose a mechanism.

        :param mechs: Passed to the client's ``choose_mechanism()``
        :param props: Extra keyword arguments forwarded to the ``SASLClient`` constructor
        """
        client = SASLClient(self.host, 'ldap', **props)
        # Store the client before choosing so it stays set even if choose_mechanism() raises
        self._sasl_client = client
        client.choose_mechanism(mechs)

    def _has_sasl_client(self):
        """Return True if ``self._sasl_client`` has been set to something other than None."""
        client = self._sasl_client
        return not (client is None)

    def _require_sasl_client(self):
        """Ensure a SASL client is present.

        :rtype: None
        :raises LDAPSASLError: if ``self._has_sasl_client()`` is false
        """
        if self._has_sasl_client():
            return
        raise LDAPSASLError('SASL init not complete')

    @property
    def sasl_qop(self):
        """The ``qop`` (quality of protection) attribute of the SASL client.

        :raises LDAPSASLError: if no SASL client has been initialized
        """
        self._require_sasl_client()
        client = self._sasl_client
        return client.qop

    @property
    def sasl_mech(self):
        """The ``mechanism`` attribute of the SASL client.

        :raises LDAPSASLError: if no SASL client has been initialized, or its mechanism is None
        """
        self._require_sasl_client()
        chosen = self._sasl_client.mechanism
        if chosen is not None:
            return chosen
        raise LDAPSASLError('SASL init not complete - no mech chosen')

    def sasl_process_auth_challenge(self, challenge):
        """Hand an auth challenge to the SASL client's ``process()`` and return its result.

        :param challenge: The challenge to pass to the SASL client
        :return: Whatever the SASL client's ``process()`` returns for the challenge
        :raises LDAPSASLError: if no SASL client has been initialized
        """
        self._require_sasl_client()
        response = self._sasl_client.process(challenge)
        return response

    def _prep_message(self, op, obj, controls=None):
        """Build the bytes for a new message and allocate its message ID.

        Takes the current value of ``self._next_message_id`` as the ID and increments the counter, packs the
        message and encodes it with ``ber_encode``, then passes the result through the SASL client's ``wrap()``
        when a SASL client is present.

        :param str op: The protocol operation name
        :param object obj: The associated protocol object
        :param controls: Any request controls for the message
        :return: A tuple of (message ID, bytes to transmit)
        :rtype: tuple
        """
        message_id = self._next_message_id
        self._next_message_id += 1
        encoded = ber_encode(pack(message_id, op, obj, controls))
        if not self._has_sasl_client():
            return message_id, encoded
        return message_id, self._sasl_client.wrap(encoded)

    def send_message(self, op, obj, controls=None):
        """Build an LDAPMessage for the given operation and write it to the socket.

        Operation names must be defined as component names in laurelin.ldap.rfc4511.ProtocolOp and
        the object must be of the corresponding type.

        :param str op: The protocol operation name
        :param object obj: The associated protocol object (see :class:`.rfc4511.ProtocolOp` for mapping).
        :param controls: Any request controls for the message
        :type controls: rfc4511.Controls or None
        :return: The message ID for this message
        :rtype: int
        """
        prepared = self._prep_message(op, obj, controls)
        message_id, payload = prepared
        self._sock.sendall(payload)
        return message_id

    def recv_one(self, want_message_id):
        """Return only the first message that ``recv_messages()`` produces for ``want_message_id``.

        :param int want_message_id: The desired message ID.
        :return: The LDAP message
        :rtype: rfc4511.LDAPMessage
        """
        messages = self.recv_messages(want_message_id)
        first = next(messages)
        return first

    def recv_messages(self, want_message_id):
        """Iterate all messages with ``want_message_id`` being sent by the server.

        Messages already queued for ``want_message_id`` are yielded first. Data is then read from the socket and
        decoded; messages carrying a different non-zero ID are stored in ``self._message_queues`` under their own
        ID. The generator ends once ``want_message_id`` is found in ``self.abandoned_mids``; otherwise it keeps
        reading for as long as the caller keeps iterating.

        :param int want_message_id: The desired message ID.
        :return: An iterator over :class:`.rfc4511.LDAPMessage`.
        :raises LDAPUnsolicitedMessage: if a message with message ID 0 is received
        :raises LDAPConnectionError: if the socket returns no data, i.e. the connection was closed
        """
        flush_queue = True
        raw = b''
        while True:
            if flush_queue:
                if want_message_id in self._message_queues:
                    q = self._message_queues[want_message_id]
                    while q:
                        obj = q.popleft()
                        if not q:
                            # Drop the emptied queue before yielding the last item
                            del self._message_queues[want_message_id]
                        yield obj
            else:
                # The queue flush is skipped for exactly one pass after a decode underrun
                flush_queue = True
            if want_message_id in self.abandoned_mids:
                # A plain return ends the generator; raising StopIteration here would be turned into a
                # RuntimeError on Python 3.7+ (PEP 479)
                return
            try:
                newraw = self._sock.recv(LDAPSocket.RECV_BUFFER)
                if not newraw:
                    # An empty read means the peer closed the connection; looping again would spin forever
                    raise LDAPConnectionError('Connection closed by server')
                if self._has_sasl_client():
                    newraw = self._sasl_client.unwrap(newraw)
                raw += newraw
                while len(raw) > 0:
                    # raw is only replaced with the remainder when a full message was decoded
                    response, raw = ber_decode(raw, asn1Spec=LDAPMessage())
                    have_message_id = response.getComponentByName('messageID')
                    if want_message_id == have_message_id:
                        yield response
                    elif have_message_id == 0:
                        msg = 'Received unsolicited message (default message - should never be seen)'
                        try:
                            mid, xr, ctrls = unpack('extendedResp', response)
                            res_code = xr.getComponentByName('resultCode')
                            xr_oid = six.text_type(
                                xr.getComponentByName('responseName'))
                            if xr_oid == LDAPSocket.OID_DISCONNECTION_NOTICE:
                                mtype = 'Notice of Disconnection'
                            else:
                                mtype = 'Unhandled ({0})'.format(xr_oid)
                            diag = xr.getComponentByName('diagnosticMessage')
                            msg = 'Got unsolicited message: {0}: {1}: {2}'.format(
                                mtype, res_code, diag)
                            if res_code == ResultCode('protocolError'):
                                msg += (
                                    ' (This may indicate an incompatability between laurelin-ldap and your server '
                                    'distribution)')
                            elif res_code == ResultCode(
                                    'strongerAuthRequired'):
                                # this is a direct quote from RFC 4511 sec 4.4.1
                                msg += (
                                    ' (The server has detected that an established security association between the '
                                    'client and server has unexpectedly failed or been compromised)'
                                )
                        except UnexpectedResponseType:
                            msg = 'Unhandled unsolicited message from server'
                        finally:
                            # Raised on every path, including unexpected errors while building the message
                            raise LDAPUnsolicitedMessage(response, msg)
                    else:
                        if have_message_id not in self._message_queues:
                            self._message_queues[have_message_id] = deque()
                        self._message_queues[have_message_id].append(response)
            except SubstrateUnderrunError:
                # Incomplete message in the buffer: keep what we have and read more from the socket
                flush_queue = False
                continue

    def close(self):
        """Close the underlying socket and return the result of its ``close()`` call."""
        sock = self._sock
        return sock.close()