def setparams(self, methodname: str, *params) -> int:
    """Set this object's attributes (the parameters of ``methodname``) to
    the given values, making the object behave as a copy w.r.t. those
    operations.

    ``*params`` may be an excessive list: only the first
    ``len(self.getparamnames(methodname))`` values are consumed, in order.

    Returns:
        int: the number of parameters actually set (which can be smaller
        than ``len(params)`` when the list is excessive).
    """
    paramnames = self.getparamnames(methodname)
    nset = 0
    for name, val in zip(paramnames, params):
        set_attr(self, name, val)
        nset += 1
    return nset
def setparams(self, methodname: str, *params) -> int:
    """Set this object's attributes (the parameters of ``methodname``) to
    the given values, making the object behave as a copy w.r.t. those
    operations.

    ``*params`` may be an excessive list: only the first
    ``len(self.cached_getparamnames(methodname))`` values are consumed,
    in order.

    Returns:
        int: the number of parameters actually set (which can be smaller
        than ``len(params)`` when the list is excessive).
    """
    paramnames = self.cached_getparamnames(methodname)
    nset = 0
    for name, val in zip(paramnames, params):
        try:
            set_attr(self, name, val)
        except TypeError:
            # Direct assignment failed — presumably the existing attribute
            # is a torch.nn.Parameter and ``val`` is a plain tensor (TODO:
            # confirm against set_attr). Delete the old attribute first,
            # then assign the new value.
            del_attr(self, name)
            set_attr(self, name, val)
        nset += 1
    return nset
def _useobjparams(self, objparams):
    """Temporarily substitute the named attributes of ``self.obj`` with
    the tensors in ``objparams``, yielding the module, and restore the
    original attributes afterwards (even on error).

    NOTE(review): this is a generator meant to be used as a context
    manager — presumably decorated with ``contextlib.contextmanager``
    outside this view; verify at the definition site.
    """
    nnmodule = self.obj
    names = self.names
    # Save the current state BEFORE entering the try-block, so the
    # finally-clause never references an unbound name if a lookup fails.
    state_tensors = [get_attr(nnmodule, name) for name in names]
    try:
        # Substitute the module's attributes with the given tensors.
        for (name, param) in zip(names, objparams):
            # delete required in case the param is not a torch.nn.Parameter
            del_attr(nnmodule, name)
            set_attr(nnmodule, name, param)
        yield nnmodule
    finally:
        # Restore the saved tensors, whether or not the body raised.
        for (name, param) in zip(names, state_tensors):
            set_attr(nnmodule, name, param)
def _set_all_obj_params(self, objparams: List):
    """Overwrite the named attributes of ``self.obj`` with ``objparams``,
    pairing ``self.names`` with the given values in order."""
    obj = self.obj
    for name, param in zip(self.names, objparams):
        # Remove the existing attribute first: assigning over a
        # torch.nn.Parameter would fail when param is not a Parameter.
        del_attr(obj, name)
        set_attr(obj, name, param)