def check_tensors_dtypes_same_and_valid(args, valid_dtypes, prim_name):
    """Checks whether the element types of input tensors are the same and valid.

    Args:
        args: The arguments to check, forwarded unchanged to
            ``Validator.check_types_same_and_valid`` (presumably a mapping of
            argument name to type -- confirm against Validator).
        valid_dtypes: A single dtype, or an iterable of dtypes, that the
            tensor element type is allowed to be.
        prim_name: Name of the calling primitive, forwarded for error messages.
    """
    # A lone dtype is wrapped so it can be treated the same as a collection.
    if not isinstance(valid_dtypes, Iterable):
        valid_dtypes = [valid_dtypes]
    # Lift each element dtype to its corresponding tensor type before checking.
    expected_tensor_types = list(map(mstype.tensor_type, valid_dtypes))
    Validator.check_types_same_and_valid(args, expected_tensor_types, prim_name)
def check_tensor_dtype_valid(arg_name, arg_type, valid_dtypes, prim_name):
    """Checks whether the element types of input tensors are valid.

    Args:
        arg_name: Name of the argument being checked, forwarded for messages.
        arg_type: The type of the argument, checked against the allowed
            tensor types via ``Validator.check_subclass``.
        valid_dtypes: A single dtype, or an iterable of dtypes, that the
            tensor element type is allowed to be.
        prim_name: Name of the calling primitive, forwarded for error messages.
    """
    if isinstance(valid_dtypes, Iterable):
        candidate_dtypes = valid_dtypes
    else:
        # Wrap a single dtype so it can be iterated like a collection.
        candidate_dtypes = [valid_dtypes]

    allowed_tensor_types = []
    for dtype in candidate_dtypes:
        allowed_tensor_types.append(mstype.tensor_type(dtype))

    Validator.check_subclass(arg_name, arg_type, allowed_tensor_types, prim_name)
def check_tensor_type_same(args, valid_values, prim_name):
    """Checks whether the element types of input tensors are the same.

    Args:
        args: The arguments to check, forwarded unchanged to
            ``Validator.check_type_same`` (presumably a mapping of argument
            name to type -- confirm against Validator).
        valid_values: A single dtype, or an iterable of dtypes, that the
            tensor element type is allowed to be.
        prim_name: Name of the calling primitive, forwarded for error messages.
    """
    # Accept a lone dtype as well as a collection, matching the sibling
    # dtype checkers; previously a single dtype would fail when iterated.
    valid_values = valid_values if isinstance(valid_values, Iterable) else [valid_values]
    # Lift each element dtype to its corresponding tensor type before checking.
    tensor_types = [mstype.tensor_type(t) for t in valid_values]
    Validator.check_type_same(args, tensor_types, prim_name)