Example #1
0

def _remote_module_reducer(remote_module):
    """
    Reduce a RemoteModule into a ``(receiver, args)`` pair for the RPC pickler.

    The instance ``__dict__`` is walked in insertion order. For each attribute:

    * ``module_rref`` contributes the result of its ``_serialize()`` method
      (the RRef itself is not pickled directly);
    * an attribute named in ``_REMOTE_MODULE_PICKLED_ATTRIBUTES`` contributes
      its value unchanged;
    * an attribute named in ``_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING``
      is silently dropped;
    * any other attribute is dropped, and a notice is printed to stderr so
      the developer classifies it as pickled or ignored.

    Returns:
        ``(_remote_module_receiver, values)``, where ``values`` is a tuple of
        the collected attribute values in ``__dict__`` order.
        NOTE(review): the receiver presumably consumes these positionally, so
        their order must match what it expects — confirm against
        ``_remote_module_receiver``.
    """
    collected = []
    for name, value in remote_module.__dict__.items():
        if name == "module_rref":
            # The RRef must go through its own `_serialize()` to be sent over RPC.
            collected.append(value._serialize())
            continue
        if name in _REMOTE_MODULE_PICKLED_ATTRIBUTES:  # type: ignore[attr-defined]
            collected.append(value)
            continue
        if name in _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING:  # type: ignore[attr-defined]
            # Explicitly excluded from pickling; nothing to report.
            continue
        # Unclassified attribute: it is not pickled; tell the developer.
        notice = (
            "The new attribute ``{}`` of RemoteModule is ignored during RPC pickling. "
            "To pickle this attribute, please add it to ``_REMOTE_MODULE_PICKLED_ATTRIBUTES``. "
            "Otherwise, please explicitly add it to ``_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING``."
        ).format(name)
        print(notice, file=sys.stderr)

    return _remote_module_receiver, tuple(collected)


# Register `_remote_module_reducer` as the reducer the internal RPC pickler
# uses for `RemoteModule` instances.
_internal_rpc_pickler._register_reducer(RemoteModule, _remote_module_reducer)
Example #2
0

def _recursive_script_module_receiver(recursive_script_module_serialized):
    """
    Rebuild a RecursiveScriptModule (one that does not contain a script
    RemoteModule) from its serialized bytes.

    Args:
        recursive_script_module_serialized: the bytes that
            ``torch.jit.save`` wrote for the module.

    Returns:
        The module loaded back with ``torch.jit.load``.
    """
    return torch.jit.load(io.BytesIO(recursive_script_module_serialized))


def _recursive_script_module_reducer(recursive_script_module):
    """
    Reduce a RecursiveScriptModule to ``(receiver, (bytes,))`` for the RPC pickler.

    The module is written with ``torch.jit.save`` into an in-memory buffer.
    The resulting bytes are passed to ``_recursive_script_module_receiver``.

    Raises:
        RuntimeError: if ``recursive_script_module._c`` has a ``module_rref``
            attribute (a script RemoteModule), which cannot be sent over RPC
            this way.
    """
    wraps_remote_module = hasattr(recursive_script_module._c, "module_rref")
    if not wraps_remote_module:
        buffer = io.BytesIO()
        torch.jit.save(recursive_script_module, buffer)
        payload = buffer.getvalue()
        return _recursive_script_module_receiver, (payload,)

    raise RuntimeError(
        "Passing a script RemoteModule over RPC is not supported. Please create a RemoteModule in the sender, "
        "send the `module_rref` to the receiver, and create a new instance on the receiver end by passing this `module_rref`."
    )


# Register the reducer that the internal RPC pickler uses for TorchScript
# RecursiveScriptModule instances.
# `RemoteModule` is already registered with `_remote_module_reducer` earlier in
# this module (right after that reducer's definition), so it is not registered
# a second time here.
_internal_rpc_pickler._register_reducer(torch.jit.RecursiveScriptModule,
                                        _recursive_script_module_reducer)