コード例 #1
0
ファイル: authentication.py プロジェクト: baby636/lamden
    def __init__(self, client: ContractingClient, ctx: zmq.asyncio.Context, bootnodes: dict = None,
                 loop=None, domain='*', cert_dir=CERT_DIR, debug=False):
        """Start a CURVE authenticator on ``ctx`` and trust the bootnode keys.

        :param client: Contracting client, stored on ``self.client``.
        :param ctx: Asyncio ZMQ context handed to ``AsyncioAuthenticator``.
        :param bootnodes: Mapping whose keys are verifying keys passed to
            ``self.add_verifying_key`` at startup. Defaults to a fresh empty dict.
        :param loop: Event loop handed to ``AsyncioAuthenticator``. When omitted it
            is resolved with ``asyncio.get_event_loop()`` at construction time.
        :param domain: ZAP domain used for ``configure_curve``.
        :param cert_dir: Certificate directory relative to the user's home;
            created if it does not exist.
        :param debug: Value assigned to the logger's ``propagate`` flag.
        """
        self.client = client

        # Create the certificate directory if it doesn't exist
        self.cert_dir = pathlib.Path.home() / cert_dir
        self.cert_dir.mkdir(parents=True, exist_ok=True)

        self.ctx = ctx
        self.domain = domain

        # A default of asyncio.get_event_loop() would be evaluated only once, when
        # the ``def`` executes; resolve it per instance instead.
        self.loop = loop if loop is not None else asyncio.get_event_loop()

        self.log = get_logger('zmq.auth')
        self.log.propagate = debug

        # A ``{}`` default would be one dict shared by every instance.
        self.bootnodes = bootnodes if bootnodes is not None else {}

        # Per the log message below, this is expected to raise ZMQBaseError when an
        # authenticator is already running. Track the instance in a local so the
        # curve configuration below is skipped when construction itself failed,
        # instead of raising AttributeError and hiding the logged error.
        authenticator = None
        try:
            authenticator = AsyncioAuthenticator(context=self.ctx, loop=self.loop)
            self.authenticator = authenticator
            authenticator.start()
        except ZMQBaseError:
            self.log.error('Error starting ZMQ Authenticator. Is it already running?')

        for node in self.bootnodes:
            self.add_verifying_key(node)

        if authenticator is not None:
            authenticator.configure_curve(domain=self.domain, location=self.cert_dir)
コード例 #2
0
    def __init__(self,
                 ctx: zmq.asyncio.Context,
                 loop=None,
                 domain='*',
                 cert_dir=CERT_DIR,
                 debug=True):
        """Start a CURVE authenticator for ``ctx`` backed by a certificate directory.

        :param ctx: Asyncio ZMQ context handed to ``AsyncioAuthenticator``.
        :param loop: Event loop for the authenticator. When omitted it is resolved
            with ``asyncio.get_event_loop()`` at construction time.
        :param domain: ZAP domain used for ``configure_curve``.
        :param cert_dir: Certificate directory relative to the user's home;
            created if it does not exist.
        :param debug: Value assigned to the logger's ``propagate`` flag.
        """
        # Create the certificate directory if it doesn't exist
        self.cert_dir = pathlib.Path.home() / cert_dir
        self.cert_dir.mkdir(parents=True, exist_ok=True)

        self.ctx = ctx
        self.domain = domain

        # Resolve per instance; a default argument would be evaluated only once.
        self.loop = loop if loop is not None else asyncio.get_event_loop()

        self.log = get_logger('zmq.auth')
        self.log.propagate = debug

        # Failure to start is deliberately non-fatal, but it is now reported
        # instead of being silently swallowed.
        try:
            self.authenticator = AsyncioAuthenticator(context=self.ctx,
                                                      loop=self.loop)
            self.authenticator.start()
            self.authenticator.configure_curve(domain=self.domain,
                                               location=self.cert_dir)

        except ZMQBaseError:
            self.log.error('Error starting ZMQ Authenticator. Is it already running?')
コード例 #3
0
 def start_auth(self, context: zmq.Context) -> bool:
     """
     Turn this server's socket into a CURVE server and start authentication.

     When ``self.auth_configured`` is falsy nothing is touched and False is
     returned. Otherwise the socket gets the server key pair from
     ``self._auth_config`` and is marked as a CURVE server. An
     ``AsyncioAuthenticator`` is then created for ``context``. Client keys come
     from ``self._preloaded_keys`` when present, else from
     ``load_client_keys_from_directory()``. Finally the authenticator is started.

     :param context: ZeroMQ context the authenticator is created for.
     :return: True after the authenticator was started, False if auth is not configured.
     """
     if self.auth_configured:
         config = self._auth_config
         sock = self._socket
         sock.curve_secretkey = config.server_secret_key
         sock.curve_publickey = config.server_public_key
         sock.curve_server = True

         self._authenticator = AsyncioAuthenticator(context)
         preloaded = self._preloaded_keys
         if not preloaded:
             self.load_client_keys_from_directory()
         else:
             self.set_client_keys(preloaded)

         self._authenticator.start()
         return True
     return False
コード例 #4
0
    def setup(self, socket_type):
        """Create the event loop, context and socket, then block running the loop.

        A DEALER socket gets a unique identity (and, in secure mode, an
        ephemeral CURVE key pair) and connects to ``self._address``. A ROUTER
        socket (in secure mode, a CURVE server whose authenticator allows any
        client key) binds ``self._address``. Once the receive (and optional
        heartbeat) coroutines are scheduled, waiters on ``self._condition`` are
        notified and the loop runs until stopped elsewhere.

        :param socket_type: zmq.DEALER or zmq.ROUTER
        :raises LocalConfigurationError: secure mode without both server keys.
        """
        if self._secured and (self._server_public_key is None
                              or self._server_private_key is None):
            raise LocalConfigurationError("Attempting to start socket "
                                          "in secure mode, but complete "
                                          "server keys were not provided")

        loop = zmq.asyncio.ZMQEventLoop()
        self._event_loop = loop
        asyncio.set_event_loop(loop)
        self._context = zmq.asyncio.Context()
        sock = self._context.socket(socket_type)
        self._socket = sock

        if socket_type == zmq.DEALER:
            suffix = hashlib.sha512(uuid.uuid4().hex.encode()).hexdigest()[:23]
            sock.identity = "{}-{}".format(self._zmq_identity, suffix).encode('ascii')

            if self._secured:
                # Fresh CURVE key pair for this outgoing connection only
                client_public, client_secret = zmq.curve_keypair()
                sock.curve_publickey = client_public
                sock.curve_secretkey = client_secret
                sock.curve_serverkey = self._server_public_key

            self._dispatcher.add_send_message(self._connection, self.send_message)
            sock.connect(self._address)
        elif socket_type == zmq.ROUTER:
            if self._secured:
                authenticator = AsyncioAuthenticator(self._context)
                authenticator.start()
                authenticator.configure_curve(
                    domain='*', location=zmq.auth.CURVE_ALLOW_ANY)

                sock.curve_secretkey = self._server_private_key
                sock.curve_publickey = self._server_public_key
                sock.curve_server = True

            self._dispatcher.add_send_message(self._connection, self.send_message)
            sock.bind(self._address)

        self._recv_queue = asyncio.Queue()

        asyncio.ensure_future(self._receive_message(), loop=loop)
        if self._heartbeat:
            asyncio.ensure_future(self._send_heartbeat(), loop=loop)

        with self._condition:
            self._condition.notify_all()
        self._event_loop.run_forever()
