Example #1
class EngineConnector(HasTraits):
    """A simple object for accessing the various zmq connections of an object.
    Attributes are:
    id (int): engine ID
    uuid (str): uuid (unused?)
    queue (str): identity of queue's XREQ socket
    registration (str): identity of registration XREQ socket
    heartbeat (str): identity of heartbeat XREQ socket
    """
    id = Integer(0)
    queue = CBytes()
    control = CBytes()
    registration = CBytes()
    heartbeat = CBytes()
    pending = Set()
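
A minimal usage sketch (hypothetical identities; assumes the traitlets names above resolve as in IPython):

ec = EngineConnector(id=3, queue=b'engine-3', control=b'engine-3',
                     registration=b'reg-3', heartbeat=b'hb-3')
ec.pending.add('some-msg-uuid')   # msg_ids of tasks outstanding on this engine
print(ec.id, ec.queue)            # -> 3 b'engine-3'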
Example #2
class ZMQDataPublisher(Configurable):

    topic = CBytes(b'datapub')
    session = Instance(Session)
    pub_socket = Instance('zmq.Socket')
    parent_header = Dict({})

    def set_parent(self, parent):
        """Set the parent for outbound messages."""
        self.parent_header = extract_header(parent)
    
    def publish_data(self, data):
        """publish a data_message on the IOPub channel
    
        Parameters
        ----------
    
        data : dict
            The data to be published. Think of it as a namespace.
        """
        session = self.session
        buffers = serialize_object(data,
            buffer_threshold=session.buffer_threshold,
            item_threshold=session.item_threshold,
        )
        content = json_clean(dict(keys=list(data.keys())))
        session.send(self.pub_socket, 'data_message', content=content,
            parent=self.parent_header,
            buffers=buffers,
            ident=self.topic,
        )
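
A hedged usage sketch; session, iopub_socket, and request_msg are stand-ins for the kernel's real Session, PUB socket, and the message currently being handled:

pub = ZMQDataPublisher(session=session, pub_socket=iopub_socket)
pub.set_parent(request_msg)   # stamp outbound data with the request's header
pub.publish_data({'x': 1})    # goes out as a 'data_message' on IOPub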
Example #3
class ZMQDisplayPublisher(DisplayPublisher):
    """A display publisher that publishes data using a ZeroMQ PUB socket."""

    session = Instance(Session)
    pub_socket = Instance(SocketABC)
    parent_header = Dict({})
    topic = CBytes(b'displaypub')

    def set_parent(self, parent):
        """Set the parent for outbound messages."""
        self.parent_header = extract_header(parent)

    def _flush_streams(self):
        """flush IO Streams prior to display"""
        sys.stdout.flush()
        sys.stderr.flush()

    def publish(self, source, data, metadata=None):
        self._flush_streams()
        if metadata is None:
            metadata = {}
        self._validate_data(source, data, metadata)
        content = {}
        content['source'] = source
        content['data'] = encode_images(data)
        content['metadata'] = metadata
        self.session.send(
            self.pub_socket,
            u'display_data',
            json_clean(content),
            parent=self.parent_header,
            ident=self.topic,
        )

    def clear_output(self, stdout=True, stderr=True, other=True):
        content = dict(stdout=stdout, stderr=stderr, other=other)

        if stdout:
            print('\r', file=sys.stdout, end='')
        if stderr:
            print('\r', file=sys.stderr, end='')

        self._flush_streams()

        self.session.send(
            self.pub_socket,
            u'clear_output',
            content,
            parent=self.parent_header,
            ident=self.topic,
        )
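
Usage follows the same pattern; this sketch assumes the same hypothetical session/iopub_socket wiring:

dp = ZMQDisplayPublisher(session=session, pub_socket=iopub_socket)
dp.set_parent(request_msg)
dp.publish(source='IPython.core.display',
           data={'text/plain': u'hi', 'text/html': u'<b>hi</b>'})
dp.clear_output()             # clears stdout, stderr, and other output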
Example #4
class ZMQDisplayPublisher(DisplayPublisher):
    """A display publisher that publishes data using a ZeroMQ PUB socket."""

    session = Instance(Session, allow_none=True)
    pub_socket = Instance(SocketABC, allow_none=True)
    parent_header = Dict({})
    topic = CBytes(b'display_data')

    def set_parent(self, parent):
        """Set the parent for outbound messages."""
        self.parent_header = extract_header(parent)

    def _flush_streams(self):
        """flush IO Streams prior to display"""
        sys.stdout.flush()
        sys.stderr.flush()

    def publish(self, data, metadata=None, source=None):
        self._flush_streams()
        if metadata is None:
            metadata = {}
        self._validate_data(data, metadata)
        content = {}
        content['data'] = encode_images(data)
        content['metadata'] = metadata
        self.session.send(
            self.pub_socket,
            u'display_data',
            json_clean(content),
            parent=self.parent_header,
            ident=self.topic,
        )

    def clear_output(self, wait=False):
        content = dict(wait=wait)
        self._flush_streams()
        self.session.send(
            self.pub_socket,
            u'clear_output',
            content,
            parent=self.parent_header,
            ident=self.topic,
        )
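
Example #4 is a later revision of the same class: data is now the first argument, source is an optional trailing argument kept only for backward compatibility, and clear_output takes a wait flag instead of per-stream booleans. A sketch with the same hypothetical wiring:

dp = ZMQDisplayPublisher(session=session, pub_socket=iopub_socket)
dp.publish({'text/plain': u'hi'})   # source no longer required
dp.clear_output(wait=True)          # defer clearing until new output arrives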
Example #5
class Session(Configurable):
    """Object for handling serialization and sending of messages.
    
    The Session object handles building messages and sending them
    with ZMQ sockets or ZMQStream objects.  Objects can communicate with each
    other over the network via Session objects, and only need to work with the
    dict-based IPython message spec. The Session will handle 
    serialization/deserialization, security, and metadata.
    
    Sessions support configurable serialization via packer/unpacker traits,
    and signing with HMAC digests via the key/keyfile traits.
    
    Parameters
    ----------
    
    debug : bool
        whether to trigger extra debugging statements
    packer/unpacker : str : 'json', 'pickle' or import_string
        importstrings for methods to serialize message parts.  If just
        'json' or 'pickle', predefined JSON and pickle packers will be used.
        Otherwise, the entire importstring must be used.
        
        The functions must accept at least valid JSON input, and output *bytes*.
        
        For example, to use msgpack:
        packer = 'msgpack.packb', unpacker='msgpack.unpackb'
    pack/unpack : callables
        You can also set the pack/unpack callables for serialization directly.
    session : bytes
        the ID of this Session object.  The default is to generate a new UUID.
    username : unicode
        username added to message headers.  The default is to ask the OS.
    key : bytes
        The key used to initialize an HMAC signature.  If unset, messages
        will not be signed or checked.
    keyfile : filepath
        The file containing a key.  If this is set, `key` will be initialized
        to the contents of the file.
    
    """

    debug = Bool(False, config=True, help="""Debug output in the Session""")

    packer = DottedObjectName(
        'json',
        config=True,
        help="""The name of the packer for serializing messages.
            Should be one of 'json', 'pickle', or an import name
            for a custom callable serializer.""")

    def _packer_changed(self, name, old, new):
        if new.lower() == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
        elif new.lower() == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
        else:
            self.pack = import_item(str(new))

    unpacker = DottedObjectName(
        'json',
        config=True,
        help="""The name of the unpacker for unserializing messages.
        Only used with custom functions for `packer`.""")

    def _unpacker_changed(self, name, old, new):
        if new.lower() == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
        elif new.lower() == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
        else:
            self.unpack = import_item(str(new))

    session = CBytes(b'',
                     config=True,
                     help="""The UUID identifying this session.""")

    def _session_default(self):
        return bytes(uuid.uuid4())

    username = Unicode(
        os.environ.get('USER', 'username'),
        config=True,
        help="""Username for the Session. Default is your system username.""")

    # message signature related traits:
    key = CBytes(b'',
                 config=True,
                 help="""execution key, for extra authentication.""")

    def _key_changed(self, name, old, new):
        if new:
            self.auth = hmac.HMAC(new)
        else:
            self.auth = None

    auth = Instance(hmac.HMAC)
    digest_history = Set()

    keyfile = Unicode('',
                      config=True,
                      help="""path to file containing execution key.""")

    def _keyfile_changed(self, name, old, new):
        with open(new, 'rb') as f:
            self.key = f.read().strip()

    pack = Any(default_packer)  # the actual packer function

    def _pack_changed(self, name, old, new):
        if not callable(new):
            raise TypeError("packer must be callable, not %s" % type(new))

    unpack = Any(default_unpacker)  # the actual unpacker function

    def _unpack_changed(self, name, old, new):
        # unpacker is not checked against the packer - it is assumed to be its inverse
        if not callable(new):
            raise TypeError("unpacker must be callable, not %s" % type(new))

    def __init__(self, **kwargs):
        """create a Session object
        
        Parameters
        ----------

        debug : bool
            whether to trigger extra debugging statements
        packer/unpacker : str : 'json', 'pickle' or import_string
            importstrings for methods to serialize message parts.  If just
            'json' or 'pickle', predefined JSON and pickle packers will be used.
            Otherwise, the entire importstring must be used.

            The functions must accept at least valid JSON input, and output
            *bytes*.

            For example, to use msgpack:
            packer = 'msgpack.packb', unpacker='msgpack.unpackb'
        pack/unpack : callables
            You can also set the pack/unpack callables for serialization
            directly.
        session : bytes
            the ID of this Session object.  The default is to generate a new 
            UUID.
        username : unicode
            username added to message headers.  The default is to ask the OS.
        key : bytes
            The key used to initialize an HMAC signature.  If unset, messages
            will not be signed or checked.
        keyfile : filepath
            The file containing a key.  If this is set, `key` will be 
            initialized to the contents of the file.
        """
        super(Session, self).__init__(**kwargs)
        self._check_packers()
        self.none = self.pack({})

    @property
    def msg_id(self):
        """always return new uuid"""
        return str(uuid.uuid4())

    def _check_packers(self):
        """check packers for binary data and datetime support."""
        pack = self.pack
        unpack = self.unpack

        # check simple serialization
        msg = dict(a=[1, 'hi'])
        try:
            packed = pack(msg)
        except Exception:
            raise ValueError("packer could not serialize a simple message")

        # ensure packed message is bytes
        if not isinstance(packed, bytes):
            raise ValueError("message packed to %r, but bytes are required" %
                             type(packed))

        # check that unpack is pack's inverse
        try:
            unpacked = unpack(packed)
        except Exception:
            raise ValueError("unpacker could not handle the packer's output")

        # check datetime support
        msg = dict(t=datetime.now())
        try:
            unpacked = unpack(pack(msg))
        except Exception:
            self.pack = lambda o: pack(squash_dates(o))
            self.unpack = lambda s: extract_dates(unpack(s))

    def msg_header(self, msg_type):
        return msg_header(self.msg_id, msg_type, self.username, self.session)

    def msg(self, msg_type, content=None, parent=None, subheader=None):
        """Return the nested message dict.

        This format is different from what is sent over the wire. The
        self.serialize method converts this nested message dict to the wire
        format, which uses a message list.
        """
        msg = {}
        msg['header'] = self.msg_header(msg_type)
        msg['msg_id'] = msg['header']['msg_id']
        msg['parent_header'] = {} if parent is None else extract_header(parent)
        msg['msg_type'] = msg_type
        msg['content'] = {} if content is None else content
        sub = {} if subheader is None else subheader
        msg['header'].update(sub)
        return msg

    def sign(self, msg_list):
        """Sign a message with HMAC digest. If no auth, return b''.

        Parameters
        ----------
        msg_list : list
            The [p_header,p_parent,p_content] part of the message list. 
        """
        if self.auth is None:
            return b''
        h = self.auth.copy()
        for m in msg_list:
            h.update(m)
        return h.hexdigest()

    def serialize(self, msg, ident=None):
        """Serialize the message components to bytes.

        Parameters
        ----------
        msg : dict or Message
            The nested message dict as returned by the self.msg method.

        Returns
        -------
        msg_list : list
            The list of bytes objects to be sent with the format:
            [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
             buffer1,buffer2,...]. In this list, the p_* entities are
            the packed or serialized versions, so if JSON is used, these
            are utf8 encoded JSON strings.
        """
        content = msg.get('content', {})
        if content is None:
            content = self.none
        elif isinstance(content, dict):
            content = self.pack(content)
        elif isinstance(content, bytes):
            # content is already packed, as in a relayed message
            pass
        elif isinstance(content, unicode):
            # should be bytes, but JSON often spits out unicode
            content = content.encode('utf8')
        else:
            raise TypeError("Content incorrect type: %s" % type(content))

        real_message = [
            self.pack(msg['header']),
            self.pack(msg['parent_header']), content
        ]

        to_send = []

        if isinstance(ident, list):
            # accept list of idents
            to_send.extend(ident)
        elif ident is not None:
            to_send.append(ident)
        to_send.append(DELIM)

        signature = self.sign(real_message)
        to_send.append(signature)

        to_send.extend(real_message)

        return to_send

    def send(self,
             stream,
             msg_or_type,
             content=None,
             parent=None,
             ident=None,
             buffers=None,
             subheader=None,
             track=False):
        """Build and send a message via stream or socket.

        The message format used by this function internally is as follows:

        [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
         buffer1,buffer2,...] 

        The self.serialize method converts the nested message dict into this
        format.

        Parameters
        ----------
        
        stream : zmq.Socket or ZMQStream
            the socket-like object used to send the data
        msg_or_type : str or Message/dict
            Normally, msg_or_type will be a msg_type unless a message is being 
            sent more than once.
        
        content : dict or None
            the content of the message (ignored if msg_or_type is a message)
        parent : Message or dict or None
            the parent or parent header describing the parent of this message
        ident : bytes or list of bytes
            the zmq.IDENTITY routing path
        subheader : dict or None
            extra header keys for this message's header
        buffers : list or None
            the already-serialized buffers to be appended to the message
        track : bool
            whether to track.  Only for use with Sockets, 
            because ZMQStream objects cannot track messages.
        
        Returns
        -------
        msg : message dict
            the constructed message
        (msg,tracker) : (message dict, MessageTracker)
            if track=True, then a 2-tuple will be returned, 
            the first element being the constructed
            message, and the second being the MessageTracker
            
        """

        if not isinstance(stream, (zmq.Socket, ZMQStream)):
            raise TypeError("stream must be Socket or ZMQStream, not %r" %
                            type(stream))
        elif track and isinstance(stream, ZMQStream):
            raise TypeError("ZMQStream cannot track messages")

        if isinstance(msg_or_type, (Message, dict)):
            # we got a Message, not a msg_type
            # don't build a new Message
            msg = msg_or_type
        else:
            msg = self.msg(msg_or_type, content, parent, subheader)

        buffers = [] if buffers is None else buffers
        to_send = self.serialize(msg, ident)
        flag = 0
        if buffers:
            flag = zmq.SNDMORE
            _track = False
        else:
            _track = track
        if track:
            tracker = stream.send_multipart(to_send,
                                            flag,
                                            copy=False,
                                            track=_track)
        else:
            tracker = stream.send_multipart(to_send, flag, copy=False)
        for b in buffers[:-1]:
            stream.send(b, flag, copy=False)
        if buffers:
            if track:
                tracker = stream.send(buffers[-1], copy=False, track=track)
            else:
                tracker = stream.send(buffers[-1], copy=False)

        # omsg = Message(msg)
        if self.debug:
            pprint.pprint(msg)
            pprint.pprint(to_send)
            pprint.pprint(buffers)

        msg['tracker'] = tracker

        return msg

    def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
        """Send a raw message via ident path.

        This method is used to send a already serialized message.

        Parameters
        ----------
        stream : ZMQStream or Socket
            The ZMQ stream or socket to use for sending the message.
        msg_list : list
            The serialized list of messages to send. This only includes the
            [p_header,p_parent,p_content,buffer1,buffer2,...] portion of
            the message.
        ident : ident or list
            A single ident or a list of idents to use in sending.
        """
        to_send = []
        if isinstance(ident, bytes):
            ident = [ident]
        if ident is not None:
            to_send.extend(ident)

        to_send.append(DELIM)
        to_send.append(self.sign(msg_list))
        to_send.extend(msg_list)
        stream.send_multipart(to_send, flags, copy=copy)

    def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
        """Receive and unpack a message.

        Parameters
        ----------
        socket : ZMQStream or Socket
            The socket or stream to use in receiving.

        Returns
        -------
        [idents], msg
            [idents] is a list of idents and msg is a nested message dict of
            same format as self.msg returns.
        """
        if isinstance(socket, ZMQStream):
            socket = socket.socket
        try:
            msg_list = socket.recv_multipart(mode, copy=copy)
        except zmq.ZMQError as e:
            if e.errno == zmq.EAGAIN:
                # We can convert EAGAIN to None as we know in this case
                # recv_multipart won't return None.
                return None, None
            else:
                raise
        # split multipart message into identity list and message dict
        # invalid large messages can cause very expensive string comparisons
        idents, msg_list = self.feed_identities(msg_list, copy)
        try:
            return idents, self.unpack_message(msg_list,
                                               content=content,
                                               copy=copy)
        except Exception as e:
            print(idents, msg_list)
            # TODO: handle it
            raise e

    def feed_identities(self, msg_list, copy=True):
        """Split the identities from the rest of the message.

        Feed until DELIM is reached, then return the prefix as idents and
        remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
        but that would be silly.
        
        Parameters
        ----------
        msg_list : a list of Message or bytes objects
            The message to be split.
        copy : bool
            flag determining whether the arguments are bytes or Messages
        
        Returns
        -------
        (idents,msg_list) : two lists
            idents will always be a list of bytes - the identity prefix
            msg_list will be a list of bytes or Messages, unchanged from input
            msg_list should be unpackable via self.unpack_message at this point.
        """
        if copy:
            idx = msg_list.index(DELIM)
            return msg_list[:idx], msg_list[idx + 1:]
        else:
            failed = True
            for idx, m in enumerate(msg_list):
                if m.bytes == DELIM:
                    failed = False
                    break
            if failed:
                raise ValueError("DELIM not in msg_list")
            idents, msg_list = msg_list[:idx], msg_list[idx + 1:]
            return [m.bytes for m in idents], msg_list

    def unpack_message(self, msg_list, content=True, copy=True):
        """Return a message object from the format
        sent by self.send.
        
        Parameters
        ----------
        
        content : bool (True)
            whether to unpack the content dict (True), 
            or leave it serialized (False)
        
        copy : bool (True)
            whether to return the bytes (True), 
            or the non-copying Message object in each place (False)
        
        """
        minlen = 4
        message = {}
        # validate length before indexing into the message list
        if not len(msg_list) >= minlen:
            raise TypeError(
                "malformed message, must have at least %i elements" % minlen)
        if not copy:
            for i in range(minlen):
                msg_list[i] = msg_list[i].bytes
        if self.auth is not None:
            signature = msg_list[0]
            if signature in self.digest_history:
                raise ValueError("Duplicate Signature: %r" % signature)
            self.digest_history.add(signature)
            check = self.sign(msg_list[1:4])
            if not signature == check:
                raise ValueError("Invalid Signature: %r" % signature)
        message['header'] = self.unpack(msg_list[1])
        message['msg_type'] = message['header']['msg_type']
        message['parent_header'] = self.unpack(msg_list[2])
        if content:
            message['content'] = self.unpack(msg_list[3])
        else:
            message['content'] = msg_list[3]

        message['buffers'] = msg_list[4:]
        return message
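
A round-trip sketch over an in-process PAIR socket pair (assumes pyzmq; the endpoint name is arbitrary, and this older Session is Python 2-era code). Both ends share one Session object here, so the HMAC key matches automatically:

import zmq

ctx = zmq.Context.instance()
a = ctx.socket(zmq.PAIR); a.bind('inproc://session-demo')
b = ctx.socket(zmq.PAIR); b.connect('inproc://session-demo')

s = Session(key=b'secret')              # enables HMAC signing
s.send(a, 'execute_request', content={'code': 'print(1)'})
idents, msg = s.recv(b, mode=0)         # mode=0 blocks instead of zmq.NOBLOCK
print(msg['msg_type'], msg['content'])  # execute_request {u'code': u'print(1)'}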
Example #6
class Session(Configurable):
    """Object for handling serialization and sending of messages.

    The Session object handles building messages and sending them
    with ZMQ sockets or ZMQStream objects.  Objects can communicate with each
    other over the network via Session objects, and only need to work with the
    dict-based IPython message spec. The Session will handle
    serialization/deserialization, security, and metadata.

    Sessions support configurable serialization via packer/unpacker traits,
    and signing with HMAC digests via the key/keyfile traits.

    Parameters
    ----------

    debug : bool
        whether to trigger extra debugging statements
    packer/unpacker : str : 'json', 'pickle' or import_string
        importstrings for methods to serialize message parts.  If just
        'json' or 'pickle', predefined JSON and pickle packers will be used.
        Otherwise, the entire importstring must be used.

        The functions must accept at least valid JSON input, and output *bytes*.

        For example, to use msgpack:
        packer = 'msgpack.packb', unpacker='msgpack.unpackb'
    pack/unpack : callables
        You can also set the pack/unpack callables for serialization directly.
    session : bytes
        the ID of this Session object.  The default is to generate a new UUID.
    username : unicode
        username added to message headers.  The default is to ask the OS.
    key : bytes
        The key used to initialize an HMAC signature.  If unset, messages
        will not be signed or checked.
    keyfile : filepath
        The file containing a key.  If this is set, `key` will be initialized
        to the contents of the file.

    """

    debug = Bool(False, config=True, help="""Debug output in the Session""")

    packer = DottedObjectName(
        'json',
        config=True,
        help="""The name of the packer for serializing messages.
            Should be one of 'json', 'pickle', or an import name
            for a custom callable serializer.""")

    def _packer_changed(self, name, old, new):
        if new.lower() == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
            self.unpacker = new
        elif new.lower() == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
            self.unpacker = new
        else:
            self.pack = import_item(str(new))

    unpacker = DottedObjectName(
        'json',
        config=True,
        help="""The name of the unpacker for unserializing messages.
        Only used with custom functions for `packer`.""")

    def _unpacker_changed(self, name, old, new):
        if new.lower() == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
            self.packer = new
        elif new.lower() == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
            self.packer = new
        else:
            self.unpack = import_item(str(new))

    session = CUnicode(u'',
                       config=True,
                       help="""The UUID identifying this session.""")

    def _session_default(self):
        u = unicode_type(uuid.uuid4())
        self.bsession = u.encode('ascii')
        return u

    def _session_changed(self, name, old, new):
        self.bsession = self.session.encode('ascii')

    # bsession is the session as bytes
    bsession = CBytes(b'')

    username = Unicode(
        str_to_unicode(os.environ.get('USER', 'username')),
        help="""Username for the Session. Default is your system username.""",
        config=True)

    metadata = Dict(
        {},
        config=True,
        help=
        """Metadata dictionary, which serves as the default top-level metadata dict for each message."""
    )

    # if 0, no adapting to do.
    adapt_version = Integer(0)

    # message signature related traits:

    key = CBytes(b'',
                 config=True,
                 help="""execution key, for extra authentication.""")

    def _key_changed(self):
        self._new_auth()

    signature_scheme = Unicode(
        'hmac-sha256',
        config=True,
        help="""The digest scheme used to construct the message signatures.
        Must have the form 'hmac-HASH'.""")

    def _signature_scheme_changed(self, name, old, new):
        if not new.startswith('hmac-'):
            raise TraitError(
                "signature_scheme must start with 'hmac-', got %r" % new)
        hash_name = new.split('-', 1)[1]
        try:
            self.digest_mod = getattr(hashlib, hash_name)
        except AttributeError:
            raise TraitError("hashlib has no such attribute: %s" % hash_name)
        self._new_auth()

    digest_mod = Any()

    def _digest_mod_default(self):
        return hashlib.sha256

    auth = Instance(hmac.HMAC)

    def _new_auth(self):
        if self.key:
            self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
        else:
            self.auth = None

    digest_history = Set()
    digest_history_size = Integer(
        2**16,
        config=True,
        help="""The maximum number of digests to remember.
        
        The digest history will be culled when it exceeds this value.
        """)

    keyfile = Unicode('',
                      config=True,
                      help="""path to file containing execution key.""")

    def _keyfile_changed(self, name, old, new):
        with open(new, 'rb') as f:
            self.key = f.read().strip()

    # for protecting against sends from forks
    pid = Integer()

    # serialization traits:

    pack = Any(default_packer)  # the actual packer function

    def _pack_changed(self, name, old, new):
        if not callable(new):
            raise TypeError("packer must be callable, not %s" % type(new))

    unpack = Any(default_unpacker)  # the actual unpacker function

    def _unpack_changed(self, name, old, new):
        # unpacker is not checked against the packer - it is assumed to be its inverse
        if not callable(new):
            raise TypeError("unpacker must be callable, not %s" % type(new))

    # thresholds:
    copy_threshold = Integer(
        2**16,
        config=True,
        help=
        "Threshold (in bytes) beyond which a buffer should be sent without copying."
    )
    buffer_threshold = Integer(
        MAX_BYTES,
        config=True,
        help=
        "Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling."
    )
    item_threshold = Integer(
        MAX_ITEMS,
        config=True,
        help=
        """The maximum number of items for a container to be introspected for custom serialization.
        Containers larger than this are pickled outright.
        """)

    def __init__(self, **kwargs):
        """create a Session object

        Parameters
        ----------

        debug : bool
            whether to trigger extra debugging statements
        packer/unpacker : str : 'json', 'pickle' or import_string
            importstrings for methods to serialize message parts.  If just
            'json' or 'pickle', predefined JSON and pickle packers will be used.
            Otherwise, the entire importstring must be used.

            The functions must accept at least valid JSON input, and output
            *bytes*.

            For example, to use msgpack:
            packer = 'msgpack.packb', unpacker='msgpack.unpackb'
        pack/unpack : callables
            You can also set the pack/unpack callables for serialization
            directly.
        session : unicode (must be ascii)
            the ID of this Session object.  The default is to generate a new
            UUID.
        bsession : bytes
            The session as bytes
        username : unicode
            username added to message headers.  The default is to ask the OS.
        key : bytes
            The key used to initialize an HMAC signature.  If unset, messages
            will not be signed or checked.
        signature_scheme : str
            The message digest scheme. Currently must be of the form 'hmac-HASH',
            where 'HASH' is a hashing function available in Python's hashlib.
            The default is 'hmac-sha256'.
            This is ignored if 'key' is empty.
        keyfile : filepath
            The file containing a key.  If this is set, `key` will be
            initialized to the contents of the file.
        """
        super(Session, self).__init__(**kwargs)
        self._check_packers()
        self.none = self.pack({})
        # ensure self._session_default() if necessary, so bsession is defined:
        self.session
        self.pid = os.getpid()

    @property
    def msg_id(self):
        """always return new uuid"""
        return str(uuid.uuid4())

    def _check_packers(self):
        """check packers for datetime support."""
        pack = self.pack
        unpack = self.unpack

        # check simple serialization
        msg = dict(a=[1, 'hi'])
        try:
            packed = pack(msg)
        except Exception as e:
            msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
            if self.packer == 'json':
                jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
            else:
                jsonmsg = ""
            raise ValueError(
                msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg))

        # ensure packed message is bytes
        if not isinstance(packed, bytes):
            raise ValueError("message packed to %r, but bytes are required" %
                             type(packed))

        # check that unpack is pack's inverse
        try:
            unpacked = unpack(packed)
            assert unpacked == msg
        except Exception as e:
            msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
            if self.packer == 'json':
                jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
            else:
                jsonmsg = ""
            raise ValueError(
                msg.format(packer=self.packer,
                           unpacker=self.unpacker,
                           e=e,
                           jsonmsg=jsonmsg))

        # check datetime support
        msg = dict(t=datetime.now())
        try:
            unpacked = unpack(pack(msg))
            if isinstance(unpacked['t'], datetime):
                raise ValueError("Shouldn't deserialize to datetime")
        except Exception:
            self.pack = lambda o: pack(squash_dates(o))
            self.unpack = lambda s: unpack(s)

    def msg_header(self, msg_type):
        return msg_header(self.msg_id, msg_type, self.username, self.session)

    def msg(self,
            msg_type,
            content=None,
            parent=None,
            header=None,
            metadata=None):
        """Return the nested message dict.

        This format is different from what is sent over the wire. The
        serialize/deserialize methods convert this nested message dict to the wire
        format, which is a list of message parts.
        """
        msg = {}
        header = self.msg_header(msg_type) if header is None else header
        msg['header'] = header
        msg['msg_id'] = header['msg_id']
        msg['msg_type'] = header['msg_type']
        msg['parent_header'] = {} if parent is None else extract_header(parent)
        msg['content'] = {} if content is None else content
        msg['metadata'] = self.metadata.copy()
        if metadata is not None:
            msg['metadata'].update(metadata)
        return msg

    def sign(self, msg_list):
        """Sign a message with HMAC digest. If no auth, return b''.

        Parameters
        ----------
        msg_list : list
            The [p_header,p_parent,p_content] part of the message list.
        """
        if self.auth is None:
            return b''
        h = self.auth.copy()
        for m in msg_list:
            h.update(m)
        return str_to_bytes(h.hexdigest())

    def serialize(self, msg, ident=None):
        """Serialize the message components to bytes.

        This is roughly the inverse of deserialize. The serialize/deserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg : dict or Message
            The nested message dict as returned by the self.msg method.

        Returns
        -------
        msg_list : list
            The list of bytes objects to be sent with the format::

                [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
                 p_metadata, p_content, buffer1, buffer2, ...]

            In this list, the ``p_*`` entities are the packed or serialized
            versions, so if JSON is used, these are utf8 encoded JSON strings.
        """
        content = msg.get('content', {})
        if content is None:
            content = self.none
        elif isinstance(content, dict):
            content = self.pack(content)
        elif isinstance(content, bytes):
            # content is already packed, as in a relayed message
            pass
        elif isinstance(content, unicode_type):
            # should be bytes, but JSON often spits out unicode
            content = content.encode('utf8')
        else:
            raise TypeError("Content incorrect type: %s" % type(content))

        real_message = [
            self.pack(msg['header']),
            self.pack(msg['parent_header']),
            self.pack(msg['metadata']),
            content,
        ]

        to_send = []

        if isinstance(ident, list):
            # accept list of idents
            to_send.extend(ident)
        elif ident is not None:
            to_send.append(ident)
        to_send.append(DELIM)

        signature = self.sign(real_message)
        to_send.append(signature)

        to_send.extend(real_message)

        return to_send

    def send(self,
             stream,
             msg_or_type,
             content=None,
             parent=None,
             ident=None,
             buffers=None,
             track=False,
             header=None,
             metadata=None):
        """Build and send a message via stream or socket.

        The message format used by this function internally is as follows:

        [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_metadata,p_content,
         buffer1,buffer2,...]

        The serialize/deserialize methods convert the nested message dict into this
        format.

        Parameters
        ----------

        stream : zmq.Socket or ZMQStream
            The socket-like object used to send the data.
        msg_or_type : str or Message/dict
            Normally, msg_or_type will be a msg_type unless a message is being
            sent more than once. If a header is supplied, this can be set to
            None and the msg_type will be pulled from the header.

        content : dict or None
            The content of the message (ignored if msg_or_type is a message).
        header : dict or None
            The header dict for the message (ignored if msg_or_type is a message).
        parent : Message or dict or None
            The parent or parent header describing the parent of this message
            (ignored if msg_or_type is a message).
        ident : bytes or list of bytes
            The zmq.IDENTITY routing path.
        metadata : dict or None
            The metadata describing the message
        buffers : list or None
            The already-serialized buffers to be appended to the message.
        track : bool
            Whether to track.  Only for use with Sockets, because ZMQStream
            objects cannot track messages.
            

        Returns
        -------
        msg : dict
            The constructed message.
        """
        if not isinstance(stream, zmq.Socket):
            # ZMQStreams and dummy sockets do not support tracking.
            track = False

        if isinstance(msg_or_type, (Message, dict)):
            # We got a Message or message dict, not a msg_type so don't
            # build a new Message.
            msg = msg_or_type
            buffers = buffers or msg.get('buffers', [])
        else:
            msg = self.msg(msg_or_type,
                           content=content,
                           parent=parent,
                           header=header,
                           metadata=metadata)
        if not os.getpid() == self.pid:
            io.rprint("WARNING: attempted to send message from fork")
            io.rprint(msg)
            return
        buffers = [] if buffers is None else buffers
        if self.adapt_version:
            msg = adapt(msg, self.adapt_version)
        to_send = self.serialize(msg, ident)
        to_send.extend(buffers)
        longest = max([len(s) for s in to_send])
        copy = (longest < self.copy_threshold)

        if buffers and track and not copy:
            # only really track when we are doing zero-copy buffers
            tracker = stream.send_multipart(to_send, copy=False, track=True)
        else:
            # use dummy tracker, which will be done immediately
            tracker = DONE
            stream.send_multipart(to_send, copy=copy)

        if self.debug:
            pprint.pprint(msg)
            pprint.pprint(to_send)
            pprint.pprint(buffers)

        msg['tracker'] = tracker

        return msg

    def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
        """Send a raw message via ident path.

        This method is used to send a already serialized message.

        Parameters
        ----------
        stream : ZMQStream or Socket
            The ZMQ stream or socket to use for sending the message.
        msg_list : list
            The serialized list of messages to send. This only includes the
            [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
            the message.
        ident : ident or list
            A single ident or a list of idents to use in sending.
        """
        to_send = []
        if isinstance(ident, bytes):
            ident = [ident]
        if ident is not None:
            to_send.extend(ident)

        to_send.append(DELIM)
        to_send.append(self.sign(msg_list))
        to_send.extend(msg_list)
        stream.send_multipart(to_send, flags, copy=copy)

    def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
        """Receive and unpack a message.

        Parameters
        ----------
        socket : ZMQStream or Socket
            The socket or stream to use in receiving.

        Returns
        -------
        [idents], msg
            [idents] is a list of idents and msg is a nested message dict of
            same format as self.msg returns.
        """
        if isinstance(socket, ZMQStream):
            socket = socket.socket
        try:
            msg_list = socket.recv_multipart(mode, copy=copy)
        except zmq.ZMQError as e:
            if e.errno == zmq.EAGAIN:
                # We can convert EAGAIN to None as we know in this case
                # recv_multipart won't return None.
                return None, None
            else:
                raise
        # split multipart message into identity list and message dict
        # invalid large messages can cause very expensive string comparisons
        idents, msg_list = self.feed_identities(msg_list, copy)
        try:
            return idents, self.deserialize(msg_list,
                                            content=content,
                                            copy=copy)
        except Exception as e:
            # TODO: handle it
            raise e

    def feed_identities(self, msg_list, copy=True):
        """Split the identities from the rest of the message.

        Feed until DELIM is reached, then return the prefix as idents and
        remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
        but that would be silly.

        Parameters
        ----------
        msg_list : a list of Message or bytes objects
            The message to be split.
        copy : bool
            flag determining whether the arguments are bytes or Messages

        Returns
        -------
        (idents, msg_list) : two lists
            idents will always be a list of bytes, each of which is a ZMQ
            identity. msg_list will be a list of bytes or zmq.Messages of the
            form [HMAC,p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] and
            should be unpackable/unserializable via self.deserialize at this
            point.
        """
        if copy:
            idx = msg_list.index(DELIM)
            return msg_list[:idx], msg_list[idx + 1:]
        else:
            failed = True
            for idx, m in enumerate(msg_list):
                if m.bytes == DELIM:
                    failed = False
                    break
            if failed:
                raise ValueError("DELIM not in msg_list")
            idents, msg_list = msg_list[:idx], msg_list[idx + 1:]
            return [m.bytes for m in idents], msg_list

    def _add_digest(self, signature):
        """add a digest to history to protect against replay attacks"""
        if self.digest_history_size == 0:
            # no history, never add digests
            return

        self.digest_history.add(signature)
        if len(self.digest_history) > self.digest_history_size:
            # threshold reached, cull 10%
            self._cull_digest_history()

    def _cull_digest_history(self):
        """cull the digest history
        
        Removes a randomly selected 10% of the digest history
        """
        current = len(self.digest_history)
        n_to_cull = max(int(current // 10), current - self.digest_history_size)
        if n_to_cull >= current:
            self.digest_history = set()
            return
        to_cull = random.sample(self.digest_history, n_to_cull)
        self.digest_history.difference_update(to_cull)

    def deserialize(self, msg_list, content=True, copy=True):
        """Unserialize a msg_list to a nested message dict.

        This is roughly the inverse of serialize. The serialize/deserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg_list : list of bytes or Message objects
            The list of message parts of the form [HMAC,p_header,p_parent,
            p_metadata,p_content,buffer1,buffer2,...].
        content : bool (True)
            Whether to unpack the content dict (True), or leave it packed
            (False).
        copy : bool (True)
            Whether to return the bytes (True), or the non-copying Message
            object in each place (False).

        Returns
        -------
        msg : dict
            The nested message dict with top-level keys [header, parent_header,
            metadata, content, buffers].
        """
        minlen = 5
        message = {}
        # validate length before indexing into the message list
        if not len(msg_list) >= minlen:
            raise TypeError(
                "malformed message, must have at least %i elements" % minlen)
        if not copy:
            for i in range(minlen):
                msg_list[i] = msg_list[i].bytes
        if self.auth is not None:
            signature = msg_list[0]
            if not signature:
                raise ValueError("Unsigned Message")
            if signature in self.digest_history:
                raise ValueError("Duplicate Signature: %r" % signature)
            self._add_digest(signature)
            check = self.sign(msg_list[1:5])
            if not compare_digest(signature, check):
                raise ValueError("Invalid Signature: %r" % signature)
        header = self.unpack(msg_list[1])
        message['header'] = extract_dates(header)
        message['msg_id'] = header['msg_id']
        message['msg_type'] = header['msg_type']
        message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
        message['metadata'] = self.unpack(msg_list[3])
        if content:
            message['content'] = self.unpack(msg_list[4])
        else:
            message['content'] = msg_list[4]

        message['buffers'] = msg_list[5:]
        # adapt to the current version
        return adapt(message)

    def unserialize(self, *args, **kwargs):
        warnings.warn(
            "Session.unserialize is deprecated. Use Session.deserialize.",
            DeprecationWarning,
        )
        return self.deserialize(*args, **kwargs)
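
The newer Session splits the wire-level serialize/deserialize methods from the per-part pack/unpack callables and makes the digest scheme configurable. A hedged round-trip sketch (the msgpack packer from the docstring would be selected the same way, via packer='msgpack.packb', unpacker='msgpack.unpackb'):

s = Session(key=b'secret', signature_scheme='hmac-sha512')
msg = s.msg('status', content={'execution_state': 'idle'})
wire = s.serialize(msg)   # [DELIM, HMAC, p_header, p_parent, p_metadata, p_content]
idents, parts = s.feed_identities(wire)
restored = s.deserialize(parts)
assert restored['content'] == msg['content']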
Example #7
class EngineFactory(RegistrationFactory):
    """IPython engine"""

    # configurables:
    out_stream_factory = Type('IPython.zmq.iostream.OutStream',
                              config=True,
                              help="""The OutStream for handling stdout/err.
        Typically 'IPython.zmq.iostream.OutStream'""")
    display_hook_factory = Type('IPython.zmq.displayhook.ZMQDisplayHook',
                                config=True,
                                help="""The class for handling displayhook.
        Typically 'IPython.zmq.displayhook.ZMQDisplayHook'""")
    location = Unicode(
        config=True,
        help="""The location (an IP address) of the controller.  This is
        used for disambiguating URLs, to determine whether
        loopback should be used to connect or the public address.""")
    timeout = CFloat(
        2,
        config=True,
        help="""The time (in seconds) to wait for the Controller to respond
        to registration requests before giving up.""")
    sshserver = Unicode(
        config=True,
        help=
        """The SSH server to use for tunneling connections to the Controller."""
    )
    sshkey = Unicode(
        config=True,
        help=
        """The SSH private key file to use when tunneling connections to the Controller."""
    )
    paramiko = Bool(
        sys.platform == 'win32',
        config=True,
        help="""Whether to use paramiko instead of openssh for tunnels.""")

    # not configurable:
    user_ns = Dict()
    id = Integer(allow_none=True)
    registrar = Instance('zmq.eventloop.zmqstream.ZMQStream')
    kernel = Instance(Kernel)

    bident = CBytes()
    ident = Unicode()

    def _ident_changed(self, name, old, new):
        self.bident = cast_bytes(new)

    using_ssh = Bool(False)

    def __init__(self, **kwargs):
        super(EngineFactory, self).__init__(**kwargs)
        self.ident = self.session.session

    def init_connector(self):
        """construct connection function, which handles tunnels."""
        self.using_ssh = bool(self.sshkey or self.sshserver)

        if self.sshkey and not self.sshserver:
            # We are using ssh directly to the controller, tunneling localhost to localhost
            self.sshserver = self.url.split('://')[1].split(':')[0]

        if self.using_ssh:
            if tunnel.try_passwordless_ssh(self.sshserver, self.sshkey,
                                           self.paramiko):
                password = False
            else:
                password = getpass("SSH Password for %s: " % self.sshserver)
        else:
            password = False

        def connect(s, url):
            url = disambiguate_url(url, self.location)
            if self.using_ssh:
                self.log.debug("Tunneling connection to %s via %s" %
                               (url, self.sshserver))
                return tunnel.tunnel_connection(
                    s,
                    url,
                    self.sshserver,
                    keyfile=self.sshkey,
                    paramiko=self.paramiko,
                    password=password,
                )
            else:
                return s.connect(url)

        def maybe_tunnel(url):
            """like connect, but don't complete the connection (for use by heartbeat)"""
            url = disambiguate_url(url, self.location)
            if self.using_ssh:
                self.log.debug("Tunneling connection to %s via %s" %
                               (url, self.sshserver))
                url, tunnelobj = tunnel.open_tunnel(
                    url,
                    self.sshserver,
                    keyfile=self.sshkey,
                    paramiko=self.paramiko,
                    password=password,
                )
            return url

        return connect, maybe_tunnel

    def register(self):
        """send the registration_request"""

        self.log.info("Registering with controller at %s" % self.url)
        ctx = self.context
        connect, maybe_tunnel = self.init_connector()
        reg = ctx.socket(zmq.DEALER)
        reg.setsockopt(zmq.IDENTITY, self.bident)
        connect(reg, self.url)
        self.registrar = zmqstream.ZMQStream(reg, self.loop)

        content = dict(queue=self.ident,
                       heartbeat=self.ident,
                       control=self.ident)
        self.registrar.on_recv(
            lambda msg: self.complete_registration(msg, connect, maybe_tunnel))
        # print (self.session.key)
        self.session.send(self.registrar,
                          "registration_request",
                          content=content)

    def complete_registration(self, msg, connect, maybe_tunnel):
        # print msg
        self._abort_dc.stop()
        ctx = self.context
        loop = self.loop
        identity = self.bident
        idents, msg = self.session.feed_identities(msg)
        msg = Message(self.session.unserialize(msg))

        if msg.content.status == 'ok':
            self.id = int(msg.content.id)

            # launch heartbeat
            hb_addrs = msg.content.heartbeat

            # possibly forward hb ports with tunnels
            hb_addrs = [maybe_tunnel(addr) for addr in hb_addrs]
            heart = Heart(*map(str, hb_addrs), heart_id=identity)
            heart.start()

            # create Shell Streams (MUX, Task, etc.):
            queue_addr = msg.content.mux
            shell_addrs = [str(queue_addr)]
            task_addr = msg.content.task
            if task_addr:
                shell_addrs.append(str(task_addr))

            # Uncomment this to go back to two-socket model
            # shell_streams = []
            # for addr in shell_addrs:
            #     stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
            #     stream.setsockopt(zmq.IDENTITY, identity)
            #     stream.connect(disambiguate_url(addr, self.location))
            #     shell_streams.append(stream)

            # Now use only one shell stream for mux and tasks
            stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
            stream.setsockopt(zmq.IDENTITY, identity)
            shell_streams = [stream]
            for addr in shell_addrs:
                connect(stream, addr)
            # end single stream-socket

            # control stream:
            control_addr = str(msg.content.control)
            control_stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
            control_stream.setsockopt(zmq.IDENTITY, identity)
            connect(control_stream, control_addr)

            # create iopub stream:
            iopub_addr = msg.content.iopub
            iopub_socket = ctx.socket(zmq.PUB)
            iopub_socket.setsockopt(zmq.IDENTITY, identity)
            connect(iopub_socket, iopub_addr)

            # disable history:
            self.config.HistoryManager.hist_file = ':memory:'

            # Redirect input streams and set a display hook.
            if self.out_stream_factory:
                sys.stdout = self.out_stream_factory(self.session,
                                                     iopub_socket, u'stdout')
                sys.stdout.topic = cast_bytes('engine.%i.stdout' % self.id)
                sys.stderr = self.out_stream_factory(self.session,
                                                     iopub_socket, u'stderr')
                sys.stderr.topic = cast_bytes('engine.%i.stderr' % self.id)
            if self.display_hook_factory:
                sys.displayhook = self.display_hook_factory(
                    self.session, iopub_socket)
                sys.displayhook.topic = cast_bytes('engine.%i.pyout' % self.id)

            self.kernel = Kernel(config=self.config,
                                 int_id=self.id,
                                 ident=self.ident,
                                 session=self.session,
                                 control_stream=control_stream,
                                 shell_streams=shell_streams,
                                 iopub_socket=iopub_socket,
                                 loop=loop,
                                 user_ns=self.user_ns,
                                 log=self.log)

            self.kernel.shell.display_pub.topic = cast_bytes(
                'engine.%i.displaypub' % self.id)

            # FIXME: This is a hack until IPKernelApp and IPEngineApp can be fully merged
            app = IPKernelApp(config=self.config,
                              shell=self.kernel.shell,
                              kernel=self.kernel,
                              log=self.log)
            app.init_profile_dir()
            app.init_code()

            self.kernel.start()
        else:
            self.log.fatal("Registration Failed: %s" % msg)
            raise Exception("Registration Failed: %s" % msg)

        self.log.info("Completed registration with id %i" % self.id)

    def abort(self):
        self.log.fatal("Registration timed out after %.1f seconds" %
                       self.timeout)
        if self.url.startswith('127.'):
            self.log.fatal("""
            If the controller and engines are not on the same machine,
            you will have to instruct the controller to listen on an external IP (in ipcontroller_config.py):
                c.HubFactory.ip='*' # for all interfaces, internal and external
                c.HubFactory.ip='192.168.1.101' # or any interface that the engines can see
            or tunnel connections via ssh.
            """)
        self.session.send(self.registrar,
                          "unregistration_request",
                          content=dict(id=self.id))
        time.sleep(1)
        sys.exit(255)

    def start(self):
        dc = ioloop.DelayedCallback(self.register, 0, self.loop)
        dc.start()
        self._abort_dc = ioloop.DelayedCallback(self.abort,
                                                self.timeout * 1000, self.loop)
        self._abort_dc.start()
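
The start/abort pair above implements a simple timeout guard around registration: the request is scheduled to run immediately, and the abort callback fires after `timeout` seconds unless `complete_registration` stops it first. A minimal sketch of the pattern, using the old `zmq.eventloop.ioloop.DelayedCallback` API seen above; `guarded_register` and its arguments are illustrative, not part of IPython.parallel:

from zmq.eventloop import ioloop

def guarded_register(loop, register, abort, timeout):
    """Schedule `register` now and `abort` after `timeout` seconds."""
    ioloop.DelayedCallback(register, 0, loop).start()
    abort_dc = ioloop.DelayedCallback(abort, int(timeout * 1000), loop)
    abort_dc.start()
    return abort_dc  # the registration callback calls .stop() on success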
Exemplo n.º 8
0
class TaskScheduler(SessionFactory):
    """Python TaskScheduler object.

    This is the simplest object that supports msg_id based
    DAG dependencies. *Only* task msg_ids are checked, not
    msg_ids of jobs submitted via the MUX queue.

    """

    hwm = Integer(1,
                  config=True,
                  help="""specify the High Water Mark (HWM) for the downstream
        socket in the Task scheduler. This is the maximum number
        of allowed outstanding tasks on each engine.

        The default (1) means that only one task can be outstanding on each
        engine.  Setting TaskScheduler.hwm=0 means there is no limit, and
        engines continue to be assigned tasks while they are working,
        effectively hiding network latency behind computation, but this can
        result in an imbalance of work when submitting many heterogeneous
        tasks all at once.  Any value greater than one is a compromise
        between the two.

        """)
    scheme_name = Enum(
        ('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
        'leastload',
        config=True,
        allow_none=False,
        help="""select the task scheduler scheme  [default: Python LRU]
        Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
    )

    def _scheme_name_changed(self, old, new):
        self.log.debug("Using scheme %r" % new)
        self.scheme = globals()[new]

    # input arguments:
    scheme = Instance(FunctionType)  # function for determining the destination

    def _scheme_default(self):
        return leastload

    client_stream = Instance(zmqstream.ZMQStream)  # client-facing stream
    engine_stream = Instance(zmqstream.ZMQStream)  # engine-facing stream
    notifier_stream = Instance(zmqstream.ZMQStream)  # hub-facing sub stream
    mon_stream = Instance(zmqstream.ZMQStream)  # hub-facing pub stream
    query_stream = Instance(zmqstream.ZMQStream)  # hub-facing DEALER stream

    # internals:
    graph = Dict()  # dict by msg_id of [ msg_ids that depend on key ]
    retries = Dict()  # dict by msg_id of retries remaining (non-neg ints)
    # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
    depending = Dict()  # dict by msg_id of Jobs
    pending = Dict()  # dict by engine_uuid of submitted tasks
    completed = Dict()  # dict by engine_uuid of completed tasks
    failed = Dict()  # dict by engine_uuid of failed tasks
    destinations = Dict()  # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
    clients = Dict()  # dict by msg_id for who submitted the task
    targets = List()  # list of target IDENTs
    loads = List()  # list of engine loads
    # full = Set() # set of IDENTs that have HWM outstanding tasks
    all_completed = Set()  # set of all completed tasks
    all_failed = Set()  # set of all failed tasks
    all_done = Set()  # set of all finished tasks=union(completed,failed)
    all_ids = Set()  # set of all submitted task IDs

    auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')

    # ZMQ identity. This should just be self.session.session, but ensure Bytes
    ident = CBytes()

    def _ident_default(self):
        return self.session.bsession

    def start(self):
        self.query_stream.on_recv(self.dispatch_query_reply)
        self.session.send(self.query_stream, "connection_request", {})

        self.engine_stream.on_recv(self.dispatch_result, copy=False)
        self.client_stream.on_recv(self.dispatch_submission, copy=False)

        self._notification_handlers = dict(
            registration_notification=self._register_engine,
            unregistration_notification=self._unregister_engine)
        self.notifier_stream.on_recv(self.dispatch_notification)
        self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3,
                                               self.loop)  # every 2 seconds
        self.auditor.start()
        self.log.info("Scheduler started [%s]" % self.scheme_name)

    def resume_receiving(self):
        """Resume accepting jobs."""
        self.client_stream.on_recv(self.dispatch_submission, copy=False)

    def stop_receiving(self):
        """Stop accepting jobs while there are no engines.
        Leave them in the ZMQ queue."""
        self.client_stream.on_recv(None)

    #-----------------------------------------------------------------------
    # [Un]Registration Handling
    #-----------------------------------------------------------------------

    def dispatch_query_reply(self, msg):
        """handle reply to our initial connection request"""
        try:
            idents, msg = self.session.feed_identities(msg)
        except ValueError:
            self.log.warn("task::Invalid Message: %r", msg)
            return
        try:
            msg = self.session.unserialize(msg)
        except ValueError:
            self.log.warn("task::Unauthorized message from: %r" % idents)
            return

        content = msg['content']
        for uuid in content.get('engines', {}).values():
            self._register_engine(cast_bytes(uuid))

    @util.log_errors
    def dispatch_notification(self, msg):
        """dispatch register/unregister events."""
        try:
            idents, msg = self.session.feed_identities(msg)
        except ValueError:
            self.log.warn("task::Invalid Message: %r", msg)
            return
        try:
            msg = self.session.unserialize(msg)
        except ValueError:
            self.log.warn("task::Unauthorized message from: %r" % idents)
            return

        msg_type = msg['header']['msg_type']

        handler = self._notification_handlers.get(msg_type, None)
        if handler is None:
            self.log.error("Unhandled message type: %r" % msg_type)
        else:
            try:
                handler(cast_bytes(msg['content']['uuid']))
            except Exception:
                self.log.error("task::Invalid notification msg: %r",
                               msg,
                               exc_info=True)

    def _register_engine(self, uid):
        """New engine with ident `uid` became available."""
        # head of the line:
        self.targets.insert(0, uid)
        self.loads.insert(0, 0)

        # initialize sets
        self.completed[uid] = set()
        self.failed[uid] = set()
        self.pending[uid] = {}

        # rescan the graph:
        self.update_graph(None)

    def _unregister_engine(self, uid):
        """Existing engine with ident `uid` became unavailable."""
        if len(self.targets) == 1:
            # this was our only engine
            pass

        # handle any potentially finished tasks:
        self.engine_stream.flush()

        # don't pop destinations, because they might be used later
        # map(self.destinations.pop, self.completed.pop(uid))
        # map(self.destinations.pop, self.failed.pop(uid))

        # prevent this engine from receiving work
        idx = self.targets.index(uid)
        self.targets.pop(idx)
        self.loads.pop(idx)

        # wait 5 seconds before cleaning up pending jobs, since the results might
        # still be incoming
        if self.pending[uid]:
            dc = ioloop.DelayedCallback(
                lambda: self.handle_stranded_tasks(uid), 5000, self.loop)
            dc.start()
        else:
            self.completed.pop(uid)
            self.failed.pop(uid)

    def handle_stranded_tasks(self, engine):
        """Deal with jobs resident in an engine that died."""
        lost = self.pending[engine]
        for msg_id in list(lost.keys()):
            if msg_id not in self.pending[engine]:
                # prevent double-handling of messages
                continue

            raw_msg = lost[msg_id].raw_msg
            idents, msg = self.session.feed_identities(raw_msg, copy=False)
            parent = self.session.unpack(msg[1].bytes)
            idents = [engine, idents[0]]

            # build fake error reply
            try:
                raise error.EngineError(
                    "Engine %r died while running task %r" % (engine, msg_id))
            except:
                content = error.wrap_exception()
            # build fake metadata
            md = dict(
                status=u'error',
                engine=engine,
                date=datetime.now(),
            )
            msg = self.session.msg('apply_reply',
                                   content,
                                   parent=parent,
                                   metadata=md)
            raw_reply = [zmq.Message(frame) for frame in
                         self.session.serialize(msg, ident=idents)]
            # and dispatch it
            self.dispatch_result(raw_reply)

        # finally scrub completed/failed lists
        self.completed.pop(engine)
        self.failed.pop(engine)

    #-----------------------------------------------------------------------
    # Job Submission
    #-----------------------------------------------------------------------

    @util.log_errors
    def dispatch_submission(self, raw_msg):
        """Dispatch job submission to appropriate handlers."""
        # ensure targets up to date:
        self.notifier_stream.flush()
        try:
            idents, msg = self.session.feed_identities(raw_msg, copy=False)
            msg = self.session.unserialize(msg, content=False, copy=False)
        except Exception:
            self.log.error("task::Invaid task msg: %r" % raw_msg,
                           exc_info=True)
            return

        # send to monitor
        self.mon_stream.send_multipart([b'intask'] + raw_msg, copy=False)

        header = msg['header']
        md = msg['metadata']
        msg_id = header['msg_id']
        self.all_ids.add(msg_id)

        # get targets as a set of bytes objects
        # from a list of unicode objects
        targets = md.get('targets', [])
        targets = map(cast_bytes, targets)
        targets = set(targets)

        retries = md.get('retries', 0)
        self.retries[msg_id] = retries

        # time dependencies
        after = md.get('after', None)
        if after:
            after = Dependency(after)
            if after.all:
                if after.success:
                    after = Dependency(
                        after.difference(self.all_completed),
                        success=after.success,
                        failure=after.failure,
                        all=after.all,
                    )
                if after.failure:
                    after = Dependency(
                        after.difference(self.all_failed),
                        success=after.success,
                        failure=after.failure,
                        all=after.all,
                    )
            if after.check(self.all_completed, self.all_failed):
                # recast as empty set, if `after` already met,
                # to prevent unnecessary set comparisons
                after = MET
        else:
            after = MET

        # location dependencies
        follow = Dependency(md.get('follow', []))

        # turn timeouts into datetime objects:
        timeout = md.get('timeout', None)
        if timeout:
            # cast to float, because jsonlib returns floats as decimal.Decimal,
            # which timedelta does not accept
            timeout = datetime.now() + timedelta(0, float(timeout), 0)

        job = Job(
            msg_id=msg_id,
            raw_msg=raw_msg,
            idents=idents,
            msg=msg,
            header=header,
            targets=targets,
            after=after,
            follow=follow,
            timeout=timeout,
            metadata=md,
        )

        # validate and reduce dependencies:
        for dep in after, follow:
            if not dep:  # empty dependency
                continue
            # check valid:
            if msg_id in dep or dep.difference(self.all_ids):
                self.depending[msg_id] = job
                return self.fail_unreachable(msg_id, error.InvalidDependency)
            # check if unreachable:
            if dep.unreachable(self.all_completed, self.all_failed):
                self.depending[msg_id] = job
                return self.fail_unreachable(msg_id)

        if after.check(self.all_completed, self.all_failed):
            # time deps already met, try to run
            if not self.maybe_run(job):
                # can't run yet
                if msg_id not in self.all_failed:
                    # could have failed as unreachable
                    self.save_unmet(job)
        else:
            self.save_unmet(job)

    def audit_timeouts(self):
        """Audit all waiting tasks for expired timeouts."""
        now = datetime.now()
        for msg_id in self.depending.keys():
            # must recheck, in case one failure cascaded to another:
            if msg_id in self.depending:
                job = self.depending[msg_id]
                if job.timeout and job.timeout < now:
                    self.fail_unreachable(msg_id, error.TaskTimeout)

    def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
        """a task has become unreachable, send a reply with an ImpossibleDependency
        error."""
        if msg_id not in self.depending:
            self.log.error("msg %r already failed!", msg_id)
            return
        job = self.depending.pop(msg_id)
        for mid in job.dependents:
            if mid in self.graph:
                self.graph[mid].remove(msg_id)

        try:
            raise why()
        except:
            content = error.wrap_exception()

        self.all_done.add(msg_id)
        self.all_failed.add(msg_id)

        msg = self.session.send(self.client_stream,
                                'apply_reply',
                                content,
                                parent=job.header,
                                ident=job.idents)
        self.session.send(self.mon_stream,
                          msg,
                          ident=[b'outtask'] + job.idents)

        self.update_graph(msg_id, success=False)

    def maybe_run(self, job):
        """check location dependencies, and run if they are met."""
        msg_id = job.msg_id
        self.log.debug("Attempting to assign task %s", msg_id)
        if not self.targets:
            # no engines, definitely can't run
            return False

        if job.follow or job.targets or job.blacklist or self.hwm:
            # we need a can_run filter
            def can_run(idx):
                # check hwm
                if self.hwm and self.loads[idx] == self.hwm:
                    return False
                target = self.targets[idx]
                # check blacklist
                if target in job.blacklist:
                    return False
                # check targets
                if job.targets and target not in job.targets:
                    return False
                # check follow
                return job.follow.check(self.completed[target],
                                        self.failed[target])

            indices = list(filter(can_run, range(len(self.targets))))

            if not indices:
                # couldn't run
                if job.follow.all:
                    # check follow for impossibility
                    dests = set()
                    relevant = set()
                    if job.follow.success:
                        relevant = self.all_completed
                    if job.follow.failure:
                        relevant = relevant.union(self.all_failed)
                    for m in job.follow.intersection(relevant):
                        dests.add(self.destinations[m])
                    if len(dests) > 1:
                        self.depending[msg_id] = job
                        self.fail_unreachable(msg_id)
                        return False
                if job.targets:
                    # check blacklist+targets for impossibility
                    job.targets.difference_update(job.blacklist)
                    if not job.targets or not job.targets.intersection(
                            self.targets):
                        self.depending[msg_id] = job
                        self.fail_unreachable(msg_id)
                        return False
                return False
        else:
            indices = None

        self.submit_task(job, indices)
        return True

    def save_unmet(self, job):
        """Save a message for later submission when its dependencies are met."""
        msg_id = job.msg_id
        self.depending[msg_id] = job
        # track the ids in follow or after, but not those already finished
        for dep_id in job.after.union(job.follow).difference(self.all_done):
            if dep_id not in self.graph:
                self.graph[dep_id] = set()
            self.graph[dep_id].add(msg_id)

    def submit_task(self, job, indices=None):
        """Submit a task to any of a subset of our targets."""
        if indices:
            loads = [self.loads[i] for i in indices]
        else:
            loads = self.loads
        idx = self.scheme(loads)
        if indices:
            idx = indices[idx]
        target = self.targets[idx]
        # print (target, map(str, msg[:3]))
        # send job to the engine
        self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
        self.engine_stream.send_multipart(job.raw_msg, copy=False)
        # update load
        self.add_job(idx)
        self.pending[target][job.msg_id] = job
        # notify Hub
        content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii'))
        self.session.send(self.mon_stream,
                          'task_destination',
                          content=content,
                          ident=[b'tracktask', self.ident])

    #-----------------------------------------------------------------------
    # Result Handling
    #-----------------------------------------------------------------------

    @util.log_errors
    def dispatch_result(self, raw_msg):
        """dispatch method for result replies"""
        try:
            idents, msg = self.session.feed_identities(raw_msg, copy=False)
            msg = self.session.unserialize(msg, content=False, copy=False)
            engine = idents[0]
            try:
                idx = self.targets.index(engine)
            except ValueError:
                pass  # skip load-update for dead engines
            else:
                self.finish_job(idx)
        except Exception:
            self.log.error("task::Invaid result: %r", raw_msg, exc_info=True)
            return

        md = msg['metadata']
        parent = msg['parent_header']
        if md.get('dependencies_met', True):
            success = (md['status'] == 'ok')
            msg_id = parent['msg_id']
            retries = self.retries[msg_id]
            if not success and retries > 0:
                # failed
                self.retries[msg_id] = retries - 1
                self.handle_unmet_dependency(idents, parent)
            else:
                del self.retries[msg_id]
                # relay to client and update graph
                self.handle_result(idents, parent, raw_msg, success)
                # send to Hub monitor
                self.mon_stream.send_multipart([b'outtask'] + raw_msg,
                                               copy=False)
        else:
            self.handle_unmet_dependency(idents, parent)

    def handle_result(self, idents, parent, raw_msg, success=True):
        """handle a real task result, either success or failure"""
        # first, relay result to client
        engine = idents[0]
        client = idents[1]
        # swap_ids for ROUTER-ROUTER mirror
        raw_msg[:2] = [client, engine]
        # print (map(str, raw_msg[:4]))
        self.client_stream.send_multipart(raw_msg, copy=False)
        # now, update our data structures
        msg_id = parent['msg_id']
        self.pending[engine].pop(msg_id)
        if success:
            self.completed[engine].add(msg_id)
            self.all_completed.add(msg_id)
        else:
            self.failed[engine].add(msg_id)
            self.all_failed.add(msg_id)
        self.all_done.add(msg_id)
        self.destinations[msg_id] = engine

        self.update_graph(msg_id, success)

    def handle_unmet_dependency(self, idents, parent):
        """handle an unmet dependency"""
        engine = idents[0]
        msg_id = parent['msg_id']

        job = self.pending[engine].pop(msg_id)
        job.blacklist.add(engine)

        if job.blacklist == job.targets:
            self.depending[msg_id] = job
            self.fail_unreachable(msg_id)
        elif not self.maybe_run(job):
            # resubmit failed
            if msg_id not in self.all_failed:
                # put it back in our dependency tree
                self.save_unmet(job)

        if self.hwm:
            try:
                idx = self.targets.index(engine)
            except ValueError:
                pass  # skip load-update for dead engines
            else:
                if self.loads[idx] == self.hwm - 1:
                    self.update_graph(None)

    def update_graph(self, dep_id=None, success=True):
        """dep_id just finished. Update our dependency
        graph and submit any jobs that just became runable.

        Called with dep_id=None to update entire graph for hwm, but without finishing
        a task.
        """
        # print ("\n\n***********")
        # pprint (dep_id)
        # pprint (self.graph)
        # pprint (self.depending)
        # pprint (self.all_completed)
        # pprint (self.all_failed)
        # print ("\n\n***********\n\n")
        # update any jobs that depended on the dependency
        jobs = self.graph.pop(dep_id, [])

        # recheck *all* jobs if
        # a) we have HWM and an engine just became no longer full
        # or b) dep_id was given as None

        if dep_id is None or (self.hwm and
                              any(load == self.hwm - 1 for load in self.loads)):
            jobs = self.depending.keys()

        for msg_id in sorted(
                jobs, key=lambda msg_id: self.depending[msg_id].timestamp):
            job = self.depending[msg_id]

            if job.after.unreachable(self.all_completed, self.all_failed)\
                    or job.follow.unreachable(self.all_completed, self.all_failed):
                self.fail_unreachable(msg_id)

            elif job.after.check(self.all_completed,
                                 self.all_failed):  # time deps met, maybe run
                if self.maybe_run(job):

                    self.depending.pop(msg_id)
                    for mid in job.dependents:
                        if mid in self.graph:
                            self.graph[mid].remove(msg_id)

    #----------------------------------------------------------------------
    # methods to be overridden by subclasses
    #----------------------------------------------------------------------

    def add_job(self, idx):
        """Called after self.targets[idx] just got the job with header.
        Override in subclasses.  The default ordering is simple LRU.
        The default loads are the number of outstanding jobs."""
        self.loads[idx] += 1
        for lis in (self.targets, self.loads):
            lis.append(lis.pop(idx))

    def finish_job(self, idx):
        """Called after self.targets[idx] just finished a job.
        Override in subclasses."""
        self.loads[idx] -= 1
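
A scheduler `scheme`, as used by `submit_task` above, is just a callable mapping the current list of engine loads to the index of the engine that should receive the next task. A minimal sketch of two of the named schemes, assuming `loads` is the list of non-negative integers this class maintains (the real implementations live in IPython.parallel's scheduler module and may differ):

import random

def leastload(loads):
    """Pick the engine with the fewest outstanding tasks."""
    return loads.index(min(loads))

def plainrandom(loads):
    """Pick a random engine, ignoring load entirely."""
    return random.randint(0, len(loads) - 1)

# e.g. leastload([2, 0, 1]) == 1, so submit_task would pick targets[1]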
Exemplo n.º 9
0
class EngineFactory(RegistrationFactory):
    """IPython engine"""

    # configurables:
    out_stream_factory=Type('IPython.kernel.zmq.iostream.OutStream', config=True,
        help="""The OutStream for handling stdout/err.
        Typically 'IPython.kernel.zmq.iostream.OutStream'""")
    display_hook_factory=Type('IPython.kernel.zmq.displayhook.ZMQDisplayHook', config=True,
        help="""The class for handling displayhook.
        Typically 'IPython.kernel.zmq.displayhook.ZMQDisplayHook'""")
    location=Unicode(config=True,
        help="""The location (an IP address) of the controller.  This is
        used for disambiguating URLs, to determine whether
        loopback should be used to connect or the public address.""")
    timeout=Float(5.0, config=True,
        help="""The time (in seconds) to wait for the Controller to respond
        to registration requests before giving up.""")
    max_heartbeat_misses=Integer(50, config=True,
        help="""The maximum number of times a check for the heartbeat ping of a 
        controller can be missed before shutting down the engine.
        
        If set to 0, the check is disabled.""")
    sshserver=Unicode(config=True,
        help="""The SSH server to use for tunneling connections to the Controller.""")
    sshkey=Unicode(config=True,
        help="""The SSH private key file to use when tunneling connections to the Controller.""")
    paramiko=Bool(sys.platform == 'win32', config=True,
        help="""Whether to use paramiko instead of openssh for tunnels.""")
    
    @property
    def tunnel_mod(self):
        from zmq.ssh import tunnel
        return tunnel


    # not configurable:
    connection_info = Dict()
    user_ns = Dict()
    id = Integer(allow_none=True)
    registrar = Instance('zmq.eventloop.zmqstream.ZMQStream')
    kernel = Instance(Kernel)
    hb_check_period=Integer()
    
    # States for the heartbeat monitoring
    # Initial values must satisfy "monitored <= pinged" so that no "missed"
    # ping is reported during the first check. Must be floats for Python 3
    # compatibility.
    _hb_last_pinged = 0.0
    _hb_last_monitored = 0.0
    _hb_missed_beats = 0
    # The zmq Stream which receives the pings from the Heart
    _hb_listener = None

    bident = CBytes()
    ident = Unicode()
    def _ident_changed(self, name, old, new):
        self.bident = cast_bytes(new)
    using_ssh=Bool(False)


    def __init__(self, **kwargs):
        super(EngineFactory, self).__init__(**kwargs)
        self.ident = self.session.session

    def init_connector(self):
        """construct connection function, which handles tunnels."""
        self.using_ssh = bool(self.sshkey or self.sshserver)

        if self.sshkey and not self.sshserver:
            # We are using ssh directly to the controller, tunneling localhost to localhost
            self.sshserver = self.url.split('://')[1].split(':')[0]

        if self.using_ssh:
            if self.tunnel_mod.try_passwordless_ssh(self.sshserver, self.sshkey, self.paramiko):
                password=False
            else:
                password = getpass("SSH Password for %s: "%self.sshserver)
        else:
            password = False

        def connect(s, url):
            url = disambiguate_url(url, self.location)
            if self.using_ssh:
                self.log.debug("Tunneling connection to %s via %s", url, self.sshserver)
                return self.tunnel_mod.tunnel_connection(s, url, self.sshserver,
                            keyfile=self.sshkey, paramiko=self.paramiko,
                            password=password,
                )
            else:
                return s.connect(url)

        def maybe_tunnel(url):
            """like connect, but don't complete the connection (for use by heartbeat)"""
            url = disambiguate_url(url, self.location)
            if self.using_ssh:
                self.log.debug("Tunneling connection to %s via %s", url, self.sshserver)
                url, tunnelobj = self.tunnel_mod.open_tunnel(url, self.sshserver,
                            keyfile=self.sshkey, paramiko=self.paramiko,
                            password=password,
                )
            return str(url)
        return connect, maybe_tunnel

    def register(self):
        """send the registration_request"""

        self.log.info("Registering with controller at %s"%self.url)
        ctx = self.context
        connect,maybe_tunnel = self.init_connector()
        reg = ctx.socket(zmq.DEALER)
        reg.setsockopt(zmq.IDENTITY, self.bident)
        connect(reg, self.url)
        self.registrar = zmqstream.ZMQStream(reg, self.loop)


        content = dict(uuid=self.ident)
        self.registrar.on_recv(lambda msg: self.complete_registration(msg, connect, maybe_tunnel))
        # print (self.session.key)
        self.session.send(self.registrar, "registration_request", content=content)

    def _report_ping(self, msg):
        """Callback for when the heartmonitor.Heart receives a ping"""
        #self.log.debug("Received a ping: %s", msg)
        self._hb_last_pinged = time.time()

    def complete_registration(self, msg, connect, maybe_tunnel):
        # print msg
        self.loop.remove_timeout(self._abort_timeout)
        ctx = self.context
        loop = self.loop
        identity = self.bident
        idents,msg = self.session.feed_identities(msg)
        msg = self.session.deserialize(msg)
        content = msg['content']
        info = self.connection_info
        
        def url(key):
            """get zmq url for given channel"""
            return str(info["interface"] + ":%i" % info[key])
        
        if content['status'] == 'ok':
            self.id = int(content['id'])

            # launch heartbeat
            # possibly forward hb ports with tunnels
            hb_ping = maybe_tunnel(url('hb_ping'))
            hb_pong = maybe_tunnel(url('hb_pong'))
            
            hb_monitor = None
            if self.max_heartbeat_misses > 0:
                # Add a monitor socket which will record the last time a ping was seen
                mon = self.context.socket(zmq.SUB)
                mport = mon.bind_to_random_port('tcp://%s' % localhost())
                mon.setsockopt(zmq.SUBSCRIBE, b"")
                self._hb_listener = zmqstream.ZMQStream(mon, self.loop)
                self._hb_listener.on_recv(self._report_ping)

                hb_monitor = "tcp://%s:%i" % (localhost(), mport)

            heart = Heart(hb_ping, hb_pong, hb_monitor, heart_id=identity)
            heart.start()

            # create Shell Connections (MUX, Task, etc.):
            shell_addrs = url('mux'), url('task')

            # Use only one shell stream for mux and tasks
            stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
            stream.setsockopt(zmq.IDENTITY, identity)
            shell_streams = [stream]
            for addr in shell_addrs:
                connect(stream, addr)

            # control stream:
            control_addr = url('control')
            control_stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
            control_stream.setsockopt(zmq.IDENTITY, identity)
            connect(control_stream, control_addr)

            # create iopub socket:
            iopub_addr = url('iopub')
            iopub_socket = ctx.socket(zmq.PUB)
            iopub_socket.setsockopt(zmq.IDENTITY, identity)
            connect(iopub_socket, iopub_addr)

            # disable history:
            self.config.HistoryManager.hist_file = ':memory:'
            
            # Redirect output streams and set a display hook.
            if self.out_stream_factory:
                sys.stdout = self.out_stream_factory(self.session, iopub_socket, u'stdout')
                sys.stdout.topic = cast_bytes('engine.%i.stdout' % self.id)
                sys.stderr = self.out_stream_factory(self.session, iopub_socket, u'stderr')
                sys.stderr.topic = cast_bytes('engine.%i.stderr' % self.id)
            if self.display_hook_factory:
                sys.displayhook = self.display_hook_factory(self.session, iopub_socket)
                sys.displayhook.topic = cast_bytes('engine.%i.execute_result' % self.id)

            self.kernel = Kernel(parent=self, int_id=self.id, ident=self.ident, session=self.session,
                    control_stream=control_stream, shell_streams=shell_streams, iopub_socket=iopub_socket,
                    loop=loop, user_ns=self.user_ns, log=self.log)
            
            self.kernel.shell.display_pub.topic = cast_bytes('engine.%i.displaypub' % self.id)

            # periodically check the heartbeat pings of the controller
            # Should be started here and not in "start()" so that the right period can be taken
            # from the hub's HeartBeatMonitor.period
            if self.max_heartbeat_misses > 0:
                # Use a slightly bigger check period than the hub signal period, so we don't warn unnecessarily
                self.hb_check_period = int(content['hb_period']) + 10
                self.log.info("Starting to monitor the heartbeat signal from the hub every %i ms.", self.hb_check_period)
                self._hb_reporter = ioloop.PeriodicCallback(self._hb_monitor, self.hb_check_period, self.loop)
                self._hb_reporter.start()
            else:
                self.log.info("Monitoring of the heartbeat signal from the hub is not enabled.")

            # FIXME: This is a hack until IPKernelApp and IPEngineApp can be fully merged
            app = IPKernelApp(parent=self, shell=self.kernel.shell, kernel=self.kernel, log=self.log)
            app.init_profile_dir()
            app.init_code()
            
            self.kernel.start()
        else:
            self.log.fatal("Registration Failed: %s"%msg)
            raise Exception("Registration Failed: %s"%msg)

        self.log.info("Completed registration with id %i"%self.id)


    def abort(self):
        self.log.fatal("Registration timed out after %.1f seconds"%self.timeout)
        if self.url.startswith('127.'):
            self.log.fatal("""
            If the controller and engines are not on the same machine,
            you will have to instruct the controller to listen on an external IP (in ipcontroller_config.py):
                c.HubFactory.ip='*' # for all interfaces, internal and external
                c.HubFactory.ip='192.168.1.101' # or any interface that the engines can see
            or tunnel connections via ssh.
            """)
        self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
        time.sleep(1)
        sys.exit(255)

    def _hb_monitor(self):
        """Callback to monitor the heartbeat from the controller"""
        self._hb_listener.flush()
        if self._hb_last_monitored > self._hb_last_pinged:
            self._hb_missed_beats += 1
            self.log.warn("No heartbeat in the last %s ms (%s time(s) in a row).", self.hb_check_period, self._hb_missed_beats)
        else:
            #self.log.debug("Heartbeat received (after missing %s beats).", self._hb_missed_beats)
            self._hb_missed_beats = 0

        if self._hb_missed_beats >= self.max_heartbeat_misses:
            self.log.fatal("Maximum number of heartbeats misses reached (%s times %s ms), shutting down.",
                           self.max_heartbeat_misses, self.hb_check_period)
            self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
            self.loop.stop()

        self._hb_last_monitored = time.time()

    def start(self):
        loop = self.loop
        def _start():
            self.register()
            self._abort_timeout = loop.add_timeout(loop.time() + self.timeout, self.abort)
        self.loop.add_callback(_start)
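
The heartbeat monitoring above reduces to comparing two timestamps on each periodic check: if the previous check is more recent than the last ping, a beat was missed, and enough consecutive misses trigger shutdown. A standalone sketch of that accounting; `HeartbeatWatch` is illustrative, not an IPython.parallel class:

import time

class HeartbeatWatch(object):
    """Count consecutive checks that saw no ping since the previous check."""

    def __init__(self, max_misses):
        self.max_misses = max_misses
        self.last_pinged = 0.0      # updated by ping()
        self.last_monitored = 0.0   # updated by check()
        self.misses = 0

    def ping(self):
        self.last_pinged = time.time()

    def check(self):
        """Return True while the peer is still considered alive."""
        if self.last_monitored > self.last_pinged:
            self.misses += 1        # no ping arrived since the last check
        else:
            self.misses = 0         # a ping was seen; reset the streak
        self.last_monitored = time.time()
        return self.misses < self.max_misses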
Exemplo n.º 10
0
class Kernel(SessionFactory):

    #---------------------------------------------------------------------------
    # Kernel interface
    #---------------------------------------------------------------------------

    # kwargs:
    exec_lines = List(Unicode, config=True, help="List of lines to execute")

    # identities:
    int_id = Int(-1)
    bident = CBytes()
    ident = Unicode()

    def _ident_changed(self, name, old, new):
        self.bident = asbytes(new)

    user_ns = Dict(config=True,
                   help="""Set the user's namespace of the Kernel""")

    control_stream = Instance(zmqstream.ZMQStream)
    task_stream = Instance(zmqstream.ZMQStream)
    iopub_stream = Instance(zmqstream.ZMQStream)
    client = Instance('IPython.parallel.Client')

    # internals
    shell_streams = List()
    compiler = Instance(CommandCompiler, (), {})
    completer = Instance(KernelCompleter)

    aborted = Set()
    shell_handlers = Dict()
    control_handlers = Dict()

    def _set_prefix(self):
        self.prefix = "engine.%s" % self.int_id

    def _connect_completer(self):
        self.completer = KernelCompleter(self.user_ns)

    def __init__(self, **kwargs):
        super(Kernel, self).__init__(**kwargs)
        self._set_prefix()
        self._connect_completer()

        self.on_trait_change(self._set_prefix, 'id')
        self.on_trait_change(self._connect_completer, 'user_ns')

        # Build dict of handlers for message types
        for msg_type in [
                'execute_request', 'complete_request', 'apply_request',
                'clear_request'
        ]:
            self.shell_handlers[msg_type] = getattr(self, msg_type)

        for msg_type in (['shutdown_request', 'abort_request'] +
                         list(self.shell_handlers.keys())):
            self.control_handlers[msg_type] = getattr(self, msg_type)

        self._initial_exec_lines()

    def _wrap_exception(self, method=None):
        e_info = dict(engine_uuid=self.ident,
                      engine_id=self.int_id,
                      method=method)
        content = wrap_exception(e_info)
        return content

    def _initial_exec_lines(self):
        s = _Passer()
        content = dict(silent=True, user_variables=[], user_expressions={})
        for line in self.exec_lines:
            self.log.debug("executing initialization: %s" % line)
            content.update({'code': line})
            msg = self.session.msg('execute_request', content)
            self.execute_request(s, [], msg)

    #-------------------- control handlers -----------------------------
    def abort_queues(self):
        for stream in self.shell_streams:
            if stream:
                self.abort_queue(stream)

    def abort_queue(self, stream):
        while True:
            idents, msg = self.session.recv(stream, zmq.NOBLOCK, content=True)
            if msg is None:
                return

            self.log.info("Aborting:")
            self.log.info(str(msg))
            msg_type = msg['header']['msg_type']
            reply_type = msg_type.split('_')[0] + '_reply'
            # reply_msg = self.session.msg(reply_type, {'status' : 'aborted'}, msg)
            # self.reply_socket.send(ident,zmq.SNDMORE)
            # self.reply_socket.send_json(reply_msg)
            reply_msg = self.session.send(stream,
                                          reply_type,
                                          content={'status': 'aborted'},
                                          parent=msg,
                                          ident=idents)
            self.log.debug(str(reply_msg))
            # We need to wait a bit for requests to come in. This can probably
            # be set shorter for true asynchronous clients.
            time.sleep(0.05)

    def abort_request(self, stream, ident, parent):
        """abort a specifig msg by id"""
        msg_ids = parent['content'].get('msg_ids', None)
        if isinstance(msg_ids, basestring):
            msg_ids = [msg_ids]
        if not msg_ids:
            self.abort_queues()
        for mid in msg_ids:
            self.aborted.add(str(mid))

        content = dict(status='ok')
        reply_msg = self.session.send(stream,
                                      'abort_reply',
                                      content=content,
                                      parent=parent,
                                      ident=ident)
        self.log.debug(str(reply_msg))

    def shutdown_request(self, stream, ident, parent):
        """kill ourself.  This should really be handled in an external process"""
        try:
            self.abort_queues()
        except:
            content = self._wrap_exception('shutdown')
        else:
            content = dict(parent['content'])
            content['status'] = 'ok'
        msg = self.session.send(stream,
                                'shutdown_reply',
                                content=content,
                                parent=parent,
                                ident=ident)
        self.log.debug(str(msg))
        dc = ioloop.DelayedCallback(lambda: sys.exit(0), 1000, self.loop)
        dc.start()

    def dispatch_control(self, msg):
        idents, msg = self.session.feed_identities(msg, copy=False)
        try:
            msg = self.session.unserialize(msg, content=True, copy=False)
        except:
            self.log.error("Invalid Message", exc_info=True)
            return
        else:
            self.log.debug("Control received, %s", msg)

        header = msg['header']
        msg_id = header['msg_id']
        msg_type = header['msg_type']

        handler = self.control_handlers.get(msg_type, None)
        if handler is None:
            self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r" % msg_type)
        else:
            handler(self.control_stream, idents, msg)

    #-------------------- queue helpers ------------------------------

    def check_dependencies(self, dependencies):
        if not dependencies:
            return True
        if len(dependencies) == 2 and dependencies[0] in 'any all'.split():
            anyorall = dependencies[0]
            dependencies = dependencies[1]
        else:
            anyorall = 'all'
        results = self.client.get_results(dependencies, status_only=True)
        if results['status'] != 'ok':
            return False

        if anyorall == 'any':
            if not results['completed']:
                return False
        else:
            if results['pending']:
                return False

        return True

    def check_aborted(self, msg_id):
        return msg_id in self.aborted

    #-------------------- queue handlers -----------------------------

    def clear_request(self, stream, idents, parent):
        """Clear our namespace."""
        self.user_ns = {}
        msg = self.session.send(stream,
                                'clear_reply',
                                ident=idents,
                                parent=parent,
                                content=dict(status='ok'))
        self._initial_exec_lines()

    def execute_request(self, stream, ident, parent):
        self.log.debug('execute request %s' % parent)
        try:
            code = parent[u'content'][u'code']
        except:
            self.log.error("Got bad msg: %s" % parent, exc_info=True)
            return
        self.session.send(self.iopub_stream,
                          u'pyin', {u'code': code},
                          parent=parent,
                          ident=asbytes('%s.pyin' % self.prefix))
        started = datetime.now()
        try:
            comp_code = self.compiler(code, '<zmq-kernel>')
            # allow for not overriding displayhook
            if hasattr(sys.displayhook, 'set_parent'):
                sys.displayhook.set_parent(parent)
                sys.stdout.set_parent(parent)
                sys.stderr.set_parent(parent)
            exec(comp_code, self.user_ns, self.user_ns)
        except:
            exc_content = self._wrap_exception('execute')
            # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
            self.session.send(self.iopub_stream,
                              u'pyerr',
                              exc_content,
                              parent=parent,
                              ident=asbytes('%s.pyerr' % self.prefix))
            reply_content = exc_content
        else:
            reply_content = {'status': 'ok'}

        reply_msg = self.session.send(stream,
                                      u'execute_reply',
                                      reply_content,
                                      parent=parent,
                                      ident=ident,
                                      subheader=dict(started=started))
        self.log.debug(str(reply_msg))
        if reply_msg['content']['status'] == u'error':
            self.abort_queues()

    def complete_request(self, stream, ident, parent):
        matches = {'matches': self.complete(parent), 'status': 'ok'}
        completion_msg = self.session.send(stream, 'complete_reply', matches,
                                           parent, ident)
        # print >> sys.__stdout__, completion_msg

    def complete(self, msg):
        return self.completer.complete(msg.content.line, msg.content.text)

    def apply_request(self, stream, ident, parent):
        # flush previous reply, so this request won't block it
        stream.flush(zmq.POLLOUT)
        try:
            content = parent[u'content']
            bufs = parent[u'buffers']
            msg_id = parent['header']['msg_id']
            # bound = parent['header'].get('bound', False)
        except:
            self.log.error("Got bad msg: %s" % parent, exc_info=True)
            return
        # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
        # self.iopub_stream.send(pyin_msg)
        # self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent)
        sub = {
            'dependencies_met': True,
            'engine': self.ident,
            'started': datetime.now()
        }
        try:
            # allow for not overriding displayhook
            if hasattr(sys.displayhook, 'set_parent'):
                sys.displayhook.set_parent(parent)
                sys.stdout.set_parent(parent)
                sys.stderr.set_parent(parent)
            # exec "f(*args,**kwargs)" in self.user_ns, self.user_ns
            working = self.user_ns
            prefix = "_" + str(msg_id).replace("-", "") + "_"

            f, args, kwargs = unpack_apply_message(bufs, working, copy=False)
            # if bound:
            #     bound_ns = Namespace(working)
            #     args = [bound_ns]+list(args)

            # mangle names with the msg_id so concurrent requests can't collide
            fname = prefix + "f"
            argname = prefix + "args"
            kwargname = prefix + "kwargs"
            resultname = prefix + "result"

            ns = {fname: f, argname: args, kwargname: kwargs, resultname: None}
            # print ns
            working.update(ns)
            code = "%s=%s(*%s,**%s)" % (resultname, fname, argname, kwargname)
            try:
                exec(code, working, working)
                result = working.get(resultname)
            finally:
                for key in ns:
                    working.pop(key)
            # if bound:
            #     working.update(bound_ns)

            packed_result, buf = serialize_object(result)
            result_buf = [packed_result] + buf
        except:
            exc_content = self._wrap_exception('apply')
            # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
            self.session.send(self.iopub_stream,
                              u'pyerr',
                              exc_content,
                              parent=parent,
                              ident=asbytes('%s.pyerr' % self.prefix))
            reply_content = exc_content
            result_buf = []

            if exc_content['ename'] == 'UnmetDependency':
                sub['dependencies_met'] = False
        else:
            reply_content = {'status': 'ok'}

        # put 'ok'/'error' status in header, for scheduler introspection:
        sub['status'] = reply_content['status']

        reply_msg = self.session.send(stream,
                                      u'apply_reply',
                                      reply_content,
                                      parent=parent,
                                      ident=ident,
                                      buffers=result_buf,
                                      subheader=sub)

        # flush i/o
        # should this be before reply_msg is sent, like in the single-kernel code,
        # or should nothing get in the way of real results?
        sys.stdout.flush()
        sys.stderr.flush()

    def dispatch_queue(self, stream, msg):
        self.control_stream.flush()
        idents, msg = self.session.feed_identities(msg, copy=False)
        try:
            msg = self.session.unserialize(msg, content=True, copy=False)
        except:
            self.log.error("Invalid Message", exc_info=True)
            return
        else:
            self.log.debug("Message received, %s", msg)

        header = msg['header']
        msg_id = header['msg_id']
        msg_type = msg['header']['msg_type']
        if self.check_aborted(msg_id):
            self.aborted.remove(msg_id)
            # is it safe to assume a msg_id will not be resubmitted?
            reply_type = msg_type.split('_')[0] + '_reply'
            status = {'status': 'aborted'}
            reply_msg = self.session.send(stream,
                                          reply_type,
                                          subheader=status,
                                          content=status,
                                          parent=msg,
                                          ident=idents)
            return
        handler = self.shell_handlers.get(msg_type, None)
        if handler is None:
            self.log.error("UNKNOWN MESSAGE TYPE: %r" % msg_type)
        else:
            handler(stream, idents, msg)

    def start(self):
        #### stream mode:
        if self.control_stream:
            self.control_stream.on_recv(self.dispatch_control, copy=False)
            self.control_stream.on_err(printer)

        def make_dispatcher(stream):
            def dispatcher(msg):
                return self.dispatch_queue(stream, msg)

            return dispatcher

        for s in self.shell_streams:
            s.on_recv(make_dispatcher(s), copy=False)
            s.on_err(printer)

        if self.iopub_stream:
            self.iopub_stream.on_err(printer)
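
The core of `apply_request` above is a name-mangling trick: the function, its arguments, and a result slot are bound under msg_id-derived names inside the user namespace, a one-line call is exec'd there, and the temporaries are scrubbed afterwards. A standalone sketch of the mechanism; `run_in_namespace` is illustrative, not an IPython API:

def run_in_namespace(f, args, kwargs, working, msg_id):
    """Call f(*args, **kwargs) via exec inside the `working` namespace."""
    prefix = "_" + str(msg_id).replace("-", "") + "_"
    fname, argname, kwargname, resultname = (
        prefix + "f", prefix + "args", prefix + "kwargs", prefix + "result")
    ns = {fname: f, argname: args, kwargname: kwargs, resultname: None}
    working.update(ns)
    code = "%s=%s(*%s,**%s)" % (resultname, fname, argname, kwargname)
    try:
        exec(code, working, working)
        return working.get(resultname)
    finally:
        for key in ns:
            working.pop(key)        # scrub the temporary names

# e.g. run_in_namespace(pow, (2, 10), {}, {}, 'abc-123') returns 1024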
Exemplo n.º 11
0
class AmberSimulator(Device):
    name = 'AMBER'
    path = 'msmaccelerator.simulate.amber_simulation.AmberSimulator'
    short_description = 'Run a single round of dynamics with AMBER'
    long_description = '''This device will connect to the msmaccelerator server,
        request the initial conditions with which to start a simulation, and
        propagate dynamics'''
    
    mdin = FilePath(config=True, exists=True, isfile=True,
        help="""AMBER .in file controlling the production run. If no production is
        desired, do not set this parameter.""")
    workdir = CBytes(config=True, help="""Directory to work in. If not set,
        we'll request a temporary directory from the OS and clean it up
        when we're finished. This option is useful for debugging.""")
    executable = Enum(['pmemd', 'pmemd.cuda', 'pmemd.cuda.MPI'], config=True,
        default_value='pmemd', help="Which AMBER executable to use?")
    precommand = Unicode(u'', config=True, help="Something to run before the command, like mpirun")
    prmtop = FilePath(config=True, exists=True, isfile=True,
                      help="""Parameter/topology file for the system""")

    amber_home = FilePath(exists=True, isdir=True, help='Home directory for AMBER installation')
    def _amber_home_default(self):
        if 'AMBERHOME' not in os.environ:
            raise KeyError("You need to set the AMBERHOME environment variable")
        return os.environ['AMBERHOME']

    aliases = dict(mdin='AmberSimulator.mdin',
                   precommand='AmberSimulator.precommand',
                   prmtop='AmberSimulator.prmtop',
                   zmq_port='Device.zmq_port',
                   zmq_url='Device.zmq_url',
                   executable='AmberSimulator.executable')

    def start(self):
        super(AmberSimulator, self).start()


    def error(self, msg):
        self.log.error(msg)
        self.exit(1)
        
    def on_startup_message(self, msg):
        """This method is called when the device receives its startup message
        from the server.
        """

        assert msg.header.msg_type in ['simulate']  # only allowed RPC
        return getattr(self, msg.header.msg_type)(msg.header, msg.content)

    def simulate(self, header, content):
        """Run the simulation in subprocesses to invoke the AMBER binaries"""

        if content.starting_state.protocol == 'localfs':
            if not content.starting_state.path.endswith('.inpcrd'):
                raise ValueError('starting state must have inpcrd extension. '
                                 'did you start server in amber mode? '
                                 'starting_state.path=%s' % content.starting_state.path)
        else:
            raise NotImplementedError('Only localfs transport is currently '
                                      'supported.')

        template = '{precommand} {binary} -O -i {mdin} -o {mdout} -p {prmtop} -c {inpcrd} -r {restart} -x {traj}'

        # RUNNING PRODUCTION
        with cd_context('amber_workdir', logger=self.log):
            base = splitext(basename(self.mdin))[0]
            binary = join(self.amber_home, 'bin', self.executable)
            mdout = base + '.out'
            restart = base + '.restart'
            traj = base + '.nc'
            cmd = template.format(binary=binary, mdin=relpath(self.mdin),
                                  mdout=mdout, prmtop=relpath(self.prmtop),
                                  inpcrd=relpath(content.starting_state.path),
                                  restart=restart, precommand=self.precommand,
                                  traj=relpath(content.output.path)).split()
            self.log.info('Executing Command: %s' % cmd)
            subprocess.check_output(cmd)
        

        self.send_recv(msg_type='simulation_done', content={
            'status': 'success',
            'output': {
                'protocol': 'localfs',
                'path': content.output.path
            }
        })
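
For reference, the command template in `simulate` expands into a flat argument list for `subprocess.check_output`. A sketch with hypothetical file names (none of these paths come from the original):

template = ('{precommand} {binary} -O -i {mdin} -o {mdout} -p {prmtop} '
            '-c {inpcrd} -r {restart} -x {traj}')
cmd = template.format(
    precommand='mpirun -np 4',
    binary='/opt/amber/bin/pmemd.cuda.MPI',
    mdin='production.in', mdout='production.out',
    prmtop='system.prmtop', inpcrd='start.inpcrd',
    restart='production.restart', traj='production.nc').split()
# cmd == ['mpirun', '-np', '4', '/opt/amber/bin/pmemd.cuda.MPI', '-O',
#         '-i', 'production.in', '-o', 'production.out', '-p', 'system.prmtop',
#         '-c', 'start.inpcrd', '-r', 'production.restart', '-x', 'production.nc']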
Exemplo n.º 12
0
class EngineFactory(RegistrationFactory):
    """IPython engine"""

    # configurables:
    out_stream_factory = Type('IPython.zmq.iostream.OutStream',
                              config=True,
                              help="""The OutStream for handling stdout/err.
        Typically 'IPython.zmq.iostream.OutStream'""")
    display_hook_factory = Type('IPython.zmq.displayhook.ZMQDisplayHook',
                                config=True,
                                help="""The class for handling displayhook.
        Typically 'IPython.zmq.displayhook.ZMQDisplayHook'""")
    location = Unicode(
        config=True,
        help="""The location (an IP address) of the controller.  This is
        used for disambiguating URLs, to determine whether
        loopback should be used to connect or the public address.""")
    timeout = CFloat(
        2,
        config=True,
        help="""The time (in seconds) to wait for the Controller to respond
        to registration requests before giving up.""")

    # not configurable:
    user_ns = Dict()
    id = Int(allow_none=True)
    registrar = Instance('zmq.eventloop.zmqstream.ZMQStream')
    kernel = Instance(Kernel)

    bident = CBytes()
    ident = Unicode()

    def _ident_changed(self, name, old, new):
        self.bident = asbytes(new)

    def __init__(self, **kwargs):
        super(EngineFactory, self).__init__(**kwargs)
        self.ident = self.session.session
        ctx = self.context

        # XREQ is the old pyzmq alias for DEALER; the identity lets the
        # controller route replies back to this engine
        reg = ctx.socket(zmq.XREQ)
        reg.setsockopt(zmq.IDENTITY, self.bident)
        reg.connect(self.url)
        self.registrar = zmqstream.ZMQStream(reg, self.loop)

    def register(self):
        """send the registration_request"""

        self.log.info("Registering with controller at %s" % self.url)
        content = dict(queue=self.ident,
                       heartbeat=self.ident,
                       control=self.ident)
        self.registrar.on_recv(self.complete_registration)
        self.session.send(self.registrar,
                          "registration_request",
                          content=content)

    def complete_registration(self, msg):
        """Process the controller's reply and wire up the engine's streams."""
        self._abort_dc.stop()
        ctx = self.context
        loop = self.loop
        identity = self.bident
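        # strip the zmq routing prefix before unpacking the message body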
        idents, msg = self.session.feed_identities(msg)
        msg = Message(self.session.unpack_message(msg))

        if msg.content.status == 'ok':
            self.id = int(msg.content.id)

            # create Shell Streams (MUX, Task, etc.):
            queue_addr = msg.content.mux
            shell_addrs = [str(queue_addr)]
            task_addr = msg.content.task
            if task_addr:
                shell_addrs.append(str(task_addr))

            # Uncomment this to go back to two-socket model
            # shell_streams = []
            # for addr in shell_addrs:
            #     stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
            #     stream.setsockopt(zmq.IDENTITY, identity)
            #     stream.connect(disambiguate_url(addr, self.location))
            #     shell_streams.append(stream)

            # Now use only one shell stream for mux and tasks
            stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
            stream.setsockopt(zmq.IDENTITY, identity)
            shell_streams = [stream]
            for addr in shell_addrs:
                stream.connect(disambiguate_url(addr, self.location))
            # end single stream-socket

            # control stream:
            control_addr = str(msg.content.control)
            control_stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
            control_stream.setsockopt(zmq.IDENTITY, identity)
            control_stream.connect(
                disambiguate_url(control_addr, self.location))

            # create iopub stream:
            iopub_addr = msg.content.iopub
            iopub_stream = zmqstream.ZMQStream(ctx.socket(zmq.PUB), loop)
            iopub_stream.setsockopt(zmq.IDENTITY, identity)
            iopub_stream.connect(disambiguate_url(iopub_addr, self.location))

            # heartbeat addresses (the Heart itself is started after the kernel below)
            hb_addrs = msg.content.heartbeat

            # Redirect output streams and set a display hook.
            if self.out_stream_factory:
                sys.stdout = self.out_stream_factory(self.session,
                                                     iopub_stream, u'stdout')
                sys.stdout.topic = 'engine.%i.stdout' % self.id
                sys.stderr = self.out_stream_factory(self.session,
                                                     iopub_stream, u'stderr')
                sys.stderr.topic = 'engine.%i.stderr' % self.id
            if self.display_hook_factory:
                sys.displayhook = self.display_hook_factory(
                    self.session, iopub_stream)
                sys.displayhook.topic = 'engine.%i.pyout' % self.id

            self.kernel = Kernel(config=self.config,
                                 int_id=self.id,
                                 ident=self.ident,
                                 session=self.session,
                                 control_stream=control_stream,
                                 shell_streams=shell_streams,
                                 iopub_stream=iopub_stream,
                                 loop=loop,
                                 user_ns=self.user_ns,
                                 log=self.log)
            self.kernel.start()
            hb_addrs = [
                disambiguate_url(addr, self.location) for addr in hb_addrs
            ]
            heart = Heart(*map(str, hb_addrs), heart_id=identity)
            heart.start()

        else:
            self.log.fatal("Registration Failed: %s" % msg)
            raise Exception("Registration Failed: %s" % msg)

        self.log.info("Completed registration with id %i" % self.id)

    def abort(self):
        """Give up on registration after the timeout expires."""
        self.log.fatal("Registration timed out after %.1f seconds" %
                       self.timeout)
        self.session.send(self.registrar,
                          "unregistration_request",
                          content=dict(id=self.id))
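        # give the unregistration request a moment to be sent before exiting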
        time.sleep(1)
        sys.exit(255)

    def start(self):
        """Schedule registration, aborting if the timeout expires first."""
        dc = ioloop.DelayedCallback(self.register, 0, self.loop)
        dc.start()
        self._abort_dc = ioloop.DelayedCallback(self.abort,
                                                self.timeout * 1000, self.loop)
        self._abort_dc.start()
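
The start/abort pairing above is a register-with-timeout pattern: registration fires immediately, and the abort callback runs unless complete_registration cancels it first. A stripped-down sketch using the same (old) pyzmq DelayedCallback API, with illustrative function names:

from zmq.eventloop import ioloop

def start_with_timeout(loop, register, abort, timeout=2.0):
    """Fire `register` on the next loop iteration; run `abort` after
    `timeout` seconds unless the returned handle is stopped first."""
    ioloop.DelayedCallback(register, 0, loop).start()
    abort_dc = ioloop.DelayedCallback(abort, int(timeout * 1000), loop)
    abort_dc.start()
    # a successful reply handler calls abort_dc.stop(), exactly as
    # complete_registration() does above
    return abort_dc
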