예제 #1
0
class ZMQDataPublisher(Configurable):
    """Publish a namespace of data as a ``data_message`` through a Session.

    The object to publish is serialized into buffers with
    ``serialize_object`` and sent on ``pub_socket``; only the list of keys
    travels in the JSON content of the message.
    """

    # Passed as ``ident`` to ``session.send`` for every outgoing message.
    topic = CBytes(b'datapub')
    # Session used for sending; also supplies the serialization thresholds.
    session = Instance(Session, allow_none=True)
    # Socket handed to ``session.send``.
    pub_socket = Any(allow_none=True)
    # Header of the current parent message, see ``set_parent``.
    parent_header = Dict({})

    def set_parent(self, parent):
        """Set the parent for outbound messages."""
        self.parent_header = extract_header(parent)

    def publish_data(self, data):
        """Publish a data_message on the IOPub channel.

        Parameters
        ----------
        data : dict
            The data to be published. Think of it as a namespace.
        """
        session = self.session
        # The payload itself is shipped as message buffers ...
        buffers = serialize_object(
            data,
            buffer_threshold=session.buffer_threshold,
            item_threshold=session.item_threshold,
        )
        # ... while the message content only advertises the key names.
        content = json_clean(dict(keys=list(data.keys())))
        session.send(
            self.pub_socket,
            'data_message',
            content=content,
            parent=self.parent_header,
            buffers=buffers,
            ident=self.topic,
        )
예제 #2
0
class Scheduler(SessionFactory):
    """Common base for schedulers sitting between clients and engines.

    Holds the ZMQ streams and the task bookkeeping sets, and wires the
    stream callbacks. ``dispatch_result`` and ``dispatch_submission`` must
    be provided by subclasses.
    """

    # client-facing stream
    client_stream = Instance(zmqstream.ZMQStream, allow_none=True)
    # engine-facing stream
    engine_stream = Instance(zmqstream.ZMQStream, allow_none=True)
    # hub-facing sub stream
    notifier_stream = Instance(zmqstream.ZMQStream, allow_none=True)
    # hub-facing pub stream
    mon_stream = Instance(zmqstream.ZMQStream, allow_none=True)
    # hub-facing DEALER stream
    query_stream = Instance(zmqstream.ZMQStream, allow_none=True)

    # set of all completed tasks
    all_completed = Set()
    # set of all failed tasks
    all_failed = Set()
    # set of all finished tasks = union(completed, failed)
    all_done = Set()
    # set of all submitted task IDs
    all_ids = Set()

    # ZMQ identity. This should just be self.session.session, but as bytes.
    ident = CBytes()

    def _ident_default(self):
        """Default identity: the session id in its bytes form."""
        return self.session.bsession

    def start(self):
        """Attach the result and submission handlers to their streams."""
        self.engine_stream.on_recv(self.dispatch_result, copy=False)
        self.client_stream.on_recv(self.dispatch_submission, copy=False)

    def resume_receiving(self):
        """Resume accepting jobs."""
        self.client_stream.on_recv(self.dispatch_submission, copy=False)

    def stop_receiving(self):
        """Stop accepting jobs while there are no engines.

        Pending submissions are left in the ZMQ queue.
        """
        self.client_stream.on_recv(None)

    def dispatch_result(self, raw_msg):
        """Handle a message arriving on the engine stream (abstract)."""
        raise NotImplementedError("Implement in subclasses")

    def dispatch_submission(self, raw_msg):
        """Handle a message arriving on the client stream (abstract)."""
        raise NotImplementedError("Implement in subclasses")

    def append_new_msg_id_to_msg(self, new_id, target_id, idents, msg):
        """Re-serialize ``msg`` under ``new_id`` with ``target_id`` as first ident.

        ``msg`` is modified in place (its header ``msg_id`` is overwritten).
        Returns the serialized frames followed by the message buffers.
        """
        routing = [cast_bytes(target_id)] + idents
        header = msg['header']
        header['msg_id'] = new_id
        frames = self.session.serialize(msg, ident=routing)
        frames.extend(msg['buffers'])
        return frames

    def get_new_msg_id(self, original_msg_id, outgoing_id):
        """Return ``"<original_msg_id>_<outgoing_id>"``.

        ``outgoing_id`` may be ``str`` or a bytes-like object; anything that
        is not ``str`` is decoded as UTF-8 first.
        """
        if isinstance(outgoing_id, str):
            suffix = outgoing_id
        else:
            suffix = outgoing_id.decode("utf8")
        return f'{original_msg_id}_{suffix}'
예제 #3
0
class ZMQDisplayPublisher(DisplayPublisher):
    """Display publisher that sends its output through a ZeroMQ PUB socket."""

    session = Instance(Session, allow_none=True)
    pub_socket = Any(allow_none=True)
    parent_header = Dict({})
    topic = CBytes(b'display_data')

    def set_parent(self, parent):
        """Remember the header of ``parent`` for subsequent outbound messages."""
        self.parent_header = extract_header(parent)

    def _flush_streams(self):
        """Flush stdout and stderr so earlier output precedes the display."""
        for stream in (sys.stdout, sys.stderr):
            stream.flush()

    def publish(self, data, metadata=None, source=None):
        """Validate ``data``/``metadata`` and send them as ``display_data``.

        ``metadata`` defaults to an empty dict. ``source`` is accepted but
        not used by this implementation.
        """
        self._flush_streams()
        metadata = {} if metadata is None else metadata
        self._validate_data(data, metadata)
        content = {
            'data': encode_images(data),
            'metadata': metadata,
        }
        self.session.send(
            self.pub_socket,
            u'display_data',
            json_clean(content),
            parent=self.parent_header,
            ident=self.topic,
        )

    def clear_output(self, wait=False):
        """Send a ``clear_output`` message carrying the ``wait`` flag."""
        payload = {'wait': wait}
        self._flush_streams()
        self.session.send(
            self.pub_socket,
            u'clear_output',
            payload,
            parent=self.parent_header,
            ident=self.topic,
        )
