예제 #1
0
파일: server.py 프로젝트: hughrobb/aioldap3
class ClientProtocol(LDAPProtocol):
    """LDAP protocol subclass with a server-role TLS upgrade and an
    attribute-list builder.

    NOTE(review): despite the name, ``starttls`` performs the *server* side
    of the TLS handshake -- presumably this object represents one client
    connection inside a server; confirm against the caller.
    """

    def __init__(self, server=None):
        """Initialise the base protocol and remember the owning *server*."""
        super().__init__()
        self.server = server

    async def starttls(self):
        """Upgrade this connection to TLS, taking the server role.

        Resumes reading on the transport, builds a TLS 1.2 context from
        ``cert.pem``, installs an ``SSLProtocol`` in front of this protocol
        and waits on the future handed to it as ``waiter``.
        """
        handshake_done = asyncio.Future()
        self.transport.resume_reading()

        tls_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
        # Relative path: resolved against the process's working directory.
        tls_context.load_cert_chain('cert.pem')

        self.tlssession = SSLProtocol(
            loop=self.transport._loop,
            app_protocol=self,
            sslcontext=tls_context,
            waiter=handshake_done,
            server_side=True,
            server_hostname='localhost',
        )
        # Re-point the transport at the TLS layer through its private
        # attribute, then start the TLS layer on that same transport.
        tls_layer = self.tlssession
        self.transport._protocol = tls_layer
        tls_layer.connection_made(self.transport)
        await handshake_done

    def make_attribute_list(self, *args):
        """Build an ``rfc4511.PartialAttributeList`` from ``(name, value)`` pairs.

        Each positional argument must unpack into two items; every pair
        becomes one ``PartialAttribute`` holding a single-value set.
        """
        attributes = rfc4511.PartialAttributeList()
        for name, raw_value in args:
            description = rfc4511.AttributeDescription(name)
            value = rfc4511.AttributeValue(raw_value)
            entry = rfc4511.PartialAttribute()
            entry[0] = description
            entry[1] = rfc4511.SetOfAttributeValue(value=[value])
            attributes.append(entry)
        return attributes
예제 #2
0
파일: protocol.py 프로젝트: tdna/aiosmtplib
    def upgrade_transport(self,
                          context: ssl.SSLContext,
                          server_hostname: str = None,
                          waiter: Awaitable = None) -> SSLProtocol:
        """
        Swap the connection's transport for a TLS-wrapped one, in place.

        An ``SSLProtocol`` (client side) is inserted between the raw
        transport and this protocol, the stream reader/writer are re-pointed
        at the TLS application transport, and the TLS layer is started on
        the raw transport.

        :param context: SSL context handed to the ``SSLProtocol``.
        :param server_hostname: forwarded to ``SSLProtocol`` unchanged.
        :param waiter: forwarded to ``SSLProtocol`` as its waiter.
        :return: the newly created ``SSLProtocol``.
        :raises AssertionError: if TLS is already active on this connection.
        :raises SMTPServerDisconnected: if the reader or writer is missing.
        """
        assert not self._over_ssl, 'Already using TLS'

        if self._stream_reader is None or self._stream_writer is None:
            raise SMTPServerDisconnected('Client not connected')

        raw_transport = self._stream_reader._transport  # type: ignore

        ssl_protocol = SSLProtocol(
            self._loop, self, context, waiter,
            server_side=False, server_hostname=server_hostname)
        wrapped_transport = ssl_protocol._app_transport

        # Only poke the private attribute when the transport offers no
        # set_protocol() method.
        if not hasattr(raw_transport, 'set_protocol'):
            raw_transport._protocol = ssl_protocol
        else:
            raw_transport.set_protocol(ssl_protocol)

        # From here on the streams talk to the TLS application transport.
        self._stream_reader._transport = wrapped_transport  # type: ignore
        self._stream_writer._transport = wrapped_transport  # type: ignore

        ssl_protocol.connection_made(raw_transport)
        self._over_ssl = True  # type: bool

        return ssl_protocol
