Example #1
    def __assert_get_correct_params(self, method, *args, **kwargs):
        # check whether getparamnames for the given method returns the
        # correct tensors

        methodname = method.__name__
        clsname = method.__self__.__class__.__name__

        # get all tensor parameters in the object
        all_params, all_names = _get_tensors(self)

        def _get_tensor_name(param):
            for tensor, name in zip(all_params, all_names):
                if id(tensor) == id(param):
                    return name
            return None

        # get the parameter tensors used in the operation and the tensors specified by the developer
        oper_names, oper_params = self.__list_operating_params(
            method, *args, **kwargs)
        user_names = self.getparamnames(method.__name__)
        user_params = [get_attr(self, name) for name in user_names]
        user_params_id = [id(p) for p in user_params]
        oper_params_id = [id(p) for p in oper_params]
        user_params_id_set = set(user_params_id)
        oper_params_id_set = set(oper_params_id)

        # check if the user params contain a non-tensor or a non-floating point tensor
        for name, param in zip(user_names, user_params):
            if not isinstance(param, torch.Tensor) or param.dtype not in torch_float_type:
                msg = "Parameter %s is not a floating point tensor" % name
                raise GetSetParamsError(msg)

        # check if there are missing parameters (present in operating params, but not in the user params)
        missing_names = []
        for name, pid in zip(oper_names, oper_params_id):
            if pid not in user_params_id_set:
                missing_names.append(name)
        # if there are missing parameters, give a warning (the program can still
        # run correctly, e.g. when the missing parameters never require grad)
        if len(missing_names) > 0:
            msg = "getparams for %s.%s does not include: %s" % (
                clsname, methodname, ", ".join(missing_names))
            warnings.warn(msg, stacklevel=3)

        # check if there are excessive parameters (present in the user params, but not in the operating params)
        excess_names = []
        for name, pid in zip(user_names, user_params_id):
            if pid not in oper_params_id_set:
                excess_names.append(name)
        # if there are excess parameters, give warnings
        if len(excess_names) > 0:
            msg = "getparams for %s.%s has excess parameters: %s" % \
                (clsname, methodname, ", ".join(excess_names))
            warnings.warn(msg, stacklevel=3)
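The comparison above works on tensor identities: the ids of the tensors actually touched by the method (collected by __list_operating_params) are matched against the ids of the tensors declared in getparamnames. Below is a minimal, self-contained sketch of the missing-parameter check; the LinearOp class and its tensors are hypothetical stand-ins, not part of the library.

    import warnings
    import torch

    class LinearOp:
        def __init__(self):
            self.weight = torch.randn(3, 3, dtype=torch.float64)
            self.bias = torch.randn(3, dtype=torch.float64)

        def mm(self, x):
            # the operation uses both weight and bias ...
            return x @ self.weight + self.bias

        def getparamnames(self, methodname):
            # ... but only "weight" was declared here
            if methodname == "mm":
                return ["weight"]
            raise KeyError(methodname)

    obj = LinearOp()
    # stand-in for what __list_operating_params would trace for obj.mm
    oper_names = ["weight", "bias"]
    oper_params = [obj.weight, obj.bias]
    user_names = obj.getparamnames("mm")
    user_ids = {id(getattr(obj, name)) for name in user_names}
    missing = [n for n, p in zip(oper_names, oper_params) if id(p) not in user_ids]
    if missing:
        warnings.warn("getparamnames for LinearOp.mm does not include: %s" % ", ".join(missing))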
Example #2
    def _useobjparams(self, objparams):
        # temporarily substitute the module's registered parameters with the
        # given tensors; the originals are restored in the finally block
        nnmodule = self.obj
        names = self.names
        try:
            # substitute the state dictionary of the module with the new tensors

            # save the current state
            state_tensors = [get_attr(nnmodule, name) for name in names]

            # substitute the state with the given tensor
            for (name, param) in zip(names, objparams):
                del_attr(nnmodule, name)  # deleting first is required in case param is not a torch.nn.Parameter
                set_attr(nnmodule, name, param)

            yield nnmodule

        except Exception as exc:
            raise exc
        finally:
            # restore back the saved tensors
            for (name, param) in zip(names, state_tensors):
                set_attr(nnmodule, name, param)
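The snippet above follows the usual swap-and-restore pattern of a yield-based context manager: save the current attribute values, replace them, yield, and put the originals back in the finally block. The standalone sketch below illustrates the same idea with plain getattr/setattr/delattr on a torch.nn.Module; the use_params function and the tensor names are hypothetical, not the library's API.

    from contextlib import contextmanager
    import torch

    @contextmanager
    def use_params(module, names, new_tensors):
        saved = [getattr(module, name) for name in names]
        try:
            for name, tensor in zip(names, new_tensors):
                delattr(module, name)          # unregister the nn.Parameter first
                setattr(module, name, tensor)  # a plain tensor can then be assigned
            yield module
        finally:
            # restore the original parameters even if the body raised
            for name, tensor in zip(names, saved):
                setattr(module, name, tensor)

    lin = torch.nn.Linear(2, 2)
    w = torch.eye(2, requires_grad=True)
    b = torch.zeros(2, requires_grad=True)
    with use_params(lin, ["weight", "bias"], [w, b]) as m:
        y = m(torch.ones(1, 2))  # forward pass runs with the substituted tensors

Deleting the attribute before assigning is what allows a plain tensor to take the place of a registered nn.Parameter, mirroring the del_attr call above.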
Example #3
    def getparams(self, methodname: str) -> Sequence[torch.Tensor]:
        # Returns a list of tensor parameters used in the object's operations

        paramnames = self.getparamnames(methodname)
        return [get_attr(self, name) for name in paramnames]
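getparams is thus just the resolution step: getparamnames declares which attribute names a method reads, and getparams maps those names to the live tensor objects so they can be handed to autograd. A minimal sketch of that contract, using a hypothetical Quadratic class:

    import torch

    class Quadratic:
        def __init__(self, a, b):
            self.a = a
            self.b = b

        def evaluate(self, x):
            return self.a * x * x + self.b

        def getparamnames(self, methodname):
            if methodname == "evaluate":
                return ["a", "b"]
            raise KeyError(methodname)

        def getparams(self, methodname):
            return [getattr(self, name) for name in self.getparamnames(methodname)]

    q = Quadratic(torch.tensor(2.0, requires_grad=True), torch.tensor(1.0, requires_grad=True))
    y = q.evaluate(torch.tensor(3.0))
    grads = torch.autograd.grad(y, q.getparams("evaluate"))  # gradients w.r.t. a and b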