예제 #4
0
class EngineFactory(RegistrationFactory):
    """IPython engine factory.

    Sends a registration request to the controller and, when the reply has
    status ``'ok'``, starts the heartbeat, creates the shell/control/iopub
    sockets and builds the ``Kernel``. When ``max_heartbeat_misses`` is
    positive it also watches the heartbeat pings and stops the loop after
    too many consecutive misses.
    """

    # configurables:
    out_stream_factory = Type('ipython_kernel.iostream.OutStream',
                              config=True,
                              help="""The OutStream for handling stdout/err.
        Typically 'ipython_kernel.iostream.OutStream'""")
    display_hook_factory = Type('ipython_kernel.displayhook.ZMQDisplayHook',
                                config=True,
                                help="""The class for handling displayhook.
        Typically 'ipython_kernel.displayhook.ZMQDisplayHook'""")
    location = Unicode(
        config=True,
        help="""The location (an IP address) of the controller.  This is
        used for disambiguating URLs, to determine whether
        loopback should be used to connect or the public address.""")
    timeout = Float(
        5.0,
        config=True,
        help="""The time (in seconds) to wait for the Controller to respond
        to registration requests before giving up.""")
    max_heartbeat_misses = Integer(
        50,
        config=True,
        help="""The maximum number of times a check for the heartbeat ping of a 
        controller can be missed before shutting down the engine.
        
        If set to 0, the check is disabled.""")
    sshserver = Unicode(
        config=True,
        help=
        """The SSH server to use for tunneling connections to the Controller."""
    )
    sshkey = Unicode(
        config=True,
        help=
        """The SSH private key file to use when tunneling connections to the Controller."""
    )
    paramiko = Bool(
        sys.platform == 'win32',
        config=True,
        help="""Whether to use paramiko instead of openssh for tunnels.""")

    @property
    def tunnel_mod(self):
        """Return the ``zmq.ssh.tunnel`` module, imported on access."""
        from zmq.ssh import tunnel
        return tunnel

    # not configurable:
    connection_info = Dict()
    user_ns = Dict()
    id = Integer(allow_none=True)
    registrar = Instance('zmq.eventloop.zmqstream.ZMQStream', allow_none=True)
    kernel = Instance(Kernel, allow_none=True)
    # Period of the heartbeat check in ms; set in complete_registration.
    hb_check_period = Integer()

    # States for the heartbeat monitoring
    # Initial values for monitored and pinged must satisfy "monitored > pinged == False" so that
    # during the first check no "missed" ping is reported. Must be floats for Python 3 compatibility.
    _hb_last_pinged = 0.0
    _hb_last_monitored = 0.0
    _hb_missed_beats = 0
    # The zmq Stream which receives the pings from the Heart
    _hb_listener = None

    # ``ident`` as text and ``bident`` as bytes; kept in sync by _ident_changed.
    bident = CBytes()
    ident = Unicode()

    def _ident_changed(self, name, old, new):
        """Mirror a new ``ident`` value into ``bident`` as bytes."""
        self.bident = cast_bytes(new)

    # Set by init_connector: True when an ssh key or ssh server is configured.
    using_ssh = Bool(False)

    def __init__(self, **kwargs):
        """Initialize the factory and take ``ident`` from the session id."""
        super(EngineFactory, self).__init__(**kwargs)
        self.ident = self.session.session

    def init_connector(self):
        """construct connection function, which handles tunnels.

        Returns
        -------
        (connect, maybe_tunnel) : tuple of callables
            ``connect(s, url)`` connects socket/stream ``s`` to ``url``,
            through an ssh tunnel when ssh is in use.
            ``maybe_tunnel(url)`` only opens the tunnel (if any) and returns
            the url to connect to as ``str``.
        """
        self.using_ssh = bool(self.sshkey or self.sshserver)

        if self.sshkey and not self.sshserver:
            # We are using ssh directly to the controller, tunneling localhost to localhost
            # The host part of self.url is used as the ssh server.
            self.sshserver = self.url.split('://')[1].split(':')[0]

        if self.using_ssh:
            # Only prompt for a password when passwordless login fails.
            if self.tunnel_mod.try_passwordless_ssh(self.sshserver,
                                                    self.sshkey,
                                                    self.paramiko):
                password = False
            else:
                password = getpass("SSH Password for %s: " % self.sshserver)
        else:
            password = False

        def connect(s, url):
            """Connect ``s`` to ``url``, tunneling over ssh when enabled."""
            url = disambiguate_url(url, self.location)
            if self.using_ssh:
                self.log.debug("Tunneling connection to %s via %s", url,
                               self.sshserver)
                return self.tunnel_mod.tunnel_connection(
                    s,
                    url,
                    self.sshserver,
                    keyfile=self.sshkey,
                    paramiko=self.paramiko,
                    password=password,
                )
            else:
                return s.connect(url)

        def maybe_tunnel(url):
            """like connect, but don't complete the connection (for use by heartbeat)"""
            url = disambiguate_url(url, self.location)
            if self.using_ssh:
                self.log.debug("Tunneling connection to %s via %s", url,
                               self.sshserver)
                # tunnelobj is not used afterwards; only the rewritten url is kept.
                url, tunnelobj = self.tunnel_mod.open_tunnel(
                    url,
                    self.sshserver,
                    keyfile=self.sshkey,
                    paramiko=self.paramiko,
                    password=password,
                )
            return str(url)

        return connect, maybe_tunnel

    def register(self):
        """send the registration_request"""

        self.log.info("Registering with controller at %s" % self.url)
        ctx = self.context
        connect, maybe_tunnel = self.init_connector()
        # DEALER socket identified by bident, wrapped in a stream on our loop.
        reg = ctx.socket(zmq.DEALER)
        reg.setsockopt(zmq.IDENTITY, self.bident)
        connect(reg, self.url)
        self.registrar = zmqstream.ZMQStream(reg, self.loop)

        content = dict(uuid=self.ident)
        # The reply is handled by complete_registration, which needs the
        # connection helpers created above.
        self.registrar.on_recv(
            lambda msg: self.complete_registration(msg, connect, maybe_tunnel))
        # print (self.session.key)
        self.session.send(self.registrar,
                          "registration_request",
                          content=content)

    def _report_ping(self, msg):
        """Callback for when the heartmonitor.Heart receives a ping"""
        #self.log.debug("Received a ping: %s", msg)
        self._hb_last_pinged = time.time()

    def complete_registration(self, msg, connect, maybe_tunnel):
        """Handle the controller's reply to the registration request.

        Parameters
        ----------
        msg : list
            Raw message frames received on the registrar stream.
        connect, maybe_tunnel : callables
            The helpers returned by ``init_connector``.

        Raises
        ------
        Exception
            If the reply's ``content['status']`` is not ``'ok'``.
        """
        # print msg
        # Cancel the abort timeout that start() scheduled.
        self.loop.remove_timeout(self._abort_timeout)
        ctx = self.context
        loop = self.loop
        identity = self.bident
        idents, msg = self.session.feed_identities(msg)
        msg = self.session.deserialize(msg)
        content = msg['content']
        info = self.connection_info

        def url(key):
            """get zmq url for given channel"""
            return str(info["interface"] + ":%i" % info[key])

        if content['status'] == 'ok':
            self.id = int(content['id'])

            # launch heartbeat
            # possibly forward hb ports with tunnels
            hb_ping = maybe_tunnel(url('hb_ping'))
            hb_pong = maybe_tunnel(url('hb_pong'))

            hb_monitor = None
            if self.max_heartbeat_misses > 0:
                # Add a monitor socket which will record the last time a ping was seen
                mon = self.context.socket(zmq.SUB)
                mport = mon.bind_to_random_port('tcp://%s' % localhost())
                mon.setsockopt(zmq.SUBSCRIBE, b"")
                self._hb_listener = zmqstream.ZMQStream(mon, self.loop)
                self._hb_listener.on_recv(self._report_ping)

                hb_monitor = "tcp://%s:%i" % (localhost(), mport)

            heart = Heart(hb_ping, hb_pong, hb_monitor, heart_id=identity)
            heart.start()

            # create Shell Connections (MUX, Task, etc.):
            shell_addrs = url('mux'), url('task')

            # Use only one shell stream for mux and tasks
            stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
            stream.setsockopt(zmq.IDENTITY, identity)
            shell_streams = [stream]
            for addr in shell_addrs:
                connect(stream, addr)

            # control stream:
            control_addr = url('control')
            control_stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
            control_stream.setsockopt(zmq.IDENTITY, identity)
            connect(control_stream, control_addr)

            # create iopub stream:
            iopub_addr = url('iopub')
            iopub_socket = ctx.socket(zmq.PUB)
            iopub_socket.setsockopt(zmq.IDENTITY, identity)
            connect(iopub_socket, iopub_addr)

            # disable history:
            self.config.HistoryManager.hist_file = ':memory:'

            # Redirect input streams and set a display hook.
            # Each replacement gets an 'engine.<id>.<name>' topic.
            if self.out_stream_factory:
                sys.stdout = self.out_stream_factory(self.session,
                                                     iopub_socket, u'stdout')
                sys.stdout.topic = cast_bytes('engine.%i.stdout' % self.id)
                sys.stderr = self.out_stream_factory(self.session,
                                                     iopub_socket, u'stderr')
                sys.stderr.topic = cast_bytes('engine.%i.stderr' % self.id)
            if self.display_hook_factory:
                sys.displayhook = self.display_hook_factory(
                    self.session, iopub_socket)
                sys.displayhook.topic = cast_bytes('engine.%i.execute_result' %
                                                   self.id)

            self.kernel = Kernel(parent=self,
                                 int_id=self.id,
                                 ident=self.ident,
                                 session=self.session,
                                 control_stream=control_stream,
                                 shell_streams=shell_streams,
                                 iopub_socket=iopub_socket,
                                 loop=loop,
                                 user_ns=self.user_ns,
                                 log=self.log)

            self.kernel.shell.display_pub.topic = cast_bytes(
                'engine.%i.displaypub' % self.id)

            # periodically check the heartbeat pings of the controller
            # Should be started here and not in "start()" so that the right period can be taken
            # from the hubs HeartBeatMonitor.period
            if self.max_heartbeat_misses > 0:
                # Use a slightly bigger check period than the hub signal period to not warn unnecessary
                self.hb_check_period = int(content['hb_period']) + 10
                self.log.info(
                    "Starting to monitor the heartbeat signal from the hub every %i ms.",
                    self.hb_check_period)
                self._hb_reporter = ioloop.PeriodicCallback(
                    self._hb_monitor, self.hb_check_period, self.loop)
                self._hb_reporter.start()
            else:
                self.log.info(
                    "Monitoring of the heartbeat signal from the hub is not enabled."
                )

            # FIXME: This is a hack until IPKernelApp and IPEngineApp can be fully merged
            app = IPKernelApp(parent=self,
                              shell=self.kernel.shell,
                              kernel=self.kernel,
                              log=self.log)
            app.init_profile_dir()
            app.init_code()

            self.kernel.start()
        else:
            self.log.fatal("Registration Failed: %s" % msg)
            raise Exception("Registration Failed: %s" % msg)

        # Only reached on success; the failure branch raises above.
        self.log.info("Completed registration with id %i" % self.id)

    def abort(self):
        """Give up on registration: log, send an unregistration request, exit.

        Scheduled by ``start`` to run after ``timeout`` seconds; the process
        exits with status 255.
        """
        self.log.fatal("Registration timed out after %.1f seconds" %
                       self.timeout)
        # NOTE(review): init_connector splits self.url on '://', which suggests
        # the url carries a scheme prefix; if so this startswith('127.') check
        # would never match and the hint below is never logged - confirm.
        if self.url.startswith('127.'):
            self.log.fatal("""
            If the controller and engines are not on the same machine,
            you will have to instruct the controller to listen on an external IP (in ipcontroller_config.py):
                c.HubFactory.ip='*' # for all interfaces, internal and external
                c.HubFactory.ip='192.168.1.101' # or any interface that the engines can see
            or tunnel connections via ssh.
            """)
        self.session.send(self.registrar,
                          "unregistration_request",
                          content=dict(id=self.id))
        # Pause briefly before exiting (presumably so the request can go out).
        time.sleep(1)
        sys.exit(255)

    def _hb_monitor(self):
        """Callback to monitor the heartbeat from the controller"""
        self._hb_listener.flush()
        # A beat counts as missed when no ping was recorded since the last check.
        if self._hb_last_monitored > self._hb_last_pinged:
            self._hb_missed_beats += 1
            self.log.warn(
                "No heartbeat in the last %s ms (%s time(s) in a row).",
                self.hb_check_period, self._hb_missed_beats)
        else:
            #self.log.debug("Heartbeat received (after missing %s beats).", self._hb_missed_beats)
            self._hb_missed_beats = 0

        if self._hb_missed_beats >= self.max_heartbeat_misses:
            self.log.fatal(
                "Maximum number of heartbeats misses reached (%s times %s ms), shutting down.",
                self.max_heartbeat_misses, self.hb_check_period)
            self.session.send(self.registrar,
                              "unregistration_request",
                              content=dict(id=self.id))
            self.loop.stop()

        self._hb_last_monitored = time.time()

    def start(self):
        """Schedule registration, plus the abort timeout, on the event loop."""
        loop = self.loop

        def _start():
            self.register()
            # abort() fires unless complete_registration removes this timeout.
            self._abort_timeout = loop.add_timeout(loop.time() + self.timeout,
                                                   self.abort)

        self.loop.add_callback(_start)
