Exemplo n.º 1
0
 def load(module: nn.Module, *, prefix: str):
     """Recursively load weights into ``module`` and all of its submodules.

     Modelled on ``torch.nn.Module.load_state_dict``: the module's own
     ``_load_from_state_dict`` is called with an empty metadata dict and
     ``strict=True``, then each non-None child is visited depth-first with
     ``"<name>."`` appended to ``prefix``.
     """
     own_metadata = dict()
     module._load_from_state_dict(
         state_dict,
         prefix,
         own_metadata,
         True,
         missing_keys,
         unexpected_keys,
         error_msgs,
     )
     for child_name, submodule in module._modules.items():
         if submodule is None:
             continue
         load(submodule, prefix=f"{prefix}{child_name}.")
 def load(module: nn.Module, prefix=""):
     """Load the enclosing ``state_dict`` into ``module`` and its submodules.

     The metadata handed to each module is ``metadata[prefix[:-1]]`` (the
     prefix without its trailing dot). It falls back to an empty dict when
     ``metadata`` is None or has no such entry. Loading is strict, and results
     are reported through the enclosing ``missing_keys``/``unexpected_keys``/
     ``error_msgs`` lists.
     """
     if metadata is None:
         local_metadata = {}
     else:
         local_metadata = metadata.get(prefix[:-1], {})
     module._load_from_state_dict(
         state_dict,
         prefix,
         local_metadata,
         True,
         missing_keys,
         unexpected_keys,
         error_msgs,
     )
     for child_name, submodule in module._modules.items():
         if submodule is None:
             continue
         load(submodule, prefix + child_name + ".")
 def load(module: nn.Module, prefix=""):
     """Recursively load ``state_dict`` into ``module``, skipping ignored prefixes.

     A module's own parameters/buffers are loaded only when none of the
     patterns in the enclosing ``ignore_weights`` occurs as a substring of its
     ``prefix``. Children are always visited and each is checked against its
     own (longer) prefix.

     ``ignore_weights`` may be a list/tuple/set of string patterns or a single
     string pattern. ``None`` or any other value means "ignore nothing".
     """
     local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
     # Normalize ignore_weights. The previous guard,
     # ``isinstance(ignore_weights, list) and not any(...)``, skipped loading
     # for every module (without recording missing keys) whenever
     # ignore_weights was not a list, e.g. None.
     if isinstance(ignore_weights, str):
         patterns = (ignore_weights,)
     elif isinstance(ignore_weights, (list, tuple, set, frozenset)):
         patterns = ignore_weights
     else:
         patterns = ()
     # NOTE: this is a substring match, so "fc" also matches "fc1." or "my_fc.".
     if not any(pattern in prefix for pattern in patterns):
         module._load_from_state_dict(
             state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs,
         )
     for name, child in module._modules.items():
         if child is not None:
             load(child, prefix + name + ".")
Exemplo n.º 4
0
def _load_from_nested_dict(module: nn.Module, nested_dict: OrderedDict,
                           prefix: str, local_metadata: Dict[str, int],
                           missing_keys: List[str], unexpected_keys: List[str],
                           error_msgs: List[str]):
    """
    A wrapped version of `torch.nn.Module._load_from_state_dict()` that loads
    from a nested dict (an `OrderedDict` that contains one state_dict for each
    loadable module).

    Copies parameters and buffers from `nested_dict` into `module` only (if
    necessary), not into its descendants. This is called on every submodule in
    the `load_nested_dict()` function.

    If `module` owns any parameters or buffers, the first remaining entry of
    `nested_dict` is popped (entries are consumed front to back) and used as
    this module's state_dict block. Modules without parameters or buffers
    consume nothing. If `nested_dict` is already empty, an empty block is used.

    Metadata found in the block's `_metadata` attribute is merged over
    `local_metadata`, and block values win on key clashes. The merge is done
    on a copy, so the caller's `local_metadata` dict is not modified.

    Args:
        module (torch.nn.Module): module to be loaded
        nested_dict (OrderedDict): an `OrderedDict` containing several dicts,
            each containing parameters and persistent buffers for one loadable
            module. The entry used for `module` is removed from it.
        prefix (str): the prefix for parameters and buffers used in this module
        local_metadata (dict): metadata of the module. Keys also present in
            the block's `_metadata` are overridden by the block's values.
            This dict is not modified.
        missing_keys (list of str): passed to
            `torch.nn.Module._load_from_state_dict()`, which is called with
            `strict=True`, so missing keys are appended to this list
        unexpected_keys (list of str): passed to
            `torch.nn.Module._load_from_state_dict()` (`strict=True`), so
            unexpected keys are appended to this list
        error_msgs (list of str): passed to
            `torch.nn.Module._load_from_state_dict()`; error messages are
            added to this list and will be reported together in the
            `load_nested_dict()` function
    """
    # Only modules that own parameters or buffers consume an entry.
    if len(module._parameters) == 0 and len(module._buffers) == 0:
        state_dict_block = {}
    else:
        try:
            _, state_dict_block = nested_dict.popitem(last=False)
        except KeyError:
            # nested_dict is exhausted. With an empty block, the strict load
            # below reports this module's keys as missing.
            state_dict_block = {}
    metadata_from_block = getattr(state_dict_block, '_metadata', {})
    # `_load_from_state_dict` looks keys up as `prefix + name`, so re-key the
    # block accordingly.
    prefixed_block = {
        prefix + key: val
        for key, val in state_dict_block.items()
    }
    # Merge into a fresh dict rather than updating `local_metadata` in place.
    # The caller may share that dict elsewhere, and it must not change as a
    # side effect of loading.
    merged_metadata = dict(local_metadata)
    merged_metadata.update(metadata_from_block)
    module._load_from_state_dict(prefixed_block, prefix, merged_metadata,
                                 True, missing_keys, unexpected_keys,
                                 error_msgs)
Exemplo n.º 5
0
 def _disable_state_dict(module: nn.Module):
     """Recursively override ``state_dict`` and ``_load_from_state_dict``.

     Walks ``module`` depth-first and handles children before their parent.
     On every visited module it installs the enclosing scope's ``state_dict``
     and ``_load_from_state_dict`` functions as bound methods.
     """
     for submodule in module._modules.values():
         if submodule is None:
             continue
         _disable_state_dict(submodule)
     # NOTE(review): the replacements are bound to the enclosing ``self``, not
     # to ``module``. Confirm this is intentional.
     bind = types.MethodType
     module.state_dict = bind(state_dict, self)
     module._load_from_state_dict = bind(_load_from_state_dict, self)