def test_is_picklable(tmpdir):
    """Check ``is_picklable`` against objects that can and cannot be pickled.

    The ``tmpdir`` fixture is accepted but not used by the test body.
    Background on which objects pickle:
    https://docs.python.org/3/library/pickle.html#pickle-picklable
    """

    class UnpicklableClass:
        # Defined inside a function on purpose: only classes that live at
        # the top level of a module are picklable.
        pass

    picklable = (None, True, 123, "str", (123, "str"), max)
    not_picklable = (unpicklable_function, UnpicklableClass, ScriptModule())

    # Pair every sample with the exact boolean it must produce; picklable
    # samples are checked first, then the non-picklable ones.
    expectations = [(sample, True) for sample in picklable]
    expectations += [(sample, False) for sample in not_picklable]

    for sample, expected in expectations:
        assert is_picklable(sample) is expected
# Example no. 2
# 0
def replicate(network, devices, no_gradient=False):
    """Create one replica of ``network`` for every entry in ``devices``.

    Parameters and buffers are distributed with
    ``comm.broadcast_coalesced``. Every module returned by
    ``network.modules()`` is shallow-copied once per device, and the copies
    are then re-linked so that each replica tree refers to its own
    sub-module copies and to the broadcast parameters/buffers of its device.

    Args:
        network: module to replicate; it must provide ``parameters()``,
            ``buffers()`` and ``modules()``.
        devices: sized collection of targets, forwarded unchanged to
            ``comm.broadcast_coalesced``.
        no_gradient: when true, ``.grad`` of every parameter of ``network``
            is reset to ``None`` (on the source network itself) before the
            broadcast.

    Returns:
        A list holding, for each device in order, the replica of the first
        module yielded by ``network.modules()``.
    """
    replica_count = len(devices)

    source_params = list(network.parameters())
    if no_gradient:
        # Drop gradients on the source parameters before broadcasting.
        for tensor in source_params:
            tensor.grad = None
    param_position = {tensor: pos for pos, tensor in enumerate(source_params)}
    params_per_device = comm.broadcast_coalesced(source_params, devices)

    source_buffers = list(network.buffers())
    buffer_position = {
        tensor: pos for pos, tensor in enumerate(source_buffers)
    }
    buffers_per_device = comm.broadcast_coalesced(source_buffers, devices)

    source_modules = list(network.modules())
    replicas_per_device = [[] for _ in devices]
    module_position = {}
    # Attributes of a ScriptModule that are rebuilt below instead of being
    # copied over from the source's ``__dict__``.
    script_reserved = {"_parameters", "_buffers", "_modules", "forward", "_c"}

    def make_replica(source):
        # Produce a single, not yet re-linked, copy of ``source``.
        if not isinstance(source, ScriptModule):
            clone = source.__new__(type(source))
            clone.__dict__ = source.__dict__.copy()
            # Give the clone its own containers so that re-linking does not
            # touch the source module.
            clone._parameters = clone._parameters.copy()
            clone._buffers = clone._buffers.copy()
            clone._modules = clone._modules.copy()
            return clone

        # A ScriptModule has to be initialized properly so that it works
        # with pybind11.
        clone = ScriptModule()
        registered = {entry[0] for entry in source._c._get_attributes()}
        plain_keys = set(source.__dict__.keys()) - script_reserved - registered
        for key in plain_keys:
            member = source.__dict__[key]
            if isinstance(member, ScriptMethod):
                # Methods are cloned separately once all replicas exist.
                continue
            clone.__dict__[key] = member
        for name, the_type, value in source._c._get_attributes():
            if name in source._buffers.keys():
                # Buffers are attached during re-linking.
                continue
            clone._c._register_attribute(name, the_type, value)
        return clone

    # Pass 1: one shallow copy of every module for every device.
    for pos, source in enumerate(source_modules):
        module_position[source] = pos
        for device_replicas in replicas_per_device:
            device_replicas.append(make_replica(source))

    # Pass 2: point each copy at the children, parameters and buffers that
    # belong to its own device.
    for pos, source in enumerate(source_modules):
        clones = [
            replicas_per_device[dev][pos] for dev in range(replica_count)
        ]

        for key, child in source._modules.items():
            if child is None:
                for clone in clones:
                    clone._modules[key] = None
                continue
            child_pos = module_position[child]
            for dev, clone in enumerate(clones):
                clone._modules[key] = replicas_per_device[dev][child_pos]

        for key, param in source._parameters.items():
            if param is None:
                for clone in clones:
                    clone._parameters[key] = None
                continue
            param_pos = param_position[param]
            for dev, clone in enumerate(clones):
                broadcast_param = params_per_device[dev][param_pos]
                clone._parameters[key] = broadcast_param.requires_grad_(
                    param.requires_grad)

        for key, buf in source._buffers.items():
            if buf is None:
                for clone in clones:
                    clone._buffers[key] = None
                continue
            buffer_pos = buffer_position[buf]
            for dev, clone in enumerate(clones):
                clone._buffers[key] = buffers_per_device[dev][buffer_pos]

    # Pass 3: copy the compiled methods of every ScriptModule, device by
    # device.
    for device_replicas in replicas_per_device:
        for source, clone in zip(source_modules, device_replicas):
            if not isinstance(source, ScriptModule):
                continue
            for method_name in source._c._method_names():
                clone._c.clone_method(source._c, method_name)

    return [device_replicas[0] for device_replicas in replicas_per_device]
# Example no. 3
# 0
 def __init__(self, optimize=True, **kwargs):
     """Run both base-class initializers explicitly, in a fixed order.

     ``ScriptModule.__init__`` receives ``optimize`` together with any
     extra keyword arguments; ``ModuleWithInit.__init__`` is then called
     with no arguments.
     """
     script_kwargs = dict(optimize=optimize, **kwargs)
     ScriptModule.__init__(self, **script_kwargs)
     ModuleWithInit.__init__(self)