def test_is_picklable(tmpdir):
    """Check ``is_picklable`` on values expected to pickle and values that are not.

    Every sample in the first group must yield exactly ``True`` and every
    sample in the second group exactly ``False``. The ``tmpdir`` argument is
    accepted but not used by the body.
    """
    # See the full list of picklable types at
    # https://docs.python.org/3/library/pickle.html#pickle-picklable

    class UnpicklableClass:
        # Only classes defined at the top level of a module are picklable.
        pass

    picklable = [None, True, 123, "str", (123, "str"), max]
    not_picklable = [unpicklable_function, UnpicklableClass, ScriptModule()]

    # Picklable samples are checked first, then the unpicklable ones.
    for expected, samples in ((True, picklable), (False, not_picklable)):
        for sample in samples:
            assert is_picklable(sample) is expected
def replicate(network, devices, no_gradient=False):
    """Build one copy of ``network`` for each entry of ``devices``.

    Parameters and buffers are broadcast with ``comm.broadcast_coalesced``.
    Every module reachable from ``network`` is shallow-copied once per
    device, and each copy is then re-wired so that its children, parameters
    and buffers refer to the objects belonging to the same device.

    Args:
        network: Root module to copy.
        devices: Sized collection of target devices; its length sets the
            number of replicas.
        no_gradient: When true, ``.grad`` of every parameter of ``network``
            itself is set to ``None`` before broadcasting. Note that this
            mutates the source network.

    Returns:
        A list holding the root replica for each entry of ``devices``.
    """
    replica_count = len(devices)

    # Gather the parameters once; optionally drop their gradients in place.
    source_params = list(network.parameters())
    if no_gradient:
        for tensor in source_params:
            tensor.grad = None
    param_slot = {tensor: slot for slot, tensor in enumerate(source_params)}
    broadcast_params = comm.broadcast_coalesced(source_params, devices)

    source_buffers = list(network.buffers())
    buffer_slot = {tensor: slot for slot, tensor in enumerate(source_buffers)}
    broadcast_buffers = comm.broadcast_coalesced(source_buffers, devices)

    source_modules = list(network.modules())
    module_slot = {}
    # per_device[d][m] is the copy of source_modules[m] for device d.
    per_device = [[] for _ in range(replica_count)]

    # Entries of a ScriptModule's __dict__ that must not be copied verbatim.
    script_reserved = {"_parameters", "_buffers", "_modules", "forward", "_c"}

    def copy_script_module(source):
        # we have to initialize ScriptModule properly so that
        # it works with pybind11
        clone = ScriptModule()
        registered = {entry[0] for entry in source._c._get_attributes()}
        plain_keys = set(source.__dict__.keys()) - script_reserved - registered
        for attr in plain_keys:
            value = source.__dict__[attr]
            if isinstance(value, ScriptMethod):
                # Script methods are cloned at the very end instead.
                continue
            clone.__dict__[attr] = value
        for attr, attr_type, attr_value in source._c._get_attributes():
            if attr in source._buffers:
                # Buffers are assigned from the broadcast copies below.
                continue
            clone._c._register_attribute(attr, attr_type, attr_value)
        return clone

    def copy_plain_module(source):
        # Bypass __init__; share attributes but give the copy its own
        # parameter / buffer / child dictionaries so they can be re-pointed.
        clone = source.__new__(type(source))
        clone.__dict__ = source.__dict__.copy()
        clone._parameters = clone._parameters.copy()
        clone._buffers = clone._buffers.copy()
        clone._modules = clone._modules.copy()
        return clone

    # Pass 1: create an unwired copy of every module for every device.
    for slot, module in enumerate(source_modules):
        module_slot[module] = slot
        if isinstance(module, ScriptModule):
            make_copy = copy_script_module
        else:
            make_copy = copy_plain_module
        for device_copies in per_device:
            device_copies.append(make_copy(module))

    # Pass 2: point children, parameters and buffers at same-device objects.
    for slot, module in enumerate(source_modules):
        for key, child in module._modules.items():
            if child is None:
                for device_copies in per_device:
                    device_copies[slot]._modules[key] = None
                continue
            child_slot = module_slot[child]
            for device_copies in per_device:
                device_copies[slot]._modules[key] = device_copies[child_slot]

        for key, param in module._parameters.items():
            if param is None:
                for device_copies in per_device:
                    device_copies[slot]._parameters[key] = None
                continue
            index = param_slot[param]
            for device, device_copies in enumerate(per_device):
                copied = broadcast_params[device][index]
                # Mirror the source parameter's requires_grad flag.
                device_copies[slot]._parameters[key] = copied.requires_grad_(
                    param.requires_grad)

        for key, buf in module._buffers.items():
            if buf is None:
                for device_copies in per_device:
                    device_copies[slot]._buffers[key] = None
                continue
            index = buffer_slot[buf]
            for device, device_copies in enumerate(per_device):
                device_copies[slot]._buffers[key] = broadcast_buffers[device][index]

    # Pass 3: once everything is wired, clone the script methods.
    for device_copies in per_device:
        for slot, module in enumerate(source_modules):
            if not isinstance(module, ScriptModule):
                continue
            clone = device_copies[slot]
            for method_name in module._c._method_names():
                clone._c.clone_method(module._c, method_name)

    # The first entry of each per-device list is the copy of the root module.
    return [device_copies[0] for device_copies in per_device]
def __init__(self, optimize=True, **kwargs):
    """Initialise both base classes explicitly.

    Args:
        optimize: Forwarded to ``ScriptModule.__init__``.
        **kwargs: Extra keyword arguments, forwarded to
            ``ScriptModule.__init__`` only.
    """
    # ScriptModule is initialised first and receives every argument.
    ScriptModule.__init__(self, optimize=optimize, **kwargs)
    # ModuleWithInit is then initialised with no arguments.
    # NOTE(review): presumably the bases are called directly rather than via
    # super() on purpose; confirm against the enclosing class definition.
    ModuleWithInit.__init__(self)