예제 #3
0
파일: server.py 프로젝트: hughrobb/aioldap3
class ClientProtocol(LDAPProtocol):
    """LDAP protocol subclass with a server-role TLS upgrade and an
    attribute-list builder.

    NOTE(review): despite the name, ``starttls`` performs the *server* side
    of the TLS handshake -- presumably this object represents one client
    connection inside a server; confirm against the caller.
    """

    def __init__(self, server=None):
        """Initialise the base protocol and remember the owning *server*."""
        super().__init__()
        self.server = server

    async def starttls(self):
        """Upgrade this connection to TLS, taking the server role.

        Resumes reading on the transport, builds a TLS 1.2 context from
        ``cert.pem``, installs an ``SSLProtocol`` in front of this protocol
        and waits on the future handed to it as ``waiter``.
        """
        handshake_done = asyncio.Future()
        self.transport.resume_reading()

        tls_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
        # Relative path: resolved against the process's working directory.
        tls_context.load_cert_chain('cert.pem')

        self.tlssession = SSLProtocol(
            loop=self.transport._loop,
            app_protocol=self,
            sslcontext=tls_context,
            waiter=handshake_done,
            server_side=True,
            server_hostname='localhost',
        )
        # Re-point the transport at the TLS layer through its private
        # attribute, then start the TLS layer on that same transport.
        tls_layer = self.tlssession
        self.transport._protocol = tls_layer
        tls_layer.connection_made(self.transport)
        await handshake_done

    def make_attribute_list(self, *args):
        """Build an ``rfc4511.PartialAttributeList`` from ``(name, value)`` pairs.

        Each positional argument must unpack into two items; every pair
        becomes one ``PartialAttribute`` holding a single-value set.
        """
        attributes = rfc4511.PartialAttributeList()
        for name, raw_value in args:
            description = rfc4511.AttributeDescription(name)
            value = rfc4511.AttributeValue(raw_value)
            entry = rfc4511.PartialAttribute()
            entry[0] = description
            entry[1] = rfc4511.SetOfAttributeValue(value=[value])
            attributes.append(entry)
        return attributes
예제 #4
0
파일: server.py 프로젝트: hughrobb/aioldap3
 async def starttls(self):
     """Upgrade this connection to TLS, taking the server role.

     Resumes reading on the transport, builds a TLS 1.2 context from
     ``cert.pem``, installs an ``SSLProtocol`` in front of this protocol
     and waits on the future handed to it as ``waiter``.
     """
     handshake_done = asyncio.Future()
     self.transport.resume_reading()

     tls_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
     # Relative path: resolved against the process's working directory.
     tls_context.load_cert_chain('cert.pem')

     self.tlssession = SSLProtocol(
         loop=self.transport._loop,
         app_protocol=self,
         sslcontext=tls_context,
         waiter=handshake_done,
         server_side=True,
         server_hostname='localhost',
     )
     # Re-point the transport at the TLS layer through its private
     # attribute, then start the TLS layer on that same transport.
     tls_layer = self.tlssession
     self.transport._protocol = tls_layer
     tls_layer.connection_made(self.transport)
     await handshake_done
예제 #5
0
    def wrap_transport(self, waiter):
        """Insert a server-side TLS layer under the existing stream pair.

        The writer's current transport and protocol are taken as the raw
        transport and the application protocol. A new ``SSLProtocol`` is
        installed on the raw transport, both streams are re-pointed at the
        TLS application transport, and the TLS layer is started.

        :param waiter: passed to ``SSLProtocol`` as its waiter.
        """
        raw_transport = self.stream_writer._transport
        app_protocol = self.stream_writer._protocol

        ssl_protocol = SSLProtocol(
            self.loop, app_protocol, self.tls_context, waiter,
            server_side=True, call_connection_made=False)

        raw_transport.set_protocol(ssl_protocol)
        self.stream_reader._transport = ssl_protocol._app_transport
        self.stream_writer._transport = ssl_protocol._app_transport

        ssl_protocol.connection_made(raw_transport)
        # NOTE(review): the flag is stored on the SSLProtocol instance, not
        # on self -- confirm that this is where readers of _over_ssl look.
        ssl_protocol._over_ssl = True
