Ejemplo n.º 1
0
    def __init__(self, service, channel, config = {}, _lazy = False):
        """Build the connection state over ``channel`` and create the service.

        ``config`` entries override those of DEFAULT_CONFIG; the argument is
        only read, never modified. ``service`` is called with a weak proxy to
        this connection and its result becomes the local root object. Unless
        ``_lazy`` is true, ``_init_service`` is invoked before returning.
        """
        # treated as closed until construction has completely succeeded
        self._closed = True

        cfg = self._config = DEFAULT_CONFIG.copy()
        cfg.update(config)
        if cfg["connid"] is None:
            next_id = _connection_id_generator.next()
            cfg["connid"] = "conn%d" % (next_id,)

        # transport and request bookkeeping
        self._channel = channel
        self._recvlock = Lock()
        self._sendlock = Lock()
        self._seqcounter = itertools.count()
        self._async_callbacks = {}
        self._sync_replies = {}

        # object / proxy caches and misc state
        self._local_objects = RefCountingColl()
        self._proxy_cache = WeakValueDict()
        self._netref_classes_cache = {}
        self._remote_root = None
        self._last_traceback = None

        # the service only receives a weak proxy to this connection
        self._local_root = service(weakref.proxy(self))
        if not _lazy:
            self._init_service()
        self._closed = False
