Example #1
    def __assert_method_preserve(self, method, *args, **kwargs):
        # assert that calling the method does not change the float tensor
        # parameters of the object (i.e. it preserves the object's state)

        all_params0, names0 = _get_tensors(self)
        all_params0 = [p.clone() for p in all_params0]
        method(*args, **kwargs)
        all_params1, names1 = _get_tensors(self)

        # now assert that all_params0 == all_params1
        clsname = method.__self__.__class__.__name__
        methodname = method.__name__
        msg = "The method %s.%s does not preserve the object's float tensors: \n" % (
            clsname, methodname)
        if len(all_params0) != len(all_params1):
            msg += "The number of parameters changed:\n"
            msg += "* number of object's parameters before: %d\n" % len(
                all_params0)
            msg += "* number of object's parameters after : %d\n" % len(
                all_params1)
            raise GetSetParamsError(msg)

        for pname, p0, p1 in zip(names0, all_params0, all_params1):
            if p0.shape != p1.shape:
                msg += "The shape of %s changed\n" % pname
                msg += "* (before) %s.shape: %s\n" % (pname, p0.shape)
                msg += "* (after ) %s.shape: %s\n" % (pname, p1.shape)
                raise GetSetParamsError(msg)
            if not torch.allclose(p0, p1):
                msg += "The value of %s changed\n" % pname
                msg += "* (before) %s: %s\n" % (pname, p0)
                msg += "* (after ) %s: %s\n" % (pname, p1)
                raise GetSetParamsError(msg)
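
The snippet relies on a few helpers that are not shown in this excerpt. Below is a minimal sketch of what they might look like, assuming `_get_tensors` simply walks the object's attributes and collects its floating-point tensors (the real helper may recurse into nested objects); `GetSetParamsError` and `torch_float_type` here are stand-ins, not the library's actual definitions:

import torch

# assumed list of floating-point dtypes checked by the snippets
torch_float_type = [torch.float16, torch.float32, torch.float64]

class GetSetParamsError(RuntimeError):
    # hypothetical stand-in for the library's own error class
    pass

def _get_tensors(obj):
    # sketch: collect the object's floating-point tensor attributes,
    # returning the tensors and their attribute names
    params, names = [], []
    for name, val in vars(obj).items():
        if isinstance(val, torch.Tensor) and val.dtype in torch_float_type:
            params.append(val)
            names.append(name)
    return params, names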
Example #2
    def __assert_get_correct_params(self, method, *args, **kwargs):
        # check that getparamnames for the method reports exactly the
        # tensors the method actually operates on

        methodname = method.__name__
        clsname = method.__self__.__class__.__name__

        # get all tensor parameters in the object
        all_params, all_names = _get_tensors(self)

        def _get_tensor_name(param):
            # map a tensor back to its attribute name by identity
            for i in range(len(all_params)):
                if id(all_params[i]) == id(param):
                    return all_names[i]
            return None

        # get the parameter tensors used in the operation and the tensors specified by the developer
        oper_names, oper_params = self.__list_operating_params(
            method, *args, **kwargs)
        user_names = self.getparamnames(method.__name__)
        user_params = [get_attr(self, name) for name in user_names]
        user_params_id = [id(p) for p in user_params]
        oper_params_id = [id(p) for p in oper_params]
        user_params_id_set = set(user_params_id)
        oper_params_id_set = set(oper_params_id)

        # check if user_params contains anything that is not a floating-point tensor
        for i in range(len(user_params)):
            param = user_params[i]
            if not isinstance(param, torch.Tensor) or param.dtype not in torch_float_type:
                msg = "Parameter %s is not a floating-point tensor" % user_names[i]
                raise GetSetParamsError(msg)

        # check if there are missing parameters (present in operating params, but not in the user params)
        missing_names = []
        for i in range(len(oper_names)):
            if oper_params_id[i] not in user_params_id_set:
                missing_names.append(oper_names[i])
        # if there are missing parameters, give a warning (the program can
        # still run correctly, e.g. if the missing parameters are never set
        # to require grad)
        if len(missing_names) > 0:
            msg = "getparams for %s.%s does not include: %s" % (
                clsname, methodname, ", ".join(missing_names))
            warnings.warn(msg, stacklevel=3)

        # check if there are excessive parameters (present in the user params, but not in the operating params)
        excess_names = []
        for i in range(len(user_names)):
            if user_params_id[i] not in oper_params_id_set:
                excess_names.append(user_names[i])
        # if there are excess parameters, give warnings
        if len(excess_names) > 0:
            msg = "getparams for %s.%s has excess parameters: %s" % \
                (clsname, methodname, ", ".join(excess_names))
            warnings.warn(msg, stacklevel=3)
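
`__list_operating_params` is not shown in this excerpt. One plausible sketch (an assumption, not necessarily the library's actual approach) discovers the operating tensors via autograd, assuming the method returns a tensor and the object's tensors are leaf tensors:

import torch

def _list_operating_params(obj, method, *args, **kwargs):
    # hypothetical sketch: mark every float tensor of the object as
    # requiring grad, run the method, then ask autograd which tensors
    # the output actually depends on
    all_params, all_names = _get_tensors(obj)
    for p in all_params:
        p.requires_grad_()
    output = method(*args, **kwargs)
    grads = torch.autograd.grad(output.sum(), all_params, allow_unused=True)
    for p in all_params:
        p.requires_grad_(False)
    names = [n for n, g in zip(all_names, grads) if g is not None]
    params = [p for p, g in zip(all_params, grads) if g is not None]
    return names, params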
Example #3
    def __assert_match_getsetparams(self, methodname):
        # assert that the getparams and setparams methods correspond to the
        # same parameters in the same order

        # count the number of parameters in getparams and setparams
        params0 = self.getparams(methodname)
        len_setparams0 = self.setparams(methodname, *params0)
        if len_setparams0 != len(params0):
            raise GetSetParamsError(
                "The number of parameters returned by getparams and set by setparams does not match\n"
                "(getparams: %d, setparams: %d)" % (len(params0), len_setparams0))

        # check if the params are assigned correctly in the correct order
        params1 = self.getparams(methodname)
        names1 = self.getparamnames(methodname)
        for i, (p0, p1) in enumerate(zip(params0, params1)):
            if id(p0) != id(p1):
                msg = "The parameter %s in getparams and setparams does not match" % names1[i]
                raise GetSetParamsError(msg)
                raise GetSetParamsError(msg)
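
For context, the contract being asserted could look like the hypothetical class below (not from the source): setparams returns the number of parameters it consumed, and after a getparams -> setparams round trip, getparams must hand back the identical tensor objects in the same order:

import torch

class LinearOp:
    # hypothetical object satisfying the getparams/setparams contract
    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def getparamnames(self, methodname):
        return ["weight", "bias"]

    def getparams(self, methodname):
        return [self.weight, self.bias]

    def setparams(self, methodname, *params):
        self.weight, self.bias = params[:2]
        return 2  # number of parameters consumed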
Example #4
    def __assert_method_preserve(self, method, *args, **kwargs):
        # assert that calling the method does not change the float tensor
        # parameters of the object (i.e. it preserves the object's state)

        all_params0, names0 = _get_tensors(self)
        all_params0 = [p.clone() for p in all_params0]
        method(*args, **kwargs)
        all_params1, names1 = _get_tensors(self)

        # now assert that all_params0 == all_params1
        clsname = method.__self__.__class__.__name__
        methodname = method.__name__
        msg = "The method %s.%s does not preserve the object's float tensors" % (
            clsname, methodname)
        if len(all_params0) != len(all_params1):
            raise GetSetParamsError(msg)

        for p0, p1 in zip(all_params0, all_params1):
            if p0.shape != p1.shape:
                raise GetSetParamsError(msg)
            if not torch.allclose(p0, p1):
                raise GetSetParamsError(msg)
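
A quick illustration (hypothetical class, not from the source) of what the preservation check catches: a method that mutates the object's tensor in place would raise GetSetParamsError, while a read-only method passes:

import torch

class Scaler:
    # hypothetical object for demonstrating the check
    def __init__(self):
        self.scale = torch.tensor([2.0])

    def bad_forward(self, x):
        self.scale += 1.0       # in-place update: state is not preserved
        return x * self.scale

    def good_forward(self, x):
        return x * self.scale   # reads the state but never modifies it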