コード例 #5
0
ファイル: stunnel_server.py プロジェクト: nanzeng/stunnel
    async def run(self):
        """Serve tunnel clients on a CURVE-secured ROUTER socket until cancelled.

        Starts an ``AsyncioAuthenticator`` on ``self.context`` together with the
        ``monitor_certificates`` task, then binds a CURVE-server ROUTER socket on
        ``self.port``. For every incoming message:

        * an address not yet in ``self.liveness`` gets ``create_session(addr)``;
        * a ``RELAY`` command is forwarded via ``to_client(addr, *args)``;
        * the address's liveness counter is reset to ``self.heartbeat_liveness``.

        Messages without a command frame are logged and dropped. When the loop
        exits (e.g. the task is cancelled) the certificate monitor is cancelled
        and the authenticator stopped.
        """
        authenticator = AsyncioAuthenticator(self.context)
        authenticator.start()

        # The event loop only keeps weak references to tasks; hold strong ones
        # so fire-and-forget tasks cannot be garbage-collected mid-flight.
        background = set()

        def spawn(coro):
            task = asyncio.create_task(coro)
            background.add(task)
            task.add_done_callback(background.discard)
            return task

        monitor = spawn(
            self.monitor_certificates(authenticator, self.public_keys_dir))

        try:
            self.socket = self.context.socket(zmq.ROUTER)
            self.socket.curve_secretkey = self.secret_key
            self.socket.curve_publickey = self.public_key
            self.socket.curve_server = True
            self.socket.bind(f'tcp://0.0.0.0:{self.port}')
            logging.info(f'Listening on tunnel port {self.port}')

            while True:
                msg = await self.socket.recv_multipart()
                addr = msg[0]
                request = msg[2:]
                if not request:
                    # A malformed frame must not take the whole server down.
                    logging.warning('Dropping message without a command from %r', addr)
                    continue
                cmd = request[0]

                if addr not in self.liveness:
                    spawn(self.create_session(addr))

                if cmd == RELAY:
                    spawn(self.to_client(addr, *request[1:]))

                self.liveness[addr] = self.heartbeat_liveness
        finally:
            # Runs on cancellation or error; this cleanup was previously
            # unreachable after the infinite loop.
            monitor.cancel()
            authenticator.stop()
コード例 #6
0
def serve(datapath):
    """Serve the JSON document at ``datapath`` from a PLAIN-authenticated PUB socket.

    The file is decoded as UTF-8 JSON. A PUB socket bound to
    ``tcp://127.0.0.1:55055`` is configured as a PLAIN server, and an
    ``AsyncioAuthenticator`` is set up with user ``stats`` and peer address
    127.0.0.1. Everything is handed to ``recv_and_process`` under
    ``asyncio.run``. The ZMQ context and its sockets are destroyed when that
    coroutine returns or raises.
    """
    # JSON is UTF-8 by specification; don't rely on the platform locale encoding.
    with open(datapath, encoding='utf-8') as fh:
        data = json.load(fh)

    ctx = zmq.asyncio.Context()
    try:
        pub_socket = ctx.socket(zmq.PUB)
        authenticator = AsyncioAuthenticator(ctx)
        # NOTE(review): authenticator.start() is not called here. Confirm that
        # recv_and_process starts it; without a running ZAP handler the PLAIN
        # credentials below may not be enforced.
        authenticator.configure_plain(passwords={'stats': 'test'})
        authenticator.allow('127.0.0.1')
        pub_socket.plain_server = True
        pub_socket.bind("tcp://127.0.0.1:55055")

        asyncio.run(recv_and_process(authenticator, pub_socket, data))
    finally:
        # Close every socket of the context and terminate it without lingering.
        ctx.destroy(linger=0)
コード例 #7
0
 def make_auth(self):
     """Return a new, not yet started, ``AsyncioAuthenticator`` bound to ``self.context``."""
     authenticator = AsyncioAuthenticator(self.context)
     return authenticator
コード例 #8
0
    def setup(self, socket_type, complete_or_error_queue):
        """Setup the asyncio event loop.

        Creates a fresh ZMQ event loop, asyncio context and socket. A DEALER
        socket connects to ``self._address``; a ROUTER socket binds it. The
        method then registers the send callbacks with the dispatcher and
        schedules the receive, optional monitor and optional heartbeat
        coroutines. Startup success or failure is reported through
        ``complete_or_error_queue``. It then blocks in ``run_forever`` until the
        loop is stopped elsewhere, after which the loop, socket(s) and context
        are closed here.

        Args:
            socket_type (int from zmq.*): One of zmq.DEALER or zmq.ROUTER
            complete_or_error_queue (queue.Queue): A way to propagate errors
                back to the calling thread. Needed since this function is
                directly used in Thread.

        Returns:
            None

        Raises:
            LocalConfigurationError: secure mode without both server keys, or
                the ROUTER socket failed to bind. Any exception raised during
                setup is first put on ``complete_or_error_queue``, then re-raised.
        """
        try:
            if self._secured:
                if self._server_public_key is None or \
                        self._server_private_key is None:
                    raise LocalConfigurationError(
                        "Attempting to start socket in secure mode, "
                        "but complete server keys were not provided")

            # Each call gets its own loop and context, installed as this
            # thread's current event loop.
            self._event_loop = zmq.asyncio.ZMQEventLoop()
            asyncio.set_event_loop(self._event_loop)
            self._context = zmq.asyncio.Context()
            self._socket = self._context.socket(socket_type)

            if socket_type == zmq.DEALER:
                # Identity: "<zmq_identity>-<first 23 hex chars of a sha512 of a
                # random uuid4>", ASCII-encoded.
                self._socket.identity = "{}-{}".format(
                    self._zmq_identity,
                    hashlib.sha512(uuid.uuid4().hex.encode()).hexdigest()
                    [:23]).encode('ascii')

                if self._secured:
                    # Generate ephemeral certificates for this connection

                    pubkey, secretkey = zmq.curve_keypair()
                    self._socket.curve_publickey = pubkey
                    self._socket.curve_secretkey = secretkey
                    self._socket.curve_serverkey = self._server_public_key

                self._socket.connect(self._address)
            elif socket_type == zmq.ROUTER:
                if self._secured:
                    # CURVE_ALLOW_ANY: any client key is accepted at the ZAP
                    # level; the authenticator is kept on self._auth.
                    auth = AsyncioAuthenticator(self._context)
                    self._auth = auth
                    auth.start()
                    auth.configure_curve(domain='*',
                                         location=zmq.auth.CURVE_ALLOW_ANY)

                    self._socket.curve_secretkey = self._server_private_key
                    self._socket.curve_publickey = self._server_public_key
                    self._socket.curve_server = True

                # Bind failures are re-raised as configuration errors.
                try:
                    self._socket.bind(self._address)
                except zmq.error.ZMQError as e:
                    raise LocalConfigurationError(
                        "Can't bind to {}: {}".format(self._address, str(e)))
                else:
                    LOGGER.info("Listening on %s", self._address)

            self._dispatcher.add_send_message(self._connection,
                                              self.send_message)
            self._dispatcher.add_send_last_message(self._connection,
                                                   self.send_last_message)

            asyncio.ensure_future(self._receive_message(),
                                  loop=self._event_loop)
            if self._monitor:
                # Monitor only disconnect events, on an inproc address made
                # unique with a 5-character suffix from _generate_id().
                self._monitor_fd = "inproc://monitor.s-{}".format(
                    _generate_id()[0:5])
                self._monitor_sock = self._socket.get_monitor_socket(
                    zmq.EVENT_DISCONNECTED, addr=self._monitor_fd)
                asyncio.ensure_future(self._monitor_disconnects(),
                                      loop=self._event_loop)

        except Exception as e:
            # Hand the exception to the thread waiting on the queue (presumably
            # in start()), then re-raise so this thread fails as well.
            # NOTE(review): on this path the loop, socket and context created
            # above are not closed.
            complete_or_error_queue.put_nowait(e)
            raise

        if self._heartbeat:
            asyncio.ensure_future(self._do_heartbeat(), loop=self._event_loop)

        # Put a 'complete with the setup tasks' sentinel on the queue.
        complete_or_error_queue.put_nowait(_STARTUP_COMPLETE_SENTINEL)

        asyncio.ensure_future(self._notify_started(), loop=self._event_loop)

        self._event_loop.run_forever()
        # event_loop.stop called elsewhere will cause the loop to break out
        # of run_forever then it can be closed and the context destroyed.
        # NOTE(review): self._auth (secured ROUTER only) is not stopped here;
        # confirm it is stopped elsewhere, or rely on destroy() closing its socket.
        self._event_loop.close()
        self._socket.close(linger=0)
        if self._monitor:
            self._monitor_sock.close(linger=0)
        self._context.destroy(linger=0)