예제 #6
0
파일: server.py 프로젝트: hughrobb/aioldap3
 async def starttls(self):
     """Upgrade this connection to TLS, taking the server role.

     Resumes reading on the transport, builds a TLS 1.2 context from
     ``cert.pem``, installs an ``SSLProtocol`` in front of this protocol
     and waits on the future handed to it as ``waiter``.
     """
     handshake_done = asyncio.Future()
     self.transport.resume_reading()

     tls_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
     # Relative path: resolved against the process's working directory.
     tls_context.load_cert_chain('cert.pem')

     self.tlssession = SSLProtocol(
         loop=self.transport._loop,
         app_protocol=self,
         sslcontext=tls_context,
         waiter=handshake_done,
         server_side=True,
         server_hostname='localhost',
     )
     # Re-point the transport at the TLS layer through its private
     # attribute, then start the TLS layer on that same transport.
     tls_layer = self.tlssession
     self.transport._protocol = tls_layer
     tls_layer.connection_made(self.transport)
     await handshake_done
예제 #7
0
 async def starttls(self):
     """Negotiate StartTLS with the server and upgrade the transport.

     Sends the extended request ``1.3.6.1.4.1.1466.20037`` (the LDAP
     StartTLS operation), checks that the result code is 0, then installs
     a client-side ``SSLProtocol`` in front of this protocol.

     Returns whatever the waiter future passed to ``SSLProtocol`` yields.

     Raises:
         LDAPExceptionError: if any request is still pending in
             ``self._getters``.
         LDAPExtensionError: if the server answers with a non-zero result.
     """
     if self._getters:
         raise LDAPExceptionError(
             'cannot start tls while any operations are active.')

     reply = await self.extend('1.3.6.1.4.1.1466.20037')
     result_code = int(reply[1].chosen[0])
     if result_code != 0:
         raise LDAPExtensionError(reply)

     handshake_done = asyncio.Future()
     # NOTE(review): the context gets no CA certificates and the hostname
     # is None, so the server certificate is presumably not verified --
     # confirm this is intended.
     tls_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
     self.tlssession = SSLProtocol(
         loop=self.transport._loop,
         app_protocol=self,
         sslcontext=tls_context,
         waiter=handshake_done,
         server_side=False,
         server_hostname=None,
     )
     # Re-point the transport at the TLS layer through its private
     # attribute, then start the TLS layer on that same transport.
     tls_layer = self.tlssession
     self.transport._protocol = tls_layer
     tls_layer.connection_made(self.transport)
     outcome = await handshake_done
     return outcome
예제 #8
0
파일: client.py 프로젝트: hughrobb/aioldap3
 async def starttls(self):
     """Negotiate StartTLS with the server and upgrade the transport.

     Sends the extended request ``1.3.6.1.4.1.1466.20037`` (the LDAP
     StartTLS operation), checks that the result code is 0, then installs
     a client-side ``SSLProtocol`` in front of this protocol.

     Returns whatever the waiter future passed to ``SSLProtocol`` yields.

     Raises:
         LDAPExceptionError: if any request is still pending in
             ``self._getters``.
         LDAPExtensionError: if the server answers with a non-zero result.
     """
     if self._getters:
         raise LDAPExceptionError(
             'cannot start tls while any operations are active.')

     reply = await self.extend('1.3.6.1.4.1.1466.20037')
     result_code = int(reply[1].chosen[0])
     if result_code != 0:
         raise LDAPExtensionError(reply)

     handshake_done = asyncio.Future()
     # NOTE(review): the context gets no CA certificates and the hostname
     # is None, so the server certificate is presumably not verified --
     # confirm this is intended.
     tls_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
     self.tlssession = SSLProtocol(
         loop=self.transport._loop,
         app_protocol=self,
         sslcontext=tls_context,
         waiter=handshake_done,
         server_side=False,
         server_hostname=None,
     )
     # Re-point the transport at the TLS layer through its private
     # attribute, then start the TLS layer on that same transport.
     tls_layer = self.tlssession
     self.transport._protocol = tls_layer
     tls_layer.connection_made(self.transport)
     outcome = await handshake_done
     return outcome
예제 #9
0
파일: fake.py 프로젝트: pyzh/blackout
async def make_pair(c, s):
    """Wire two application protocols together through TLS layers.

    ``c`` is wrapped in a client-side ``SSLProtocol`` and ``s`` in a
    server-side one. Each TLS layer is started on a ``FakeTransport`` built
    around the *other* layer. The coroutine finishes once the client-side
    waiter future resolves and returns nothing.
    """
    event_loop = get_event_loop()

    client_ready = Future()

    client_tls = SSLProtocol(event_loop, c, create_tls_context(), client_ready)
    server_tls = SSLProtocol(
        event_loop, s, create_tls_context(), None, server_side=True)

    # Each side's transport is constructed around the opposite TLS layer.
    toward_server = FakeTransport(server_tls)
    toward_client = FakeTransport(client_tls)

    # The server layer is started first, then the client layer.
    server_tls.connection_made(toward_client)
    client_tls.connection_made(toward_server)

    await client_ready
