Ejemplo n.º 1
0
    def __init__(
        self,
        args,
        encrypted_traffic_data=None,
        interface=None,
    ):
        """Initialize the Driver.

        :param args: Arguments parsed by argparse.
        :type args: Object
        :param encrypted_traffic_data: Encrypted traffic settings, stored
                                       as-is on the instance.
        :type encrypted_traffic_data: Object
        :param interface: The interface instance, stored as-is.
        :type interface: Object
        """

        self.encrypted_traffic_data = encrypted_traffic_data
        self.log = logger.getLogger(name="directord")
        self.args = args

        # Fall back to the host name when args has no identity or an empty
        # one. The fallback is only computed when it is actually needed.
        self.identity = getattr(args, "identity", None) or socket.gethostname()

        # Same lazy fallback for the machine ID.
        self.machine_id = (
            getattr(args, "machine_id", None) or self.get_machine_id()
        )

        self.interface = interface
Ejemplo n.º 2
0
    def __init__(self, host, username, port, key_file=None, debug=False):
        """Initialize the connection manager.

        Opens a TCP socket to ``host``:``port`` and connects an SSH session
        over it.

        :param host: IP or Domain to connect to.
        :type host: String
        :param username: Username for the connection.
        :type username: String
        :param port: Port number used to connect to the remote server.
        :type port: Int
        :param key_file: SSH key file used to connect.
        :type key_file: String
        :param debug: Enable or disable debug mode
        :type debug: Boolean
        :raises Exception: Any error from the TCP connect or the session
                           setup is re-raised after the socket is closed.
        """

        self.log = logger.getLogger(name="directord-ssh", debug_logging=debug)
        self.key_file = key_file
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            self.sock.connect((host, port))

            self.session = Session()
            self.session.options_set(options.HOST, host)
            self.session.options_set(options.USER, username)
            self.session.options_set_port(port)
            self.session.set_socket(self.sock)
            self.session.connect()
        except Exception:
            # The caller never receives this half-built object, so release
            # the socket here instead of leaking its file descriptor.
            self.sock.close()
            raise

        self.log.debug("Handshake with [ %s ] on port [ %s ] complete.", host,
                       port)

        self.channels = dict()
        self.host = host
        self.username = username
Ejemplo n.º 3
0
 def test_getlogger_logger_missed_handlers(self):
     """The first handler of a new logger is named after the logger.

     NOTE(review): ``logging.handlers`` is a module, so assigning a
     ``return_value`` to its patch presumably does not influence
     ``getLogger`` -- confirm what this mock is meant to simulate.
     """
     with patch("logging.handlers", autospec=True) as mock_handlers:
         mock_handlers.return_value = [{"name": "nottestLogger"}]
         log = logger.getLogger(name="testLogger")
     if not log.handlers:
         self.fail("The log handler name was not set")
     # Only the first handler is inspected. assertEqual reports both
     # values on failure, unlike assertTrue(a == b).
     self.assertEqual(log.handlers[0].name, "testLogger")
Ejemplo n.º 4
0
    def __init__(self, desc=None):
        """Set up the state shared by every component.

        Subclasses should call this initializer so that user-defined
        components receive the same defaults as the built-in ones.

        > Set ``self.cacheable`` to True or False depending on whether the
          component should be cached on the system.

        :param desc: Optional component description, kept on ``self.desc``.
        :type desc: Object
        """

        self.desc = desc
        self.log = logger.getLogger(name="directord")
        template_loader = jinja2.BaseLoader()
        # Templates keep their trailing newline, and undefined variables
        # raise instead of rendering as empty strings.
        self.blueprint = jinja2.Environment(
            undefined=StrictUndefined,
            keep_trailing_newline=True,
            loader=template_loader,
        )
        # Parsed argument results start out unset.
        self.known_args = self.unknown_args = None
        # Per-component switches for on-system caching and locking.
        self.cacheable = True
        self.requires_lock = False
Ejemplo n.º 5
0
    def __init__(self, catalog, key_file, threads, debug=False):
        """Prepare the bootstrap mixin.

        :param catalog: Bootstrap catalog, stored on ``self.catalog``.
        :type catalog: Object
        :param key_file: Key file path, stored on ``self.key_file``.
        :type key_file: String
        :param threads: Thread count setting, stored on ``self.threads``
                        (presumably the worker pool size -- confirm).
        :type threads: Integer
        :param debug: Enable debug logging.
        :type debug: Boolean
        """

        super(Bootstrap, self).__init__()
        self.catalog, self.key_file = catalog, key_file
        self.threads, self.debug = threads, debug
        template_loader = jinja2.BaseLoader()
        # Undefined template variables raise rather than render empty.
        self.blueprint = jinja2.Environment(
            undefined=StrictUndefined,
            keep_trailing_newline=True,
            loader=template_loader,
        )
        self.log = logger.getLogger(name="directord", debug_logging=self.debug)
        self.indicator = None
        self.return_queue = self.get_queue()
Ejemplo n.º 6
0
def send_data(socket_path, data):
    """Send data to the socket path and return the server's reply.

    The serialized ``data`` is encoded and written to the UNIX socket at
    ``socket_path``. The reply is read in 1 KiB chunks until the server
    closes the connection.

    :param socket_path: Filesystem path of the server's UNIX socket.
    :type socket_path: String
    :param data: Serialized payload to send.
    :type data: String
    :returns: Bytes
    :raises SystemExit: When no connection to the server is available.
    :raises PermissionError: When the socket cannot be written to.
    """

    try:
        with UNIXSocketConnect(socket_path) as s:
            if not s:
                raise SystemExit("No connection available to server.")
            s.sendall(data.encode())
            fragments = []
            while True:
                chunk = s.recv(1024)
                if not chunk:
                    # An empty read means the server closed the connection.
                    break

                fragments.append(chunk)
            return b"".join(fragments)
    except PermissionError as e:
        log = logger.getLogger(name="directord")
        error_msg = (
            "Permission error writing to {}. Check write permissions.".format(
                socket_path
            )
        )
        log.error(
            error_msg,
        )
        raise PermissionError(error_msg) from e