コード例 #9
0
async def run():
    """Run the Ironhouse example: CURVE with a directory of allowed client keys.

    Requires the ``certificates``, ``public_keys`` and ``private_keys``
    directories (produced by generate_certificates.py) next to this file, and
    exits with status 1 if any is missing. A DEALER client sends ``b"Hello"``
    to a CURVE ROUTER server on port 9000, and the outcome is logged as OK or
    FAIL. Sockets are closed and the authenticator stopped even if the
    exchange raises.
    """
    # These directories are generated by the generate_certificates script
    base_dir = Path(__file__).parent
    keys_dir = base_dir / 'certificates'
    public_keys_dir = base_dir / 'public_keys'
    secret_keys_dir = base_dir / 'private_keys'

    if not all(d.is_dir() for d in (keys_dir, public_keys_dir, secret_keys_dir)):
        logging.critical(
            "Certificates are missing - run generate_certificates.py script first"
        )
        sys.exit(1)

    ctx = Context.instance()

    # Start an authenticator for this context.
    auth = AsyncioAuthenticator(ctx)
    auth.start()
    server = client = None
    try:
        auth.allow('127.0.0.1')
        # Tell authenticator to use the certificate in a directory
        auth.configure_curve(domain='*', location=public_keys_dir)

        server = ctx.socket(zmq.ROUTER)

        server_secret_file = secret_keys_dir / "server.key_secret"
        server_public, server_secret = zmq.auth.load_certificate(server_secret_file)
        server.curve_secretkey = server_secret
        server.curve_publickey = server_public
        server.curve_server = True  # must come before bind
        server.bind('tcp://*:9000')

        client = ctx.socket(zmq.DEALER)

        # We need two certificates, one for the client and one for the server.
        client_secret_file = secret_keys_dir / "client.key_secret"
        client_public, client_secret = zmq.auth.load_certificate(client_secret_file)
        client.curve_secretkey = client_secret
        client.curve_publickey = client_public

        server_public_file = public_keys_dir / "server.key"
        server_public, _ = zmq.auth.load_certificate(server_public_file)
        # The client must know the server's public key to make a CURVE connection.
        client.curve_serverkey = server_public
        client.connect('tcp://127.0.0.1:9000')

        await client.send(b"Hello")

        if await server.poll(1000):
            # use copy=False to allow access to message properties via the zmq.Frame API
            # default recv(copy=True) returns only bytes, discarding properties
            _identity, msg = await server.recv_multipart(copy=False)
            logging.info(f"Received {msg.bytes} from {msg['User-Id']!r}")
            if msg.bytes == b"Hello":
                logging.info("Ironhouse test OK")
            else:
                # An unexpected payload is a failure too, not silence.
                logging.error("Ironhouse test FAIL")
        else:
            logging.error("Ironhouse test FAIL")
    finally:
        # close sockets, then stop the auth task, even if something above raised
        if server is not None:
            server.close()
        if client is not None:
            client.close()
        auth.stop()
