def __getattr__(self, attr):
    """Fallback attribute lookup that prefers script methods.

    If ``self._has_method(attr)`` is true, ``self._get_method(attr)`` is
    returned. When the class's ``_original_methods`` mapping also has an
    entry for ``attr``, that entry's metadata (``__name__``, ``__doc__``,
    ...) is first copied onto the returned method with
    ``functools.update_wrapper``. Every other name is delegated to
    ``Module.__getattr__``.
    """
    if not self._has_method(attr):
        return Module.__getattr__(self, attr)
    originals = self.__class__._original_methods
    if attr not in originals:
        return self._get_method(attr)
    python_method = originals[attr]
    # update_wrapper(w, f) is exactly what functools.wraps(f)(w) does.
    return functools.update_wrapper(self._get_method(attr), python_method)
def __getattr__(self, attr):
    """Resolve ``attr`` to a script method, falling back to ``Module``.

    Script methods (``self._has_method(attr)``) win. If the class recorded an
    original Python method under the same name in ``_original_methods``, the
    script method is decorated with that method's metadata before being
    returned. Anything else goes through ``Module.__getattr__``.
    """
    if self._has_method(attr):
        recorded = self.__class__._original_methods
        if attr in recorded:
            source_fn = recorded[attr]
            compiled = self._get_method(attr)
            # Copy __name__/__doc__/etc. from the Python method onto the compiled one.
            return functools.update_wrapper(compiled, source_fn)
        return self._get_method(attr)
    return Module.__getattr__(self, attr)
def load_network(self, load_path: str, network: nn.Module, strict: bool = True,
                 submodule: str = None, model_type: str = None,
                 param_key: str = None):
    """
    Load a pretrained checkpoint into an already-instantiated network.

    :param load_path: Path of the checkpoint file to load.
    :param network: The network to load into. A ``DataParallel`` or
        ``DistributedDataParallel`` wrapper is unwrapped first.
    :param strict: Passed to ``load_state_dict``; when True the checkpoint
        keys must match the network's keys exactly.
    :param submodule: Name of a child attribute of ``network`` to load into
        instead of the whole network. ``None`` or the string ``'none'``
        (any case) selects the whole network.
    :param model_type: If truthy (e.g. ``'G'`` or ``'D'``), the loaded state
        dict is passed through ``model_val`` together with ``self.opt`` for
        extra validation/conversion before loading.
    :param param_key: If not None, use the checkpoint entry stored under this
        key as the state dict. If None, the checkpoint itself (after the
        optional ``'state_dict'`` unwrap) is used.
    """
    # Get the bare model when it is wrapped by (Distributed)DataParallel.
    if isinstance(network, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        network = network.module

    # Optionally load into a specific child of the network. getattr goes
    # through normal attribute lookup instead of calling the dunder directly.
    if submodule is not None and submodule.lower() != 'none':
        network = getattr(network, submodule)

    # Returning `storage` unchanged keeps every tensor on the CPU, so
    # checkpoints saved on GPU load regardless of the current device.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    load_net = torch.load(load_path, map_location=lambda storage, loc: storage)

    # Some checkpoints nest the weights under a 'state_dict' key.
    if 'state_dict' in load_net:
        load_net = load_net['state_dict']

    # Load a specific entry of the checkpoint if requested.
    if param_key is not None:
        load_net = load_net[param_key]

    # NOTE: checkpoints saved from a DataParallel-wrapped model carry a
    # 'module.' key prefix; strip it here if such checkpoints must be loaded.

    # Validate the model type to be loaded; any additional conversion or
    # modification of the state dict can be done here.
    if model_type:
        load_net = model_val(
            opt_net=self.opt,
            state_dict=load_net,
            model_type=model_type,
        )

    # NOTE: InstanceNorm2d models trained with PyTorch < 0.4.0 may contain
    # 'running_mean'/'running_var' keys that must be dropped before loading.
    network.load_state_dict(load_net, strict=strict)
def count_parameters(model: nn.Module, keys: Optional[Sequence[str]] = None) -> Dict[str, int]:
    """
    Count the total and trainable parameters of a model, plus per-block totals.

    :param model: A model
    :param keys: Optional names of top-level blocks to count separately.
        Defaults to ``["encoder", "decoder", "logits", "head", "final"]``.
    :return: Dict with ``"total"`` and ``"trainable"`` parameter counts, plus
        one entry per name in ``keys`` whose attribute on ``model`` is an
        ``nn.Module``. Missing, ``None`` or non-module attributes are skipped.
    """
    if keys is None:
        keys = ["encoder", "decoder", "logits", "head", "final"]

    total = int(sum(p.numel() for p in model.parameters()))
    trainable = int(sum(p.numel() for p in model.parameters() if p.requires_grad))
    parameters = {"total": total, "trainable": trainable}

    for key in keys:
        # getattr (unlike model.__getattr__) also finds plain attributes, such
        # as `self.head = None` stored in __dict__. The isinstance check skips
        # values without .parameters() (None, nn.Parameter, ...), which used
        # to raise AttributeError.
        block = getattr(model, key, None)
        if isinstance(block, nn.Module):
            parameters[key] = int(sum(p.numel() for p in block.parameters()))
    return parameters
def count_parameters(model: nn.Module) -> dict:
    """
    Count number of total and trainable parameters of a model.

    :param model: A model
    :return: Dict with ``"total"`` and ``"trainable"`` parameter counts, plus
        ``"encoder"`` / ``"decoder"`` counts for each of those attributes that
        exists on ``model`` and is an ``nn.Module``.
    """
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    parameters = {"total": total, "trainable": trainable}
    for key in ["encoder", "decoder"]:
        # getattr + isinstance instead of hasattr + model.__getattr__: a plain
        # attribute such as `self.decoder = None` satisfies hasattr but makes
        # nn.Module.__getattr__ raise, and None has no .parameters().
        block = getattr(model, key, None)
        if isinstance(block, nn.Module):
            parameters[key] = sum(p.numel() for p in block.parameters())
    return parameters
def __getattr__(self, attr):
    """Fallback attribute lookup: script methods first, then ``Module``.

    Returns ``self._get_method(attr)`` when ``self._has_method(attr)`` is
    true; otherwise defers to ``Module.__getattr__``.
    """
    if not self._has_method(attr):
        return Module.__getattr__(self, attr)
    return self._get_method(attr)
def __getattr__(self, attr):
    """Return the script method named ``attr`` if one exists, else ask ``Module``."""
    found = self._has_method(attr)
    return self._get_method(attr) if found else Module.__getattr__(self, attr)