Ejemplo n.º 7
0
class Driver(drivers.BaseDriver):
    """Directord messaging driver built on oslo.messaging (AMQP).

    The RPC endpoints exposed on this object (``_heartbeat``, ``_job`` and
    ``_backend``) only enqueue the payloads they receive. The ``job_*`` and
    ``backend_*`` methods poll and drain those queues, and send outbound
    messages through ``_send``.
    """

    # Field order of the entries queued by ``_job`` and ``_backend``. In
    # server mode the sender identity is prepended to this list.
    _JOB_FIELDS = (
        "job_id",
        "control",
        "command",
        "data",
        "info",
        "stderr",
        "stdout",
    )

    def __init__(
        self,
        args,
        encrypted_traffic_data=None,
        interface=None,
    ):
        """Initialize the Driver.

        :param args: Arguments parsed by argparse.
        :type args: Object
        :param encrypted_traffic_data: Encrypted traffic settings, passed
                                       through to the base driver.
        :type encrypted_traffic_data: Object
        :param interface: The interface instance (client/server)
        :type interface: Object
        """

        super(Driver, self).__init__(
            args=args,
            encrypted_traffic_data=encrypted_traffic_data,
            interface=interface,
        )
        self.mode = getattr(args, "mode", None)

        self.proto = "amqp"
        self.connection_string = "{proto}://{addr}".format(
            proto=self.proto, addr=self.args.messaging_address
        )

        # _rpc_conf reads self.args and self.connection_string, so it must
        # run after both are set.
        self.conf = self._rpc_conf()
        self.transport = self._rpc_transport()
        self.server = None
        self.backend_server = None
        # Queues fed by the exposed RPC endpoints.
        self.job_q = queue.Queue()
        self.backend_q = queue.Queue()
        self.send_q = queue.Queue()
        self.process_send_q = None
        self.timeout = 1

    def _rpc_conf(self):
        """Initialize the RPC configuration.

        Registers the AMQP driver options, applies the SSL settings from
        ``self.args`` as defaults and points the transport at port 5672 of
        the messaging address.

        :returns: Object
        """

        conf = cfg.CONF

        # Load the amqp driver from the oslo.messaging.drivers entrypoint and
        # instantiate an instance. This is just so that we can get the options
        # registered in the conf object.
        for oslo_driver in pkg_resources.iter_entry_points(
            "oslo.messaging.drivers"
        ):
            if oslo_driver.name == "amqp":
                proton_driver = oslo_driver.load()
                proton_driver(conf, transport.TransportURL(conf))
                break

        conf.set_default(
            "ssl_cert_file",
            self.args.messaging_ssl_cert,
            "oslo_messaging_amqp",
        )
        conf.set_default(
            "ssl_key_file",
            self.args.messaging_ssl_key,
            "oslo_messaging_amqp",
        )
        conf.set_default(
            "ssl",
            self.args.messaging_ssl,
            "oslo_messaging_amqp",
        )
        conf.set_default(
            "ssl_ca_file",
            self.args.messaging_ssl_ca,
            "oslo_messaging_amqp",
        )

        conf.transport_url = "{}:5672/".format(self.connection_string)

        return conf

    def _check(self, queue, interval=1, constant=1000):
        """Return True if a job contains work ready.

        Sleeps for ``interval * constant`` milliseconds when the queue is
        empty before returning False.

        :param queue: Queueing object.
        :type queue: Object
        :param interval: Exponential Interval used to determine the polling
                         duration for a given socket.
        :type interval: Integer
        :param constant: Constant time used to poll for new jobs.
        :type constant: Integer
        :returns: Boolean
        """

        self.timeout = interval * (constant * 0.001)
        if queue.empty():
            time.sleep(self.timeout)
            return False
        else:
            return True

    def _build_job(self, kwargs):
        """Assemble a queue entry from the keyword arguments of an RPC call.

        This helper is not decorated with ``expose``. The RPC servers use
        ExplicitRPCAccessPolicy (see ``_rpc_server``), so it cannot be
        called remotely.

        :param kwargs: Named arguments received by an RPC endpoint.
        :type kwargs: Dictionary
        :returns: List
        """

        job = [kwargs.get(field) for field in self._JOB_FIELDS]
        if self.mode == "server":
            # Server mode prepends the sender identity.
            job.insert(0, kwargs.get("identity"))
        return job

    @expose
    def _heartbeat(self, *args, **kwargs):
        """Handle a heartbeat interaction.

        Because this method is exposed to the RPC server, some named objects
        may be passed through which are unused. To support these, extra
        positional and named arguments are accepted and ignored.

        :param identity: Client identity
        :type identity: String
        :param job_id: Job Id
        :type job_id: String
        :param control: Job control character
        :type control: String
        :param data: Heartbeat data
        :type data: Dictionary
        """

        self.log.debug("Handling heartbeat for [ %s ]", kwargs.get("identity"))
        self.job_q.put(
            [
                kwargs.get("identity"),
                kwargs.get("job_id"),
                kwargs.get("control"),
                None,
                kwargs.get("data"),
                None,
                None,
                None,
            ]
        )

    @expose
    def _job(
        self,
        *args,
        **kwargs,
    ):
        """Handle a job interaction.

        Because this method is exposed to the RPC server, some named objects
        may be passed through which are unused. To support these, extra
        positional and named arguments are accepted and ignored.

        :param identity: Client identity
        :type identity: String
        :param job_id: Job Id
        :type job_id: String
        :param control: Job control character
        :type control: String
        :param command: Command
        :type command: String
        :param data: Job data
        :type data: Dictionary
        :param info: Job info
        :type info: Dictionary
        :param stderr: Job stderr output
        :type stderr: String
        :param stdout: Job stdout output
        :type stdout: String
        """

        self.log.debug(
            "Handling job [ %s ] for [ %s ]",
            kwargs.get("job_id"),
            kwargs.get("identity"),
        )
        self.job_q.put(self._build_job(kwargs))

    @expose
    def _backend(
        self,
        *args,
        **kwargs,
    ):
        """Handle a backend interaction.

        Because this method is exposed to the RPC server, some named objects
        may be passed through which are unused. To support these, extra
        positional and named arguments are accepted and ignored.

        :param identity: Client identity
        :type identity: String
        :param job_id: Job Id
        :type job_id: String
        :param control: Job control character
        :type control: String
        :param command: Command
        :type command: String
        :param data: Job data
        :type data: Dictionary
        :param info: Job info
        :type info: Dictionary
        :param stderr: Job stderr output
        :type stderr: String
        :param stdout: Job stdout output
        :type stdout: String
        """

        self.log.debug(
            "Handling backend [ %s ] for [ %s ]",
            kwargs.get("job_id"),
            kwargs.get("identity"),
        )
        self.backend_q.put(self._build_job(kwargs))

    def _close(self, process_obj):
        """Stop an RPC server and wait for it to finish.

        :param process_obj: Server process object; a falsy value is ignored.
        :type process_obj: Object
        """

        self.log.debug("Stopping messaging server")
        if not process_obj:
            self.log.debug("No server to stop")
        else:
            process_obj.stop()
            process_obj.wait()
            self.log.debug("Server stopped")

    def _init_rpc_servers(self):
        """Initialize the rpc servers that are not already running.

        In server mode both servers listen as ``directord`` with a pool of
        16. Otherwise they listen under this machine's ID with a pool of 1.
        """

        if self.mode == "server":
            server_target = "directord"
            pool_size = 16
        else:
            server_target = self.machine_id
            pool_size = 1

        if not self.backend_server:
            self.backend_server = self._rpc_server(
                server_target=server_target, topic="directord-backend"
            )
            self.log.info("Starting messaging backend server")
            self.backend_server.start(override_pool_size=pool_size)

        if not self.server:
            self.server = self._rpc_server(
                server_target=server_target, topic="directord"
            )
            self.log.info("Starting messaging server")
            self.server.start(override_pool_size=pool_size)

    def _process_send(
        self,
        method,
        topic,
        identity=None,
        msg_id=None,
        control=None,
        command=None,
        data=None,
        info=None,
        stderr=None,
        stdout=None,
    ):
        """Send a job message.

        Without an ``identity`` the message goes to the ``directord`` server
        under this driver's identity. Otherwise it is routed to the machine
        ID of the named worker. If that worker or its machine ID is unknown,
        a fatal message is logged and nothing is sent.

        :param method: messaging method
        :type method: String
        :param topic: Messaging topic
        :type topic: String
        :param identity: Client identity
        :type identity: String
        :param msg_id: Job Id
        :type msg_id: String
        :param control: Job control character
        :type control: String
        :param command: Command
        :type command: String
        :param data: Job data
        :type data: Dictionary
        :param info: Job info
        :type info: Dictionary
        :param stderr: Job stderr output
        :type stderr: String
        :param stdout: Job stdout output
        :type stdout: String
        """

        if not identity:
            target = "directord"
            identity = self.identity
        else:
            worker = self.interface.workers.get(identity)
            # An unknown identity yields no worker. Treat that like a missing
            # machine ID instead of failing with an AttributeError.
            target = getattr(worker, "machine_id", None)
            if not target:
                self.log.fatal(
                    "Machine ID for identity [ %s ] not found", identity
                )
                return

        self._send(
            method=method,
            topic=topic,
            server=target,
            identity=identity,
            job_id=msg_id,
            control=control,
            command=command,
            data=data,
            info=info,
            stderr=stderr,
            stdout=stdout,
        )

    def _rpc_transport(self):
        """Returns an rpc transport.

        :returns: Object
        """

        return oslo_messaging.get_rpc_transport(self.conf)

    def _rpc_server(self, server_target, topic):
        """Returns an rpc server object.

        The server dispatches to this object with a threading executor. Only
        ``expose``-decorated methods are callable (ExplicitRPCAccessPolicy).

        :param server_target: Server name used for the messaging target.
        :type server_target: String
        :param topic: Messaging topic
        :type topic: String
        :returns: Object
        """

        return oslo_messaging.get_rpc_server(
            transport=self.transport,
            target=oslo_messaging.Target(
                topic=topic,
                server=server_target,
            ),
            endpoints=[self],
            executor="threading",
            access_policy=dispatcher.ExplicitRPCAccessPolicy,
        )

    @tenacity.retry(
        retry=tenacity.retry_if_exception_type(Exception),
        wait=tenacity.wait_fixed(1),
        before_sleep=tenacity.before_sleep_log(
            logger.getLogger(name="directord"), logging.WARN
        ),
    )
    def _send(self, method, topic, server="directord", **kwargs):
        """Send a message.

        Any exception is logged and re-raised. The tenacity decorator then
        retries the send after a fixed one-second wait.

        :param method: Send method type
        :type method: String
        :param topic: Messaging topic
        :type topic: String
        :param server: Server name; a falsy value targets the topic only.
        :type server: String
        :param kwargs: Extra named arguments
        :type kwargs: Dictionary
        :returns: Object
        """

        if server:
            target = oslo_messaging.Target(topic=topic, server=server)
        else:
            target = oslo_messaging.Target(topic=topic)

        client = oslo_messaging.RPCClient(
            self.transport, target, timeout=2, retry=3
        )

        try:
            return client.call({}, method, **kwargs)
        except Exception:
            self.log.warning(
                "Failed to send message using topic [ %s ] to server [ %s ]",
                topic,
                server,
            )
            raise

    def backend_check(self, interval=1, constant=1000):
        """Return True if the backend contains work ready.

        :param interval: Exponential Interval used to determine the polling
                         duration for a given socket.
        :type interval: Integer
        :param constant: Constant time used to poll for new jobs.
        :type constant: Integer
        :returns: Boolean
        """

        return self._check(
            queue=self.backend_q, interval=interval, constant=constant
        )

    def backend_close(self):
        """Close the backend (a no-op for this driver)."""

        self.log.debug(
            "The messaging driver does not initialize a backend connection"
            " so nothing to close"
        )

    def backend_init(self, sentinel=False):
        """Initialize servers (a no-op for this driver).

        :param sentinel: Breaks the loop
        :type sentinel: Boolean
        """

        self.log.debug(
            "The messaging driver does not initialize a backend connection"
            " so nothing to start"
        )

    def backend_recv(self):
        """Receive a message, blocking until the backend queue has one."""

        return self.backend_q.get()

    def backend_send(
        self,
        identity=None,
        msg_id=None,
        control=None,
        command=None,
        data=None,
        info=None,
        stderr=None,
        stdout=None,
    ):
        """Send a message over the backend.

        :param identity: Client identity
        :type identity: String
        :param msg_id: Job Id
        :type msg_id: String
        :param control: Job control character
        :type control: String
        :param command: Command
        :type command: String
        :param data: Job data
        :type data: Dictionary
        :param info: Job info
        :type info: Dictionary
        :param stderr: Job stderr output
        :type stderr: String
        :param stdout: Job stdout output
        :type stdout: String
        """

        self._process_send(
            method="_backend",
            topic="directord-backend",
            identity=identity,
            msg_id=msg_id,
            control=control,
            command=command,
            data=data,
            info=info,
            stderr=stderr,
            stdout=stdout,
        )

    def heartbeat_send(
        self, host_uptime=None, agent_uptime=None, version=None, driver=None
    ):
        """Send a heartbeat to the server.

        :param host_uptime: Sender uptime
        :type host_uptime: String
        :param agent_uptime: Sender agent uptime
        :type agent_uptime: String
        :param version: Sender directord version
        :type version: String
        :param driver: Driver information
        :type driver: String
        """

        job_id = utils.get_uuid()
        self.log.info(
            "Job [ %s ] sending heartbeat from [ %s ] to server",
            job_id,
            self.identity,
        )
        return self._process_send(
            method="_heartbeat",
            topic="directord",
            msg_id=job_id,
            control=self.heartbeat_notice,
            data=json.dumps(
                {
                    "job_id": job_id,
                    "version": version,
                    "host_uptime": host_uptime,
                    "agent_uptime": agent_uptime,
                    "machine_id": self.machine_id,
                    "driver": driver,
                }
            ),
        )

    def job_check(self, interval=1, constant=1000):
        """Return True if a job contains work ready.

        :param interval: Exponential Interval used to determine the polling
                         duration for a given socket.
        :type interval: Integer
        :param constant: Constant time used to poll for new jobs.
        :type constant: Integer
        :returns: Boolean
        """

        return self._check(
            queue=self.job_q, interval=interval, constant=constant
        )

    def job_close(self):
        """Stop both RPC servers."""

        self._close(process_obj=self.server)
        self._close(process_obj=self.backend_server)

    def job_init(self, sentinel=False):
        """Initialize servers.

        :param sentinel: Breaks the loop
        :type sentinel: Boolean
        """

        self._init_rpc_servers()

    def job_recv(self):
        """Receive a message, blocking until the job queue has one."""

        return self.job_q.get()

    def job_send(
        self,
        identity=None,
        msg_id=None,
        control=None,
        command=None,
        data=None,
        info=None,
        stderr=None,
        stdout=None,
    ):
        """Send a job message.

        :param identity: Client identity
        :type identity: String
        :param msg_id: Job Id
        :type msg_id: String
        :param control: Job control character
        :type control: String
        :param command: Command
        :type command: String
        :param data: Job data
        :type data: Dictionary
        :param info: Job info
        :type info: Dictionary
        :param stderr: Job stderr output
        :type stderr: String
        :param stdout: Job stdout output
        :type stdout: String
        """

        self._process_send(
            method="_job",
            topic="directord",
            identity=identity,
            msg_id=msg_id,
            control=control,
            command=command,
            data=data,
            info=info,
            stderr=stderr,
            stdout=stdout,
        )
