def __assert_method_preserve(self, method, *args, **kwargs):
    """Assert that calling ``method`` does not change the object's tensors.

    The tensors returned by ``_get_tensors(self)`` are cloned, ``method`` is
    called with ``*args`` and ``**kwargs``, and the tensors are then collected
    again and compared against the snapshot.

    Arguments
    ---------
    method:
        Bound method to call. Its ``__self__`` and ``__name__`` are used to
        build the error message.
    *args, **kwargs:
        Forwarded unchanged to ``method``.

    Raises
    ------
    GetSetParamsError
        If the number of tensors differs after the call, or if the shape,
        dtype, device or value of any tensor differs.
    """
    # snapshot the state; clone so in-place modifications are detected
    all_params0, names0 = _get_tensors(self)
    all_params0 = [p.clone() for p in all_params0]
    method(*args, **kwargs)
    all_params1, _ = _get_tensors(self)

    # now assert if all_params0 == all_params1
    clsname = method.__self__.__class__.__name__
    methodname = method.__name__
    msg = "The method %s.%s does not preserve the object's float tensors: \n" % (
        clsname, methodname)

    if len(all_params0) != len(all_params1):
        msg += "The number of parameters changed:\n"
        msg += "* number of object's parameters before: %d\n" % len(
            all_params0)
        msg += "* number of object's parameters after : %d\n" % len(
            all_params1)
        raise GetSetParamsError(msg)

    for pname, p0, p1 in zip(names0, all_params0, all_params1):
        if p0.shape != p1.shape:
            msg += "The shape of %s changed\n" % pname
            msg += "* (before) %s.shape: %s\n" % (pname, p0.shape)
            msg += "* (after ) %s.shape: %s\n" % (pname, p1.shape)
            raise GetSetParamsError(msg)

        # torch.allclose raises RuntimeError on mismatching dtypes or devices,
        # so report those changes explicitly before comparing the values
        if p0.dtype != p1.dtype:
            msg += "The dtype of %s changed\n" % pname
            msg += "* (before) %s.dtype: %s\n" % (pname, p0.dtype)
            msg += "* (after ) %s.dtype: %s\n" % (pname, p1.dtype)
            raise GetSetParamsError(msg)
        if p0.device != p1.device:
            msg += "The device of %s changed\n" % pname
            msg += "* (before) %s.device: %s\n" % (pname, p0.device)
            msg += "* (after ) %s.device: %s\n" % (pname, p1.device)
            raise GetSetParamsError(msg)

        # equal_nan=True: an untouched tensor that contains NaN must not be
        # reported as changed
        if not torch.allclose(p0, p1, equal_nan=True):
            msg += "The value of %s changed\n" % pname
            msg += "* (before) %s: %s\n" % (pname, p0)
            msg += "* (after ) %s: %s\n" % (pname, p1)
            raise GetSetParamsError(msg)
def __assert_get_correct_params(self, method, *args, **kwargs):
    """Check that ``getparamnames`` lists the tensors ``method`` operates on.

    The tensors named by ``self.getparamnames(method.__name__)`` (the "user"
    parameters) are compared, by object identity, with the tensors returned
    by ``self.__list_operating_params(method, *args, **kwargs)`` (the
    "operating" parameters).

    Arguments
    ---------
    method:
        Bound method to check. Its ``__name__`` and ``__self__`` are used for
        the lookup and for the messages.
    *args, **kwargs:
        Forwarded to ``self.__list_operating_params`` together with
        ``method``.

    Raises
    ------
    GetSetParamsError
        If a user parameter is not a ``torch.Tensor`` or its dtype is not in
        ``torch_float_type``.

    Warns
    -----
    A warning is issued (``stacklevel=3``) when operating parameters are
    missing from the user parameters, and another one when the user
    parameters contain tensors that are not operating parameters.
    """
    methodname = method.__name__
    clsname = method.__self__.__class__.__name__

    # get the parameter tensors used in the operation and the tensors
    # specified by the developer
    oper_names, oper_params = self.__list_operating_params(
        method, *args, **kwargs)
    user_names = self.getparamnames(method.__name__)
    user_params = [get_attr(self, name) for name in user_names]

    # tensors are matched by identity, not by name or by value
    user_params_id = [id(p) for p in user_params]
    oper_params_id = [id(p) for p in oper_params]
    user_params_id_set = set(user_params_id)
    oper_params_id_set = set(oper_params_id)

    # check if the userparams contains non-tensor or non-floating point tensor
    for name, param in zip(user_names, user_params):
        if not isinstance(param, torch.Tensor) or \
                param.dtype not in torch_float_type:
            msg = "Parameter %s is a non-floating point tensor" % name
            raise GetSetParamsError(msg)

    # missing parameters: present in operating params, but not in user params
    missing_names = [
        name for name, pid in zip(oper_names, oper_params_id)
        if pid not in user_params_id_set
    ]
    # only a warning, because the program can still run correctly, e.g.
    # missing parameters are parameters that are never set to require grad
    if missing_names:
        msg = "getparams for %s.%s does not include: %s" % (
            clsname, methodname, ", ".join(missing_names))
        warnings.warn(msg, stacklevel=3)

    # excess parameters: present in user params, but not in operating params
    excess_names = [
        name for name, pid in zip(user_names, user_params_id)
        if pid not in oper_params_id_set
    ]
    if excess_names:
        msg = "getparams for %s.%s has excess parameters: %s" % \
            (clsname, methodname, ", ".join(excess_names))
        warnings.warn(msg, stacklevel=3)