コード例 #10
0
ファイル: ironhouse.py プロジェクト: Mat001/cilantro
class Ironhouse:
    """Derive and manage a node's CURVE certificates from an ED25519 signing key.

    Certificate files live under ``certs/<keyname>/`` in the ``certificates``,
    ``public_keys`` and ``private_keys`` sub-directories.
    """

    def __init__(self,
                 sk=None,
                 auth_validate=None,
                 wipe_certs=False,
                 auth_port=None,
                 keyname=None,
                 *args,
                 **kwargs):
        """
        :param sk: Hex-encoded ED25519 seed. When given, ``generate_certificates``
            prepares the certificate directories (and files, with ``wipe_certs``).
        :param auth_validate: Validation callable. Defaults to
            ``Ironhouse.auth_validate``, which is not defined in the code shown
            here. NOTE(review): confirm it exists on the full class.
        :param wipe_certs: Remove and regenerate existing certificate directories.
        :param auth_port: Defaults to the ``AUTH_PORT`` environment variable,
            then 4523. NOTE(review): the env value is a str, the fallback an int.
        :param keyname: Certificate base name. Defaults to ``HOSTNAME`` from the
            environment, then this module's file name.
        """
        self.auth_port = auth_port or os.getenv('AUTH_PORT', 4523)
        self.keyname = keyname or os.getenv('HOSTNAME',
                                            basename(splitext(__file__)[0]))
        self.base_dir = 'certs/{}'.format(self.keyname)
        self.keys_dir = os.path.join(self.base_dir, 'certificates')
        self.public_keys_dir = os.path.join(self.base_dir, 'public_keys')
        self.secret_keys_dir = os.path.join(self.base_dir, 'private_keys')
        self.secret_file = os.path.join(self.secret_keys_dir,
                                        "{}.key_secret".format(self.keyname))
        if auth_validate:
            self.auth_validate = auth_validate
        else:
            self.auth_validate = Ironhouse.auth_validate
        self.wipe_certs = wipe_certs
        if sk:
            self.generate_certificates(sk)
        # This node's key pair is always loaded from its secret certificate file.
        self.public_key, self.secret = zmq.auth.load_certificate(
            self.secret_file)

    def vk2pk(self, vk):
        """Convert a hex ED25519 verifying key to its Curve25519 public key, passed through ``encode``."""
        return encode(
            VerifyKey(
                bytes.fromhex(vk)).to_curve25519_public_key()._public_key)

    def generate_certificates(self, sk):
        """Prepare certificate directories for the hex signing-key seed ``sk``.

        Sets ``self.vk`` (hex verifying key) and ``self.public_key``. When
        ``self.wipe_certs`` is set, existing directories are removed and a fresh
        key pair is written: ``.key`` files are moved to ``public_keys_dir`` and
        ``.key_secret`` files to ``secret_keys_dir``.
        """
        sk = SigningKey(seed=bytes.fromhex(sk))
        self.vk = sk.verify_key.encode().hex()
        self.public_key = self.vk2pk(self.vk)
        private_key = crypto_sign_ed25519_sk_to_curve25519(
            sk._signing_key).hex()

        for d in [self.keys_dir, self.public_keys_dir, self.secret_keys_dir]:
            if self.wipe_certs and os.path.exists(d):
                shutil.rmtree(d)
            os.makedirs(d, exist_ok=True)

        if self.wipe_certs:
            self.create_from_private_key(private_key)

            # move public keys to appropriate directory
            for key_file in os.listdir(self.keys_dir):
                if key_file.endswith(".key"):
                    shutil.move(os.path.join(self.keys_dir, key_file),
                                os.path.join(self.public_keys_dir, '.'))

            # move secret keys to appropriate directory
            for key_file in os.listdir(self.keys_dir):
                if key_file.endswith(".key_secret"):
                    shutil.move(os.path.join(self.keys_dir, key_file),
                                os.path.join(self.secret_keys_dir, '.'))

            log.info('Generated CURVE certificate files!')

    def create_from_private_key(self, private_key):
        """Write ``<keyname>.key`` and ``<keyname>.key_secret`` into ``keys_dir``.

        :param private_key: Hex-encoded Curve25519 private key. Also sets
            ``self.public_key`` to the matching encoded public key.
        """
        priv = PrivateKey(bytes.fromhex(private_key))
        publ = priv.public_key
        self.public_key = public_key = encode(publ._public_key)
        secret_key = encode(priv._private_key)

        base_filename = os.path.join(self.keys_dir, self.keyname)
        secret_key_file = "{0}.key_secret".format(base_filename)
        public_key_file = "{0}.key".format(base_filename)
        now = datetime.datetime.now()

        # NOTE: relies on pyzmq's private certificate-writing helpers.
        zmq.auth.certs._write_key_file(
            public_key_file, zmq.auth.certs._cert_public_banner.format(now),
            public_key)

        zmq.auth.certs._write_key_file(
            secret_key_file,
            zmq.auth.certs._cert_secret_banner.format(now),
            public_key,
            secret_key=secret_key)

    def create_from_public_key(self, public_key):
        """Write a public certificate for a peer's ``public_key`` into ``public_keys_dir``.

        The file is named after the hex of ``decode(public_key)``. Nothing is
        written for this node's own key or when the file already exists.
        """
        if self.public_key == public_key:
            return
        keyname = decode(public_key).hex()
        base_filename = os.path.join(self.public_keys_dir, keyname)
        public_key_file = "{0}.key".format(base_filename)
        now = datetime.datetime.now()

        if os.path.exists(public_key_file):
            log.debug('Public cert for {} has already been created.'.format(
                public_key))
            return

        os.makedirs(self.public_keys_dir, exist_ok=True)
        log.info(
            'Adding new public key cert {} to the system.'.format(public_key))

        zmq.auth.certs._write_key_file(
            public_key_file, zmq.auth.certs._cert_public_banner.format(now),
            public_key)

    def secure_context(self, async_=False):
        """Create a ZMQ context with a started authenticator.

        :param async_: True for a ``zmq.asyncio.Context`` with an
            ``AsyncioAuthenticator``; False for a ``zmq.Context`` with a
            ``ThreadAuthenticator``. Formerly named ``async``, which is a
            reserved keyword since Python 3.7 and made this module a
            SyntaxError. Positional callers are unaffected.
        :return: ``(ctx, auth)``, after ``self.reconfigure_curve(auth)`` has run
            (``reconfigure_curve`` is not defined in the code shown here).
        """
        if async_:
            ctx = zmq.asyncio.Context()
            auth = AsyncioAuthenticator(ctx)
        else:
            ctx = zmq.Context()
            auth = ThreadAuthenticator(ctx)
        auth.start()
        self.reconfigure_curve(auth)

        return ctx, auth
