def __assert_get_correct_params(self, method, *args, **kwargs):
    # Check that the parameters declared by ``self.getparamnames`` for
    # ``method`` are exactly the tensors that ``method(*args, **kwargs)``
    # operates on (as reported by ``self.__list_operating_params``).
    #
    # Parameters are matched by object identity (``id``), not by attribute
    # name, so the same tensor reachable under two names still matches.
    #
    # Raises GetSetParamsError if any declared parameter is not a tensor, or
    # is a tensor whose dtype is not in ``torch_float_type``.
    # Emits a warning (default category, stacklevel=3) when tensors are
    # operated on but not declared ("does not include"), and when declared
    # parameters are not operated on ("has excess parameters").
    methodname = method.__name__
    clsname = method.__self__.__class__.__name__

    # tensors actually used by the operation vs. tensors declared by the developer
    oper_names, oper_params = self.__list_operating_params(method, *args, **kwargs)
    user_names = self.getparamnames(methodname)
    user_params = [get_attr(self, name) for name in user_names]

    user_params_id_set = {id(p) for p in user_params}
    oper_params_id_set = {id(p) for p in oper_params}

    # every declared parameter must be a floating-point tensor
    for name, param in zip(user_names, user_params):
        if not isinstance(param, torch.Tensor) or param.dtype not in torch_float_type:
            msg = "Parameter %s is a non-floating point tensor" % name
            raise GetSetParamsError(msg)

    # Missing parameters (operated on, but not declared) only warn, because
    # the program can still run correctly, e.g. when they are never set to
    # require grad.
    missing_names = [name for name, param in zip(oper_names, oper_params)
                     if id(param) not in user_params_id_set]
    if missing_names:
        msg = "getparams for %s.%s does not include: %s" % (
            clsname, methodname, ", ".join(missing_names))
        # warn directly from here: stacklevel=3 depends on this call depth
        warnings.warn(msg, stacklevel=3)

    # Excess parameters (declared, but not operated on) also only warn.
    excess_names = [name for name, param in zip(user_names, user_params)
                    if id(param) not in oper_params_id_set]
    if excess_names:
        msg = "getparams for %s.%s has excess parameters: %s" % \
            (clsname, methodname, ", ".join(excess_names))
        warnings.warn(msg, stacklevel=3)
def _useobjparams(self, objparams):
    # Generator that temporarily replaces the attributes listed in
    # ``self.names`` on ``self.obj`` with the tensors in ``objparams``,
    # yields ``self.obj``, and restores the original values afterwards,
    # including when the consumer raises at the ``yield``.
    # NOTE(review): presumably wrapped with ``contextlib.contextmanager`` at
    # the definition site; the decorator is not visible here, so confirm.
    nnmodule = self.obj
    names = self.names

    # Save the current state BEFORE entering ``try``. If this lookup fails,
    # nothing has been modified, so nothing needs restoring. Previously the
    # ``finally`` clause then read an unbound ``state_tensors`` and raised an
    # UnboundLocalError that hid the real error.
    state_tensors = [get_attr(nnmodule, name) for name in names]
    try:
        # substitute the state with the given tensors
        for name, param in zip(names, objparams):
            # delete first, in case ``param`` is not a torch.nn.Parameter
            del_attr(nnmodule, name)
            set_attr(nnmodule, name, param)

        yield nnmodule
    finally:
        # restore the saved tensors, even after a partial substitution
        for name, param in zip(names, state_tensors):
            set_attr(nnmodule, name, param)
def getparams(self, methodname: str) -> Sequence[torch.Tensor]:
    # Collect the tensor parameters used by ``methodname``: each attribute
    # named by ``self.getparamnames(methodname)`` is fetched from this object
    # with ``get_attr``. Returns a new list, in the same order as the names.
    params = []
    for attr_name in self.getparamnames(methodname):
        params.append(get_attr(self, attr_name))
    return params