def __assert_match_getsetparams(self, methodname):
    """Verify that ``getparams`` and ``setparams`` agree for ``methodname``.

    The parameters returned by ``self.getparams(methodname)`` are passed back
    through ``self.setparams(methodname, ...)``; the value returned by
    ``setparams`` is compared with the number of parameters. The parameters
    are then fetched again and must be the very same objects, in the same
    order.

    Raises
    ------
    GetSetParamsError
        If the value returned by ``setparams`` differs from the number of
        parameters returned by ``getparams``, or if a parameter fetched after
        the round trip is not the same object as before it.
    """
    # round trip: feed the output of getparams straight back into setparams
    fetched_before = self.getparams(methodname)
    num_set = self.setparams(methodname, *fetched_before)
    num_fetched = len(fetched_before)
    if num_set != num_fetched:
        raise GetSetParamsError("The number of parameters returned by getparams and set by setparams do not match \n"
                                "(getparams: %d, setparams: %d)" % (num_fetched, num_set))

    # after the round trip, every position must hold the same object
    fetched_after = self.getparams(methodname)
    param_names = self.getparamnames(methodname)
    for pos, (before, after) in enumerate(zip(fetched_before, fetched_after)):
        if before is after:
            continue
        raise GetSetParamsError(
            "The parameter %s in getparams and setparams does not match" % param_names[pos])
def __assert_method_preserve(self, method, *args, **kwargs):
    """Assert that calling ``method`` preserves the object's tensors.

    A cloned snapshot of the tensors returned by ``_get_tensors(self)`` is
    taken, ``method(*args, **kwargs)`` is executed, and the tensors collected
    afterwards are compared with the snapshot.

    Arguments
    ---------
    method:
        Bound method to call. Its ``__self__`` and ``__name__`` are used to
        build the error message.
    *args, **kwargs:
        Forwarded unchanged to ``method``.

    Raises
    ------
    GetSetParamsError
        If the number of tensors differs after the call, or if the shape,
        dtype, device or value of any tensor differs. The message starts
        with the generic "does not preserve" sentence, followed by the name
        of the offending tensor and its state before and after the call.
    """
    # clone so that in-place modifications made by the method are detected
    before, names = _get_tensors(self)
    before = [p.clone() for p in before]
    method(*args, **kwargs)
    after, _ = _get_tensors(self)

    clsname = method.__self__.__class__.__name__
    methodname = method.__name__
    msg = "The method %s.%s does not preserve the object's float tensors" % (
        clsname, methodname)

    def _error(*details):
        # generic message first, then the specifics of what changed
        return GetSetParamsError("%s: \n%s" % (msg, "".join(details)))

    if len(before) != len(after):
        raise _error(
            "The number of parameters changed:\n",
            "* number of object's parameters before: %d\n" % len(before),
            "* number of object's parameters after : %d\n" % len(after),
        )

    for pname, p0, p1 in zip(names, before, after):
        if p0.shape != p1.shape:
            raise _error(
                "The shape of %s changed\n" % pname,
                "* (before) %s.shape: %s\n" % (pname, p0.shape),
                "* (after ) %s.shape: %s\n" % (pname, p1.shape),
            )
        # torch.allclose raises RuntimeError on mismatching dtypes or devices,
        # so report those changes explicitly before comparing the values
        if p0.dtype != p1.dtype:
            raise _error(
                "The dtype of %s changed\n" % pname,
                "* (before) %s.dtype: %s\n" % (pname, p0.dtype),
                "* (after ) %s.dtype: %s\n" % (pname, p1.dtype),
            )
        if p0.device != p1.device:
            raise _error(
                "The device of %s changed\n" % pname,
                "* (before) %s.device: %s\n" % (pname, p0.device),
                "* (after ) %s.device: %s\n" % (pname, p1.device),
            )
        # equal_nan=True: an untouched tensor that contains NaN must not be
        # reported as changed
        if not torch.allclose(p0, p1, equal_nan=True):
            raise _error(
                "The value of %s changed\n" % pname,
                "* (before) %s: %s\n" % (pname, p0),
                "* (after ) %s: %s\n" % (pname, p1),
            )