예제 #5
0
class ZMQDisplayPublisher(DisplayPublisher):
    """Display publisher that sends its output through a ZeroMQ PUB socket.

    Outgoing display messages can be intercepted by per-thread hooks, see
    ``register_hook``.
    """

    session = Instance(Session, allow_none=True)
    pub_socket = Any(allow_none=True)
    parent_header = Dict({})
    topic = CBytes(b'display_data')

    # Thread-local storage holding the hook list of each thread, so that the
    # correct output message is processed (see ipykernel issue 113).
    _thread_local = Any()

    @default('_thread_local')
    def _default_thread_local(self):
        """Create the thread-local container on first use."""
        return local()

    @property
    def _hooks(self):
        """The hook list of the calling thread, created empty if missing."""
        storage = self._thread_local
        if hasattr(storage, 'hooks'):
            return storage.hooks
        # first access from this thread: start with no hooks
        storage.hooks = []
        return storage.hooks

    def set_parent(self, parent):
        """Remember the header of ``parent`` for subsequent outbound messages."""
        self.parent_header = extract_header(parent)

    def _flush_streams(self):
        """Flush stdout and stderr so earlier output precedes the display."""
        for stream in (sys.stdout, sys.stderr):
            stream.flush()

    def publish(
        self,
        data,
        metadata=None,
        source=None,
        transient=None,
        update=False,
    ):
        """Publish a display-data message.

        Parameters
        ----------
        data : dict
            A mime-bundle dict, keyed by mime-type.
        metadata : dict, optional
            Metadata associated with the data. Defaults to ``{}``.
        source : optional
            Accepted but not used by this implementation.
        transient : dict, optional
            Transient data that may only be relevant during a live display,
            such as display_id. Transient data should not be persisted to
            documents. Defaults to ``{}``.
        update : bool, optional
            If True, send an ``update_display_data`` message instead of
            ``display_data``.
        """
        self._flush_streams()
        metadata = {} if metadata is None else metadata
        transient = {} if transient is None else transient
        self._validate_data(data, metadata)
        content = {
            'data': encode_images(data),
            'metadata': metadata,
            'transient': transient,
        }

        if update:
            msg_type = 'update_display_data'
        else:
            msg_type = 'display_data'

        # Build the message first and send it second, so the hooks get a
        # chance to transform or swallow it in between.
        msg = self.session.msg(
            msg_type,
            json_clean(content),
            parent=self.parent_header,
        )

        # A hook returns either the (possibly new) message or None; None
        # means the message was consumed and must not be sent.
        for transform in self._hooks:
            msg = transform(msg)
            if msg is None:
                return

        self.session.send(self.pub_socket, msg, ident=self.topic)

    def clear_output(self, wait=False):
        """Clear output associated with the current execution (cell).

        Parameters
        ----------
        wait : bool (default: False)
            If True, the output will not be cleared immediately,
            instead waiting for the next display before clearing.
            This reduces bounce during repeated clear & display loops.
        """
        payload = {'wait': wait}
        self._flush_streams()
        self.session.send(
            self.pub_socket,
            u'clear_output',
            payload,
            parent=self.parent_header,
            ident=self.topic,
        )

    def register_hook(self, hook):
        """Append ``hook`` to the calling thread's hook list.

        Parameters
        ----------
        hook : callable
            Called by ``publish`` with the outgoing message. It must return
            a message if ``session.send`` should still be called after the
            transformation, or ``None`` to halt that path so that nothing
            is sent.
        """
        self._hooks.append(hook)

    def unregister_hook(self, hook):
        """Remove ``hook`` from the calling thread's hook list.

        Parameters
        ----------
        hook : callable
            A hook previously passed to ``register_hook``.

        Returns
        -------
        bool
            ``True`` if the hook was removed, ``False`` if it wasn't found.
        """
        try:
            self._hooks.remove(hook)
        except ValueError:
            return False
        return True
