def _remote_module_reducer(remote_module):
    """
    Serializes a RemoteModule.

    Builds the ``(callable, args)`` pair used by the RPC pickler's reducer
    mechanism. On unpickling, ``_remote_module_receiver`` is invoked with the
    collected attribute values as positional arguments.

    Which attributes are carried over:

    * ``module_rref`` is serialized through ``RRef._serialize()``.
    * Attributes listed in ``_REMOTE_MODULE_PICKLED_ATTRIBUTES`` are passed
      through unchanged.
    * Any other attribute not listed in
      ``_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING`` is dropped. A message
      is printed to stderr so that a newly added attribute is not lost
      silently.

    Args:
        remote_module: the ``RemoteModule`` instance being pickled.

    Returns:
        A tuple ``(_remote_module_receiver, attribute_values)``.
    """
    pickled_attrs = {}
    for k, v in remote_module.__dict__.items():
        # Pickling the attribute `module_rref` must invoke RRef's `_serialize()` method.
        if k == "module_rref":
            pickled_attrs[k] = v._serialize()
        elif k in _REMOTE_MODULE_PICKLED_ATTRIBUTES:  # type: ignore[attr-defined]
            pickled_attrs[k] = v
        # Warn about attributes that are neither pickled nor explicitly ignored.
        elif k not in _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING:  # type: ignore[attr-defined]
            print(
                "The new attribute ``{}`` of RemoteModule is ignored during RPC pickling. "
                "To pickle this attribute, please add it to ``_REMOTE_MODULE_PICKLED_ATTRIBUTES``. "
                "Otherwise, please explicitly add it to ``_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING``."
                .format(k),
                file=sys.stderr,
            )

    # NOTE(review): values are passed positionally, in the order the attributes
    # appear in ``remote_module.__dict__``; confirm this matches the parameter
    # order ``_remote_module_receiver`` expects.
    return (
        _remote_module_receiver,
        tuple(pickled_attrs.values()),
    )
def _recursive_script_module_receiver(
    recursive_script_module_serialized,
):
    """
    Deserializes a RecursiveScriptModule that does not contain a script RemoteModule.

    Args:
        recursive_script_module_serialized: the bytes produced by ``torch.jit.save``
            in ``_recursive_script_module_reducer``.

    Returns:
        The module loaded from those bytes via ``torch.jit.load``.
    """
    # NOTE(review): this loads serialized TorchScript received over RPC;
    # presumably peers are trusted — confirm before exposing to untrusted callers.
    with io.BytesIO(recursive_script_module_serialized) as f:
        m = torch.jit.load(f)
    return m


def _recursive_script_module_reducer(recursive_script_module):
    """
    Serializes a RecursiveScriptModule that does not contain a script RemoteModule,
    and raises an error otherwise.

    Args:
        recursive_script_module: the ``torch.jit.RecursiveScriptModule`` being pickled.

    Returns:
        A ``(callable, args)`` reducer pair. The callable is
        ``_recursive_script_module_receiver``, and ``args`` is a 1-tuple holding
        the bytes written by ``torch.jit.save``.

    Raises:
        RuntimeError: if the underlying C++ module (``._c``) has a ``module_rref``
            attribute, i.e. it is a script RemoteModule.
    """
    # A scripted RemoteModule exposes `module_rref` on its C++ module; sending it
    # this way is unsupported, so callers are told to send the RRef instead.
    if hasattr(recursive_script_module._c, "module_rref"):
        raise RuntimeError(
            "Passing a script RemoteModule over RPC is not supported. Please create a RemoteModule in the sender, "
            "send the `module_rref` to the receiver, and create a new instance on the receiver end by passing this `module_rref`."
        )

    with io.BytesIO() as f:
        torch.jit.save(recursive_script_module, f)
        # Grab the bytes before the buffer is closed by the `with` block.
        serialized = f.getvalue()
    return (_recursive_script_module_receiver, (serialized,))


# Register the reducers with the internal RPC pickler so that these module types
# are serialized through their reducer functions when sent over RPC.
_internal_rpc_pickler._register_reducer(RemoteModule, _remote_module_reducer)
_internal_rpc_pickler._register_reducer(
    torch.jit.RecursiveScriptModule, _recursive_script_module_reducer
)