コード例 #11
0
class Ironhouse:
    """Derive a node's CURVE keys from an ED25519 seed and manage authorized peer keys.

    Paths are class-level: ``certs/<keyname>/certificates`` for generated files
    and ``certs/<keyname>/authorized_keys`` for trusted public certificates.
    """

    auth_port = os.getenv('AUTH_PORT', 4523)
    keyname = os.getenv('HOST_NAME', 'ironhouse')
    # NOTE(review): class-level dict, shared by every instance.
    authorized_nodes = {}
    # Built from the class-level keyname; an instance ``keyname`` does not change these.
    base_dir = 'certs/{}'.format(keyname)
    keys_dir = join(base_dir, 'certificates')
    authorized_keys_dir = join(base_dir, 'authorized_keys')
    ctx = None
    auth = None
    daemon_auth = None

    def __init__(self,
                 sk=None,
                 auth_validate=None,
                 wipe_certs=False,
                 auth_port=None,
                 keyname=None,
                 *args,
                 **kwargs):
        """
        :param sk: Hex-encoded ED25519 seed. NOTE(review): the default of None
            fails in ``generate_certificates`` (``bytes.fromhex(None)``).
        :param auth_validate: Validation callable. Defaults to
            ``Ironhouse.auth_validate``, which is not defined in the code shown here.
        :param wipe_certs: Remove the certificate directories and rewrite this
            node's public certificate.
        :param auth_port: Overrides the class-level ``auth_port``.
        :param keyname: Overrides the class-level ``keyname`` for this instance.
        """
        if auth_validate:
            self.auth_validate = auth_validate
        else:
            self.auth_validate = Ironhouse.auth_validate
        self.auth_port = auth_port or self.auth_port
        self.keyname = keyname or Ironhouse.keyname
        self.authorized_keys = {}
        self.pk2vk = {}
        self.vk, self.public_key, self.secret = self.generate_certificates(
            sk, wipe_certs=wipe_certs)

    @classmethod
    def vk2pk(cls, vk):
        """Convert a hex ED25519 verifying key to its Curve25519 public key, passed through ``encode``."""
        return encode(
            VerifyKey(
                bytes.fromhex(vk)).to_curve25519_public_key()._public_key)

    @classmethod
    def generate_certificates(cls, sk_hex, wipe_certs=False):
        """Derive this node's keys from the hex seed ``sk_hex``.

        :param wipe_certs: When true, the certificate directories are recreated
            and this node's public certificate is written to
            ``authorized_keys_dir``; otherwise existing files are left alone.
        :return: ``(vk, public_key, secret)``. ``vk`` is the hex verifying key.
            ``public_key`` and ``secret`` are the encoded Curve25519 keys.
        """
        sk = SigningKey(seed=bytes.fromhex(sk_hex))
        vk = sk.verify_key.encode().hex()
        public_key = cls.vk2pk(vk)
        private_key = crypto_sign_ed25519_sk_to_curve25519(
            sk._signing_key).hex()

        for d in [cls.keys_dir, cls.authorized_keys_dir]:
            if wipe_certs and exists(d):
                shutil.rmtree(d)
            os.makedirs(d, exist_ok=True)

        if wipe_certs:
            _, secret = cls.create_from_private_key(private_key)

            for key_file in os.listdir(cls.keys_dir):
                if key_file.endswith(".key"):
                    shutil.move(join(cls.keys_dir, key_file),
                                join(cls.authorized_keys_dir, '.'))

            if exists(cls.keys_dir):
                shutil.rmtree(cls.keys_dir)

            log.info('Generated CURVE certificate files!')
        else:
            # No files are touched, but the secret must still be derived. It
            # used to be assigned only in the branch above, so this path raised
            # UnboundLocalError at the return statement.
            _, secret = cls._curve_keys_from_private_key(private_key)

        return vk, public_key, secret

    @classmethod
    def _curve_keys_from_private_key(cls, private_key):
        """Return ``(public_key, secret)`` from ``encode`` for a hex Curve25519 private key."""
        priv = PrivateKey(bytes.fromhex(private_key))
        publ = priv.public_key
        return encode(publ._public_key), encode(priv._private_key)

    @classmethod
    def create_from_private_key(cls, private_key):
        """Write ``<keyname>.key`` into ``keys_dir`` for a hex Curve25519 private key.

        :return: ``(public_key, secret)`` as produced by ``encode``.
        """
        public_key, secret = cls._curve_keys_from_private_key(private_key)

        base_filename = join(cls.keys_dir, cls.keyname)
        public_key_file = "{0}.key".format(base_filename)
        now = datetime.datetime.now()

        # NOTE: relies on pyzmq's private certificate-writing helpers.
        zmq.auth.certs._write_key_file(
            public_key_file, zmq.auth.certs._cert_public_banner.format(now),
            public_key)

        return public_key, secret

    def add_public_key(self, public_key):
        """Authorize a peer's ``public_key`` by writing its certificate and reloading curve auth.

        Nothing happens for this node's own key or when the certificate file
        already exists.
        """
        if self.public_key == public_key:
            return
        keyname = decode(public_key).hex()
        base_filename = join(self.authorized_keys_dir, keyname)
        public_key_file = "{0}.key".format(base_filename)
        now = datetime.datetime.now()

        if exists(public_key_file):
            log.debug('Public cert for {} has already been created.'.format(
                public_key))
            return

        os.makedirs(self.authorized_keys_dir, exist_ok=True)
        log.info(
            'Adding new public key cert {} to the system.'.format(public_key))

        zmq.auth.certs._write_key_file(
            public_key_file, zmq.auth.certs._cert_public_banner.format(now),
            public_key)

        log.debug('{} has added {} to its authorized list'.format(
            os.getenv('HOST_IP', '127.0.0.1'), public_key))
        # reconfigure_curve is not defined in the code shown here.
        self.reconfigure_curve()
        self.authorized_keys[public_key] = True

    def remove_public_key(self, public_key):
        """Revoke a peer's ``public_key``: delete its certificate file (if any) and reload curve auth."""
        if self.public_key == public_key:
            return
        keyname = decode(public_key).hex()
        base_filename = join(self.authorized_keys_dir, keyname)
        public_key_file = "{0}.key".format(base_filename)
        if exists(public_key_file):
            os.remove(public_key_file)

        log.debug('{} has remove {} from its authorized list'.format(
            os.getenv('HOST_IP', '127.0.0.1'), public_key))
        self.reconfigure_curve()
        self.authorized_keys[public_key] = False

    @classmethod
    def secure_context(cls, async_=False):
        """Create a ZMQ context together with a started authenticator.

        :param async_: True for a ``zmq.asyncio.Context`` with an
            ``AsyncioAuthenticator``; False for a ``zmq.Context`` with a
            ``ThreadAuthenticator``. Formerly named ``async``, a reserved
            keyword since Python 3.7 that made this module a SyntaxError.
            Positional callers are unaffected.
        :return: ``(ctx, auth)``.
        """
        if async_:
            ctx = zmq.asyncio.Context()
            auth = AsyncioAuthenticator(ctx)
            auth.log = log  # The constructor doesn't have "log" like its synchronous counter-part
        else:
            ctx = zmq.Context()
            auth = ThreadAuthenticator(ctx, log=log)
        auth.start()
        return ctx, auth
コード例 #12
0
class SocketAuthenticator:
    """Maintain the CURVE keys trusted by an ``AsyncioAuthenticator``.

    Trusted keys are written as public certificate files into ``~/<cert_dir>``,
    and the authenticator is (re)configured from that directory.
    """
    def __init__(self,
                 ctx: zmq.asyncio.Context,
                 loop=None,
                 domain='*',
                 cert_dir=CERT_DIR,
                 debug=True):
        """
        :param ctx: Asyncio ZMQ context handed to ``AsyncioAuthenticator``.
        :param loop: Event loop for the authenticator. When omitted it is
            resolved with ``asyncio.get_event_loop()`` at construction time; a
            default argument would be evaluated only once, at class definition.
        :param domain: ZAP domain used for ``configure_curve``.
        :param cert_dir: Certificate directory relative to the user's home;
            created if it does not exist.
        :param debug: Value assigned to the logger's ``propagate`` flag.
        """
        # Create the certificate directory if it doesn't exist
        self.cert_dir = pathlib.Path.home() / cert_dir
        self.cert_dir.mkdir(parents=True, exist_ok=True)

        self.ctx = ctx
        self.domain = domain
        self.loop = loop if loop is not None else asyncio.get_event_loop()

        self.log = get_logger('zmq.auth')
        self.log.propagate = debug

        # Failure (e.g. an authenticator already running) stays non-fatal, but
        # it is reported instead of being silently swallowed.
        try:
            self.authenticator = AsyncioAuthenticator(context=self.ctx,
                                                      loop=self.loop)
            self.authenticator.start()
            self.authenticator.configure_curve(domain=self.domain,
                                               location=self.cert_dir)

        except ZMQBaseError:
            self.log.error('Error starting ZMQ Authenticator. Is it already running?')

    def add_governance_sockets(self, masternode_list, on_deck_masternode,
                               delegate_list, on_deck_delegate):
        """Replace all trusted keys with the given governance members and reload curve auth.

        ``on_deck_masternode`` / ``on_deck_delegate`` are added only when not None.
        """
        self.flush_all_keys()

        for mn in masternode_list:
            self.add_verifying_key(mn)

        for dl in delegate_list:
            self.add_verifying_key(dl)

        if on_deck_masternode is not None:
            self.add_verifying_key(on_deck_masternode)

        if on_deck_delegate is not None:
            self.add_verifying_key(on_deck_delegate)

        self.authenticator.configure_curve(domain=self.domain,
                                           location=self.cert_dir)

    def add_verifying_key(self, vk: bytes):
        """Trust an ED25519 verifying key by writing its Curve25519 certificate.

        :param vk: Raw verifying-key bytes, or the same key as a hex string.
            Keys that are not valid hex, or that cannot be converted to
            Curve25519, are logged and skipped.
        """
        # Convert to bytes if hex string
        if isinstance(vk, str):
            try:
                vk = bytes.fromhex(vk)
            except ValueError:
                self.log.error(f'Invalid verifying key {vk!r}: not a hex string.')
                return

        try:
            pk = crypto_sign_ed25519_pk_to_curve25519(vk)
        # Error is thrown if the VK is not within the possibility space of the ED25519 algorithm
        except RuntimeError:
            self.log.error('ED25519 Cryptographic error. The key provided is not within the cryptographic key space.')
            return

        zvk = z85.encode(pk).decode('utf-8')
        _write_key_file(self.cert_dir / f'{vk.hex()}.key',
                        banner=_cert_public_banner,
                        public_key=zvk)

    def flush_all_keys(self):
        """Remove every trusted certificate by recreating ``cert_dir`` empty."""
        # Tolerate the directory having been removed externally.
        if self.cert_dir.exists():
            shutil.rmtree(str(self.cert_dir))
        self.cert_dir.mkdir(parents=True, exist_ok=True)