Ejemplo n.º 2
0
class Connection(object):
    """The RPyC connection (also known as the RPyC protocol).

    * service: the service to expose; it is called with a weak proxy to this
      connection and the result becomes the local root object
    * channel: the channel over which messages are passed
    * config: this connection's config dict (overriding parameters from the
      default config dict)
    * _lazy: whether to defer initializing the service. default is False,
      meaning the service is initialized when the connection is created.
      if set to True, you will need to call _init_service manually later
    """
    def __init__(self, service, channel, config = {}, _lazy = False):
        # start out "closed" so that close() (and therefore __del__) is a
        # no-op on a connection whose construction failed part-way through
        self._closed = True
        self._config = DEFAULT_CONFIG.copy()
        # ``config`` is only read, never mutated, so the ``{}`` default is safe
        self._config.update(config)
        if self._config["connid"] is None:
            self._config["connid"] = "conn%d" % (_connection_id_generator.next(),)

        self._channel = channel
        self._seqcounter = itertools.count()     # request sequence numbers
        self._recvlock = Lock()
        self._sendlock = Lock()
        self._sync_replies = {}                  # seq -> (isexc, obj)
        self._async_callbacks = {}               # seq -> callback(isexc, obj)
        self._local_objects = RefCountingColl()  # local objects boxed by reference
        self._last_traceback = None
        self._proxy_cache = WeakValueDict()      # oid -> netref proxy
        self._netref_classes_cache = {}          # (clsname, modname) -> netref class
        self._remote_root = None
        # the service only gets a weak proxy, so it does not keep us alive
        self._local_root = service(weakref.proxy(self))
        if not _lazy:
            self._init_service()
        self._closed = False

    def _init_service(self):
        """Notify the local service that the connection is ready."""
        self._local_root.on_connect()

    def __del__(self):
        self.close()
    def __enter__(self):
        return self
    def __exit__(self, t, v, tb):
        self.close()
    def __repr__(self):
        # "<mod.Connection object at 0x..>" -> "<mod.Connection 'conn1' object at 0x..>"
        a, b = object.__repr__(self).split(" object ")
        return "%s %r object %s" % (a, self._config["connid"], b)

    #
    # IO
    #
    def _cleanup(self, _anyway = True):
        """Close the channel, notify the service and drop all cached state.

        With the default ``_anyway=True`` this runs even when the connection
        is already marked closed; with ``_anyway=False`` it is a no-op then.
        """
        if self._closed and not _anyway:
            return
        self._closed = True
        self._channel.close()
        self._local_root.on_disconnect()
        self._sync_replies.clear()
        self._async_callbacks.clear()
        self._local_objects.clear()
        self._proxy_cache.clear()
        self._netref_classes_cache.clear()
        self._last_traceback = None
        self._remote_root = None
        self._local_root = None
        # _config is deliberately kept: __repr__ still reads its "connid"

    def close(self, _catchall = True):
        """Close the connection, telling the other party first.

        Does nothing if the connection is already closed. A HANDLE_CLOSE
        request is sent to the peer. An EOFError while sending it is ignored,
        and so is any other error unless ``_catchall`` is False. Local cleanup
        (_cleanup) runs in every case.
        """
        if self._closed:
            return
        self._closed = True
        try:
            try:
                self._async_request(consts.HANDLE_CLOSE)
            except EOFError:
                pass
            except Exception:
                if not _catchall:
                    raise
        finally:
            self._cleanup(_anyway = True)

    @property
    def closed(self):
        """True once the connection has been closed (or is not yet set up)."""
        return self._closed
    def fileno(self):
        """Return the file descriptor of the underlying channel."""
        return self._channel.fileno()

    def ping(self, data = "the world is a vampire!" * 20, timeout = 3):
        """Assert that the other party is functioning properly.

        Sends ``data`` with HANDLE_PING and checks that the very same data is
        echoed back. ``timeout`` is set as the expiry of the pending result.
        Raises PingError if the echoed data differs.
        """
        res = self.async_request(consts.HANDLE_PING, data, timeout = timeout)
        if res.value != data:
            raise PingError("echo mismatches sent data")

    def _send(self, msg, seq, args):
        """Serialize ``(msg, seq, args)`` with brine and send it.

        The send lock is held during the send so that concurrent senders do
        not interleave their messages on the channel.
        """
        data = brine.dump((msg, seq, args))
        self._sendlock.acquire()
        try:
            self._channel.send(data)
        finally:
            self._sendlock.release()
    def _send_request(self, handler, args):
        """Send a request for ``handler`` with boxed ``args``; return its seq."""
        seq = self._seqcounter.next()
        self._send(consts.MSG_REQUEST, seq, (handler, self._box(args)))
        return seq
    def _send_reply(self, seq, obj):
        """Send ``obj`` (boxed) as the reply to request ``seq``."""
        self._send(consts.MSG_REPLY, seq, self._box(obj))
    def _send_exception(self, seq, exctype, excval, exctb):
        """Send an exception, serialized by vinegar, as the reply to ``seq``.

        Whether the local traceback is included is controlled by
        config["include_local_traceback"].
        """
        exc = vinegar.dump(exctype, excval, exctb,
            include_local_traceback = self._config["include_local_traceback"])
        self._send(consts.MSG_EXCEPTION, seq, exc)

    #
    # boxing
    #
    def _box(self, obj):
        """Store a local object in such a way that it could be recreated on
        the remote party either by-value or by-reference.

        Returns a ``(label, value)`` pair:
        * LABEL_VALUE: brine-dumpable objects, passed as-is
        * LABEL_TUPLE: tuples, boxed element by element
        * LABEL_LOCAL_REF: netrefs owned by this connection, sent back as
          their oid
        * LABEL_REMOTE_REF: anything else; the object is kept alive in
          _local_objects and sent as ``(id, class name, module name)``
        """
        if brine.dumpable(obj):
            return consts.LABEL_VALUE, obj
        if type(obj) is tuple:
            return consts.LABEL_TUPLE, tuple(self._box(item) for item in obj)
        elif isinstance(obj, netref.BaseNetref) and obj.____conn__() is self:
            return consts.LABEL_LOCAL_REF, obj.____oid__
        else:
            self._local_objects.add(obj)
            # ``obj.__class__`` may be overridden or raise on exotic objects;
            # fall back to the real type in that case
            try:
                cls = obj.__class__
            except Exception:
                cls = type(obj)
            return consts.LABEL_REMOTE_REF, (id(obj), cls.__name__, cls.__module__)

    def _unbox(self, package):
        """Recreate a local object representation of the remote object: if
        the object is passed by value, just return it; if the object is
        passed by reference, create (or reuse a cached) netref to it.

        Raises ValueError for an unknown label.
        """
        label, value = package
        if label == consts.LABEL_VALUE:
            return value
        if label == consts.LABEL_TUPLE:
            return tuple(self._unbox(item) for item in value)
        if label == consts.LABEL_LOCAL_REF:
            return self._local_objects[value]
        if label == consts.LABEL_REMOTE_REF:
            oid, clsname, modname = value
            # single lookup: the cache holds proxies weakly, so an entry may
            # disappear between a membership test and a subsequent get
            try:
                return self._proxy_cache[oid]
            except KeyError:
                pass
            proxy = self._netref_factory(oid, clsname, modname)
            self._proxy_cache[oid] = proxy
            return proxy
        raise ValueError("invalid label %r" % (label,))

    def _netref_factory(self, oid, clsname, modname):
        """Build a netref for the remote object ``oid``.

        The netref class for ``(clsname, modname)`` is taken from this
        connection's cache or from netref.builtin_classes_cache; otherwise the
        peer is asked for the object's methods (HANDLE_INSPECT) and the class
        generated from them is cached on this connection.
        """
        typeinfo = (clsname, modname)
        if typeinfo in self._netref_classes_cache:
            cls = self._netref_classes_cache[typeinfo]
        elif typeinfo in netref.builtin_classes_cache:
            cls = netref.builtin_classes_cache[typeinfo]
        else:
            info = self.sync_request(consts.HANDLE_INSPECT, oid)
            cls = netref.class_factory(clsname, modname, info)
            self._netref_classes_cache[typeinfo] = cls
        return cls(weakref.ref(self), oid)

    #
    # dispatching
    #
    def _dispatch_request(self, seq, raw_args):
        """Run the handler for an incoming request and send back the result.

        Any error raised by the handler is sent to the peer as an exception
        reply, except KeyboardInterrupt (always re-raised locally) and
        SystemExit when config["propagate_SystemExit_locally"] is set.
        """
        try:
            handler, args = raw_args
            args = self._unbox(args)
            res = self._HANDLERS[handler](self, *args)
        except KeyboardInterrupt:
            raise
        except:
            # deliberately catch everything so that it can be forwarded
            t, v, tb = sys.exc_info()
            self._last_traceback = tb
            if t is SystemExit and self._config["propagate_SystemExit_locally"]:
                raise
            self._send_exception(seq, t, v, tb)
        else:
            self._send_reply(seq, res)

    def _dispatch_reply(self, seq, raw):
        """Deliver a reply to its async callback, or store it for sync_request."""
        obj = self._unbox(raw)
        if seq in self._async_callbacks:
            self._async_callbacks.pop(seq)(False, obj)
        else:
            self._sync_replies[seq] = (False, obj)

    def _dispatch_exception(self, seq, raw):
        """Rebuild a remote exception (via vinegar) and deliver it like a reply."""
        obj = vinegar.load(raw,
            import_custom_exceptions = self._config["import_custom_exceptions"],
            instantiate_custom_exceptions = self._config["instantiate_custom_exceptions"],
            instantiate_oldstyle_exceptions = self._config["instantiate_oldstyle_exceptions"])
        if seq in self._async_callbacks:
            self._async_callbacks.pop(seq)(True, obj)
        else:
            self._sync_replies[seq] = (True, obj)

    #
    # serving
    #
    def _recv(self, timeout, wait_for_lock):
        """Receive one raw message, waiting up to ``timeout`` for it.

        Returns None if the receive lock could not be acquired (only possible
        when ``wait_for_lock`` is false) or if nothing arrived in time. On
        EOFError the connection is closed and the error is re-raised.
        """
        if not self._recvlock.acquire(wait_for_lock):
            return None
        try:
            try:
                if self._channel.poll(timeout):
                    data = self._channel.recv()
                else:
                    data = None
            except EOFError:
                self.close()
                raise
        finally:
            self._recvlock.release()
        return data

    def _dispatch(self, data):
        """Decode a raw message and route it by its message type.

        Raises ValueError for an unknown message type.
        """
        msg, seq, args = brine.load(data)
        if msg == consts.MSG_REQUEST:
            self._dispatch_request(seq, args)
        elif msg == consts.MSG_REPLY:
            self._dispatch_reply(seq, args)
        elif msg == consts.MSG_EXCEPTION:
            self._dispatch_exception(seq, args)
        else:
            raise ValueError("invalid message type: %r" % (msg,))

    def poll(self, timeout = 0):
        """Serve a single transaction, should one arrive within the given
        interval. Note that handling a request/reply may trigger nested
        requests, which are all part of the transaction. Unlike serve(), this
        does not wait for the receive lock if another thread holds it.

        Returns True if one was served, False otherwise.
        """
        data = self._recv(timeout, wait_for_lock = False)
        if not data:
            return False
        self._dispatch(data)
        return True

    def serve(self, timeout = 1):
        """Serve a single request or reply that arrives within the given
        time frame (default is 1 sec). Note that the dispatching of a request
        might trigger multiple (nested) requests, thus this function may be
        reentrant. Returns True if a request or reply was received, False
        otherwise.
        """
        data = self._recv(timeout, wait_for_lock = True)
        if not data:
            return False
        self._dispatch(data)
        return True

    def serve_all(self):
        """Serve all requests and replies while the connection is alive.

        Returns once the peer disconnects (EOFError); the connection is
        closed on exit in every case.
        """
        try:
            try:
                while True:
                    self.serve(0.1)
            except select.error:
                # select errors are only propagated while the connection is
                # still open. (This used to ``raise e`` with ``e`` unbound,
                # which turned the real error into a NameError.)
                if not self.closed:
                    raise
            except EOFError:
                pass
        finally:
            self.close()

    def poll_all(self, timeout = 0):
        """Serve all requests and replies that arrive within the given interval.

        Stops quietly on EOFError. Returns True if at least one was served,
        False otherwise.
        """
        at_least_once = False
        try:
            while self.poll(timeout):
                at_least_once = True
        except EOFError:
            pass
        return at_least_once

    #
    # requests
    #
    def sync_request(self, handler, *args):
        """Send a request and serve incoming messages until its reply arrives.

        Returns the unboxed result; if the peer replied with an exception,
        that exception is raised here.
        """
        seq = self._send_request(handler, args)
        while seq not in self._sync_replies:
            self.serve(0.1)
        isexc, obj = self._sync_replies.pop(seq)
        if isexc:
            raise obj
        else:
            return obj

    def _async_request(self, handler, args = (), callback = (lambda a, b: None)):
        """Send a request; ``callback(isexc, obj)`` is invoked with its reply."""
        seq = self._send_request(handler, args)
        self._async_callbacks[seq] = callback
    def async_request(self, handler, *args, **kwargs):
        """Send a request and return an AsyncResult object, which will
        eventually hold the reply.

        The only accepted keyword is ``timeout``, which (if given) is set as
        the result's expiry; any other keyword raises TypeError.
        """
        timeout = kwargs.pop("timeout", None)
        if kwargs:
            raise TypeError("got unexpected keyword argument %r" % (list(kwargs)[0],))
        res = AsyncResult(weakref.proxy(self))
        self._async_request(handler, args, res)
        if timeout is not None:
            res.set_expiry(timeout)
        return res

    @property
    def root(self):
        """Fetch (once, then cache) the root object of the other party."""
        if self._remote_root is None:
            self._remote_root = self.sync_request(consts.HANDLE_GETROOT)
        return self._remote_root

    #
    # attribute access
    #
    def _check_attr(self, obj, name):
        """Apply the attribute-access policy from the config.

        Returns the attribute name to actually use (possibly with
        config["exposed_prefix"] prepended), or False if access is denied.
        """
        if self._config["allow_exposed_attrs"]:
            if name.startswith(self._config["exposed_prefix"]):
                name2 = name
            else:
                name2 = self._config["exposed_prefix"] + name
            if hasattr(obj, name2):
                return name2
        if self._config["allow_all_attrs"]:
            return name
        if self._config["allow_safe_attrs"] and name in self._config["safe_attrs"]:
            return name
        if self._config["allow_public_attrs"] and not name.startswith("_"):
            return name
        return False

    def _access_attr(self, oid, name, args, overrider, param, default):
        """Perform a get/set/delattr on the local object ``oid``.

        If the object's type defines the hook named ``overrider`` (e.g.
        ``_rpyc_getattr``), it is called directly and the config policy is
        bypassed. Otherwise access requires config[param] to be true and
        _check_attr to allow the name, and ``default`` (getattr / setattr /
        delattr) is used. Raises TypeError for non-str names and
        AttributeError when access is denied.
        """
        if type(name) is not str:
            raise TypeError("attr name must be a string")
        obj = self._local_objects[oid]
        accessor = getattr(type(obj), overrider, None)
        if accessor is None:
            name2 = self._check_attr(obj, name)
            if not self._config[param] or not name2:
                raise AttributeError("cannot access %r" % (name,))
            accessor = default
            name = name2
        return accessor(obj, name, *args)

    #
    # handlers: each _handle_<name> serves the request consts.HANDLE_<NAME>
    # (see the collection loop at the end of the class); ``oid`` arguments
    # are keys into self._local_objects
    #
    def _handle_ping(self, data):
        return data
    def _handle_close(self):
        self._cleanup()
    def _handle_getroot(self):
        return self._local_root
    def _handle_del(self, oid):
        self._local_objects.decref(oid)
    def _handle_repr(self, oid):
        return repr(self._local_objects[oid])
    def _handle_str(self, oid):
        return str(self._local_objects[oid])
    def _handle_cmp(self, oid, other):
        # call the type's __cmp__ directly instead of cmp(), which might
        # recurse back through the netref machinery
        obj = self._local_objects[oid]
        try:
            return type(obj).__cmp__(obj, other)
        except TypeError:
            return NotImplemented
    def _handle_hash(self, oid):
        return hash(self._local_objects[oid])
    def _handle_call(self, oid, args, kwargs):
        # kwargs arrive as a sequence of (name, value) pairs
        return self._local_objects[oid](*args, **dict(kwargs))
    def _handle_dir(self, oid):
        return tuple(dir(self._local_objects[oid]))
    def _handle_inspect(self, oid):
        return tuple(netref.inspect_methods(self._local_objects[oid]))
    def _handle_getattr(self, oid, name):
        return self._access_attr(oid, name, (), "_rpyc_getattr", "allow_getattr", getattr)
    def _handle_delattr(self, oid, name):
        return self._access_attr(oid, name, (), "_rpyc_delattr", "allow_delattr", delattr)
    def _handle_setattr(self, oid, name, value):
        return self._access_attr(oid, name, (value,), "_rpyc_setattr", "allow_setattr", setattr)
    def _handle_callattr(self, oid, name, args, kwargs):
        # goes through _handle_getattr, so the attribute policy still applies
        return self._handle_getattr(oid, name)(*args, **dict(kwargs))
    def _handle_pickle(self, oid, proto):
        if not self._config["allow_pickle"]:
            raise ValueError("pickling is disabled")
        return pickle.dumps(self._local_objects[oid], proto)
    def _handle_buffiter(self, oid, count):
        # advance the local iterator up to ``count`` times; fewer items are
        # returned if it is exhausted first
        items = []
        obj = self._local_objects[oid]
        for i in xrange(count):
            try:
                items.append(obj.next())
            except StopIteration:
                break
        return tuple(items)

    # collect handlers: map every _handle_<name> defined above to the
    # matching consts.HANDLE_<NAME>. Iterate over a snapshot because the loop
    # variables themselves are bound in this (class) namespace.
    _HANDLERS = {}
    for name, obj in list(locals().items()):
        if name.startswith("_handle_"):
            name2 = "HANDLE_" + name[8:].upper()
            if hasattr(consts, name2):
                _HANDLERS[getattr(consts, name2)] = obj
            else:
                raise NameError("no constant defined for %r" % (name,))
    del name, name2, obj
Ejemplo n.º 3
0
            yield elem

class _Async(object):
    """creates an async proxy wrapper over an existing proxy. async proxies 
    are cached. invoking an async proxy will return an AsyncResult instead of
    blocking"""
    
    __slots__ = ("proxy", "__weakref__")
    def __init__(self, proxy):
        self.proxy = proxy
    def __call__(self, *args, **kwargs):
        return asyncreq(self.proxy, HANDLE_CALL, args, tuple(kwargs.items()))
    def __repr__(self):
        return "async(%r)" % (self.proxy,)

_async_proxies_cache = WeakValueDict()
def async(proxy):
    pid = id(proxy)
    if pid in _async_proxies_cache:
        return _async_proxies_cache[pid]
    if not hasattr(proxy, "____conn__") or not hasattr(proxy, "____oid__"):
        raise TypeError("'proxy' must be a Netref: %r", (proxy,))
    if not callable(proxy):
        raise TypeError("'proxy' must be callable: %r" % (proxy,))
    caller = _Async(proxy)
    _async_proxies_cache[id(caller)] = _async_proxies_cache[pid] = caller
    return caller

async.__doc__ = _Async.__doc__

class timed(object):