예제 #10
0
 def create_protocol(self):  # type: () -> Protocol
     """Create the protocol for a new connection obtained from the server.

     Without an SSL context the plain ``HttpProtocol`` is returned. With
     one, the ``HttpProtocol`` is created with ``on_ssl=True`` and wrapped
     in a server-side ``SSLProtocol`` bound to the running event loop.
     """
     connection = self.server.new_connection()
     if self.ssl_context is None:
         return HttpProtocol(connection, on_ssl=False)

     http_protocol = HttpProtocol(connection, on_ssl=True)
     return SSLProtocol(
         loop=asyncio.get_running_loop(),
         app_protocol=http_protocol,
         sslcontext=self.ssl_context,
         waiter=None,
         server_side=True,
     )
예제 #11
0
파일: compat.py 프로젝트: decaz/aiosmtplib
async def start_tls(
    loop: asyncio.AbstractEventLoop,
    transport: asyncio.Transport,
    protocol: asyncio.Protocol,
    sslcontext: ssl.SSLContext,
    server_side: bool = False,
    server_hostname: Optional[str] = None,
    ssl_handshake_timeout: Optional[Union[float, int]] = None,
) -> asyncio.Transport:
    """Upgrade an existing transport to TLS and return the TLS transport.

    When the loop provides ``start_tls`` the call is delegated to it.
    Otherwise an ``SSLProtocol`` is installed on ``transport`` by hand and
    the coroutine waits for its waiter future.

    Args:
        loop: event loop that owns ``transport``.
        transport: the transport to upgrade.
        protocol: application protocol that will sit above the TLS layer.
        sslcontext: SSL context used for the TLS layer.
        server_side: passed through to the loop / ``SSLProtocol``.
        server_hostname: passed through to the loop / ``SSLProtocol``.
        ssl_handshake_timeout: passed to ``loop.start_tls`` or, in the
            fallback path, used as the ``asyncio.wait_for`` timeout.

    Returns:
        The result of ``loop.start_tls`` or, in the fallback path, the
        ``SSLProtocol``'s application transport.

    Raises:
        Whatever the wait raises (including timeout and cancellation); in
        the fallback path the transport is closed and the scheduled
        callbacks are cancelled before the exception propagates.
    """
    # We use hasattr here, as uvloop also supports start_tls.
    if hasattr(loop, "start_tls"):
        return await loop.start_tls(  # type: ignore
            transport,
            protocol,
            sslcontext,
            server_side=server_side,
            server_hostname=server_hostname,
            ssl_handshake_timeout=ssl_handshake_timeout,
        )

    waiter = loop.create_future()
    ssl_protocol = SSLProtocol(loop, protocol, sslcontext, waiter, server_side,
                               server_hostname)

    # Pause early so that "ssl_protocol.data_received()" doesn't
    # have a chance to get called before "ssl_protocol.connection_made()".
    transport.pause_reading()

    # Use set_protocol if we can
    if hasattr(transport, "set_protocol"):
        transport.set_protocol(ssl_protocol)  # type: ignore
    else:
        transport._protocol = ssl_protocol  # type: ignore

    conmade_cb = loop.call_soon(ssl_protocol.connection_made, transport)
    resume_cb = loop.call_soon(transport.resume_reading)

    try:
        await asyncio.wait_for(waiter, timeout=ssl_handshake_timeout)
    except BaseException:
        # BaseException rather than Exception: asyncio.CancelledError has
        # been a BaseException since Python 3.8, and a cancelled upgrade
        # must still close the transport and drop the pending callbacks.
        transport.close()
        conmade_cb.cancel()
        resume_cb.cancel()
        raise

    return ssl_protocol._app_transport