コード例 #13
0
ファイル: authentication.py プロジェクト: baby636/lamden
class SocketAuthenticator:
    """Maintain the CURVE keys trusted by an ``AsyncioAuthenticator``.

    Trusted keys are written as public certificate files into ``~/<cert_dir>``
    and reloaded from the ``masternodes`` / ``delegates`` contract state.
    """
    def __init__(self, client: ContractingClient, ctx: zmq.asyncio.Context, bootnodes: dict = None,
                 loop=None, domain='*', cert_dir=CERT_DIR, debug=False):
        """
        :param client: Contracting client used to read governance members.
        :param ctx: Asyncio ZMQ context handed to ``AsyncioAuthenticator``.
        :param bootnodes: Mapping whose keys are hex verifying keys trusted at
            startup. Defaults to a fresh empty dict.
        :param loop: Event loop for the authenticator. When omitted it is
            resolved with ``asyncio.get_event_loop()`` at construction time.
        :param domain: ZAP domain used for ``configure_curve``.
        :param cert_dir: Certificate directory relative to the user's home;
            created if it does not exist.
        :param debug: Value assigned to the logger's ``propagate`` flag.
        """
        self.client = client

        # Create the certificate directory if it doesn't exist
        self.cert_dir = pathlib.Path.home() / cert_dir
        self.cert_dir.mkdir(parents=True, exist_ok=True)

        self.ctx = ctx
        self.domain = domain

        # A default argument would be evaluated only once, when the class is defined.
        self.loop = loop if loop is not None else asyncio.get_event_loop()

        self.log = get_logger('zmq.auth')
        self.log.propagate = debug

        # A ``{}`` default would be one dict shared by every instance.
        self.bootnodes = bootnodes if bootnodes is not None else {}

        # Expected to raise ZMQBaseError if an authenticator is already running.
        # Track the instance in a local so curve configuration is skipped when
        # construction failed, instead of raising AttributeError from a
        # ``finally`` clause and hiding the logged error.
        authenticator = None
        try:
            authenticator = AsyncioAuthenticator(context=self.ctx, loop=self.loop)
            self.authenticator = authenticator
            authenticator.start()
        except ZMQBaseError:
            self.log.error('Error starting ZMQ Authenticator. Is it already running?')

        for node in self.bootnodes:
            self.add_verifying_key(node)

        if authenticator is not None:
            authenticator.configure_curve(domain=self.domain, location=self.cert_dir)

    def refresh_governance_sockets(self):
        """Rebuild the trusted keys from the current masternode and delegate members.

        Reads ``S['members']`` of the ``masternodes`` and ``delegates``
        contracts, replaces every certificate in ``cert_dir`` with those keys
        and reloads the CURVE configuration.
        """
        masternode_list = self.client.get_var(
            contract='masternodes',
            variable='S',
            arguments=['members']
        )

        delegate_list = self.client.get_var(
            contract='delegates',
            variable='S',
            arguments=['members']
        )

        # NOTE(review): presumably get_var returns None when the entry is unset.
        # Bail out before flushing, so the current keys survive instead of the
        # refresh crashing halfway with an emptied certificate directory.
        if masternode_list is None or delegate_list is None:
            self.log.error('Could not read masternode/delegate members; keeping the current keys.')
            return

        self.flush_all_keys()

        for mn in masternode_list:
            self.add_verifying_key(mn)

        for dl in delegate_list:
            self.add_verifying_key(dl)

        self.log.info(f'Refreshing keys for {len(masternode_list)} masters and {len(delegate_list)} delegates.')

        self.authenticator.configure_curve(domain=self.domain, location=self.cert_dir)

    def add_verifying_key(self, vk: str):
        """Trust an ED25519 verifying key by writing its Curve25519 certificate.

        :param vk: Hex-encoded verifying key (raw bytes are accepted as well).
            Keys that are not valid hex, or that cannot be converted to
            Curve25519, are logged and skipped.
        """
        if isinstance(vk, bytes):
            bvk = vk
            vk = vk.hex()
        else:
            # Convert to bytes from the hex string
            try:
                bvk = bytes.fromhex(vk)
            except ValueError:
                self.log.error(f'Invalid verifying key {vk!r}: not a hex string.')
                return

        try:
            pk = crypto_sign_ed25519_pk_to_curve25519(bvk)
        # Error is thrown if the VK is not within the possibility space of the ED25519 algorithm
        except RuntimeError:
            self.log.error('ED25519 Cryptographic error. The key provided is not within the cryptographic key space.')
            return

        zvk = z85.encode(pk).decode('utf-8')
        _write_key_file(self.cert_dir / f'{vk}.key', banner=_cert_public_banner, public_key=zvk)

    def flush_all_keys(self):
        """Remove every trusted certificate by recreating ``cert_dir`` empty."""
        # Tolerate the directory having been removed externally.
        if self.cert_dir.exists():
            shutil.rmtree(str(self.cert_dir))
        self.cert_dir.mkdir(parents=True, exist_ok=True)

    def configure(self):
        """Reload the authenticator's CURVE configuration from ``cert_dir``."""
        self.authenticator.configure_curve(domain=self.domain, location=self.cert_dir)