예제 #6
0
class Session(Configurable):
    """Object for handling serialization and sending of messages.

    The Session object handles building messages and sending them
    with ZMQ sockets or ZMQStream objects.  Objects can communicate with each
    other over the network via Session objects, and only need to work with the
    dict-based IPython message spec. The Session will handle
    serialization/deserialization, security, and metadata.

    Sessions support configurable serialization via packer/unpacker traits,
    and signing with HMAC digests via the key/keyfile traits.

    Parameters
    ----------

    debug : bool
        whether to trigger extra debugging statements
    packer/unpacker : str : 'json', 'pickle' or import_string
        importstrings for methods to serialize message parts.  If just
        'json' or 'pickle', predefined JSON and pickle packers will be used.
        Otherwise, the entire importstring must be used.

        The functions must accept at least valid JSON input, and output *bytes*.

        For example, to use msgpack:
        packer = 'msgpack.packb', unpacker='msgpack.unpackb'
    pack/unpack : callables
        You can also set the pack/unpack callables for serialization directly.
    session : bytes
        the ID of this Session object.  The default is to generate a new UUID.
    username : unicode
        username added to message headers.  The default is to ask the OS.
    key : bytes
        The key used to initialize an HMAC signature.  If unset, messages
        will not be signed or checked.
    keyfile : filepath
        The file containing a key.  If this is set, `key` will be initialized
        to the contents of the file.

    """

    debug = Bool(False, config=True, help="""Debug output in the Session""")

    check_pid = Bool(True, config=True,
        help="""Whether to check PID to protect against calls after fork.

        This check can be disabled if fork-safety is handled elsewhere.
        """)

    packer = DottedObjectName('json',config=True,
            help="""The name of the packer for serializing messages.
            Should be one of 'json', 'pickle', or an import name
            for a custom callable serializer.""")

    @observe('packer')
    def _packer_changed(self, change):
        """Install the pack function matching the new ``packer`` name.

        For the built-in names 'json' and 'pickle' (case-insensitive) both
        pack and unpack are set and ``unpacker`` is given the same name;
        any other value is imported and used as the pack callable only.
        """
        name = change['new']
        lowered = name.lower()
        if lowered not in ('json', 'pickle'):
            self.pack = import_item(str(name))
            return
        if lowered == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
        else:
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
        self.unpacker = name

    unpacker = DottedObjectName('json', config=True,
        help="""The name of the unpacker for unserializing messages.
        Only used with custom functions for `packer`.""")

    @observe('unpacker')
    def _unpacker_changed(self, change):
        """Install the unpack function matching the new ``unpacker`` name.

        For the built-in names 'json' and 'pickle' (case-insensitive) both
        pack and unpack are set and ``packer`` is given the same name;
        any other value is imported and used as the unpack callable only.
        """
        name = change['new']
        lowered = name.lower()
        if lowered not in ('json', 'pickle'):
            self.unpack = import_item(str(name))
            return
        if lowered == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
        else:
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
        self.packer = name

    session = CUnicode(u'', config=True,
        help="""The UUID identifying this session.""")
    def _session_default(self):
        """Generate a fresh session id and store its bytes form in ``bsession``."""
        fresh = new_id()
        self.bsession = fresh.encode('ascii')
        return fresh

    @observe('session')
    def _session_changed(self, change):
        """Keep ``bsession`` equal to the ASCII-encoded ``session`` value."""
        self.bsession = self.session.encode('ascii')

    # bsession is the session as bytes
    bsession = CBytes(b'')

    username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
        help="""Username for the Session. Default is your system username.""",
        config=True)

    metadata = Dict({}, config=True,
        help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")

    # if 0, no adapting to do.
    adapt_version = Integer(0)

    # message signature related traits:

    key = CBytes(config=True,
        help="""execution key, for signing messages.""")
    def _key_default(self):
        """Default signing key: whatever ``new_id_bytes()`` returns."""
        return new_id_bytes()

    @observe('key')
    def _key_changed(self, change):
        """Rebuild ``auth`` whenever the key changes."""
        self._new_auth()

    signature_scheme = Unicode('hmac-sha256', config=True,
        help="""The digest scheme used to construct the message signatures.
        Must have the form 'hmac-HASH'.""")

    @observe('signature_scheme')
    def _signature_scheme_changed(self, change):
        """Select the hashlib digest named by the new scheme and rebuild ``auth``.

        Raises
        ------
        TraitError
            If the scheme does not start with 'hmac-', or hashlib has no
            attribute with the given hash name.
        """
        scheme = change['new']
        if not scheme.startswith('hmac-'):
            raise TraitError("signature_scheme must start with 'hmac-', got %r" % scheme)
        # everything after the first dash names the hashlib constructor
        _prefix, hash_name = scheme.split('-', 1)
        try:
            self.digest_mod = getattr(hashlib, hash_name)
        except AttributeError:
            raise TraitError("hashlib has no such attribute: %s" % hash_name)
        self._new_auth()

    digest_mod = Any()
    def _digest_mod_default(self):
        """Default digest constructor: ``hashlib.sha256``."""
        return hashlib.sha256

    auth = Instance(hmac.HMAC, allow_none=True)

    def _new_auth(self):
        """Set ``auth`` to an HMAC built from the key, or None if the key is empty."""
        self.auth = (
            hmac.HMAC(self.key, digestmod=self.digest_mod) if self.key else None
        )

    digest_history = Set()
    digest_history_size = Integer(2**16, config=True,
        help="""The maximum number of digests to remember.

        The digest history will be culled when it exceeds this value.
        """
    )

    keyfile = Unicode('', config=True,
        help="""path to file containing execution key.""")

    @observe('keyfile')
    def _keyfile_changed(self, change):
        """Load ``key`` from the new key file (bytes, surrounding whitespace stripped)."""
        path = change['new']
        with open(path, 'rb') as handle:
            self.key = handle.read().strip()

    # for protecting against sends from forks
    pid = Integer()

    # serialization traits:

    pack = Any(default_packer) # the actual packer function

    @observe('pack')
    def _pack_changed(self, change):
        """Reject a new ``pack`` value that is not callable (TypeError)."""
        candidate = change['new']
        if callable(candidate):
            return
        raise TypeError("packer must be callable, not %s"%type(candidate))

    unpack = Any(default_unpacker) # the actual packer function

    @observe('unpack')
    def _unpack_changed(self, change):
        """Reject a new ``unpack`` value that is not callable (TypeError)."""
        # only callability is verified here
        candidate = change['new']
        if callable(candidate):
            return
        raise TypeError("unpacker must be callable, not %s"%type(candidate))

    # thresholds:
    copy_threshold = Integer(2**16, config=True,
        help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
    buffer_threshold = Integer(MAX_BYTES, config=True,
        help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
    item_threshold = Integer(MAX_ITEMS, config=True,
        help="""The maximum number of items for a container to be introspected for custom serialization.
        Containers larger than this are pickled outright.
        """
    )


    def __init__(self, **kwargs):
        """Create a Session object.

        All keyword arguments are forwarded to the parent constructor.

        Parameters
        ----------
        debug : bool
            whether to trigger extra debugging statements
        packer/unpacker : str : 'json', 'pickle' or import_string
            importstrings for methods to serialize message parts.  If just
            'json' or 'pickle', predefined JSON and pickle packers will be used.
            Otherwise, the entire importstring must be used.

            The functions must accept at least valid JSON input, and output
            *bytes*.

            For example, to use msgpack:
            packer = 'msgpack.packb', unpacker='msgpack.unpackb'
        pack/unpack : callables
            You can also set the pack/unpack callables for serialization
            directly.
        session : unicode (must be ascii)
            the ID of this Session object.  The default is to generate a new
            UUID.
        bsession : bytes
            The session as bytes
        username : unicode
            username added to message headers.  The default is to ask the OS.
        key : bytes
            The key used to initialize an HMAC signature.  If empty, no HMAC
            object is created (``auth`` is None) and a warning is logged.
        signature_scheme : str
            The message digest scheme. Currently must be of the form 'hmac-HASH',
            where 'HASH' is a hashing function available in Python's hashlib.
            The default is 'hmac-sha256'.
            This is ignored if 'key' is empty.
        keyfile : filepath
            The file containing a key.  If this is set, `key` will be
            initialized to the contents of the file.
        """
        super(Session, self).__init__(**kwargs)
        # Validate the configured pack/unpack pair before using it.
        self._check_packers()
        # Packed form of an empty dict, computed once.
        self.none = self.pack({})
        # ensure self._session_default() if necessary, so bsession is defined:
        self.session
        # Record the creating process id (see the ``pid`` trait comment:
        # used for protecting against sends from forks).
        self.pid = os.getpid()
        self._new_auth()
        if not self.key:
            get_logger().warning("Message signing is disabled.  This is insecure and not recommended!")

    def clone(self):
        """Return a copy of this Session with its own digest history.

        Handy when connecting several times to the same kernel: each clone
        tracks digests separately, which avoids duplicate-digest warnings
        caused by multiple IOPub connections within one process.

        .. versionadded:: 5.1
        """
        duplicate = type(self)()
        # carry every trait value over to the new instance
        for trait_name in self.traits():
            value = getattr(self, trait_name)
            setattr(duplicate, trait_name, value)
        # give the copy an independent set, seeded with the current digests
        duplicate.digest_history = set()
        duplicate.digest_history.update(self.digest_history)
        return duplicate

    @property
    def msg_id(self):
        """A fresh message id; every access calls ``new_id()`` again."""
        fresh_id = new_id()
        return fresh_id

    def _check_packers(self):
        """Validate the configured pack/unpack callables.

        Three checks are performed, in order:

        1. ``pack`` can serialize a simple message and produces ``bytes``.
        2. ``unpack`` is the inverse of ``pack`` for that message.
        3. datetime handling: if a message containing a datetime cannot be
           round-tripped, or comes back still containing a datetime,
           ``self.pack`` is wrapped so dates are squashed before packing.

        Raises
        ------
        ValueError
            If packing fails, the packed value is not bytes, or unpacking
            does not reproduce the original message.
        """
        pack = self.pack
        unpack = self.unpack

        def jsonmsg():
            # extra diagnostic, only relevant when the json packer is in use
            if self.packer == 'json':
                return "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
            return ""

        # check simple serialization
        msg = dict(a=[1,'hi'])
        try:
            packed = pack(msg)
        except Exception as e:
            template = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
            raise ValueError(
                template.format(packer=self.packer, e=e, jsonmsg=jsonmsg())
            )

        # ensure packed message is bytes
        if not isinstance(packed, bytes):
            raise ValueError("message packed to %r, but bytes are required"%type(packed))

        # check that unpack is pack's inverse
        try:
            unpacked = unpack(packed)
            # explicit comparison rather than `assert`, which is removed
            # when Python runs with -O and would silently skip this check
            if unpacked != msg:
                raise ValueError(
                    "round trip produced %r instead of %r" % (unpacked, msg)
                )
        except Exception as e:
            template = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
            raise ValueError(
                template.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg())
            )

        # check datetime support
        msg = dict(t=utcnow())
        try:
            unpacked = unpack(pack(msg))
            if isinstance(unpacked['t'], datetime):
                # raised on purpose so the handler below installs the wrappers
                raise ValueError("Shouldn't deserialize to datetime")
        except Exception:
            self.pack = lambda o: pack(squash_dates(o))
            self.unpack = lambda s: unpack(s)

    def msg_header(self, msg_type):
        """Build a header for a new message of type ``msg_type``.

        Delegates to the module-level ``msg_header`` helper, passing a
        fresh message id plus this session's username and session id.
        """
        return msg_header(
            self.msg_id,
            msg_type,
            self.username,
            self.session,
        )

    def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
        """Assemble the nested message dict.

        The nested dict is not the wire format; ``serialize`` and
        ``deserialize`` translate between this dict and the list of
        message parts that is actually sent.

        A new header is generated from ``msg_type`` unless ``header`` is
        supplied.  The session's default metadata is copied and then
        updated with ``metadata`` when given.
        """
        if header is None:
            header = self.msg_header(msg_type)
        message = {
            'header': header,
            'msg_id': header['msg_id'],
            'msg_type': header['msg_type'],
            'parent_header': extract_header(parent) if parent is not None else {},
            'content': content if content is not None else {},
            # copy so per-message additions never leak into the defaults
            'metadata': self.metadata.copy(),
        }
        if metadata is not None:
            message['metadata'].update(metadata)
        return message

    def sign(self, msg_list):
        """Compute the HMAC hex digest over the given message parts.

        Parameters
        ----------
        msg_list : list
            The packed parts to sign, fed to the digest in order.

        Returns
        -------
        bytes
            The hex digest as bytes, or ``b''`` when ``self.auth`` is None
            (signing disabled).
        """
        if self.auth is None:
            return b''
        # work on a copy so the keyed base object stays reusable
        digest = self.auth.copy()
        for part in msg_list:
            digest.update(part)
        hex_signature = digest.hexdigest()
        return str_to_bytes(hex_signature)

    def serialize(self, msg, ident=None):
        """Turn a nested message dict into the list of parts to send.

        Roughly the inverse of ``deserialize``.  ``serialize`` and
        ``deserialize`` operate on whole message lists, while ``pack`` and
        ``unpack`` handle the individual parts within such a list.

        Parameters
        ----------
        msg : dict or Message
            The nested message dict, as returned by ``self.msg``.
        ident : bytes or list of bytes, optional
            Routing prefix placed before the delimiter.  A list is spliced
            in element by element; any other non-None value is used as a
            single part.

        Returns
        -------
        msg_list : list
            The list of bytes objects to be sent, laid out as::

                [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
                 p_metadata, p_content, buffer1, buffer2, ...]

            The ``p_*`` entries are the packed forms, so with JSON they are
            utf8-encoded JSON strings.  Buffers are not added here.

        Raises
        ------
        TypeError
            If the message content is not None, a dict, bytes or text.
        """
        body = msg.get('content', {})
        if body is None:
            packed_content = self.none
        elif isinstance(body, dict):
            packed_content = self.pack(body)
        elif isinstance(body, bytes):
            # already packed, e.g. a message being relayed
            packed_content = body
        elif isinstance(body, unicode_type):
            # should be bytes, but JSON tends to hand back text
            packed_content = body.encode('utf8')
        else:
            raise TypeError("Content incorrect type: %s"%type(body))

        signed_parts = [
            self.pack(msg['header']),
            self.pack(msg['parent_header']),
            self.pack(msg['metadata']),
            packed_content,
        ]

        if isinstance(ident, list):
            routing = list(ident)
        elif ident is not None:
            routing = [ident]
        else:
            routing = []

        return routing + [DELIM, self.sign(signed_parts)] + signed_parts

    def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
             buffers=None, track=False, header=None, metadata=None):
        """Build and send a message via stream or socket.

        The message format used by this function internally is as follows:

        [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_metadata,p_content,
         buffer1,buffer2,...]

        The serialize/deserialize methods convert the nested message dict into this
        format.

        Parameters
        ----------

        stream : zmq.Socket or ZMQStream
            The socket-like object used to send the data.
        msg_or_type : str or Message/dict
            Normally, msg_or_type will be a msg_type unless a message is being
            sent more than once. If a header is supplied, this can be set to
            None and the msg_type will be pulled from the header.

        content : dict or None
            The content of the message (ignored if msg_or_type is a message).
        header : dict or None
            The header dict for the message (ignored if msg_or_type is a message).
        parent : Message or dict or None
            The parent or parent header describing the parent of this message
            (ignored if msg_or_type is a message).
        ident : bytes or list of bytes
            The zmq.IDENTITY routing path.
        metadata : dict or None
            The metadata describing the message
        buffers : list or None
            The already-serialized buffers to be appended to the message.
            Each one must support the buffer protocol and be contiguous.
        track : bool
            Whether to track.  Only for use with Sockets, because ZMQStream
            objects cannot track messages.


        Returns
        -------
        msg : dict or None
            The constructed message, with the send tracker stored under the
            'tracker' key.  None is returned, and nothing is sent, when
            ``check_pid`` is set and the current process id differs from
            the one recorded on this Session.

        Raises
        ------
        TypeError
            If a buffer does not support the buffer protocol.
        ValueError
            If a buffer is not contiguous.
        """
        if not isinstance(stream, zmq.Socket):
            # ZMQStreams and dummy sockets do not support tracking.
            track = False

        if isinstance(msg_or_type, (Message, dict)):
            # We got a Message or message dict, not a msg_type so don't
            # build a new Message.
            msg = msg_or_type
            # explicitly passed buffers win over any stored on the message
            buffers = buffers or msg.get('buffers', [])
        else:
            msg = self.msg(msg_or_type, content=content, parent=parent,
                           header=header, metadata=metadata)
        # refuse to send when the pid no longer matches the recorded one
        if self.check_pid and not os.getpid() == self.pid:
            get_logger().warning("WARNING: attempted to send message from fork\n%s",
                msg
            )
            return
        buffers = [] if buffers is None else buffers
        # validate every buffer up front, before anything is sent
        for idx, buf in enumerate(buffers):
            if isinstance(buf, memoryview):
                view = buf
            else:
                try:
                    # check to see if buf supports the buffer protocol.
                    view = memoryview(buf)
                except TypeError:
                    raise TypeError("Buffer objects must support the buffer protocol.")
            # memoryview.contiguous is new in 3.3,
            # just skip the check on Python 2
            if hasattr(view, 'contiguous') and not view.contiguous:
                # zmq requires memoryviews to be contiguous
                raise ValueError("Buffer %i (%r) is not contiguous" % (idx, buf))

        if self.adapt_version:
            msg = adapt(msg, self.adapt_version)
        to_send = self.serialize(msg, ident)
        to_send.extend(buffers)
        # copy only when every part is shorter than copy_threshold
        longest = max([ len(s) for s in to_send ])
        copy = (longest < self.copy_threshold)

        if buffers and track and not copy:
            # only really track when we are doing zero-copy buffers
            tracker = stream.send_multipart(to_send, copy=False, track=True)
        else:
            # use dummy tracker, which will be done immediately
            tracker = DONE
            stream.send_multipart(to_send, copy=copy)

        if self.debug:
            pprint.pprint(msg)
            pprint.pprint(to_send)
            pprint.pprint(buffers)

        # NOTE: this mutates the message, including one passed in by the caller
        msg['tracker'] = tracker

        return msg

    def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
        """Sign and send an already-serialized message along an ident path.

        Parameters
        ----------
        stream : ZMQStream or Socket
            The ZMQ stream or socket used for sending.
        msg_list : list
            The serialized message parts.  This covers only the
            [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...]
            portion; routing idents, delimiter and signature are added here.
        flags : int
            Passed through to ``send_multipart``.
        copy : bool
            Passed through to ``send_multipart``.
        ident : ident or list
            A single ident or a list of idents to route with.
        """
        if isinstance(ident, bytes):
            ident = [ident]
        outgoing = [] if ident is None else list(ident)

        outgoing.append(DELIM)
        # Per spec the signature covers only the first four parts,
        # never the buffers.
        signature = self.sign(msg_list[0:4])
        outgoing.append(signature)
        outgoing.extend(msg_list)
        stream.send_multipart(outgoing, flags, copy=copy)

    def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
        """Receive and unpack a message.

        Parameters
        ----------
        socket : ZMQStream or Socket
            The socket or stream to use in receiving.  For a ZMQStream the
            underlying socket is used.
        mode : int
            Flags handed to ``recv_multipart``; ``zmq.NOBLOCK`` by default.
        content : bool
            Forwarded to ``deserialize``: whether to unpack the content.
        copy : bool
            Forwarded to ``recv_multipart``, ``feed_identities`` and
            ``deserialize``.

        Returns
        -------
        [idents], msg
            [idents] is a list of idents and msg is a nested message dict of
            same format as self.msg returns.  ``(None, None)`` is returned
            when the receive fails with EAGAIN (nothing available).

        Raises
        ------
        zmq.ZMQError
            For any receive error other than EAGAIN.
        Exception
            Whatever ``feed_identities`` or ``deserialize`` raise for a
            malformed or unauthenticated message propagates unchanged.
        """
        if isinstance(socket, ZMQStream):
            socket = socket.socket
        try:
            msg_list = socket.recv_multipart(mode, copy=copy)
        except zmq.ZMQError as e:
            if e.errno == zmq.EAGAIN:
                # We can convert EAGAIN to None as we know in this case
                # recv_multipart won't return None.
                return None, None
            raise
        # split multipart message into identity list and message dict
        # invalid large messages can cause very expensive string comparisons
        idents, msg_list = self.feed_identities(msg_list, copy)
        # Errors from deserialize are left to propagate as-is; the former
        # catch-and-re-raise wrapper added nothing beyond traceback noise.
        return idents, self.deserialize(msg_list, content=content, copy=copy)

    def feed_identities(self, msg_list, copy=True):
        """Separate the routing identities from the rest of a message.

        Everything before the first DELIM part is returned as the idents,
        everything after it as the message.  An IDENT equal to DELIM would
        break this, but nobody should be doing that.

        Parameters
        ----------
        msg_list : a list of Message or bytes objects
            The message to be split.
        copy : bool
            True when the parts are bytes, False when they are Message
            objects (whose payload is read from ``.bytes``).

        Returns
        -------
        (idents, msg_list) : two lists
            idents is always a list of bytes, one per ZMQ identity.
            msg_list keeps the type of its input parts and has the form
            [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...], ready
            to be handed to self.deserialize.

        Raises
        ------
        ValueError
            If no DELIM part is present.
        """
        if copy:
            split_at = msg_list.index(DELIM)
            return msg_list[:split_at], msg_list[split_at+1:]

        # Message objects: compare the payload of each part
        for split_at, part in enumerate(msg_list):
            if part.bytes == DELIM:
                break
        else:
            raise ValueError("DELIM not in msg_list")
        ident_parts = msg_list[:split_at]
        remainder = msg_list[split_at+1:]
        return [part.bytes for part in ident_parts], remainder

    def _add_digest(self, signature):
        """Record a signature in the digest history (replay protection).

        Nothing is recorded when ``digest_history_size`` is 0.  Once the
        history grows beyond that size it is culled.
        """
        limit = self.digest_history_size
        if limit == 0:
            # history disabled: never remember digests
            return

        history = self.digest_history
        history.add(signature)
        if len(history) > limit:
            # over the threshold, trim the history back down
            self._cull_digest_history()

    def _cull_digest_history(self):
        """Cull the digest history.

        Removes randomly selected entries: 10% of the current history, or
        however many are needed to get back down to
        ``digest_history_size``, whichever is larger.  If that would
        remove everything, the history is simply replaced by an empty set.
        """
        current = len(self.digest_history)
        n_to_cull = max(current // 10, current - self.digest_history_size)
        if n_to_cull >= current:
            self.digest_history = set()
            return
        # random.sample needs a sequence: sampling directly from a set was
        # deprecated in Python 3.9 and raises TypeError from 3.11 onwards.
        to_cull = random.sample(list(self.digest_history), n_to_cull)
        self.digest_history.difference_update(to_cull)

    def deserialize(self, msg_list, content=True, copy=True):
        """Unserialize a msg_list to a nested message dict.

        This is roughly the inverse of serialize. The serialize/deserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg_list : list of bytes or Message objects
            The list of message parts of the form [HMAC,p_header,p_parent,
            p_metadata,p_content,buffer1,buffer2,...].  With ``copy=False``
            the leading parts are replaced in place by their bytes.
        content : bool (True)
            Whether to unpack the content dict (True), or leave it packed
            (False).
        copy : bool (True)
            Whether msg_list contains bytes (True) or the non-copying Message
            objects in each place (False).

        Returns
        -------
        msg : dict
            The nested message dict with top-level keys [header, msg_id,
            msg_type, parent_header, metadata, content, buffers], passed
            through ``adapt``.  The buffers are returned as memoryviews.

        Raises
        ------
        ValueError
            When signing is enabled and the message is unsigned, repeats a
            signature already in the digest history, or carries a
            signature that does not verify.
        TypeError
            When the message has fewer than five parts.
        """
        minlen = 5
        message = {}
        if not copy:
            # pyzmq didn't copy the first parts of the message, so we'll do it.
            # Only touch parts that exist, so a short message reaches the
            # validation below instead of failing here with IndexError.
            for i in range(min(minlen, len(msg_list))):
                msg_list[i] = msg_list[i].bytes
        if self.auth is not None:
            # an empty message carries no signature at all
            signature = msg_list[0] if msg_list else b''
            if not signature:
                raise ValueError("Unsigned Message")
            if signature in self.digest_history:
                raise ValueError("Duplicate Signature: %r" % signature)
            check = self.sign(msg_list[1:5])
            if not compare_digest(signature, check):
                raise ValueError("Invalid Signature: %r" % signature)
            # Record the digest only once it has verified, so that forged
            # messages cannot fill the replay history and push out the
            # digests of genuine messages.
            if content:
                # Only store signature if we are unpacking content, don't store if just peeking.
                self._add_digest(signature)
        if not len(msg_list) >= minlen:
            raise TypeError("malformed message, must have at least %i elements"%minlen)
        header = self.unpack(msg_list[1])
        message['header'] = extract_dates(header)
        message['msg_id'] = header['msg_id']
        message['msg_type'] = header['msg_type']
        message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
        message['metadata'] = self.unpack(msg_list[3])
        if content:
            message['content'] = self.unpack(msg_list[4])
        else:
            message['content'] = msg_list[4]
        buffers = [memoryview(b) for b in msg_list[5:]]
        if buffers and buffers[0].shape is None:
            # force copy to workaround pyzmq #646
            buffers = [memoryview(b.bytes) for b in msg_list[5:]]
        message['buffers'] = buffers
        if self.debug:
            pprint.pprint(message)
        # adapt to the current version
        return adapt(message)

    def unserialize(self, *args, **kwargs):
        """Deprecated alias for ``deserialize``.

        Emits a DeprecationWarning, then forwards all arguments to
        ``deserialize`` and returns its result.
        """
        notice = "Session.unserialize is deprecated. Use Session.deserialize."
        warnings.warn(notice, DeprecationWarning)
        result = self.deserialize(*args, **kwargs)
        return result
예제 #7
0
class ZMQDisplayPublisher(DisplayPublisher):
    """Display publisher that sends its messages through a ZeroMQ PUB socket."""

    session = Instance(Session, allow_none=True)
    pub_socket = Any(allow_none=True)
    parent_header = Dict({})
    topic = CBytes(b'display_data')

    # Thread-local storage for the registered hooks, used to make sure the
    # correct output message is processed (see ipykernel issue 113).
    _thread_local = Any()

    def set_parent(self, parent):
        """Remember the header of ``parent`` for messages sent from now on."""
        self.parent_header = extract_header(parent)

    def _flush_streams(self):
        """Flush stdout and stderr before a display message goes out."""
        for stream in (sys.stdout, sys.stderr):
            stream.flush()

    @default('_thread_local')
    def _default_thread_local(self):
        """Provide the thread-local storage object."""
        return local()

    @property
    def _hooks(self):
        """The hook list of the calling thread, created on first access."""
        storage = self._thread_local
        if hasattr(storage, 'hooks'):
            return storage.hooks
        # first access from this thread: start with an empty list
        storage.hooks = []
        return storage.hooks

    def publish(self, data, metadata=None, source=None):
        """Publish ``data`` and ``metadata`` as a ``display_data`` message.

        The message is built first and then run through every hook
        registered for the current thread.  If any hook returns None the
        message counts as consumed and nothing is sent.  ``source`` is
        accepted but not used here.
        """
        self._flush_streams()
        metadata = {} if metadata is None else metadata
        self._validate_data(data, metadata)
        content = {'data': encode_images(data), 'metadata': metadata}

        # Two stages (build, then send) so the hooks get a chance to
        # transform or swallow the message before it leaves.
        msg = self.session.msg(
            u'display_data',
            json_clean(content),
            parent=self.parent_header,
        )

        for transform in self._hooks:
            msg = transform(msg)
            if msg is None:
                # a hook consumed the message
                return

        self.session.send(self.pub_socket, msg, ident=self.topic)

    def clear_output(self, wait=False):
        """Send a ``clear_output`` message carrying the ``wait`` flag."""
        payload = {'wait': wait}
        self._flush_streams()
        self.session.send(
            self.pub_socket,
            u'clear_output',
            payload,
            parent=self.parent_header,
            ident=self.topic,
        )

    def register_hook(self, hook):
        """Add a hook to the calling thread's hook list.

        Parameters
        ----------
        hook : Any callable object
            Called by ``publish`` with the outgoing message.  It should
            return a message (the same one or a replacement) when
            ``session.send`` should still be called afterwards, or `None`
            to stop processing so that nothing is sent.
        """
        self._hooks.append(hook)

    def unregister_hook(self, hook):
        """Remove a hook from the calling thread's hook list.

        Parameters
        ----------
        hook : Any callable object which has previously been
               registered as a hook.

        Returns
        -------
        bool - `True` if the hook was removed, `False` if it wasn't
               found.
        """
        try:
            self._hooks.remove(hook)
        except ValueError:
            return False
        return True
예제 #8
0
 def test_invalid_args(self):
     # a bare integer positional argument must be rejected
     with self.assertRaises(TypeError):
         Tuple(5)
     # so must a string default value
     with self.assertRaises(TypeError):
         Tuple(default_value='hello')
     # a well-formed declaration has to construct without raising
     Tuple(Int(), CBytes(), default_value=(1,5))
예제 #9
0
class TaskScheduler(SessionFactory):
    """Python TaskScheduler object.

    This is the simplest object that supports msg_id based
    DAG dependencies. *Only* task msg_ids are checked, not
    msg_ids of jobs submitted via the MUX queue.

    """

    hwm = Integer(1,
                  config=True,
                  help="""specify the High Water Mark (HWM) for the downstream
        socket in the Task scheduler. This is the maximum number
        of allowed outstanding tasks on each engine.
        
        The default (1) means that only one task can be outstanding on each
        engine.  Setting TaskScheduler.hwm=0 means there is no limit, and the
        engines continue to be assigned tasks while they are working,
        effectively hiding network latency behind computation, but can result
        in an imbalance of work when submitting many heterogenous tasks all at
        once.  Any positive value greater than one is a compromise between the
        two.

        """)
    scheme_name = Enum(
        ('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
        'leastload',
        config=True,
        help="""select the task scheduler scheme  [default: Python LRU]
        Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
    )

    def _scheme_name_changed(self, old, new):
        """Switch ``self.scheme`` to the module-level object named ``new``.

        The name is logged at debug level first, then looked up in this
        module's globals.  ``old`` is not used.
        """
        self.log.debug("Using scheme %r" % new)
        module_namespace = globals()
        self.scheme = module_namespace[new]

    # input arguments:
    scheme = Instance(FunctionType)  # function for determining the destination

    def _scheme_default(self):
        """Default value for ``scheme``: the ``leastload`` function."""
        default_scheme = leastload
        return default_scheme

    client_stream = Instance(zmqstream.ZMQStream,
                             allow_none=True)  # client-facing stream
    engine_stream = Instance(zmqstream.ZMQStream,
                             allow_none=True)  # engine-facing stream
    notifier_stream = Instance(zmqstream.ZMQStream,
                               allow_none=True)  # hub-facing sub stream
    mon_stream = Instance(zmqstream.ZMQStream,
                          allow_none=True)  # hub-facing pub stream
    query_stream = Instance(zmqstream.ZMQStream,
                            allow_none=True)  # hub-facing DEALER stream

    # internals:
    queue = Instance(deque)  # sorted list of Jobs

    def _queue_default(self):
        """Default value for ``queue``: a new, empty deque."""
        empty_queue = deque()
        return empty_queue

    queue_map = Dict()  # dict by msg_id of Jobs (for O(1) access to the Queue)
    graph = Dict()  # dict by msg_id of [ msg_ids that depend on key ]
    retries = Dict()  # dict by msg_id of retries remaining (non-neg ints)
    # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
    pending = Dict()  # dict by engine_uuid of submitted tasks
    completed = Dict()  # dict by engine_uuid of completed tasks
    failed = Dict()  # dict by engine_uuid of failed tasks
    destinations = Dict(
    )  # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
    clients = Dict()  # dict by msg_id for who submitted the task
    targets = List()  # list of target IDENTs
    loads = List()  # list of engine loads
    # full = Set() # set of IDENTs that have HWM outstanding tasks
    all_completed = Set()  # set of all completed tasks
    all_failed = Set()  # set of all failed tasks
    all_done = Set()  # set of all finished tasks=union(completed,failed)
    all_ids = Set()  # set of all submitted task IDs

    ident = CBytes()  # ZMQ identity. This should just be self.session.session

    # but ensure Bytes
    def _ident_default(self):
        """Default value for ``ident``: the session's ``bsession`` attribute."""
        session = self.session
        return session.bsession

    def start(self):
        """Hook up the stream callbacks and send the initial connection request.

        The reply handler is installed on the query stream before the
        ``connection_request`` is sent on it.  Result and submission
        handlers are registered with ``copy=False``.
        """
        query = self.query_stream
        query.on_recv(self.dispatch_query_reply)
        self.session.send(query, "connection_request", {})

        self.engine_stream.on_recv(self.dispatch_result, copy=False)
        self.client_stream.on_recv(self.dispatch_submission, copy=False)

        # notification msg_type -> handler, used by dispatch_notification
        self._notification_handlers = {
            'registration_notification': self._register_engine,
            'unregistration_notification': self._unregister_engine,
        }
        self.notifier_stream.on_recv(self.dispatch_notification)
        self.log.info("Scheduler started [%s]" % self.scheme_name)

    def resume_receiving(self):
        """Start accepting jobs again by re-installing the submission handler."""
        clients = self.client_stream
        clients.on_recv(self.dispatch_submission, copy=False)

    def stop_receiving(self):
        """Stop accepting jobs while no engines are available.

        The client stream's receive callback is removed, which leaves
        incoming jobs sitting in the ZMQ queue.
        """
        clients = self.client_stream
        clients.on_recv(None)

    #-----------------------------------------------------------------------
    # [Un]Registration Handling
    #-----------------------------------------------------------------------

    def dispatch_query_reply(self, msg):
        """Handle the reply to our initial connection request.

        Every engine uuid listed under the reply content's 'engines' key
        is registered.  A message that cannot be split into identities or
        that fails deserialization with ValueError is logged as a warning
        and dropped.
        """
        try:
            idents, msg = self.session.feed_identities(msg)
        except ValueError:
            # Logger.warn is a deprecated alias that no longer exists on
            # Python 3.13; use warning() with lazily formatted arguments.
            self.log.warning("task::Invalid Message: %r", msg)
            return
        try:
            msg = self.session.deserialize(msg)
        except ValueError:
            self.log.warning("task::Unauthorized message from: %r", idents)
            return

        content = msg['content']
        for uuid in content.get('engines', {}).values():
            self._register_engine(cast_bytes(uuid))

    @util.log_errors
    def dispatch_notification(self, msg):
        """Dispatch register/unregister events.

        The message type selects a handler from
        ``self._notification_handlers``, which is called with the uuid
        from the message content cast to bytes.  Malformed or
        unauthenticated messages are logged as warnings and dropped; an
        unknown message type or a failing handler is logged as an error.
        """
        try:
            idents, msg = self.session.feed_identities(msg)
        except ValueError:
            # Logger.warn is a deprecated alias that no longer exists on
            # Python 3.13; use warning() with lazily formatted arguments.
            self.log.warning("task::Invalid Message: %r", msg)
            return
        try:
            msg = self.session.deserialize(msg)
        except ValueError:
            self.log.warning("task::Unauthorized message from: %r", idents)
            return

        msg_type = msg['header']['msg_type']

        handler = self._notification_handlers.get(msg_type, None)
        if handler is None:
            self.log.error("Unhandled message type: %r", msg_type)
        else:
            try:
                handler(cast_bytes(msg['content']['uuid']))
            except Exception:
                self.log.error("task::Invalid notification msg: %r",
                               msg,
                               exc_info=True)

    def _register_engine(self, uid):
        """Make the engine with ident `uid` available for work.

        The engine goes to the front of ``targets`` with a load of 0, gets
        empty completed/failed/pending records, and the dependency graph
        is rescanned.
        """
        # new engines go to the head of the line
        self.targets.insert(0, uid)
        self.loads.insert(0, 0)

        # fresh bookkeeping for this engine
        for registry, factory in (
            (self.completed, set),
            (self.failed, set),
            (self.pending, dict),
        ):
            registry[uid] = factory()

        # rescan the graph
        self.update_graph(None)

    def _unregister_engine(self, uid):
        """Take the engine with ident `uid` out of service.

        The engine is removed from ``targets``/``loads`` right away.  If
        it still has pending tasks, ``handle_stranded_tasks`` is scheduled
        five seconds later; otherwise its completed/failed records are
        dropped immediately.
        """
        # NOTE: losing the last remaining engine gets no special treatment.

        # process any results that have already arrived
        self.engine_stream.flush()

        # destinations is intentionally left untouched, because its
        # entries might still be used later:
        # map(self.destinations.pop, self.completed.pop(uid))
        # map(self.destinations.pop, self.failed.pop(uid))

        # stop routing work to this engine
        position = self.targets.index(uid)
        self.targets.pop(position)
        self.loads.pop(position)

        if not self.pending[uid]:
            self.completed.pop(uid)
            self.failed.pop(uid)
            return

        # results may still be on their way, so wait 5 seconds before
        # cleaning up the pending jobs
        self.loop.add_timeout(
            self.loop.time() + 5,
            lambda: self.handle_stranded_tasks(uid),
        )

    def handle_stranded_tasks(self, engine):
        """Deal with jobs resident in an engine that died.

        For every task still pending on ``engine`` a fake ``apply_reply``
        carrying an EngineError is built and fed to ``dispatch_result``,
        as though the engine itself had reported the failure.  Afterwards
        the engine's completed/failed records are removed.
        """
        lost = self.pending[engine]
        # Iterate over a snapshot of the ids.  The membership check below
        # shows that entries can disappear while we loop (presumably
        # removed by dispatch_result), and iterating the live dict while
        # it shrinks raises RuntimeError on Python 3.
        for msg_id in list(lost):
            if msg_id not in self.pending[engine]:
                # prevent double-handling of messages
                continue

            raw_msg = lost[msg_id].raw_msg
            idents, msg = self.session.feed_identities(raw_msg, copy=False)
            parent = self.session.unpack(msg[1].bytes)
            idents = [engine, idents[0]]

            # build fake error reply: raise and catch for real so that
            # wrap_exception has an active exception to work from
            try:
                raise error.EngineError(
                    "Engine %r died while running task %r" % (engine, msg_id))
            except Exception:
                content = error.wrap_exception()
            # build fake metadata
            md = dict(
                status=u'error',
                engine=engine.decode('ascii'),
                date=datetime.now(),
            )
            msg = self.session.msg('apply_reply',
                                   content,
                                   parent=parent,
                                   metadata=md)
            raw_reply = list(
                map(zmq.Message, self.session.serialize(msg, ident=idents)))
            # and dispatch it
            self.dispatch_result(raw_reply)

        # finally scrub completed/failed lists
        self.completed.pop(engine)
        self.failed.pop(engine)

    #-----------------------------------------------------------------------
    # Job Submission
    #-----------------------------------------------------------------------

    @util.log_errors
    def dispatch_submission(self, raw_msg):
        """Dispatch job submission to appropriate handlers.

        The raw message is parsed with its content left packed, forwarded
        to the monitor stream, and turned into a ``Job`` whose scheduling
        options (targets, retries, after, follow, timeout) are read from
        the message metadata.  A job with invalid or unreachable
        dependencies is failed at once.  Otherwise it is run if its time
        dependencies are already met and ``maybe_run`` accepts it, and
        saved via ``save_unmet`` if not.

        Parameters
        ----------
        raw_msg : list
            The message parts as received, identities included.
        """
        # ensure targets up to date:
        self.notifier_stream.flush()
        try:
            idents, msg = self.session.feed_identities(raw_msg, copy=False)
            # content=False: the content part is left packed
            msg = self.session.deserialize(msg, content=False, copy=False)
        except Exception:
            self.log.error("task::Invaid task msg: %r" % raw_msg,
                           exc_info=True)
            return

        # send to monitor
        self.mon_stream.send_multipart([b'intask'] + raw_msg, copy=False)

        header = msg['header']
        md = msg['metadata']
        msg_id = header['msg_id']
        self.all_ids.add(msg_id)

        # get targets as a set of bytes objects
        # from a list of unicode objects
        targets = md.get('targets', [])
        targets = set(map(cast_bytes, targets))

        retries = md.get('retries', 0)
        self.retries[msg_id] = retries

        # time dependencies
        after = md.get('after', None)
        if after:
            after = Dependency(after)
            if after.all:
                # for an "all" dependency, drop the ids that are already
                # satisfied so that only outstanding ones remain
                if after.success:
                    after = Dependency(
                        after.difference(self.all_completed),
                        success=after.success,
                        failure=after.failure,
                        all=after.all,
                    )
                if after.failure:
                    after = Dependency(
                        after.difference(self.all_failed),
                        success=after.success,
                        failure=after.failure,
                        all=after.all,
                    )
            if after.check(self.all_completed, self.all_failed):
                # recast as empty set, if `after` already met,
                # to prevent unnecessary set comparisons
                after = MET
        else:
            after = MET

        # location dependencies
        follow = Dependency(md.get('follow', []))

        timeout = md.get('timeout', None)
        if timeout:
            timeout = float(timeout)

        job = Job(
            msg_id=msg_id,
            raw_msg=raw_msg,
            idents=idents,
            msg=msg,
            header=header,
            targets=targets,
            after=after,
            follow=follow,
            timeout=timeout,
            metadata=md,
        )
        # validate and reduce dependencies:
        for dep in after, follow:
            if not dep:  # empty dependency
                continue
            # check valid: a job may not depend on itself, nor on an id
            # that was never submitted
            if msg_id in dep or dep.difference(self.all_ids):
                # the job is stored in queue_map before being failed
                self.queue_map[msg_id] = job
                return self.fail_unreachable(msg_id, error.InvalidDependency)
            # check if unreachable:
            if dep.unreachable(self.all_completed, self.all_failed):
                self.queue_map[msg_id] = job
                return self.fail_unreachable(msg_id)

        if after.check(self.all_completed, self.all_failed):
            # time deps already met, try to run
            if not self.maybe_run(job):
                # can't run yet
                if msg_id not in self.all_failed:
                    # could have failed as unreachable
                    self.save_unmet(job)
        else:
            self.save_unmet(job)

    def job_timeout(self, job, timeout_id):
        """Callback fired when a queued job's timeout expires.

        The job may or may not have been run at this point; it is only
        failed (with ``error.TaskTimeout``) if it is still waiting in
        ``self.queue_map``.

        Parameters
        ----------
        job : Job
            The job whose timeout fired.
        timeout_id : int
            The value ``job.timeout_id`` had when this callback was
            scheduled.  If the job's current id differs, the call is
            treated as stale and ignored.
        """
        if job.timeout_id != timeout_id:
            # not the most recent call
            return
        now = time.time()
        # NOTE(review): save_unmet schedules this callback at
        # ``time.time() + job.timeout``, i.e. job.timeout looks like a
        # relative duration, while ``now`` is absolute -- so this
        # "premature" check would practically never trigger.  Confirm intent.
        if job.timeout >= (now + 1):
            # Logger.warn is a deprecated alias; use Logger.warning
            self.log.warning("task %s timeout fired prematurely: %s > %s",
                             job.msg_id, job.timeout, now)
        if job.msg_id in self.queue_map:
            # still waiting, but ran out of time
            self.log.info("task %r timed out", job.msg_id)
            self.fail_unreachable(job.msg_id, error.TaskTimeout)

    def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
        """Fail a queued task that can no longer run, and report the failure.

        The task is removed from ``self.queue_map``, flagged for lazy
        deletion from the queue, recorded in ``all_done`` / ``all_failed``,
        and an ``apply_reply`` carrying the wrapped exception is sent to the
        client stream and relayed to the monitor stream.  Finally the
        dependency graph is updated with the task marked as failed.

        Parameters
        ----------
        msg_id : msg_id
            The task to fail.  It must currently be in ``self.queue_map``;
            otherwise an error is logged and nothing else happens.
        why : exception class
            Instantiated without arguments and raised to build the error
            content.  Defaults to ``error.ImpossibleDependency``.
        """
        if msg_id not in self.queue_map:
            self.log.error("task %r already failed!", msg_id)
            return
        job = self.queue_map.pop(msg_id)
        # lazy-delete from the queue
        job.removed = True
        for mid in job.dependents:
            if mid in self.graph:
                # discard rather than remove: callers may insert a job into
                # queue_map and fail it immediately, without it ever having
                # been registered in the graph by save_unmet, in which case
                # remove() would raise KeyError.
                self.graph[mid].discard(msg_id)

        # raise-and-catch so that wrap_exception() has a live exception
        # to describe
        try:
            raise why()
        except Exception:
            content = error.wrap_exception()
        self.log.debug("task %r failing as unreachable with: %s", msg_id,
                       content['ename'])

        self.all_done.add(msg_id)
        self.all_failed.add(msg_id)

        # reply to the client, then relay the same message to the monitor
        msg = self.session.send(self.client_stream,
                                'apply_reply',
                                content,
                                parent=job.header,
                                ident=job.idents)
        self.session.send(self.mon_stream,
                          msg,
                          ident=[b'outtask'] + job.idents)

        self.update_graph(msg_id, success=False)

    def available_engines(self):
        """Return the indices of engines that may accept another task.

        When no high-water mark is set (``self.hwm`` is falsy) every index
        into ``self.targets`` is returned.  Otherwise only the indices whose
        entry in ``self.loads`` is strictly below ``self.hwm`` are returned.
        """
        every_index = range(len(self.targets))
        if self.hwm:
            return [i for i in every_index if self.loads[i] < self.hwm]
        return list(every_index)

    def maybe_run(self, job):
        """Check location dependencies, and run the job if they are met.

        Returns True when the job was handed to ``submit_task``, and False
        when it could not be run right now.  In the False case the job may
        additionally have been failed via ``fail_unreachable`` if its
        ``follow`` or ``targets`` constraints were found to be impossible.

        Side effect: when no engine qualifies and the job has targets,
        blacklisted engines are removed from ``job.targets`` in place.
        """
        msg_id = job.msg_id
        self.log.debug("Attempting to assign task %s", msg_id)
        available = self.available_engines()
        if not available:
            # no engines, definitely can't run
            return False

        if job.follow or job.targets or job.blacklist or self.hwm:
            # we need a can_run filter
            def can_run(idx):
                # check hwm
                if self.hwm and self.loads[idx] == self.hwm:
                    return False
                target = self.targets[idx]
                # check blacklist
                if target in job.blacklist:
                    return False
                # check targets
                if job.targets and target not in job.targets:
                    return False
                # check follow
                return job.follow.check(self.completed[target],
                                        self.failed[target])

            indices = list(filter(can_run, available))

            if not indices:
                # couldn't run
                if job.follow.all:
                    # check follow for impossibility:
                    # collect the engines on which the already-finished
                    # followed tasks ended up; more than one distinct engine
                    # is treated as impossible to satisfy.
                    dests = set()
                    relevant = set()
                    if job.follow.success:
                        relevant = self.all_completed
                    if job.follow.failure:
                        # union() builds a new set, so all_completed is
                        # not modified
                        relevant = relevant.union(self.all_failed)
                    for m in job.follow.intersection(relevant):
                        # NOTE(review): in the visible code, destinations is
                        # only written by handle_result; an id that reached
                        # all_failed through fail_unreachable would have no
                        # entry and raise KeyError here -- confirm whether
                        # that situation can occur.
                        dests.add(self.destinations[m])
                    if len(dests) > 1:
                        # fail_unreachable requires the job to be in queue_map
                        self.queue_map[msg_id] = job
                        self.fail_unreachable(msg_id)
                        return False
                if job.targets:
                    # check blacklist+targets for impossibility
                    job.targets.difference_update(job.blacklist)
                    if not job.targets or not job.targets.intersection(
                            self.targets):
                        # fail_unreachable requires the job to be in queue_map
                        self.queue_map[msg_id] = job
                        self.fail_unreachable(msg_id)
                        return False
                # not impossible, just not runnable at the moment
                return False
        else:
            # no constraints at all: let submit_task choose among all engines
            indices = None

        self.submit_task(job, indices)
        return True

    def save_unmet(self, job):
        """Queue a job for later submission, once its dependencies are met.

        The job is stored in ``self.queue_map`` and appended to
        ``self.queue``, and is registered in ``self.graph`` under every
        ``after`` / ``follow`` dependency that is not yet in
        ``self.all_done``.  If the job has a timeout, a ``job_timeout``
        callback is scheduled on ``self.loop``.
        """
        job_id = job.msg_id
        self.log.debug("Adding task %s to the queue", job_id)
        self.queue_map[job_id] = job
        self.queue.append(job)

        # track the ids in follow or after, but not those already finished
        outstanding = job.after.union(job.follow).difference(self.all_done)
        for dep_id in outstanding:
            self.graph.setdefault(dep_id, set()).add(job_id)

        if not job.timeout:
            return

        # schedule timeout callback; bumping the id lets job_timeout
        # recognise callbacks from earlier schedulings as stale
        this_id = job.timeout_id + 1
        job.timeout_id = this_id

        def fire():
            return self.job_timeout(job, this_id)

        self.loop.add_timeout(time.time() + job.timeout, fire)

    def submit_task(self, job, indices=None):
        """Send a job to one engine, chosen from a subset of our targets.

        Parameters
        ----------
        job : Job
            The job to send; its ``raw_msg`` is forwarded to the engine.
        indices : list of int, optional
            Indices into ``self.targets`` that are eligible.  When empty or
            None, all targets are candidates.

        ``self.scheme`` is called with the candidate loads and returns the
        position of the chosen engine within that list.
        """
        candidate_loads = (
            [self.loads[i] for i in indices] if indices else self.loads
        )
        choice = self.scheme(candidate_loads)
        # translate a position within the subset back into a targets index
        engine_idx = indices[choice] if indices else choice
        engine = self.targets[engine_idx]

        # send job to the engine: the engine identity frame goes first
        self.engine_stream.send(engine, flags=zmq.SNDMORE, copy=False)
        self.engine_stream.send_multipart(job.raw_msg, copy=False)

        # update load
        self.add_job(engine_idx)
        self.pending[engine][job.msg_id] = job

        # notify Hub
        self.session.send(
            self.mon_stream,
            'task_destination',
            content={
                'msg_id': job.msg_id,
                'engine_id': engine.decode('ascii'),
            },
            ident=[b'tracktask', self.ident],
        )

    #-----------------------------------------------------------------------
    # Result Handling
    #-----------------------------------------------------------------------

    @util.log_errors
    def dispatch_result(self, raw_msg):
        """Dispatch method for result replies coming back from engines.

        The message is deserialized and the sending engine's load is
        updated (skipped if the engine is no longer in ``self.targets``).
        A reply whose metadata reports unmet dependencies, or a failed
        reply with retries remaining, is routed to
        ``handle_unmet_dependency``.  Any other reply is passed to
        ``handle_result`` and forwarded to the monitor stream.
        """
        try:
            idents, msg = self.session.feed_identities(raw_msg, copy=False)
            msg = self.session.deserialize(msg, content=False, copy=False)
            engine = idents[0]
            # skip load-update for dead engines
            if engine in self.targets:
                self.finish_job(self.targets.index(engine))
        except Exception:
            self.log.error("task::Invalid result: %r", raw_msg, exc_info=True)
            return

        md = msg['metadata']
        parent = msg['parent_header']

        if not md.get('dependencies_met', True):
            self.handle_unmet_dependency(idents, parent)
            return

        success = (md['status'] == 'ok')
        msg_id = parent['msg_id']
        retries = self.retries[msg_id]

        if success or not retries > 0:
            # final outcome: forget the retry counter
            del self.retries[msg_id]
            # relay to client and update graph
            self.handle_result(idents, parent, raw_msg, success)
            # send to Hub monitor
            self.mon_stream.send_multipart([b'outtask'] + raw_msg,
                                           copy=False)
        else:
            # failed, but retries remain: consume one and resubmit
            self.retries[msg_id] = retries - 1
            self.handle_unmet_dependency(idents, parent)

    def handle_result(self, idents, parent, raw_msg, success=True):
        """Handle a real task result, either success or failure.

        The reply is relayed to the client, then the task is moved from the
        engine's pending jobs into the completed or failed bookkeeping
        (per engine and overall), its destination engine is recorded, and
        the dependency graph is updated.

        Parameters
        ----------
        idents : list
            Routing identities; ``idents[0]`` is the engine and
            ``idents[1]`` the client.
        parent : dict
            Parent header of the reply; supplies the task's ``msg_id``.
        raw_msg : list
            The raw multipart message.  Its first two frames are rewritten
            in place.
        success : bool
            Whether the task succeeded.
        """
        engine, client = idents[0], idents[1]

        # first, relay result to client
        # swap_ids for ROUTER-ROUTER mirror
        raw_msg[:2] = [client, engine]
        self.client_stream.send_multipart(raw_msg, copy=False)

        # now, update our data structures
        msg_id = parent['msg_id']
        self.pending[engine].pop(msg_id)

        if success:
            per_engine, overall = self.completed, self.all_completed
        else:
            per_engine, overall = self.failed, self.all_failed
        per_engine[engine].add(msg_id)
        overall.add(msg_id)

        self.all_done.add(msg_id)
        self.destinations[msg_id] = engine

        self.update_graph(msg_id, success)

    def handle_unmet_dependency(self, idents, parent):
        """Handle a task that an engine returned without running it to a result.

        The engine (``idents[0]``) is added to the job's blacklist.  If the
        blacklist now equals the job's targets, the job is failed as
        unreachable; otherwise it is offered to ``maybe_run`` again and, if
        that does not run it and it has not failed, saved back as unmet.
        With a high-water mark set, the whole graph is re-evaluated when the
        engine's load is exactly one below the mark.
        """
        engine = idents[0]
        msg_id = parent['msg_id']

        job = self.pending[engine].pop(msg_id)
        job.blacklist.add(engine)

        if job.blacklist == job.targets:
            # every requested target has refused the job;
            # fail_unreachable expects the job to be in queue_map
            self.queue_map[msg_id] = job
            self.fail_unreachable(msg_id)
        elif not self.maybe_run(job) and msg_id not in self.all_failed:
            # resubmit failed: put it back in our dependency tree
            self.save_unmet(job)

        # skip load-update for dead engines, or when there is no hwm
        if not self.hwm or engine not in self.targets:
            return
        if self.loads[self.targets.index(engine)] == self.hwm - 1:
            self.update_graph(None)

    def update_graph(self, dep_id=None, success=True):
        """Update the dependency graph and submit any jobs that became runnable.

        Called when ``dep_id`` has just finished.  Jobs that have become
        unreachable are failed; jobs whose time dependencies are met are
        offered to ``maybe_run``.

        Called with ``dep_id=None`` to re-check the entire queue for hwm,
        without any task having finished.

        Parameters
        ----------
        dep_id : msg_id or None
            The task that just finished, or None for a full-queue pass.
        success : bool
            Accepted for interface compatibility; not consulted by this
            implementation.
        """
        # jobs registered as waiting on dep_id
        msg_ids = self.graph.pop(dep_id, [])

        # recheck *all* jobs if
        # a) we have HWM and an engine just become no longer full
        # or b) dep_id was given as None
        if dep_id is None or self.hwm and any(
                load == self.hwm - 1 for load in self.loads):
            jobs = self.queue
            using_queue = True
        else:
            using_queue = False
            jobs = deque(sorted(self.queue_map[mid] for mid in msg_ids))

        to_restore = []
        while jobs:
            job = jobs.popleft()
            if job.removed:
                # lazily deleted earlier; just drop it
                continue
            msg_id = job.msg_id

            put_it_back = True

            if job.after.unreachable(self.all_completed, self.all_failed)\
                    or job.follow.unreachable(self.all_completed, self.all_failed):
                self.fail_unreachable(msg_id)
                put_it_back = False

            elif job.after.check(self.all_completed,
                                 self.all_failed):  # time deps met, maybe run
                if self.maybe_run(job):
                    put_it_back = False
                    # NOTE(review): a job started here is not flagged
                    # `removed`; on the graph path it may still physically
                    # sit in self.queue and be seen again by a later
                    # full-queue pass -- confirm whether that is intended.
                    self.queue_map.pop(msg_id)
                    for mid in job.dependents:
                        if mid in self.graph:
                            # discard: tolerate the job not being registered
                            # under this dependency instead of raising
                            self.graph[mid].discard(msg_id)

                    # abort the loop if we just filled up all of our engines.
                    # avoids an O(N) operation in situation of full queue,
                    # where graph update is triggered as soon as an engine becomes
                    # non-full, and all tasks after the first are checked,
                    # even though they can't run.
                    if not self.available_engines():
                        break

            if using_queue and put_it_back:
                # popped a job from the queue but it neither ran nor failed,
                # so we need to put it back when we are done
                to_restore.append(job)

        # put back any tasks we popped but didn't run.
        # deque.extendleft() inserts items one at a time at the left end,
        # which reverses them; feed it the reversed list so the restored
        # jobs keep their original relative order at the front of the queue.
        if using_queue:
            self.queue.extendleft(reversed(to_restore))

    #----------------------------------------------------------------------
    # methods to be overridden by subclasses
    #----------------------------------------------------------------------

    def add_job(self, idx):
        """Record that ``self.targets[idx]`` was just given a job.

        Meant to be overridden by subclasses.  The default load is the
        number of outstanding jobs, and the default ordering is simple LRU:
        the engine and its load are moved to the end of their lists.
        """
        self.loads[idx] += 1
        # rotate both parallel lists so the engine just used comes last
        self.targets.append(self.targets.pop(idx))
        self.loads.append(self.loads.pop(idx))

    def finish_job(self, idx):
        """Record that ``self.targets[idx]`` just finished a job.

        Meant to be overridden by subclasses.  The default implementation
        lowers that engine's load by one.
        """
        self.loads[idx] = self.loads[idx] - 1