Ejemplo n.º 8
0
class Driver(drivers.BaseDriver):
    def __init__(
        self,
        args,
        encrypted_traffic_data=None,
        interface=None,
    ):
        """Set up the ZeroMQ driver.

        When ``args.zmq_generate_keys`` is True, new CURVE certificates are
        generated and the process exits immediately.

        :param args: Arguments parsed by argparse.
        :type args: Object
        :param encrypted_traffic_data: Encrypted traffic settings; the
                                       "enabled", "secret_keys_dir" and
                                       "public_keys_dir" keys are read.
        :type encrypted_traffic_data: Dictionary
        :param interface: The interface instance (client/server)
        :type interface: Object
        """

        # Process and synchronisation primitives used by the driver.
        self.thread_processor = multiprocessing.Process
        self.event = multiprocessing.Event()
        self.semaphore = multiprocessing.Semaphore
        self.flushqueue = _FlushQueue
        self.args = args

        # Key generation is a one-shot action: create the keys and exit.
        generate_keys = getattr(self.args, "zmq_generate_keys", False)
        if generate_keys is True:
            self._generate_certificates()
            print("New certificates generated")
            raise SystemExit(0)

        self.encrypted_traffic_data = encrypted_traffic_data

        # Servers bind their own address, clients target the server, and
        # anything else binds every interface.
        mode = getattr(self.args, "mode", None)
        if mode == "server":
            address = self.args.zmq_bind_address
        elif mode == "client":
            address = self.args.zmq_server_address
        else:
            address = "*"
        self.bind_address = address
        self.proto = "tcp"
        self.connection_string = "{}://{}".format(
            self.proto, self.bind_address)

        traffic = self.encrypted_traffic_data
        if traffic:
            self.encrypted_traffic = traffic.get("enabled")
            self.secret_keys_dir = traffic.get("secret_keys_dir")
            self.public_keys_dir = traffic.get("public_keys_dir")
        else:
            self.encrypted_traffic = False
            self.secret_keys_dir = self.public_keys_dir = None

        self._context = zmq.Context()
        # NOTE(review): Context.instance() is a classmethod returning pyzmq's
        # process-wide singleton, so self.ctx is likely not self._context.
        # Confirm this is intended.
        self.ctx = self._context.instance()
        self.poller = zmq.Poller()
        self.interface = interface
        super(Driver, self).__init__(
            args=args,
            encrypted_traffic_data=self.encrypted_traffic_data,
            interface=interface,
        )
        self.bind_job = self.bind_backend = None
        self.hwm = getattr(self.args, "zmq_highwater_mark", 1024)

    def __copy__(self):
        """Build a new Driver from this driver's constructor inputs.

        The copy is created by running ``__init__`` again with the same
        args, encrypted traffic data and interface.

        :returns: Driver
        """

        init_kwargs = {
            "args": self.args,
            "encrypted_traffic_data": self.encrypted_traffic_data,
            "interface": self.interface,
        }
        return Driver(**init_kwargs)

    def _backend_bind(self):
        """Bind a ROUTER socket on the backend port and return it.

        The socket's high-water mark is set to ``self.hwm``.

        :returns: Object
        """

        bind = self._socket_bind(
            socket_type=zmq.ROUTER,
            connection=self.connection_string,
            port=self.args.backend_port,
        )
        bind.set_hwm(self.hwm)
        # This is the bind path, so say "bind" rather than "connect" to keep
        # it distinguishable from _backend_connect in the logs.
        self.log.debug(
            "Identity [ %s ] backend bind hwm state [ %s ]",
            self.identity,
            bind.get_hwm(),
        )
        return bind

    def _backend_connect(self):
        """Open a DEALER socket to the backend port and return it.

        The socket's high-water mark is set to ``self.hwm`` before it is
        returned.

        :returns: Object
        """

        self.log.debug("Establishing backend connection.")
        dealer = self._socket_connect(
            socket_type=zmq.DEALER,
            connection=self.connection_string,
            port=self.args.backend_port,
        )
        dealer.set_hwm(self.hwm)
        current_hwm = dealer.get_hwm()
        self.log.debug(
            "Identity [ %s ] backend connect hwm state [ %s ]",
            self.identity,
            current_hwm,
        )
        return dealer

    def _bind_check(self, bind, interval=1, constant=1000):
        """Return True if a bind type contains work ready.

        Polls every socket registered with ``self.poller`` for up to
        ``interval * constant`` milliseconds.

        :param bind: A given Socket bind to identify.
        :type bind: Object
        :param interval: Exponential Interval used to determine the polling
                         duration for a given socket.
        :type interval: Integer
        :param constant: Constant time used to poll for new jobs.
        :type constant: Integer
        :returns: Boolean
        """

        ready = dict(self.poller.poll(interval * constant))
        # poll() reports an event bitmask, so test the POLLIN bit rather than
        # requiring the mask to equal POLLIN exactly.
        return bool(ready.get(bind, 0) & zmq.POLLIN)

    def _close(self, socket):
        """Close a socket, retrying for up to 60 seconds.

        ``close(linger=2)`` is repeated once per second until the socket
        reports itself closed. Any error, including the timeout, is logged
        instead of raised.

        :param socket: Socket to close; None is ignored.
        :type socket: Object
        """

        if socket is None:
            return

        try:
            socket.close(linger=2)
            # Monotonic time keeps the deadline immune to wall-clock jumps.
            close_time = time.monotonic()
            while not socket.closed:
                if time.monotonic() - close_time > 60:
                    # job_id is not set anywhere in the visible driver code,
                    # so read it defensively. Otherwise an AttributeError
                    # would replace the intended timeout error.
                    raise TimeoutError(
                        "Job [ {} ] failed to close transfer socket".format(
                            getattr(self, "job_id", None)))
                else:
                    socket.close(linger=2)
                    time.sleep(1)
        except Exception as e:
            self.log.error(
                "Ran into an exception while closing the socket %s",
                str(e),
            )
        else:
            self.log.debug("Backend socket closed")

    def _generate_certificates(self, base_dir="/etc/directord"):
        """Create fresh server and client CURVE key pairs under ``base_dir``.

        Existing public (``.key``) and secret (``.key_secret``) files are
        renamed with a ``.bak`` suffix first. New pairs are then generated
        in ``certificates/`` and moved into ``public_keys/`` and
        ``private_keys/``.

        :param base_dir: Directord configuration path.
        :type base_dir: String
        """

        keys_dir, public_keys_dir, secret_keys_dir = (
            os.path.join(base_dir, name)
            for name in ("certificates", "public_keys", "private_keys")
        )

        for path in (keys_dir, public_keys_dir, secret_keys_dir):
            os.makedirs(path, exist_ok=True)

        # Keep a backup of whatever keys are currently installed.
        self._move_certificates(directory=public_keys_dir, backup=True)
        self._move_certificates(directory=secret_keys_dir,
                                backup=True,
                                suffix=".key_secret")

        for key_type in ("server", "client"):
            self._key_generate(keys_dir=keys_dir, key_type=key_type)

        # Publish the freshly generated files into their final directories.
        placements = (
            (public_keys_dir, ".key"),
            (secret_keys_dir, ".key_secret"),
        )
        for destination, suffix in placements:
            self._move_certificates(
                directory=keys_dir,
                target_directory=destination,
                suffix=suffix,
            )

    def _job_bind(self):
        """Bind a ROUTER socket on the job port and return it.

        :returns: Object
        """

        job_port = self.args.job_port
        router = self._socket_bind(
            socket_type=zmq.ROUTER,
            connection=self.connection_string,
            port=job_port,
        )
        return router

    def _job_connect(self):
        """Open a DEALER socket to the job port and return it.

        :returns: Object
        """

        self.log.debug("Establishing Job connection.")
        job_port = self.args.job_port
        dealer = self._socket_connect(
            socket_type=zmq.DEALER,
            connection=self.connection_string,
            port=job_port,
        )
        return dealer

    def _key_generate(self, keys_dir, key_type):
        """Write a CURVE key pair named ``key_type`` into ``keys_dir``.

        Delegates to ``zmq.auth.create_certificates``; its return value is
        discarded.

        :param keys_dir: Full Directory path where a given key will be stored.
        :type keys_dir: String
        :param key_type: Key type to be generated.
        :type key_type: String
        """

        create_pair = zmq_auth.create_certificates
        create_pair(keys_dir, key_type)

    @staticmethod
    def _move_certificates(directory,
                           target_directory=None,
                           backup=False,
                           suffix=".key"):
        """Rename every entry in ``directory`` whose name ends with ``suffix``.

        Matching entries go to ``target_directory``, or stay in ``directory``
        when no target is given. With ``backup`` they also gain a ``.bak``
        suffix.

        :param directory: Set the origin path.
        :type directory: String
        :param target_directory: Set the target path.
        :type target_directory: String
        :param backup: Enable file backup before moving.
        :type backup:  Boolean
        :param suffix: Set the search suffix
        :type suffix: String
        """

        destination = target_directory or directory
        for entry in os.listdir(directory):
            if not entry.endswith(suffix):
                continue

            new_name = os.path.basename(entry)
            if backup:
                new_name = "{}.bak".format(new_name)

            os.rename(
                os.path.join(directory, entry),
                os.path.join(destination, new_name),
            )

    def _socket_bind(self, socket_type, connection, port, poller_type=None):
        """Return a socket object which has been bound to a given address.

        When the socket_type is not PUB, the bound socket will also be
        registered with self.poller using ``poller_type``.

        When shared-key or CURVE authentication is enabled in ``self.args``,
        one ZAP authenticator (``self.auth``) is started for the driver
        context and reused by later binds. A second authenticator on the same
        context would try to bind the single ZAP endpoint again and would
        leak the first authenticator's thread.

        :param socket_type: Set the Socket type, typically defined using a
                            ZeroMQ constant.
        :type socket_type: Integer
        :param connection: Set the Address information used for the bound
                           socket.
        :type connection: String
        :param port: Define the port which the socket will be bound to.
        :type port: Integer
        :param poller_type: Poll event mask used when registering the socket
                            with self.poller; defaults to ``zmq.POLLIN``.
        :type poller_type: Integer
        :returns: Object
        :raises SystemExit: When CURVE encryption is enabled and a key
                            directory or the server secret key is missing or
                            cannot be loaded.
        """

        if poller_type is None:
            poller_type = zmq.POLLIN

        bind = self._socket_context(socket_type=socket_type)
        auth_enabled = (self.args.zmq_shared_key
                        or self.args.zmq_curve_encryption)

        if auth_enabled:
            # Start the authenticator only once per driver; reuse it while
            # its thread is still running.
            auth = getattr(self, "auth", None)
            if auth is None or not auth.is_alive():
                self.auth = ThreadAuthenticator(self.ctx, log=self.log)
                self.auth.start()
                self.auth.allow()

            if self.args.zmq_shared_key:
                # Enables basic auth
                self.auth.configure_plain(
                    domain="*", passwords={"admin": self.args.zmq_shared_key})
                bind.plain_server = True  # Enable shared key authentication
                self.log.info("Shared key authentication enabled.")
            elif self.args.zmq_curve_encryption:
                server_secret_file = os.path.join(self.secret_keys_dir,
                                                  "server.key_secret")
                for item in [
                        self.public_keys_dir,
                        self.secret_keys_dir,
                        server_secret_file,
                ]:
                    if not os.path.exists(item):
                        raise SystemExit(
                            "The required path [ {} ] does not exist. Have"
                            " you generated your keys?".format(item))
                self.auth.configure_curve(domain="*",
                                          location=self.public_keys_dir)
                try:
                    server_public, server_secret = zmq_auth.load_certificate(
                        server_secret_file)
                except OSError as e:
                    self.log.error(
                        "Failed to load certificates: %s, Configuration: %s",
                        str(e),
                        vars(self.args),
                    )
                    raise SystemExit("Failed to load certificates")
                else:
                    bind.curve_secretkey = server_secret
                    bind.curve_publickey = server_public
                    bind.curve_server = True  # Enable curve authentication
        bind.bind("{connection}:{port}".format(
            connection=connection,
            port=port,
        ))

        # PUB sockets are send-only, so they are never polled.
        if socket_type not in [zmq.PUB]:
            self.poller.register(bind, poller_type)

        return bind

    def _socket_connect(self, socket_type, connection, port, poller_type=None):
        """Return a socket object which has been connected to a given address.

        Steps:
        1. Create the socket via ``_socket_context``.
        2. Configure shared-key (PLAIN) or CURVE client credentials when
           either option is set in ``self.args``.
        3. Tag the socket with ``self.identity``: as the subscription prefix
           for SUB sockets, and as the socket identity otherwise.
        4. Register the socket with ``self.poller``.
        5. Connect to ``connection:port``.

        NOTE(review): earlier documentation said a reconnect waits 10 seconds
        for an ack before retrying. That logic is not in this method; confirm
        where it lives.

        :param socket_type: Set the Socket type, typically defined using a
                            ZeroMQ constant.
        :type socket_type: Integer
        :param connection: Address information (scheme and host) to connect
                           to.
        :type connection: String
        :param port: Port on the remote end to connect to.
        :type port: Integer
        :param poller_type: Poll event mask used when registering the socket
                            with self.poller; defaults to ``zmq.POLLIN``.
        :type poller_type: Integer
        :returns: Object
        :raises SystemExit: When CURVE encryption is enabled and a required
                            key path is missing or cannot be loaded.
        """

        events = zmq.POLLIN if poller_type is None else poller_type
        sock = self._socket_context(socket_type=socket_type)

        if self.args.zmq_shared_key:
            sock.plain_username = b"admin"  # User is hard coded.
            sock.plain_password = self.args.zmq_shared_key.encode()
            self.log.info("Shared key authentication enabled.")
        elif self.args.zmq_curve_encryption:
            client_secret_file = os.path.join(
                self.secret_keys_dir, "client.key_secret"
            )
            server_public_file = os.path.join(
                self.public_keys_dir, "server.key"
            )
            required_paths = (
                self.public_keys_dir,
                self.secret_keys_dir,
                client_secret_file,
                server_public_file,
            )
            # Report the first missing path, in the order listed above.
            missing = next(
                (path for path in required_paths
                 if not os.path.exists(path)),
                None,
            )
            if missing is not None:
                raise SystemExit(
                    "The required path [ {} ] does not exist. Have"
                    " you generated your keys?".format(missing))

            try:
                client_public, client_secret = zmq_auth.load_certificate(
                    client_secret_file)
                server_public, _ = zmq_auth.load_certificate(
                    server_public_file)
            except OSError as e:
                self.log.error(
                    "Error while loading certificates: %s. Configuration: %s",
                    str(e),
                    vars(self.args),
                )
                raise SystemExit("Failed to load keys.")

            sock.curve_secretkey = client_secret
            sock.curve_publickey = client_public
            sock.curve_serverkey = server_public

        # SUB sockets filter on the identity; all other types announce it.
        option = zmq.SUBSCRIBE if socket_type == zmq.SUB else zmq.IDENTITY
        sock.setsockopt_string(option, self.identity)

        self.poller.register(sock, events)
        sock.connect("{}:{}".format(connection, port))

        self.log.info("Socket connected to [ %s ].", connection)
        return sock

    def _socket_context(self, socket_type):
        """Create a socket from ``self.ctx`` and apply the common options.

        The socket gets:
        - a linger period taken from ``self.args.heartbeat_interval``
          (default 60);
        - a high-water mark of ``int(self.hwm * 4)`` on both its send and
          receive queues;
        - ``ROUTER_MANDATORY`` when it is a ROUTER socket.

        :param socket_type: Set the Socket type, typically defined using a
                            ZeroMQ constant.
        :type socket_type: Integer
        :returns: Object
        """

        bind = self.ctx.socket(socket_type)
        # ZMQ_LINGER is expressed in milliseconds. NOTE(review): this reuses
        # heartbeat_interval as the value; confirm the units are intended.
        bind.linger = getattr(self.args, "heartbeat_interval", 60)
        # set_hwm() applies the limit to both SNDHWM and RCVHWM (or to the
        # legacy HWM option on libzmq 2.x). It replaces four redundant
        # assignments of the same value.
        bind.set_hwm(int(self.hwm * 4))
        if socket_type == zmq.ROUTER:
            # Report unroutable messages as errors instead of silently
            # dropping them.
            bind.setsockopt(zmq.ROUTER_MANDATORY, 1)

        return bind

    @staticmethod
    def _socket_recv(socket, nonblocking=False):
        """Receive one multipart message from a ZeroMQ socket.

        Server-side messages are laid out as:

            [
                b"Identity",
                b"ID",
                b"ASCII Control Characters",
                b"command",
                b"data",
                b"info",
                b"stderr",
                b"stdout",
            ]

        Client-side messages omit the leading identity frame:

            [
                b"ID",
                b"ASCII Control Characters",
                b"command",
                b"data",
                b"info",
                b"stderr",
                b"stdout",
            ]

        Every frame is returned as raw bytes. The control characters are
        defined within the Interface class; see
        URL(https://donsnotes.com/tech/charsets/ascii.html#cntrl).

        :param socket: ZeroMQ socket object.
        :type socket: Object
        :param nonblocking: When True, receive with ``zmq.NOBLOCK``.
        :type nonblocking: Boolean
        :returns: List of frames as returned by ``recv_multipart``.
        """

        return socket.recv_multipart(flags=zmq.NOBLOCK if nonblocking else 0)

    @tenacity.retry(
        retry=tenacity.retry_if_exception_type(Exception),
        wait=tenacity.wait_fixed(5),
        before_sleep=tenacity.before_sleep_log(
            logger.getLogger(name="directord"), logging.WARN),
    )
    def _socket_send(
        self,
        socket,
        identity=None,
        msg_id=None,
        control=None,
        command=None,
        data=None,
        info=None,
        stderr=None,
        stdout=None,
        nonblocking=False,
    ):
        """Send a message over a ZeroMQ socket.

        The message specification for server is as follows.

            [
                b"Identity",
                b"ID",
                b"ASCII Control Characters",
                b"command",
                b"data",
                b"info",
                b"stderr",
                b"stdout",
            ]

        The message specification for client is as follows.

            [
                b"ID",
                b"ASCII Control Characters",
                b"command",
                b"data",
                b"info",
                b"stderr",
                b"stdout",
            ]

        How parts are prepared:
        - A falsy ``msg_id`` is replaced with ``utils.get_uuid()``.
        - Falsy control/command/data/info/stderr/stdout values are replaced
          with ``self.nullbyte``.
        - ``str`` parts are encoded with ``str.encode()`` (UTF-8); all other
          parts are sent as-is.
        - The identity frame is prepended only when ``identity`` is truthy.

        All possible control characters are defined within the Interface class.
        For more on control characters review the following
        URL(https://donsnotes.com/tech/charsets/ascii.html#cntrl).

        NOTE(review): the retry decorator sets no ``stop`` condition. A failing
        send is therefore retried every 5 seconds indefinitely, including
        ``zmq.Again`` raised by a non-blocking send. Confirm this is intended.

        :param socket: ZeroMQ socket object.
        :type socket: Object
        :param identity: Target where message will be sent.
        :type identity: Bytes
        :param msg_id: ID information for a given message. If no ID is
                       provided a UUID will be generated.
        :type msg_id: Bytes
        :param control: ASCII control characters.
        :type control: Bytes
        :param command: Command definition for a given message.
        :type command: Bytes
        :param data: Encoded data that will be transmitted.
        :type data: Bytes
        :param info: Encoded information that will be transmitted.
        :type info: Bytes
        :param stderr: Encoded error information from a command.
        :type stderr: Bytes
        :param stdout: Encoded output information from a command.
        :type stdout: Bytes
        :param nonblocking: Enable non-blocking send.
        :type nonblocking: Boolean
        :returns: Object
        """
        def _encoder(item):
            # Encode str parts; bytes (and anything else) pass through.
            try:
                return item.encode()
            except AttributeError:
                return item

        msg_id = msg_id or utils.get_uuid()
        payload = [
            part or self.nullbyte
            for part in (control, command, data, info, stderr, stdout)
        ]
        message_parts = [msg_id] + payload

        if identity:
            message_parts.insert(0, identity)

        message_parts = [_encoder(i) for i in message_parts]
        flags = zmq.NOBLOCK if nonblocking else 0

        try:
            return socket.send_multipart(message_parts, flags=flags)
        except Exception:
            # Logger.warn is a deprecated alias; a bare raise keeps the
            # original traceback for the retry decorator.
            self.log.warning("Failed to send message to [ %s ]", identity)
            raise

    def _recv(self, socket, nonblocking=False):
        """Receive a multipart message and decode every frame to ``str``.

        :param socket: ZeroMQ socket object.
        :type socket: Object
        :param nonblocking: When True, receive without waiting.
        :type nonblocking: Boolean
        :returns: Tuple of decoded frames, in the order they were received.
        """

        frames = self._socket_recv(socket=socket, nonblocking=nonblocking)
        return tuple(frame.decode() for frame in frames)

    def backend_recv(self, nonblocking=False):
        """Receive one message from the backend socket.

        :param nonblocking: When True, receive without waiting.
        :type nonblocking: Boolean
        :returns: Tuple of decoded message frames.
        """

        backend_socket = self.bind_backend
        return self._recv(socket=backend_socket, nonblocking=nonblocking)

    def backend_init(self):
        """Initialize the backend socket and store it on ``self.bind_backend``.

        For server mode, this is a bound local socket.
        For client mode, it is a connection to the server socket.

        :returns: None; the socket is available as ``self.bind_backend``.
        """

        if self.args.mode == "server":
            self.bind_backend = self._backend_bind()
        else:
            self.bind_backend = self._backend_connect()

    def backend_close(self):
        """Close the socket stored on ``self.bind_backend``."""

        backend_socket = self.bind_backend
        self._close(socket=backend_socket)

    def backend_check(self, interval=1, constant=1000):
        """Poll the backend socket for pending work.

        Delegates to ``self._bind_check`` with ``self.bind_backend``.

        :param interval: Interval passed to ``_bind_check``; per the original
                         documentation it exponentially scales the polling
                         duration.
        :type interval: Integer
        :param constant: Constant time passed to ``_bind_check`` to poll for
                         new jobs.
        :type constant: Integer
        :returns: Whatever ``_bind_check`` returns for the backend socket
                  (presumably truthy when work is ready; confirm against
                  ``_bind_check``).
        """

        return self._bind_check(bind=self.bind_backend,
                                interval=interval,
                                constant=constant)

    def backend_send(self, *args, **kwargs):
        """Send a message over the backend socket.

        * All args and kwargs are passed through to ``_socket_send``.
        * ``socket`` is always forced to ``self.bind_backend``, overriding any
          caller-supplied value.

        :returns: Object returned by ``_socket_send``.
        """

        kwargs["socket"] = self.bind_backend
        return self._socket_send(*args, **kwargs)

    @staticmethod
    def get_lock():
        """Returns a thread lock."""

        return multiprocessing.Lock()

    def heartbeat_send(self,
                       host_uptime=None,
                       agent_uptime=None,
                       version=None,
                       driver=None):
        """Send a heartbeat to the server over the job socket.

        A fresh UUID is used both as the message ID and as the ``job_id`` in
        the JSON payload. The message is sent with the ``heartbeat_notice``
        control character.

        :param host_uptime: Sender uptime
        :type host_uptime: String
        :param agent_uptime: Sender agent uptime
        :type agent_uptime: String
        :param version: Sender directord version
        :type version: String
        :param driver: Driver information
        :type driver: String
        :returns: The value returned by ``job_send``.
        """

        job_id = utils.get_uuid()
        self.log.info(
            "Job [ %s ] sending heartbeat from [ %s ] to server",
            job_id,
            self.identity,
        )

        return self.job_send(
            control=self.heartbeat_notice,
            msg_id=job_id,
            data=json.dumps({
                "job_id": job_id,
                "version": version,
                "host_uptime": host_uptime,
                "agent_uptime": agent_uptime,
                "machine_id": self.machine_id,
                "driver": driver,
            }),
        )

    def job_send(self, *args, **kwargs):
        """Send a message over the job socket.

        * Positional and keyword arguments are forwarded to ``_socket_send``.
        * ``socket`` is always set to ``self.bind_job``.

        :returns: Object returned by ``_socket_send``.
        """

        kwargs.update(socket=self.bind_job)
        return self._socket_send(*args, **kwargs)

    def job_recv(self, nonblocking=False):
        """Receive one message from the job socket.

        :param nonblocking: When True, receive without waiting.
        :type nonblocking: Boolean
        :returns: Tuple of decoded message frames.
        """

        job_socket = self.bind_job
        return self._recv(socket=job_socket, nonblocking=nonblocking)

    def job_init(self):
        """Initialize the job socket and store it on ``self.bind_job``.

        For server mode, this is a bound local socket.
        For client mode, it is a connection to the server socket.

        :returns: None; the socket is available as ``self.bind_job``.
        """

        if self.args.mode == "server":
            self.bind_job = self._job_bind()
        else:
            self.bind_job = self._job_connect()

    def job_close(self):
        """Close the socket stored on ``self.bind_job``."""

        job_socket = self.bind_job
        self._close(socket=job_socket)

    def job_check(self, interval=1, constant=1000):
        """Poll the job socket for pending work.

        Delegates to ``self._bind_check`` with ``self.bind_job``.

        :param interval: Interval passed to ``_bind_check``; per the original
                         documentation it exponentially scales the polling
                         duration.
        :type interval: Integer
        :param constant: Constant time passed to ``_bind_check`` to poll for
                         new jobs.
        :type constant: Integer
        :returns: Whatever ``_bind_check`` returns for the job socket
                  (presumably truthy when work is ready; confirm against
                  ``_bind_check``).
        """

        return self._bind_check(bind=self.bind_job,
                                interval=interval,
                                constant=constant)

    def shutdown(self):
        """Shutdown the driver.

        Shutdown order:
        1. Stop the ZAP authenticator stored on ``self.auth``, if any.
        2. Close ``self.ctx`` and ``self._context`` when they expose a
           ``close`` method.
        3. Close the job and backend sockets.
        """

        # The authenticator runs its own thread with sockets on self.ctx.
        # Stop it explicitly; otherwise it keeps running after shutdown.
        auth = getattr(self, "auth", None)
        if auth is not None:
            auth.stop()
            self.auth = None

        if hasattr(self.ctx, "close"):
            self.ctx.close()

        if hasattr(self._context, "close"):
            self._context.close()

        self.job_close()
        self.backend_close()
Ejemplo n.º 9
0
 def test_getlogger_new_logger(self):
     """getLogger() must name its first handler after the requested logger."""
     log = logger.getLogger(name="testLogger")
     if not log.handlers:
         self.fail("The log handler name was not set")
     # assertEqual reports the actual handler name when the check fails.
     self.assertEqual(log.handlers[0].name, "testLogger")
Ejemplo n.º 10
0
    def __init__(self):
        """Initialize Processor class."""

        self.log = logger.getLogger(name="directord")