コード例 #14
0
class Server:
    """
    Server that accepts JSON RPC calls through a socket.

    Requests are received on a ZeroMQ ROUTER socket and decoded from msgpack.
    Each one is dispatched through ``rpc_spec`` as its own task, and the reply
    is sent back to the identity that issued the request.
    """
    def __init__(self,
                 rpc_spec: RPCSpec = None,
                 announce_timing: bool = False,
                 serialize_exceptions: bool = True,
                 auth_config: Optional[ServerAuthConfig] = None):
        """
        Create a server that will be linked to a socket

        :param rpc_spec: JSON RPC spec. A fresh RPCSpec is created when omitted.
        :param announce_timing: If True, log the wall-clock duration of every request at
            INFO level.
        :param serialize_exceptions: If set to True, this Server will catch all exceptions occurring
            internally to it and, when possible, communicate them to the interrogating Client.  If
            set to False, this Server will re-raise any exceptions it encounters (including, but not
            limited to, those which might occur through method calls to rpc_spec) for Server's
            local owner to handle.

            IMPORTANT NOTE: When set to False, this *almost definitely* means an unrecoverable
            crash, and the Server should then be _shutdown().
        :param auth_config: The configuration values necessary to enable Curve ZeroMQ authentication.
            These must be provided at instantiation, so they are available between the creation of the
            context and socket.
        """
        self.announce_timing = announce_timing
        self.serialize_exceptions = serialize_exceptions

        self.rpc_spec = rpc_spec if rpc_spec else RPCSpec(
            serialize_exceptions=serialize_exceptions)
        self._exit_handlers = []

        self._socket = None
        self._auth_config = auth_config
        self._authenticator = None
        # Client keys given to set_client_keys before an authenticator exists;
        # they are applied by start_auth.
        self._preloaded_keys = None

    def rpc_handler(self, f: Callable):
        """
        Add a function to the server. It will respond to JSON RPC requests with the corresponding method name.
        This can be used as both a side-effecting function or as a decorator.

        :param f: Function to add
        :return: Function wrapper (so it can be used as a decorator)
        """
        return self.rpc_spec.add_handler(f)

    def exit_handler(self, f: Callable):
        """
        Add an exit handler - a function which will be called when the server shuts down.

        :param f: Function to add
        """
        self._exit_handlers.append(f)

    async def recv_multipart(self):
        """
        Receive one multipart request from the socket.

        :return: A tuple of the received frames followed by the client's public key. When auth
            is disabled the key is unavailable and the final element is None.
        """
        if self.auth_enabled:
            return await self.recv_multipart_with_auth()
        else:
            # If auth is not enabled, then the client "User-Id" will not be retrieved from
            #   the frames received, and we return None for that value.
            return (*await self._socket.recv_multipart(), None)

    async def recv_multipart_with_auth(self) -> Tuple[bytes, bytearray, Optional[bytes]]:
        """
        Code taken from pyzmq itself: https://github.com/zeromq/pyzmq/blob/master/zmq/sugar/socket.py#L449
          and then adapted to allow us to access the information in the frames.

        When copy=True, only the contents of the messages are returned, rather than the messages themselves.
          The message is necessary to be able to fetch the "User-Id", which is the public key the client used
          to connect to this socket.

        When using auth, knowing which client sent which message is important for authentication, and so
          we reimplement recv_multipart here, and return the client key as the final member of a tuple

        :return: (identity, concatenated payload of all remaining frames, client public key or None)
        """

        copy = False
        # Given a ROUTER socket, the first frame will be the sender's identity.
        #   While, per the docs, this _should_ be retrievable from any frame with
        #   frame.get('Identity'), in practice this value was always blank.
        #   If we don't record the identity value, messages cannot be returned to
        #   the correct client.
        identity_frame = await self._socket.recv(0, copy=copy, track=False)
        identity = identity_frame.bytes

        # The client_id is the public key the client used to establish this connection
        #   It can be retrieved from all frames after the first. Here, we assume it
        #   is the same among all frames, and set it to the value from the first frame
        client_key = None

        # After the identity frame, we assemble all further frames in a single buffer.
        # NOTE(review): an empty delimiter frame (sent by REQ clients) is folded into this
        #   buffer and is therefore not echoed back in the reply; presumably only DEALER
        #   clients are expected when auth is enabled - confirm.
        parts = bytearray(b'')
        while self._socket.getsockopt(zmq.RCVMORE):
            part = await self._socket.recv(0, copy=copy, track=False)
            data = part.bytes
            if client_key is None:
                client_key = part.get('User-Id')
                if not isinstance(client_key,
                                  bytes) and client_key is not None:
                    client_key = client_key.encode('utf-8')
            parts += data

        _log.debug('Received authenticated request from client_key %s', client_key)

        return (identity, parts, client_key)

    async def run_async(self, endpoint: str):
        """
        Run server main task (asynchronously).

        Keeps exactly one receive task pending at all times and spawns a processing task for
        every request received, so slow handlers do not block reception.

        :param endpoint: Socket endpoint to listen to, e.g. "tcp://*:1234"
        """
        self._connect(endpoint)

        # spawn an initial listen task
        listen_task = asyncio.ensure_future(self.recv_multipart())
        task_list = [listen_task]

        while True:
            dones, pendings = await asyncio.wait(
                task_list, return_when=asyncio.FIRST_COMPLETED)

            # grab one "done" task to handle; the others are re-queued and picked up
            # (already finished) on the next iteration
            task_list, done_list = list(pendings), list(dones)
            done = done_list.pop()
            task_list += done_list

            if done == listen_task:
                try:
                    # empty_frame may either be:
                    # 1. a single null frame if the client is a REQ socket
                    # 2. an empty list (ie. no frames) if the client is a DEALER socket
                    identity, *empty_frame, msg, client_key = done.result()
                    request = from_msgpack(msg)
                    request.params['client_key'] = client_key

                    # spawn a processing task
                    task_list.append(
                        asyncio.ensure_future(
                            self._process_request(identity, empty_frame,
                                                  request)))
                except Exception as e:
                    if self.serialize_exceptions:
                        _log.exception(
                            'Exception thrown in Server run loop during request '
                            'reception: {}'.format(repr(e)))
                    else:
                        raise
                finally:
                    # spawn a new listen task
                    listen_task = asyncio.ensure_future(self.recv_multipart())
                    task_list.append(listen_task)
            else:
                # if there's been an exception during processing, consider reraising it
                try:
                    done.result()
                except Exception as e:
                    if self.serialize_exceptions:
                        _log.exception(
                            'Exception thrown in Server run loop during request '
                            'dispatch: {}'.format(repr(e)))
                    else:
                        raise

    def run(self, endpoint: str, loop: AbstractEventLoop = None):
        """
        Run server main task.

        :param endpoint: Socket endpoint to listen to, e.g. "tcp://*:1234"
        :param loop: Event loop to run server in (alternatively just use run_async method)
        """
        if not loop:
            loop = asyncio.get_event_loop()

        try:
            loop.run_until_complete(self.run_async(endpoint))
        except KeyboardInterrupt:
            self._shutdown()

    def stop(self):
        """
        DEPRECATED
        """
        pass

    def _shutdown(self):
        """
        Shut down the server.

        Runs the registered exit handlers and stops the Curve authenticator, if one was
        started. Then it closes the socket.
        """
        for exit_handler in self._exit_handlers:
            exit_handler()

        # Stop the authenticator while the socket still exists so it does not outlive the
        # server; this is a no-op when auth was never started.
        self.stop_auth()

        if self._socket:
            self._socket.close()
            self._socket = None

    def _connect(self, endpoint: str):
        """
        Connect the server to an endpoint. Creates a ZMQ ROUTER socket for the given endpoint.

        Curve options are applied (via start_auth) before bind, because they must be set on the
        socket before it binds.

        :param endpoint: Socket endpoint, e.g. "tcp://*:1234"
        """
        if self._socket:
            raise RuntimeError(
                'Cannot run multiple Servers on the same socket')

        context = zmq.asyncio.Context()
        self._socket = context.socket(zmq.ROUTER)
        self.start_auth(context)
        self._socket.bind(endpoint)

        _log.info("Starting server, listening on endpoint {}".format(endpoint))

    async def _process_request(self, identity: bytes, empty_frame: list,
                               request: RPCRequest):
        """
        Executes the method specified in a JSON RPC request and then sends the reply to the socket.

        :param identity: Client identity provided by ZeroMQ
        :param empty_frame: Either an empty list or a single null frame depending on the client type
        :param request: JSON RPC request
        """
        try:
            _log.debug("Client %s sent request: %s", identity, request)
            start_time = datetime.now()
            reply = await self.rpc_spec.run_handler(request)
            if self.announce_timing:
                _log.info("Request {} for {} lasted {} seconds".format(
                    request.id, request.method,
                    (datetime.now() - start_time).total_seconds()))

            _log.debug("Sending client %s reply: %s", identity, reply)
            await self._socket.send_multipart(
                [identity, *empty_frame,
                 to_msgpack(reply)])
        except Exception:
            if self.serialize_exceptions:
                _log.exception('Exception thrown in _process_request')
            else:
                raise

    @property
    def auth_configured(self) -> bool:
        """True when an auth config with both a bytes secret key and a bytes public key was supplied."""
        return (self._auth_config is not None) and isinstance(
            self._auth_config.server_secret_key, bytes) and isinstance(
                self._auth_config.server_public_key, bytes)

    @property
    def auth_enabled(self) -> bool:
        """True when the socket exists and has its curve_server flag set."""
        return bool(self._socket and self._socket.curve_server)

    def start_auth(self, context: zmq.Context) -> bool:
        """
        Starts the ZMQ auth service, enabling authorization on all sockets within this context.

        Keys preloaded through set_client_keys take precedence (even an empty list, which
        authorizes nobody). Otherwise keys are loaded from the configured directory.

        :return: True if auth was started, False if auth is not configured
        """
        if not self.auth_configured:
            return False
        self._socket.curve_secretkey = self._auth_config.server_secret_key
        self._socket.curve_publickey = self._auth_config.server_public_key
        self._socket.curve_server = True
        self._authenticator = AsyncioAuthenticator(context)
        if self._preloaded_keys is not None:
            self.set_client_keys(self._preloaded_keys)
        else:
            self.load_client_keys_from_directory()
        self._authenticator.start()
        return True

    def stop_auth(self) -> bool:
        """
        Stops the ZMQ auth service and clears the socket's curve_server flag (if the socket
            still exists).

        :return: True if a running authenticator was stopped, False if none was running
        """
        if self._authenticator is None:
            return False
        if self._socket is not None:
            self._socket.curve_server = False
        self._authenticator.stop()
        # Forget the stopped authenticator: a second call becomes a no-op, and keys set
        # afterwards are cached for the next start_auth instead of being lost.
        self._authenticator = None
        return True

    def load_client_keys_from_directory(self,
                                        directory: Optional[str] = None
                                        ) -> bool:
        """
        Reset authorized public key list to those present in the specified directory

        :param directory: Directory of client public key files. Defaults to
            ``auth_config.client_keys_directory``.
        :return: True if the keys were loaded; False when auth is not configured, the
            authenticator has not been started, or no directory is available.
        """
        # Check configuration first: with no auth config there is no default directory to read.
        if not self.auth_configured or self._authenticator is None:
            return False

        # The directory must either be specified at class creation or on each method call
        if directory is None:
            directory = self._auth_config.client_keys_directory
        if not directory:
            return False
        self._authenticator.configure_curve(domain='*', location=directory)
        return True

    def set_client_keys(self, client_keys: List[bytes]):
        """
        Reset authorized public key list to this set. Avoids the disk read required by configure_curve,
            and allows keys to be managed externally.

        In some cases, keys may be preloaded before the authenticator is started. In this case, we
            cache those preloaded keys
        """
        if self._authenticator:
            _log.debug(f"Authorizer: Setting client keys to {client_keys}")
            self._authenticator.certs['*'] = {key: True for key in client_keys}
        else:
            _log.debug(f"Authorizer: Preloading client keys to {client_keys}")
            self._preloaded_keys = client_keys