예제 #12
0
    async def init_connection(self, loop):
        """Run a TLS handshake where the role is decided by comparing hellos.

        A client-side ``SSLProtocol`` is started on a capturing transport to
        obtain its first outgoing bytes, which are then written to the real
        transport. Once the peer's hello has arrived, bytes 15..42 of both
        messages are compared:

        * equal: the transport is closed and the method returns;
        * ours greater: a fresh server-side ``SSLProtocol`` takes over and
          is fed the whole receive buffer;
        * otherwise: the original client-side ``SSLProtocol`` is attached
          to the real transport and fed the buffer past ``self.length + 5``.

        Finally the shared waiter future is awaited.

        NOTE(review): bytes 15..42 are presumably the random part of a TLS
        ClientHello -- confirm against the record layout.
        """
        handshake_done = Future()

        hello_captured = Future()

        tls_context = create_tls_context()
        capture = CaptureClientHello(hello_captured)
        client_tls = SSLProtocol(
            loop, ClientProtocol(), tls_context, handshake_done)
        client_tls.connection_made(capture)

        own_hello = await hello_captured
        my_random = own_hello[15:43]

        transport = await self.connected
        transport.write(own_hello)

        await self.hello_received
        peer_random = self.buffer[15:43]

        # Identical values give no way to pick a role: give up.
        if my_random == peer_random:
            transport.close()
            return

        if my_random > peer_random:
            print("I am server", self.proxy.name, repr(self), transport)

            # Detach the client-side layer that only produced the hello.
            client_tls._waiter = None
            client_tls._transport = None

            tls_context = create_tls_context()
            server_tls = SSLProtocol(
                loop, ServerProtocol(), tls_context, handshake_done,
                server_side=True)

            self.proxy.switch(server_tls)
            server_tls.connection_made(transport)
            loop.call_soon(server_tls.data_received, self.buffer)

        else:
            print("I am client", self.proxy.name, repr(self), transport)

            client_tls._transport = transport
            self.proxy.switch(client_tls)
            remainder = self.buffer[self.length+5:]
            loop.call_soon(client_tls.data_received, remainder)

        await handshake_done
예제 #13
0
    async def init_connection(self, loop, connection):
        """Run a TLS handshake where the role is decided by comparing hellos.

        A client-side ``SSLProtocol`` around ``connection`` is started on a
        capturing transport to obtain its first outgoing bytes, which are
        then written to the real transport. Once the peer's hello has
        arrived, bytes 15..42 of both messages are compared:

        * equal: the transport is closed and the method returns;
        * ours greater: a fresh server-side ``SSLProtocol`` takes over and
          is fed the whole receive buffer;
        * otherwise: the original client-side ``SSLProtocol`` is attached
          to the real transport and fed the buffer past ``self.length + 5``.

        Finally the shared waiter future is awaited; if that raises, the
        connection is removed from ``connection.endpoint.connections`` and
        the exception is re-raised.

        NOTE(review): bytes 15..42 are presumably the random part of a TLS
        ClientHello -- confirm against the record layout.
        """
        handshake_done = Future()
        hello_captured = Future()

        tls_context = create_tls_context()
        capture = CaptureClientHello(hello_captured)
        client_tls = SSLProtocol(loop, connection, tls_context, handshake_done)
        client_tls.connection_made(capture)

        own_hello = await hello_captured
        my_random = own_hello[15:43]

        transport = await self.connected
        transport.write(own_hello)

        await self.hello_received
        peer_random = self.buffer[15:43]

        # Identical values give no way to pick a role: give up.
        if my_random == peer_random:
            transport.close()
            return

        if my_random > peer_random:
            # Detach the client-side layer that only produced the hello.
            client_tls._waiter = None
            client_tls._transport = None

            tls_context = create_tls_context()
            server_tls = SSLProtocol(
                loop, connection, tls_context, handshake_done,
                server_side=True)

            self.proxy.switch(server_tls)
            server_tls.connection_made(transport)
            loop.call_soon(server_tls.data_received, self.buffer)

        else:
            client_tls._transport = transport
            self.proxy.switch(client_tls)
            remainder = self.buffer[self.length+5:]
            loop.call_soon(client_tls.data_received, remainder)

        try:
            await handshake_done
        except BaseException:
            # Catches everything, cancellation included, exactly like a
            # bare "except:"; the error is re-raised after the cleanup.
            connection.endpoint.connections.pop(connection.addr)
            raise
예제 #14
0
class LDAPClient(LDAPProtocol):
    """LDAP client protocol that routes responses to per-request futures.

    Every request gets a fresh message id. The future created for it is
    kept in ``_getters`` until it reports ``done()``.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the base protocol and reset state."""
        super().__init__(*args, **kwargs)
        self._sequence = 0  # FIXME
        self._getters = dict()

    def bind(self, user, password, controls=None):
        """Send a simple bind request and return the future for its reply.

        ``user`` and ``password`` are UTF-8 encoded; the request's first
        field (the protocol version) is set to 3. ``controls`` is accepted
        but not used.
        """
        msg = self.make_message()
        request = rfc4511.BindRequest()
        request[0] = 3
        request[1] = user.encode('utf8')
        request[2] = rfc4511.AuthenticationChoice('simple',
                                                  password.encode('utf8'))
        msg[1] = rfc4511.ProtocolOp('bindRequest', request)
        self.send(msg)
        return self.make_future(FutureSimpleResponse, msg)

    def search(
            self,
            search_base,
            search_filter,
            search_scope=2,  # subtree
            dereference_aliases=3,  # deref-always
            attributes=None,
            size_limit=0,
            time_limit=0,
            types_only=False,
            get_operational_attributes=False,
            controls=None,
            paged_size=None,
            paged_criticality=False):
        """Send a search request and return the future for its results.

        Only ``search_base``, ``search_scope``, ``dereference_aliases``,
        ``size_limit``, ``time_limit`` and ``types_only`` reach the wire.

        FIXME: ``search_filter``, ``attributes``,
        ``get_operational_attributes``, ``controls``, ``paged_size`` and
        ``paged_criticality`` are currently ignored; the filter and the
        attribute selection below are hard-coded.
        """
        msg = self.make_message()
        request = rfc4511.SearchRequest()
        request[0] = search_base.encode('utf8')
        request[1] = search_scope
        request[2] = dereference_aliases
        request[3] = size_limit
        request[4] = time_limit
        request[5] = types_only
        # Hard-coded filter (cn=hugh); search_filter is not consulted.
        kv = rfc4511.AttributeValueAssertion()
        kv[0], kv[1] = b'cn', b'hugh'
        request[6] = rfc4511.Filter('equalityMatch', kv)
        # Hard-coded attribute selection; attributes is not consulted.
        request[7] = rfc4511.AttributeSelection(
            value=b'cn givenName sn'.split())
        msg[1] = rfc4511.ProtocolOp('searchRequest', request)
        self.send(msg)
        return self.make_future(FutureSearchResponse, msg)

    def extend(self, request_name, request_value=None, controls=None):
        """Send an extended request and return the future for its reply.

        ``request_name`` and, when given, ``request_value`` are UTF-8
        encoded. ``controls`` is accepted but not used.
        """
        msg = self.make_message()
        request = rfc4511.ExtendedRequest()
        request[0] = request_name.encode('utf8')
        if request_value is not None:
            request[1] = request_value.encode('utf8')
        msg[1] = rfc4511.ProtocolOp('extendedReq', request)
        self.send(msg)
        return self.make_future(FutureSimpleResponse, msg)

    def make_message(self):
        """Return a new ``LDAPMessage`` carrying the next message id."""
        self._sequence += 1
        msg = rfc4511.LDAPMessage()
        msg[0] = self._sequence
        return msg

    def make_future(self, future_factory, ldap_message):
        """Create and register the future that will receive the replies.

        ``future_factory`` is called with the message id of
        ``ldap_message``; the result is stored in ``_getters`` under that
        id and returned.
        """
        message_id = int(ldap_message[0])
        future = future_factory(message_id)
        self._getters[message_id] = future
        return future

    def message_id_overflowed(self, future, ldap_response):
        """Pause reading because ``future`` cannot take another response.

        The pending response is delivered, and reading resumed, when the
        future invokes the resume callback registered here.
        """
        self.transport.pause_reading()
        future.set_resume_reading_callback(
            lambda: self.can_resume_reading(future, ldap_response))

    def can_resume_reading(self, future, ldap_message):
        """Deliver the held-back message and resume reading."""
        future.write_response(ldap_message)
        self.transport.resume_reading()

    def message_received(self, ldap_message):
        """Route an incoming message to the future waiting for its id.

        Messages with id 0 and messages without a registered future are
        logged and dropped. A future that reports ``full()`` triggers the
        pause/resume path; a future that reports ``done()`` after the
        write is unregistered.
        """
        message_id = int(ldap_message[0])

        # unsolicited messages
        if message_id == 0:
            # logging.warning, not the deprecated logging.warn alias
            # (removed in Python 3.13).
            logging.warning('unsolicited message {}'.format(ldap_message))
            return

        # route message to getter
        future = self._getters.get(message_id, None)
        if future is None:
            logging.warning('message without getter {}'.format(ldap_message))
            return

        if future.full():
            self.message_id_overflowed(future, ldap_message)
            return

        future.write_response(ldap_message)
        if future.done():
            del self._getters[future.message_id]

    async def starttls(self):
        """Negotiate StartTLS with the server and upgrade the transport.

        Sends the extended request ``1.3.6.1.4.1.1466.20037`` (the LDAP
        StartTLS operation), checks that the result code is 0, then
        installs a client-side ``SSLProtocol`` in front of this protocol.

        Returns whatever the waiter future passed to ``SSLProtocol``
        yields.

        Raises:
            LDAPExceptionError: if any request is still pending.
            LDAPExtensionError: if the server answers with a non-zero
                result code.
        """
        if self._getters:
            raise LDAPExceptionError(
                'cannot start tls while any operations are active.')
        response = await self.extend('1.3.6.1.4.1.1466.20037')
        if int(response[1].chosen[0]) != 0:
            raise LDAPExtensionError(response)
        future = asyncio.Future()
        # NOTE(review): the context gets no CA certificates and the
        # hostname is None, so the server certificate is presumably not
        # verified -- confirm this is intended.
        sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        self.tlssession = SSLProtocol(loop=self.transport._loop,
                                      app_protocol=self,
                                      sslcontext=sslcontext,
                                      waiter=future,
                                      server_side=False,
                                      server_hostname=None)
        # Re-point the transport at the TLS layer through its private
        # attribute, then start the TLS layer on that same transport.
        self.transport._protocol = self.tlssession
        self.tlssession.connection_made(self.transport)
        return (await future)