コード例 #15
0
ファイル: asyncio-ironhouse.py プロジェクト: xj-sun/pyzmq
async def run():
    '''
    Run Ironhouse example.

    Starts a CURVE authenticator, connects a PULL client to a PUSH server using
    the generated certificates, and checks that a b"Hello" message arrives.
    Sockets are closed and the authenticator is stopped even if a step fails.
    '''

    # These directories are generated by the generate_certificates script
    base_dir = Path(__file__).parent
    keys_dir = base_dir / 'certificates'
    public_keys_dir = base_dir / 'public_keys'
    secret_keys_dir = base_dir / 'private_keys'

    if not (keys_dir.is_dir() and public_keys_dir.is_dir() and secret_keys_dir.is_dir()):
        logging.critical(
            "Certificates are missing - run generate_certificates.py script first"
        )
        sys.exit(1)

    ctx = Context.instance()

    # Start an authenticator for this context.
    auth = AsyncioAuthenticator(ctx)
    auth.start()
    auth.allow('127.0.0.1')
    # Tell authenticator to use the certificate in a directory
    auth.configure_curve(domain='*', location=public_keys_dir)

    server = ctx.socket(zmq.PUSH)
    client = ctx.socket(zmq.PULL)
    try:
        server_secret_file = secret_keys_dir / "server.key_secret"
        server_public, server_secret = zmq.auth.load_certificate(
            server_secret_file)
        server.curve_secretkey = server_secret
        server.curve_publickey = server_public
        server.curve_server = True  # must come before bind
        server.bind('tcp://*:9000')

        # We need two certificates, one for the client and one for
        # the server. The client must know the server's public key
        # to make a CURVE connection.
        client_secret_file = secret_keys_dir / "client.key_secret"
        client_public, client_secret = zmq.auth.load_certificate(
            client_secret_file)
        client.curve_secretkey = client_secret
        client.curve_publickey = client_public

        server_public_file = public_keys_dir / "server.key"
        server_public, _ = zmq.auth.load_certificate(server_public_file)
        client.curve_serverkey = server_public
        client.connect('tcp://127.0.0.1:9000')

        await server.send(b"Hello")

        # Both a timeout and an unexpected payload count as failure
        if await client.poll(1000) and await client.recv() == b"Hello":
            logging.info("Ironhouse test OK")
        else:
            logging.error("Ironhouse test FAIL")
    finally:
        # close sockets
        server.close()
        client.close()
        # stop auth task
        auth.stop()