예제 #15
0
파일: client.py 프로젝트: hughrobb/aioldap3
class LDAPClient(LDAPProtocol):
    """LDAP client protocol that routes responses to per-request futures.

    Every request gets a fresh message id. The future created for it is
    kept in ``_getters`` until it reports ``done()``.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the base protocol and reset state."""
        super().__init__(*args, **kwargs)
        self._sequence = 0     # FIXME
        self._getters = dict()

    def bind(self, user, password, controls=None):
        """Send a simple bind request and return the future for its reply.

        ``user`` and ``password`` are UTF-8 encoded; the request's first
        field (the protocol version) is set to 3. ``controls`` is accepted
        but not used.
        """
        msg = self.make_message()
        request = rfc4511.BindRequest()
        request[0] = 3
        request[1] = user.encode('utf8')
        request[2] = rfc4511.AuthenticationChoice('simple', password.encode('utf8'))
        msg[1] = rfc4511.ProtocolOp('bindRequest', request)
        self.send(msg)
        return self.make_future(FutureSimpleResponse, msg)

    def search(self,
        search_base,
        search_filter,
        search_scope=2,     # subtree
        dereference_aliases=3,  # deref-always
        attributes=None,
        size_limit=0,
        time_limit=0,
        types_only=False,
        get_operational_attributes=False,
        controls=None,
        paged_size=None,
        paged_criticality=False
    ):
        """Send a search request and return the future for its results.

        Only ``search_base``, ``search_scope``, ``dereference_aliases``,
        ``size_limit``, ``time_limit`` and ``types_only`` reach the wire.

        FIXME: ``search_filter``, ``attributes``,
        ``get_operational_attributes``, ``controls``, ``paged_size`` and
        ``paged_criticality`` are currently ignored; the filter and the
        attribute selection below are hard-coded.
        """
        msg = self.make_message()
        request = rfc4511.SearchRequest()
        request[0] = search_base.encode('utf8')
        request[1] = search_scope
        request[2] = dereference_aliases
        request[3] = size_limit
        request[4] = time_limit
        request[5] = types_only
        # Hard-coded filter (cn=hugh); search_filter is not consulted.
        kv = rfc4511.AttributeValueAssertion()
        kv[0], kv[1] = b'cn', b'hugh'
        request[6] = rfc4511.Filter('equalityMatch', kv)
        # Hard-coded attribute selection; attributes is not consulted.
        request[7] = rfc4511.AttributeSelection(value=b'cn givenName sn'.split())
        msg[1] = rfc4511.ProtocolOp('searchRequest', request)
        self.send(msg)
        return self.make_future(FutureSearchResponse, msg)

    def extend(self, request_name, request_value=None, controls=None):
        """Send an extended request and return the future for its reply.

        ``request_name`` and, when given, ``request_value`` are UTF-8
        encoded. ``controls`` is accepted but not used.
        """
        msg = self.make_message()
        request = rfc4511.ExtendedRequest()
        request[0] = request_name.encode('utf8')
        if request_value is not None:
            request[1] = request_value.encode('utf8')
        msg[1] = rfc4511.ProtocolOp('extendedReq', request)
        self.send(msg)
        return self.make_future(FutureSimpleResponse, msg)

    def make_message(self):
        """Return a new ``LDAPMessage`` carrying the next message id."""
        self._sequence += 1
        msg = rfc4511.LDAPMessage()
        msg[0] = self._sequence
        return msg

    def make_future(self, future_factory, ldap_message):
        """Create and register the future that will receive the replies.

        ``future_factory`` is called with the message id of
        ``ldap_message``; the result is stored in ``_getters`` under that
        id and returned.
        """
        message_id = int(ldap_message[0])
        future = future_factory(message_id)
        self._getters[message_id] = future
        return future

    def message_id_overflowed(self, future, ldap_response):
        """Pause reading because ``future`` cannot take another response.

        The pending response is delivered, and reading resumed, when the
        future invokes the resume callback registered here.
        """
        self.transport.pause_reading()
        future.set_resume_reading_callback(
            lambda: self.can_resume_reading(future, ldap_response))

    def can_resume_reading(self, future, ldap_message):
        """Deliver the held-back message and resume reading."""
        future.write_response(ldap_message)
        self.transport.resume_reading()

    def message_received(self, ldap_message):
        """Route an incoming message to the future waiting for its id.

        Messages with id 0 and messages without a registered future are
        logged and dropped. A future that reports ``full()`` triggers the
        pause/resume path; a future that reports ``done()`` after the
        write is unregistered.
        """
        message_id = int(ldap_message[0])

        # unsolicited messages
        if message_id == 0:
            # logging.warning, not the deprecated logging.warn alias
            # (removed in Python 3.13).
            logging.warning('unsolicited message {}'.format(ldap_message))
            return

        # route message to getter
        future = self._getters.get(message_id, None)
        if future is None:
            logging.warning('message without getter {}'.format(ldap_message))
            return

        if future.full():
            self.message_id_overflowed(future, ldap_message)
            return

        future.write_response(ldap_message)
        if future.done():
            del self._getters[future.message_id]

    async def starttls(self):
        """Negotiate StartTLS with the server and upgrade the transport.

        Sends the extended request ``1.3.6.1.4.1.1466.20037`` (the LDAP
        StartTLS operation), checks that the result code is 0, then
        installs a client-side ``SSLProtocol`` in front of this protocol.

        Returns whatever the waiter future passed to ``SSLProtocol``
        yields.

        Raises:
            LDAPExceptionError: if any request is still pending.
            LDAPExtensionError: if the server answers with a non-zero
                result code.
        """
        if self._getters:
            raise LDAPExceptionError('cannot start tls while any operations are active.')
        response = await self.extend('1.3.6.1.4.1.1466.20037')
        if int(response[1].chosen[0]) != 0:
            raise LDAPExtensionError(response)
        future = asyncio.Future()
        # NOTE(review): the context gets no CA certificates and the
        # hostname is None, so the server certificate is presumably not
        # verified -- confirm this is intended.
        sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        self.tlssession = SSLProtocol(loop=self.transport._loop,
                                      app_protocol=self,
                                      sslcontext=sslcontext,
                                      waiter=future,
                                      server_side=False,
                                      server_hostname=None)
        # Re-point the transport at the TLS layer through its private
        # attribute, then start the TLS layer on that same transport.
        self.transport._protocol = self.tlssession
        self.tlssession.connection_made(self.transport)
        return (await future)